try:
from astra.torch.models import ResNetClassifier
except:
%pip install git+https://github.com/sustainability-lab/ASTRA
Imports
import os

# Restrict CUDA to GPU 0; this is set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
# Confusion matrix
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import torchsummary
from tqdm import tqdm
import umap
# ASTRA
from astra.torch.data import load_cifar_10
from astra.torch.utils import train_fn
from astra.torch.models import ResNetClassifier
# Netron, ONNX for model visualization
import netron
import onnx
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
Dataset
# Load CIFAR-10 through the ASTRA helper and display its summary.
dataset = load_cifar_10()
dataset
Files already downloaded and verified
Files already downloaded and verified
CIFAR-10 Dataset
length of dataset: 60000
shape of images: torch.Size([3, 32, 32])
len of classes: 10
classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dtype of images: torch.float32
dtype of labels: torch.int64
# Plot the first 25 images in a 5x5 grid, each titled with its class name.
plt.figure(figsize=(6, 6))
for i in range(25):
    plt.subplot(5, 5, i+1)
    # imshow expects channels last, so reorder (C, H, W) -> (H, W, C).
    plt.imshow(torch.einsum("chw->hwc", dataset.data[i].cpu()))
    plt.axis('off')
    plt.title(dataset.classes[dataset.targets[i]])
plt.tight_layout()
Train val test split
# Sizes of the labelled training set and the held-out test set;
# the remaining samples form the pool.
n_train = 1000
n_test = 20000

# Images and integer class labels for the whole dataset.
X = dataset.data
y = dataset.targets

# Inspect shapes, dtypes and the value range of the images.
print(X.shape)
print(X.shape, X.dtype)
print(X.min(), X.max())
print(y.shape, y.dtype)
torch.Size([60000, 3, 32, 32])
torch.Size([60000, 3, 32, 32]) torch.float32
tensor(0.) tensor(1.)
torch.Size([60000]) torch.int64
# Fix the seed so the random split is reproducible.
torch.manual_seed(0)
idx = torch.randperm(len(X))

# First n_train indices -> train, last n_test -> test, the rest -> pool.
train_idx = idx[:n_train]
pool_idx = idx[n_train:-n_test]
test_idx = idx[-n_test:]
print(len(train_idx), len(pool_idx), len(test_idx))
1000 39000 20000
# ResNet-18 backbone with torchvision's default weights and a 10-class head, moved to `device`.
resnet = ResNetClassifier(models.resnet18, models.ResNet18_Weights.DEFAULT, n_classes=10).to(device)
resnet
ResNetClassifier(
(featurizer): ResNet(
(resnet): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Identity()
)
(flatten): Flatten(start_dim=1, end_dim=-1)
)
(classifier): MLPClassifier(
(featurizer): MLP(
(activation): ReLU()
(dropout): Dropout(p=0.0, inplace=True)
(input_layer): Linear(in_features=512, out_features=512, bias=True)
)
(classifier): Linear(in_features=512, out_features=10, bias=True)
)
)
# Per-layer output shapes and parameter counts for a single 3x32x32 input.
torchsummary.summary(resnet, (3, 32, 32))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 16, 16] 9,408
BatchNorm2d-2 [-1, 64, 16, 16] 128
ReLU-3 [-1, 64, 16, 16] 0
MaxPool2d-4 [-1, 64, 8, 8] 0
Conv2d-5 [-1, 64, 8, 8] 36,864
BatchNorm2d-6 [-1, 64, 8, 8] 128
ReLU-7 [-1, 64, 8, 8] 0
Conv2d-8 [-1, 64, 8, 8] 36,864
BatchNorm2d-9 [-1, 64, 8, 8] 128
ReLU-10 [-1, 64, 8, 8] 0
BasicBlock-11 [-1, 64, 8, 8] 0
Conv2d-12 [-1, 64, 8, 8] 36,864
BatchNorm2d-13 [-1, 64, 8, 8] 128
ReLU-14 [-1, 64, 8, 8] 0
Conv2d-15 [-1, 64, 8, 8] 36,864
BatchNorm2d-16 [-1, 64, 8, 8] 128
ReLU-17 [-1, 64, 8, 8] 0
BasicBlock-18 [-1, 64, 8, 8] 0
Conv2d-19 [-1, 128, 4, 4] 73,728
BatchNorm2d-20 [-1, 128, 4, 4] 256
ReLU-21 [-1, 128, 4, 4] 0
Conv2d-22 [-1, 128, 4, 4] 147,456
BatchNorm2d-23 [-1, 128, 4, 4] 256
Conv2d-24 [-1, 128, 4, 4] 8,192
BatchNorm2d-25 [-1, 128, 4, 4] 256
ReLU-26 [-1, 128, 4, 4] 0
BasicBlock-27 [-1, 128, 4, 4] 0
Conv2d-28 [-1, 128, 4, 4] 147,456
BatchNorm2d-29 [-1, 128, 4, 4] 256
ReLU-30 [-1, 128, 4, 4] 0
Conv2d-31 [-1, 128, 4, 4] 147,456
BatchNorm2d-32 [-1, 128, 4, 4] 256
ReLU-33 [-1, 128, 4, 4] 0
BasicBlock-34 [-1, 128, 4, 4] 0
Conv2d-35 [-1, 256, 2, 2] 294,912
BatchNorm2d-36 [-1, 256, 2, 2] 512
ReLU-37 [-1, 256, 2, 2] 0
Conv2d-38 [-1, 256, 2, 2] 589,824
BatchNorm2d-39 [-1, 256, 2, 2] 512
Conv2d-40 [-1, 256, 2, 2] 32,768
BatchNorm2d-41 [-1, 256, 2, 2] 512
ReLU-42 [-1, 256, 2, 2] 0
BasicBlock-43 [-1, 256, 2, 2] 0
Conv2d-44 [-1, 256, 2, 2] 589,824
BatchNorm2d-45 [-1, 256, 2, 2] 512
ReLU-46 [-1, 256, 2, 2] 0
Conv2d-47 [-1, 256, 2, 2] 589,824
BatchNorm2d-48 [-1, 256, 2, 2] 512
ReLU-49 [-1, 256, 2, 2] 0
BasicBlock-50 [-1, 256, 2, 2] 0
Conv2d-51 [-1, 512, 1, 1] 1,179,648
BatchNorm2d-52 [-1, 512, 1, 1] 1,024
ReLU-53 [-1, 512, 1, 1] 0
Conv2d-54 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-55 [-1, 512, 1, 1] 1,024
Conv2d-56 [-1, 512, 1, 1] 131,072
BatchNorm2d-57 [-1, 512, 1, 1] 1,024
ReLU-58 [-1, 512, 1, 1] 0
BasicBlock-59 [-1, 512, 1, 1] 0
Conv2d-60 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-61 [-1, 512, 1, 1] 1,024
ReLU-62 [-1, 512, 1, 1] 0
Conv2d-63 [-1, 512, 1, 1] 2,359,296
BatchNorm2d-64 [-1, 512, 1, 1] 1,024
ReLU-65 [-1, 512, 1, 1] 0
BasicBlock-66 [-1, 512, 1, 1] 0
AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0
Identity-68 [-1, 512] 0
ResNet-69 [-1, 512] 0
Flatten-70 [-1, 512] 0
ResNet-71 [-1, 512] 0
Linear-72 [-1, 512] 262,656
ReLU-73 [-1, 512] 0
Dropout-74 [-1, 512] 0
MLP-75 [-1, 512] 0
Linear-76 [-1, 10] 5,130
MLPClassifier-77 [-1, 10] 0
================================================================
Total params: 11,444,298
Trainable params: 11,444,298
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.32
Params size (MB): 43.66
Estimated Total Size (MB): 44.98
----------------------------------------------------------------
# Export to ONNX and visualize with Netron
# A dummy batch of one 3x32x32 image is passed to the exporter as the example input.
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(resnet, dummy_input, "resnet.onnx", verbose=True)
netron.start("resnet.onnx")
Exported graph: graph(%input.1 : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cuda:0),
%classifier.featurizer.input_layer.weight : Float(512, 512, strides=[512, 1], requires_grad=1, device=cuda:0),
%classifier.featurizer.input_layer.bias : Float(512, strides=[1], requires_grad=1, device=cuda:0),
%classifier.classifier.weight : Float(10, 512, strides=[512, 1], requires_grad=1, device=cuda:0),
%classifier.classifier.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
%onnx::Conv_198 : Float(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_199 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_201 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_202 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_204 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_205 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_207 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_208 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_210 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_211 : Float(64, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_213 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_214 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_216 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_217 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_219 : Float(128, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_220 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_222 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_223 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_225 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_226 : Float(128, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_228 : Float(256, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_229 : Float(256, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_231 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_232 : Float(256, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_234 : Float(256, 128, 1, 1, strides=[128, 1, 1, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_235 : Float(256, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_237 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_238 : Float(256, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_240 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_241 : Float(256, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_243 : Float(512, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_244 : Float(512, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_246 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_247 : Float(512, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_249 : Float(512, 256, 1, 1, strides=[256, 1, 1, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_250 : Float(512, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_252 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_253 : Float(512, strides=[1], requires_grad=0, device=cuda:0),
%onnx::Conv_255 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=0, device=cuda:0),
%onnx::Conv_256 : Float(512, strides=[1], requires_grad=0, device=cuda:0)):
%/featurizer/resnet/conv1/Conv_output_0 : Float(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[7, 7], pads=[3, 3, 3, 3], strides=[2, 2], onnx_name="/featurizer/resnet/conv1/Conv"](%input.1, %onnx::Conv_198, %onnx::Conv_199), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/relu/Relu_output_0 : Float(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/relu/Relu"](%/featurizer/resnet/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/maxpool/MaxPool_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[ceil_mode=0, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], onnx_name="/featurizer/resnet/maxpool/MaxPool"](%/featurizer/resnet/relu/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.pooling.MaxPool2d::maxpool # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:782:0
%/featurizer/resnet/layer1/layer1.0/conv1/Conv_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer1/layer1.0/conv1/Conv"](%/featurizer/resnet/maxpool/MaxPool_output_0, %onnx::Conv_201, %onnx::Conv_202), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.0/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer1/layer1.0/relu/Relu_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer1/layer1.0/relu/Relu"](%/featurizer/resnet/layer1/layer1.0/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer1/layer1.0/conv2/Conv_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer1/layer1.0/conv2/Conv"](%/featurizer/resnet/layer1/layer1.0/relu/Relu_output_0, %onnx::Conv_204, %onnx::Conv_205), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.0/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer1/layer1.0/Add_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer1/layer1.0/Add"](%/featurizer/resnet/layer1/layer1.0/conv2/Conv_output_0, %/featurizer/resnet/maxpool/MaxPool_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer1/layer1.0/relu_1/Relu_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer1/layer1.0/relu_1/Relu"](%/featurizer/resnet/layer1/layer1.0/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer1/layer1.1/conv1/Conv_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer1/layer1.1/conv1/Conv"](%/featurizer/resnet/layer1/layer1.0/relu_1/Relu_output_0, %onnx::Conv_207, %onnx::Conv_208), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.1/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer1/layer1.1/relu/Relu_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer1/layer1.1/relu/Relu"](%/featurizer/resnet/layer1/layer1.1/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer1/layer1.1/conv2/Conv_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer1/layer1.1/conv2/Conv"](%/featurizer/resnet/layer1/layer1.1/relu/Relu_output_0, %onnx::Conv_210, %onnx::Conv_211), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.1/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer1/layer1.1/Add_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer1/layer1.1/Add"](%/featurizer/resnet/layer1/layer1.1/conv2/Conv_output_0, %/featurizer/resnet/layer1/layer1.0/relu_1/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer1/layer1.1/relu_1/Relu_output_0 : Float(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer1/layer1.1/relu_1/Relu"](%/featurizer/resnet/layer1/layer1.1/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer1/torchvision.models.resnet.BasicBlock::layer1.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer2/layer2.0/conv1/Conv_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], onnx_name="/featurizer/resnet/layer2/layer2.0/conv1/Conv"](%/featurizer/resnet/layer1/layer1.1/relu_1/Relu_output_0, %onnx::Conv_213, %onnx::Conv_214), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer2/layer2.0/relu/Relu_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer2/layer2.0/relu/Relu"](%/featurizer/resnet/layer2/layer2.0/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer2/layer2.0/conv2/Conv_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer2/layer2.0/conv2/Conv"](%/featurizer/resnet/layer2/layer2.0/relu/Relu_output_0, %onnx::Conv_216, %onnx::Conv_217), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer2/layer2.0/downsample/downsample.0/Conv_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[2, 2], onnx_name="/featurizer/resnet/layer2/layer2.0/downsample/downsample.0/Conv"](%/featurizer/resnet/layer1/layer1.1/relu_1/Relu_output_0, %onnx::Conv_219, %onnx::Conv_220), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0/torch.nn.modules.container.Sequential::downsample/torch.nn.modules.conv.Conv2d::downsample.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer2/layer2.0/Add_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer2/layer2.0/Add"](%/featurizer/resnet/layer2/layer2.0/conv2/Conv_output_0, %/featurizer/resnet/layer2/layer2.0/downsample/downsample.0/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer2/layer2.0/relu_1/Relu_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer2/layer2.0/relu_1/Relu"](%/featurizer/resnet/layer2/layer2.0/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer2/layer2.1/conv1/Conv_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer2/layer2.1/conv1/Conv"](%/featurizer/resnet/layer2/layer2.0/relu_1/Relu_output_0, %onnx::Conv_222, %onnx::Conv_223), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.1/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer2/layer2.1/relu/Relu_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer2/layer2.1/relu/Relu"](%/featurizer/resnet/layer2/layer2.1/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer2/layer2.1/conv2/Conv_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer2/layer2.1/conv2/Conv"](%/featurizer/resnet/layer2/layer2.1/relu/Relu_output_0, %onnx::Conv_225, %onnx::Conv_226), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.1/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer2/layer2.1/Add_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer2/layer2.1/Add"](%/featurizer/resnet/layer2/layer2.1/conv2/Conv_output_0, %/featurizer/resnet/layer2/layer2.0/relu_1/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer2/layer2.1/relu_1/Relu_output_0 : Float(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer2/layer2.1/relu_1/Relu"](%/featurizer/resnet/layer2/layer2.1/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer2/torchvision.models.resnet.BasicBlock::layer2.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer3/layer3.0/conv1/Conv_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], onnx_name="/featurizer/resnet/layer3/layer3.0/conv1/Conv"](%/featurizer/resnet/layer2/layer2.1/relu_1/Relu_output_0, %onnx::Conv_228, %onnx::Conv_229), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer3/layer3.0/relu/Relu_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer3/layer3.0/relu/Relu"](%/featurizer/resnet/layer3/layer3.0/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer3/layer3.0/conv2/Conv_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer3/layer3.0/conv2/Conv"](%/featurizer/resnet/layer3/layer3.0/relu/Relu_output_0, %onnx::Conv_231, %onnx::Conv_232), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer3/layer3.0/downsample/downsample.0/Conv_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[2, 2], onnx_name="/featurizer/resnet/layer3/layer3.0/downsample/downsample.0/Conv"](%/featurizer/resnet/layer2/layer2.1/relu_1/Relu_output_0, %onnx::Conv_234, %onnx::Conv_235), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0/torch.nn.modules.container.Sequential::downsample/torch.nn.modules.conv.Conv2d::downsample.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer3/layer3.0/Add_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer3/layer3.0/Add"](%/featurizer/resnet/layer3/layer3.0/conv2/Conv_output_0, %/featurizer/resnet/layer3/layer3.0/downsample/downsample.0/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer3/layer3.0/relu_1/Relu_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer3/layer3.0/relu_1/Relu"](%/featurizer/resnet/layer3/layer3.0/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer3/layer3.1/conv1/Conv_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer3/layer3.1/conv1/Conv"](%/featurizer/resnet/layer3/layer3.0/relu_1/Relu_output_0, %onnx::Conv_237, %onnx::Conv_238), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.1/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer3/layer3.1/relu/Relu_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer3/layer3.1/relu/Relu"](%/featurizer/resnet/layer3/layer3.1/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer3/layer3.1/conv2/Conv_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer3/layer3.1/conv2/Conv"](%/featurizer/resnet/layer3/layer3.1/relu/Relu_output_0, %onnx::Conv_240, %onnx::Conv_241), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.1/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer3/layer3.1/Add_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer3/layer3.1/Add"](%/featurizer/resnet/layer3/layer3.1/conv2/Conv_output_0, %/featurizer/resnet/layer3/layer3.0/relu_1/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer3/layer3.1/relu_1/Relu_output_0 : Float(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer3/layer3.1/relu_1/Relu"](%/featurizer/resnet/layer3/layer3.1/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer3/torchvision.models.resnet.BasicBlock::layer3.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer4/layer4.0/conv1/Conv_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2], onnx_name="/featurizer/resnet/layer4/layer4.0/conv1/Conv"](%/featurizer/resnet/layer3/layer3.1/relu_1/Relu_output_0, %onnx::Conv_243, %onnx::Conv_244), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer4/layer4.0/relu/Relu_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer4/layer4.0/relu/Relu"](%/featurizer/resnet/layer4/layer4.0/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer4/layer4.0/conv2/Conv_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer4/layer4.0/conv2/Conv"](%/featurizer/resnet/layer4/layer4.0/relu/Relu_output_0, %onnx::Conv_246, %onnx::Conv_247), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer4/layer4.0/downsample/downsample.0/Conv_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[2, 2], onnx_name="/featurizer/resnet/layer4/layer4.0/downsample/downsample.0/Conv"](%/featurizer/resnet/layer3/layer3.1/relu_1/Relu_output_0, %onnx::Conv_249, %onnx::Conv_250), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0/torch.nn.modules.container.Sequential::downsample/torch.nn.modules.conv.Conv2d::downsample.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer4/layer4.0/Add_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer4/layer4.0/Add"](%/featurizer/resnet/layer4/layer4.0/conv2/Conv_output_0, %/featurizer/resnet/layer4/layer4.0/downsample/downsample.0/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer4/layer4.0/relu_1/Relu_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer4/layer4.0/relu_1/Relu"](%/featurizer/resnet/layer4/layer4.0/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.0/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer4/layer4.1/conv1/Conv_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer4/layer4.1/conv1/Conv"](%/featurizer/resnet/layer4/layer4.0/relu_1/Relu_output_0, %onnx::Conv_252, %onnx::Conv_253), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.1/torch.nn.modules.conv.Conv2d::conv1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer4/layer4.1/relu/Relu_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer4/layer4.1/relu/Relu"](%/featurizer/resnet/layer4/layer4.1/conv1/Conv_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/layer4/layer4.1/conv2/Conv_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1], onnx_name="/featurizer/resnet/layer4/layer4.1/conv2/Conv"](%/featurizer/resnet/layer4/layer4.1/relu/Relu_output_0, %onnx::Conv_255, %onnx::Conv_256), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.1/torch.nn.modules.conv.Conv2d::conv2 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/conv.py:459:0
%/featurizer/resnet/layer4/layer4.1/Add_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Add[onnx_name="/featurizer/resnet/layer4/layer4.1/Add"](%/featurizer/resnet/layer4/layer4.1/conv2/Conv_output_0, %/featurizer/resnet/layer4/layer4.0/relu_1/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.1 # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:102:0
%/featurizer/resnet/layer4/layer4.1/relu_1/Relu_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/featurizer/resnet/layer4/layer4.1/relu_1/Relu"](%/featurizer/resnet/layer4/layer4.1/Add_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.container.Sequential::layer4/torchvision.models.resnet.BasicBlock::layer4.1/torch.nn.modules.activation.ReLU::relu # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1455:0
%/featurizer/resnet/avgpool/GlobalAveragePool_output_0 : Float(1, 512, 1, 1, strides=[512, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::GlobalAveragePool[onnx_name="/featurizer/resnet/avgpool/GlobalAveragePool"](%/featurizer/resnet/layer4/layer4.1/relu_1/Relu_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet/torch.nn.modules.pooling.AdaptiveAvgPool2d::avgpool # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1214:0
%/featurizer/resnet/Flatten_output_0 : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0) = onnx::Flatten[axis=1, onnx_name="/featurizer/resnet/Flatten"](%/featurizer/resnet/avgpool/GlobalAveragePool_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torchvision.models.resnet.ResNet::resnet # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torchvision/models/resnet.py:279:0
%/featurizer/flatten/Flatten_output_0 : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0) = onnx::Flatten[axis=1, onnx_name="/featurizer/flatten/Flatten"](%/featurizer/resnet/Flatten_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.ResNet::featurizer/torch.nn.modules.flatten.Flatten::flatten # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/flatten.py:46:0
%/classifier/featurizer/input_layer/Gemm_output_0 : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/classifier/featurizer/input_layer/Gemm"](%/featurizer/flatten/Flatten_output_0, %classifier.featurizer.input_layer.weight, %classifier.featurizer.input_layer.bias), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.MLPClassifier::classifier/astra.torch.models.MLP::featurizer/torch.nn.modules.linear.Linear::input_layer # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/linear.py:114:0
%/classifier/featurizer/activation/Relu_output_0 : Float(1, 512, strides=[512, 1], requires_grad=1, device=cuda:0) = onnx::Relu[onnx_name="/classifier/featurizer/activation/Relu"](%/classifier/featurizer/input_layer/Gemm_output_0), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.MLPClassifier::classifier/astra.torch.models.MLP::featurizer/torch.nn.modules.activation.ReLU::activation # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:1457:0
%196 : Float(1, 10, strides=[10, 1], requires_grad=1, device=cuda:0) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/classifier/classifier/Gemm"](%/classifier/featurizer/activation/Relu_output_0, %classifier.classifier.weight, %classifier.classifier.bias), scope: astra.torch.models.ResNetClassifier::/astra.torch.models.MLPClassifier::classifier/torch.nn.modules.linear.Linear::classifier # /home/nipun.batra/miniforge3/lib/python3.9/site-packages/torch/nn/modules/linear.py:114:0
return (%196)
============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Serving 'resnet.onnx' at http://localhost:8081
('localhost', 8081)
def get_accuracy(net, X, y):
    """Return hard predictions and mean accuracy of ``net`` on ``(X, y)``.

    Args:
        net: Callable mapping ``X`` to a tensor of per-class scores; the
            predicted class is the argmax over dimension 1.
        X: Input batch passed to ``net`` unchanged.
        y: Ground-truth class indices, compared element-wise with the
            predictions.

    Returns:
        Tuple ``(y_pred, acc)``: the predicted class indices and the
        fraction of predictions equal to ``y`` as a 0-dim float tensor.
    """
    # Evaluation only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits_pred = net(X)
        y_pred = logits_pred.argmax(dim=1)
        acc = (y_pred == y).float().mean()
    return y_pred, acc
def predict(net, classes, plot_confusion_matrix=False):
    """Print the accuracy of ``net`` on the train, pool and test splits.

    Reads the notebook-level globals ``X``, ``y``, ``train_idx``,
    ``pool_idx``, ``test_idx`` and ``device``.

    Args:
        net: Model evaluated through ``get_accuracy``.
        classes: Class names used as confusion-matrix display labels.
        plot_confusion_matrix: When true, also show one confusion matrix
            per split.
    """
    splits = zip(("train", "pool", "test"), (train_idx, pool_idx, test_idx))
    for name, idx in splits:
        X_dataset = X[idx].to(device)
        y_dataset = y[idx].to(device)
        y_pred, acc = get_accuracy(net, X_dataset, y_dataset)
        print(f'{name} set accuracy: {acc*100:.2f}%')
        if plot_confusion_matrix:
            # Tensors are moved to the CPU before being handed to scikit-learn.
            cm = confusion_matrix(y_dataset.cpu(), y_pred.cpu())
            ConfusionMatrixDisplay(cm, display_labels=classes).plot(
                values_format='d', cmap='Blues'
            )
            # Rotate the labels on x-axis to make them readable
            _ = plt.xticks(rotation=90)
            plt.show()
# Baseline: evaluate the network before any training on every split.
predict(resnet, dataset.classes, plot_confusion_matrix=True)
train set accuracy: 7.70%
pool set accuracy: 8.37%
test set accuracy: 8.58%
def viz_embeddings(net, X, y, device):
    """Scatter-plot a 2-D UMAP projection of ``net.featurizer`` outputs.

    Points are coloured by ``y``; the colorbar ticks are labelled with the
    notebook-level ``dataset.classes``.

    Args:
        net: Model exposing a ``featurizer`` sub-module.
        X: Inputs to embed; moved to ``device`` before the forward pass.
        y: Class index per input, used only for colouring.
        device: Device on which the featurizer is run.
    """
    reducer = umap.UMAP()
    # Feature extraction only: gradients are not needed.
    with torch.no_grad():
        emb = net.featurizer(X.to(device))
    emb = emb.cpu().numpy()
    emb = reducer.fit_transform(emb)

    plt.figure(figsize=(4, 4))
    plt.scatter(emb[:, 0], emb[:, 1], c=y.cpu().numpy(), cmap='tab10')
    # Add a colorbar legend to mark color to class mapping
    n_classes = len(dataset.classes)
    # Boundaries at k - 0.5 centre each integer class index in its colour band.
    cb = plt.colorbar(boundaries=np.arange(n_classes + 1) - 0.5)
    cb.set_ticks(np.arange(n_classes))
    cb.set_ticklabels(dataset.classes)
    plt.title("UMAP embeddings")
    plt.tight_layout()
viz_embeddings(resnet, X[train_idx], y[train_idx], device)
Train the model on train set
# Fresh ResNet-18 classifier for the 10 CIFAR-10 classes.
# NOTE(review): the second positional argument (None) presumably means
# "no pretrained weights" -- confirm against ASTRA's ResNetClassifier.
resnet = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)

# Supervised training on the labelled train split only.
iter_losses, epoch_losses = train_fn(resnet, X[train_idx], y[train_idx], nn.CrossEntropyLoss(), lr=3e-4,
                                     batch_size=128, epochs=30, verbose=False)

# Per-iteration training loss curve.
plt.plot(iter_losses)
plt.xlabel("Iteration")
plt.ylabel("Training loss")
Text(0, 0.5, 'Training loss')
# Evaluate the model trained on the labelled train split on every split.
predict(resnet, dataset.classes, plot_confusion_matrix=True)
train set accuracy: 100.00%
pool set accuracy: 35.95%
test set accuracy: 36.27%
viz_embeddings(resnet, X[train_idx], y[train_idx], device)
/home/nipun.batra/miniforge3/lib/python3.9/site-packages/sklearn/manifold/_spectral_embedding.py:273: UserWarning: Graph is not fully connected, spectral embedding may not work as expected.
warnings.warn(
# Embeddings of the first 1000 test samples.
viz_embeddings(resnet, X[test_idx[:1000]], y[test_idx[:1000]], device)
### Train on train + pool
# Indices of the train split followed by the pool split.
train_plus_pool_idx = torch.cat([train_idx, pool_idx])

# Re-initialise the classifier so this run starts from scratch.
resnet = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)

# Supervised training with labels for both the train and the pool split.
iter_losses, epoch_losses = train_fn(resnet, X[train_plus_pool_idx], y[train_plus_pool_idx], loss_fn=nn.CrossEntropyLoss(),
                                     lr=3e-4,
                                     batch_size=1024, epochs=30)
Loss: 1.311430: 100%|██████████| 40/40 [00:01<00:00, 34.22it/s]
Loss: 1.074545: 100%|██████████| 40/40 [00:01<00:00, 36.99it/s]
Loss: 1.028258: 100%|██████████| 40/40 [00:01<00:00, 36.88it/s]
Loss: 0.679134: 100%|██████████| 40/40 [00:01<00:00, 37.03it/s]
Loss: 0.573962: 100%|██████████| 40/40 [00:01<00:00, 36.61it/s]
Loss: 0.470612: 100%|██████████| 40/40 [00:01<00:00, 37.01it/s]
Loss: 0.579732: 100%|██████████| 40/40 [00:01<00:00, 36.57it/s]
Loss: 0.262016: 100%|██████████| 40/40 [00:01<00:00, 37.01it/s]
Loss: 0.147717: 100%|██████████| 40/40 [00:01<00:00, 36.54it/s]
Loss: 0.108952: 100%|██████████| 40/40 [00:01<00:00, 36.02it/s]
Loss: 0.271355: 100%|██████████| 40/40 [00:01<00:00, 36.75it/s]
Loss: 0.370208: 100%|██████████| 40/40 [00:01<00:00, 35.67it/s]
Loss: 0.304398: 100%|██████████| 40/40 [00:01<00:00, 36.01it/s]
Loss: 0.184816: 100%|██████████| 40/40 [00:01<00:00, 36.37it/s]
Loss: 0.196016: 100%|██████████| 40/40 [00:01<00:00, 35.36it/s]
Loss: 0.181180: 100%|██████████| 40/40 [00:01<00:00, 34.56it/s]
Loss: 0.196182: 100%|██████████| 40/40 [00:01<00:00, 35.60it/s]
Loss: 0.144379: 100%|██████████| 40/40 [00:01<00:00, 35.92it/s]
Loss: 0.047389: 100%|██████████| 40/40 [00:01<00:00, 35.78it/s]
Loss: 0.031721: 100%|██████████| 40/40 [00:01<00:00, 35.50it/s]
Loss: 0.071309: 100%|██████████| 40/40 [00:01<00:00, 35.90it/s]
Loss: 0.132249: 100%|██████████| 40/40 [00:01<00:00, 35.88it/s]
Loss: 0.064951: 100%|██████████| 40/40 [00:01<00:00, 35.89it/s]
Loss: 0.087784: 100%|██████████| 40/40 [00:01<00:00, 35.95it/s]
Loss: 0.022345: 100%|██████████| 40/40 [00:01<00:00, 35.38it/s]
Loss: 0.220102: 100%|██████████| 40/40 [00:01<00:00, 34.77it/s]
Loss: 0.058476: 100%|██████████| 40/40 [00:01<00:00, 35.88it/s]
Loss: 0.083048: 100%|██████████| 40/40 [00:01<00:00, 36.26it/s]
Loss: 0.047481: 100%|██████████| 40/40 [00:01<00:00, 35.44it/s]
Loss: 0.047544: 100%|██████████| 40/40 [00:01<00:00, 35.65it/s]
Epoch 1: 1.6589520935058595
Epoch 2: 1.2343804992675782
Epoch 3: 1.0032470321655274
Epoch 4: 0.8050672988891602
Epoch 5: 0.6229847793579102
Epoch 6: 0.48343274230957034
Epoch 7: 0.3824925224304199
Epoch 8: 0.3308661190032959
Epoch 9: 0.19066958694458008
Epoch 10: 0.1531820728302002
Epoch 11: 0.13519002685546874
Epoch 12: 0.1659882785797119
Epoch 13: 0.23326843185424806
Epoch 14: 0.16760483856201172
Epoch 15: 0.09269110498428344
Epoch 16: 0.08953567190170288
Epoch 17: 0.10240967988967896
Epoch 18: 0.1157124849319458
Epoch 19: 0.08735885505676269
Epoch 20: 0.032301900148391724
Epoch 21: 0.02235677945613861
Epoch 22: 0.05257979347705841
Epoch 23: 0.06249353976249695
Epoch 24: 0.05699077892303467
Epoch 25: 0.0789616925239563
Epoch 26: 0.03884266901016235
Epoch 27: 0.10908802990913391
Epoch 28: 0.06660935344696045
Epoch 29: 0.042903441429138184
Epoch 30: 0.03823600392341614
# Per-iteration training loss curve for the train + pool run.
plt.plot(iter_losses)
plt.xlabel("Iteration")
plt.ylabel("Training loss")
Text(0, 0.5, 'Training loss')
# Embeddings of the train split and of the first 1000 test samples,
# followed by accuracy and confusion matrices on every split.
viz_embeddings(resnet, X[train_idx], y[train_idx], device)
viz_embeddings(resnet, X[test_idx[:1000]], y[test_idx[:1000]], device)
predict(resnet, dataset.classes, plot_confusion_matrix=True)
train set accuracy: 99.10%
pool set accuracy: 99.58%
test set accuracy: 61.65%
SSL
Task 1: Predict angle of rotation (0, 90, 180, 270) as a classification task
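Create a dataset of rotated images, using the rotation angle as the label. Because these labels are generated automatically rather than annotated by hand, we can now use a much larger dataset (train + pool).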
# Images and labels of the train split followed by the pool split.
X_train_plus_pool = torch.cat([X[train_idx], X[pool_idx]])
y_train_plus_pool = torch.cat([y[train_idx], y[pool_idx]])

# Notebook display of the resulting shapes.
X_train_plus_pool.shape, y_train_plus_pool.shape
(torch.Size([40000, 3, 32, 32]), torch.Size([40000]))
# Build the rotation-prediction dataset: every train + pool image appears once
# per angle, labelled with the index of that angle.
X_ssl = []
y_ssl = []

# Rotation angle in degrees -> class index of the pretext task.
angles_map = {0: 0, 90: 1, 180: 2, 270: 3}
for angle_rot, angle_label in angles_map.items():
    print(f"Angle: {angle_rot}")
    X_rot = transforms.functional.rotate(X_train_plus_pool, angle_rot)

    X_ssl.append(X_rot)
    y_ssl.append(torch.tensor([angle_label] * len(X_rot)))

# Concatenate in angle order; block k holds all images rotated by the k-th angle.
X_ssl = torch.cat(X_ssl)
y_ssl = torch.cat(y_ssl)
Angle: 0
Angle: 90
Angle: 180
Angle: 270
X_ssl.shape, y_ssl.shape
(torch.Size([160000, 3, 32, 32]), torch.Size([160000]))
# Plot same image rotated at different angles
def plot_ssl(img_id):
    """Show image ``img_id`` at each of the four rotations in a 2x2 grid.

    Reads the notebook-level ``X_ssl``, ``X_train_plus_pool`` and
    ``angles_map``.

    Args:
        img_id: Index of the image within ``X_train_plus_pool``.
    """
    plt.figure(figsize=(3, 3))
    # X_ssl stacks one full copy of the dataset per angle, so the same image
    # rotated by the i-th angle sits at offset * i + img_id.
    offset = len(X_train_plus_pool)
    for i in range(4):
        plt.subplot(2, 2, i+1)
        # imshow expects channels last.
        plt.imshow(torch.einsum("chw->hwc", X_ssl[offset*i + img_id]))
        plt.axis('off')
        plt.title(f"Class: {angles_map[i*90]}\n Angle: {i*90}")
    plt.tight_layout()


plot_ssl(2)
# Pretext model: a 4-way classifier over the rotation labels.
ssl_angle = ResNetClassifier(models.resnet18, None, n_classes=4, activation=nn.GELU(), dropout=0.1).to(device)

# Train on the rotation dataset (X_ssl, y_ssl).
iter_losses, epoch_losses = train_fn(ssl_angle, X_ssl, y_ssl, lr=3e-4, loss_fn=nn.CrossEntropyLoss(),
                                     batch_size=1024, epochs=20)
Loss: 0.945101: 100%|██████████| 157/157 [00:04<00:00, 35.35it/s]
Loss: 0.785275: 100%|██████████| 157/157 [00:04<00:00, 35.88it/s]
Loss: 0.644941: 100%|██████████| 157/157 [00:04<00:00, 35.64it/s]
Loss: 0.697314: 100%|██████████| 157/157 [00:04<00:00, 36.29it/s]
Loss: 0.578679: 100%|██████████| 157/157 [00:04<00:00, 35.66it/s]
Loss: 0.478466: 100%|██████████| 157/157 [00:04<00:00, 36.11it/s]
Loss: 0.466891: 100%|██████████| 157/157 [00:04<00:00, 35.93it/s]
Loss: 0.436392: 100%|██████████| 157/157 [00:04<00:00, 35.65it/s]
Loss: 0.302828: 100%|██████████| 157/157 [00:04<00:00, 36.11it/s]
Loss: 0.226936: 100%|██████████| 157/157 [00:04<00:00, 35.73it/s]
Loss: 0.260345: 100%|██████████| 157/157 [00:04<00:00, 35.70it/s]
Loss: 0.136557: 100%|██████████| 157/157 [00:04<00:00, 36.13it/s]
Loss: 0.214188: 100%|██████████| 157/157 [00:04<00:00, 35.74it/s]
Loss: 0.114707: 100%|██████████| 157/157 [00:04<00:00, 35.61it/s]
Loss: 0.100094: 100%|██████████| 157/157 [00:04<00:00, 35.93it/s]
Loss: 0.114667: 100%|██████████| 157/157 [00:04<00:00, 36.11it/s]
Loss: 0.088765: 100%|██████████| 157/157 [00:04<00:00, 36.15it/s]
Loss: 0.094059: 100%|██████████| 157/157 [00:04<00:00, 36.16it/s]
Loss: 0.135386: 100%|██████████| 157/157 [00:04<00:00, 34.21it/s]
Loss: 0.073386: 100%|██████████| 157/157 [00:04<00:00, 35.75it/s]
Epoch 1: 0.9923952384948731
Epoch 2: 0.8178635284423829
Epoch 3: 0.7248399837493896
Epoch 4: 0.644944263458252
Epoch 5: 0.5666926818847656
Epoch 6: 0.4871123765945435
Epoch 7: 0.4118640880584717
Epoch 8: 0.3425796222686768
Epoch 9: 0.2795945789337158
Epoch 10: 0.2271963834762573
Epoch 11: 0.18415976438522338
Epoch 12: 0.15323887557983398
Epoch 13: 0.13012397813796997
Epoch 14: 0.10969336915016174
Epoch 15: 0.09902867212295532
Epoch 16: 0.08491913187503815
Epoch 17: 0.07862492105960846
Epoch 18: 0.0721992644071579
Epoch 19: 0.068406194627285
Epoch 20: 0.061716662764549256
# Per-iteration training loss curve for the rotation pretext task.
plt.plot(iter_losses)
plt.xlabel("Iteration")
plt.ylabel("Training loss")
Text(0, 0.5, 'Training loss')
# Visualise the embeddings of the SSL model trained on angles dataset
# (but wrt original 10 classes)
viz_embeddings(ssl_angle, X[train_idx], y[train_idx], device)
# Now, we can use the features from SSLAngle model to train the classifier on the original data
net_pretrained = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)

# Copy only the featurizer weights from the rotation model; the 10-class
# classifier head keeps its fresh initialisation.
net_pretrained.featurizer.load_state_dict(ssl_angle.featurizer.state_dict())

# Fine-tune on the labelled train split.
iter_losses, epoch_losses = train_fn(net_pretrained, X[train_idx], y[train_idx], nn.CrossEntropyLoss(), lr=3e-4, epochs=50, batch_size=128, verbose=False)

# Per-iteration training loss curve for the fine-tuning run.
plt.plot(iter_losses)
plt.xlabel("Iteration")
plt.ylabel("Training loss")
Text(0, 0.5, 'Training loss')
viz_embeddings(net_pretrained, X[train_idx], y[train_idx], device)
/home/nipun.batra/miniforge3/lib/python3.9/site-packages/sklearn/manifold/_spectral_embedding.py:273: UserWarning: Graph is not fully connected, spectral embedding may not work as expected.
warnings.warn(
# Embeddings of the first 1000 test samples, then accuracy and confusion
# matrices of the fine-tuned model on every split.
viz_embeddings(net_pretrained, X[test_idx[:1000]], y[test_idx[:1000]], device)
predict(net_pretrained, dataset.classes, plot_confusion_matrix=True)
train set accuracy: 100.00%
pool set accuracy: 47.94%
test set accuracy: 47.57%
Summary of the test set performance
- Untrained model: 9%
- Train on 1000 labeled samples (train set): 36%
- Train on 1000 labeled samples + 39000 labeled samples (pool set): 62%
- Train on SSL task + Finetune on 1000 labeled samples: 47%