
In the previous post, we explored the basic concepts of the PyTorch Profiler and the new capabilities that came with its recent updates. One of the coolest things I tried is the TensorBoard plugin that comes with the PyTorch Profiler. Yes, you heard that right: the well-known deep learning visualisation platform TensorBoard has a profiler plugin, which makes analysing a network's performance much easier.
I just worked through the official PyTorch Profiler tutorials, and the visualisations seem quite descriptive and come with useful analysis. I’ll take a complete deep dive into the tool in the next article.
One of the cool things I’ve noticed is the performance recommendations. Most of the recommendations made by the tool make sense, and I’m fairly confident that following them would improve model training performance.
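As an example of the kind of change these recommendations point to: when a large share of each training step is spent waiting on the DataLoader, the tool can suggest setting `num_workers` so that data loading runs in background processes. The snippet below is a minimal sketch under my own assumptions, not output from the tool: `make_loader` is a hypothetical helper, the default of 4 workers is illustrative rather than tuned, and a random `TensorDataset` stands in for CIFAR-10 so it runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=32, num_workers=4):
    """Hypothetical helper: build a DataLoader that loads batches in parallel.

    num_workers > 0 moves loading into background worker processes, so the
    training loop waits less for the next batch. The value 4 is an assumption.
    pin_memory speeds up host-to-GPU copies, so it is only enabled with a GPU.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


if __name__ == "__main__":
    # Stand-in dataset: 256 random CIFAR-10-shaped images with labels 0..9.
    data = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
    loader = make_loader(data, num_workers=2)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])
```

In the script below, the equivalent change would be passing `num_workers` (and optionally `pin_memory=True`) when constructing `train_loader`; profiling again afterwards shows whether the data-loading share of the step time actually dropped.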
In the meantime, you can play around with the tool and see how convenient it is to use in your own deep learning experiments. Here’s the script I used to take my first steps with the tool.
```python
import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms as T

# Load data: upscale CIFAR-10 images to ResNet's 224x224 input size and normalise them
transform = T.Compose(
    [T.Resize(224),
     T.ToTensor(),
     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

# Create the model, loss and optimiser on the first GPU
device = torch.device("cuda:0")
# torchvision >= 0.13 replaces the deprecated `pretrained=True` with the `weights` argument
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()

# One training step: forward pass, loss, backward pass and parameter update
def train(data):
    inputs, labels = data[0].to(device=device), data[1].to(device=device)
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Use the profiler to record execution events and write TensorBoard traces to ./log/resnet18
with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/resnet18'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
) as prof:
    for step, batch_data in enumerate(train_loader):
        # Each cycle is wait + warmup + active = 5 steps, repeated twice
        if step >= (1 + 1 + 3) * 2:
            break
        train(batch_data)
        prof.step()  # signal the profiler that the next step has started
```
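The number in the loop’s `break` condition comes from the profiler schedule. A quick way to see what `torch.profiler.schedule` does is to call the schedule function it returns directly for a range of step numbers. This is a small standalone sketch that needs no GPU or dataset; `trace_schedule` is a hypothetical helper name of my own.

```python
import torch.profiler
from torch.profiler import ProfilerAction

# Same schedule as in the script: skip 1 step, warm up for 1, record 3, repeat twice.
sched = torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2)


def trace_schedule(sched_fn, num_steps):
    """Hypothetical helper: return the action the schedule assigns to each step."""
    return [sched_fn(step) for step in range(num_steps)]


if __name__ == "__main__":
    for step, action in enumerate(trace_schedule(sched, 12)):
        print(step, action.name)
```

With `wait=1, warmup=1, active=3, repeat=2`, each cycle spans 5 steps: step 0 is skipped (`NONE`), step 1 is a warm-up, steps 2 to 4 are recorded, and at step 4 (`RECORD_AND_SAVE`) the trace is handed to `on_trace_ready`. Steps 5 to 9 repeat the cycle, and from step 10 onward the schedule returns `NONE`, which is why the training loop stops after `(1 + 1 + 3) * 2 = 10` steps.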