1. Imports and Device
!pip install timm==1.0.20 torchvision matplotlib scikit-learn --upgrade
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.8.0+cu126 requires torch==2.8.0, but you have torch 2.9.0 which is incompatible.
fastai 2.8.4 requires torch<2.9,>=1.10, but you have torch 2.9.0 which is incompatible.
Successfully installed matplotlib-3.10.7 nvidia-cublas-cu12-12.8.4.1 nvidia-cuda-cupti-cu12-12.8.90 nvidia-cuda-nvrtc-cu12-12.8.93 nvidia-cuda-runtime-cu12-12.8.90 nvidia-cufft-cu12-11.3.3.83 nvidia-cufile-cu12-1.13.1.3 nvidia-curand-cu12-10.3.9.90 nvidia-cusolver-cu12-11.7.3.90 nvidia-cusparse-cu12-12.5.8.93 nvidia-nccl-cu12-2.27.5 nvidia-nvjitlink-cu12-12.8.93 nvidia-nvshmem-cu12-3.3.20 nvidia-nvtx-cu12-12.8.90 scikit-learn-1.7.2 torch-2.9.0 torchvision-0.24.0 triton-3.5.0
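The two conflict lines above appear because `--upgrade` replaced the preinstalled torch 2.8.0 with 2.9.0. As a rough illustration of the range check behind fastai's `torch<2.9,>=1.10` requirement, here is a minimal sketch. `parse_version` and `in_range` are hypothetical helpers with deliberately simplified parsing (numeric release segments only, local tags such as `+cu126` dropped); they are not pip's resolver logic.

```python
def parse_version(text, width=3):
    """Turn '2.8.0+cu126' into (2, 8, 0). Simplified: no pre-release handling."""
    release = text.split("+")[0]                  # drop a local tag such as +cu126
    parts = [int(piece) for piece in release.split(".")]
    parts += [0] * (width - len(parts))           # pad '2.9' to (2, 9, 0)
    return tuple(parts)


def in_range(version, lower, upper):
    """Return True when lower <= version < upper."""
    return parse_version(lower) <= parse_version(version) < parse_version(upper)


# fastai 2.8.4 declares torch<2.9,>=1.10 (taken from the log above).
print(in_range("2.8.0+cu126", "1.10", "2.9"))  # preinstalled torch -> True
print(in_range("2.9.0", "1.10", "2.9"))        # torch after --upgrade -> False
```

If fastai or torchaudio are needed in the same runtime, dropping `--upgrade` or pinning torch to the preinstalled release is one possible way to avoid the conflict.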
import torch
import torch.nn as nn
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch.nn.functional as F
2. Data: CIFAR-10 Setup
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
val_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
subset_size = 500
subset_indices = np.random.choice(len(val_ds), subset_size, replace=False)
val_subset = Subset(val_ds, subset_indices)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=128, shuffle=False, num_workers=4)
val_subset_loader = DataLoader(val_subset, batch_size=128, shuffle=False, num_workers=2)
classes = train_ds.classes
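The validation subset above is drawn with the global NumPy random state, so a different 500 images are picked on each run and the before/after confusion matrices are only comparable within one session. A minimal sketch of a seeded draw follows; `sample_indices` and the seed value are hypothetical, and 10,000 is the size of the CIFAR-10 test split.

```python
import numpy as np


def sample_indices(dataset_size, subset_size, seed=0):
    """Draw unique indices with a local, seeded generator."""
    rng = np.random.default_rng(seed)
    return rng.choice(dataset_size, size=subset_size, replace=False)


first = sample_indices(10_000, 500)
second = sample_indices(10_000, 500)
print(np.array_equal(first, second))  # True: same seed, same subset
print(len(np.unique(first)))          # 500: no index is repeated
```

The returned array can be passed to `Subset` in place of `subset_indices`.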
fig, axes = plt.subplots(1, 10, figsize=(12, 3))
for i in range(10):
    img, label = train_ds[i]
    img_show = img.permute(1, 2, 0).numpy()
    img_show = np.clip((img_show * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406]), 0, 1)
    axes[i].imshow(img_show)
    axes[i].set_title(classes[label])
    axes[i].axis('off')
plt.suptitle('Sample CIFAR-10 Images')
plt.show()
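The plotting loop above undoes `transforms.Normalize` before calling `imshow`: after `permute(1, 2, 0)` the image has shape (H, W, C), so the three-element mean and std arrays broadcast over the last axis. The sketch below checks that round trip on a random stand-in image rather than a CIFAR-10 sample; `normalize` and `denormalize` are hypothetical helper names.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet channel means used above
STD = np.array([0.229, 0.224, 0.225])   # ImageNet channel stds used above


def normalize(img_hwc):
    """Same arithmetic as transforms.Normalize, on an (H, W, C) array."""
    return (img_hwc - MEAN) / STD


def denormalize(img_hwc):
    """Invert the normalization and clip to the displayable [0, 1] range."""
    return np.clip(img_hwc * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
pixels = rng.random((4, 4, 3))            # stand-in image with values in [0, 1)
restored = denormalize(normalize(pixels))
print(np.allclose(restored, pixels))      # True
```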
3. Model: timm DINOv3 Backbone
backbone_name = 'vit_base_patch16_dinov3'
backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0, global_pool='avg')
class DinoClassifier(nn.Module):
    def __init__(self, backbone, num_classes=10):
        super().__init__()
        self.backbone = backbone
        feat_dim = backbone.num_features
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
model = DinoClassifier(backbone, num_classes=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
DinoClassifier(
  (backbone): Eva(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (rope): RotaryEmbeddingDinoV3()
    (norm_pre): Identity()
    (blocks): ModuleList(
      (0-11): 12 x EvaBlock(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): EvaAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop2): Dropout(p=0.0, inplace=False)
        )
        (drop_path2): Identity()
      )
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (fc_norm): Identity()
    (head_drop): Dropout(p=0.0, inplace=False)
    (head): Identity()
  )
  (head): Linear(in_features=768, out_features=10, bias=True)
)
Image → Split into (16x16) patches → Patch Embedding (Conv2d, 768d) → Positional Encoding (RoPE) → 12 Transformer Layers (Feature Extraction) → Normalization → Classifier (Linear → 10 Classes)
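To put numbers on the first two stages of this pipeline, the sketch below works out the patch grid for a 224×224 input and the parameter count of the `Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))` projection shown in the printout. It counts patch tokens only; any extra tokens the backbone adds are not included.

```python
image_size = 224    # side length after transforms.Resize(224)
patch_size = 16     # kernel_size and stride of the patch projection
embed_dim = 768     # output channels of the projection
in_channels = 3     # RGB input

patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2

# Conv2d parameters: one weight per (out, in, kH, kW) entry plus one bias per output.
proj_params = embed_dim * in_channels * patch_size * patch_size + embed_dim

print(patches_per_side, num_patches)  # 14 196
print(proj_params)                    # 590592
```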
# Freeze backbone for faster linear probing
for param in model.backbone.parameters():
    param.requires_grad = False
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
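With the backbone frozen, only the linear head receives gradients, which for a 768-dimensional feature and 10 classes is 768 × 10 + 10 = 7,690 parameters. The sketch below confirms that count and shows one way to hand only the trainable parameters to the optimizer. `TinyProbe` and its small linear "backbone" are hypothetical stand-ins so the example runs without downloading DINOv3 weights.

```python
import torch
import torch.nn as nn


class TinyProbe(nn.Module):
    """Stand-in with the same backbone/head layout as DinoClassifier."""

    def __init__(self, feat_dim=768, num_classes=10):
        super().__init__()
        self.backbone = nn.Linear(32, feat_dim)   # hypothetical placeholder backbone
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


probe = TinyProbe()
for param in probe.backbone.parameters():
    param.requires_grad = False                   # freeze, as in the cell above

# Keep only parameters that still require gradients (the head).
trainable = [p for p in probe.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=1e-2)

num_trainable = sum(p.numel() for p in trainable)
print(num_trainable)  # 7690
```

Filtering is optional here, since the optimizer skips parameters that have no gradient, but it makes explicit which weights are being trained.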
def evaluate(model, loader, title):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(10)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot(xticks_rotation='vertical', cmap='Blues')
    plt.title(title)
    plt.show()
    acc = (np.array(all_preds) == np.array(all_labels)).mean()
    print(f"Accuracy: {acc*100:.2f}%\n")
# Evaluate before fine-tuning
evaluate(model, val_subset_loader, title='Confusion Matrix Before Fine-Tuning')
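`evaluate` reports one overall accuracy, but the confusion matrix also gives a per-class breakdown: with scikit-learn's convention that rows are true labels and columns are predictions, the accuracy for class i is the diagonal entry divided by the row sum. The sketch below uses a made-up 3×3 matrix; `per_class_accuracy` is a hypothetical helper and is not part of the notebook.

```python
def per_class_accuracy(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    result = []
    for i, row in enumerate(cm):
        total = sum(row)
        # A class with no samples in the subset has no defined accuracy.
        result.append(row[i] / total if total else float("nan"))
    return result


toy_cm = [
    [8, 2, 0],
    [1, 9, 0],
    [0, 5, 5],
]
print(per_class_accuracy(toy_cm))  # [0.8, 0.9, 0.5]
```

With only 500 validation images, each class has about 50 samples, so per-class figures from the subset are coarse.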
4. Training
from tqdm import tqdm
EPOCHS = 2
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}]", leave=False, mininterval=3)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Avg Loss: {running_loss/len(train_loader):.4f}")
Epoch [2/2] - Avg Loss: 0.2984
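The reported average divides the summed batch losses by `len(train_loader)`, so it is a mean of per-batch means. With 50,000 training images and a batch size of 64 the last batch is smaller than the rest and is weighted the same as a full batch. The sketch below shows the batch arithmetic and a sample-weighted alternative; `weighted_mean_loss` and the loss values are made up for illustration.

```python
import math

num_samples = 50_000   # size of the CIFAR-10 training split
batch_size = 64        # batch_size of train_loader above

num_batches = math.ceil(num_samples / batch_size)
last_batch = num_samples - (num_batches - 1) * batch_size
print(num_batches, last_batch)  # 782 16


def weighted_mean_loss(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each by its sample count."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


losses = [0.5, 0.5, 1.0]   # made-up per-batch mean losses
sizes = [64, 64, 16]       # the last batch is smaller

print(sum(losses) / len(losses))          # unweighted mean of batch means
print(weighted_mean_loss(losses, sizes))  # sample-weighted mean
```

Across 782 batches the difference is small, so the simpler average in the loop is a reasonable approximation.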
evaluate(model, val_subset_loader, title='Confusion Matrix After Fine-Tuning')