# Install dependencies (uncomment if needed)
# !pip install torch torchvision onnx onnxruntime numpy matplotlib

Model Profiling & Quantization — Notebook
Week 13 · CS 203 · Software Tools and Techniques for AI
Prof. Nipun Batra · IIT Gandhinagar
This notebook covers:
1. Model architecture and parameter counting
2. Profiling (timing, torch.profiler, memory)
3. Dynamic quantization
4. ONNX export and inference
5. Comparing all optimizations
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# Report environment details up front so timing results are interpretable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

plt.style.use('seaborn-v0_8-whitegrid')

# 1. Define Model
class SimpleCNN(nn.Module):
    """Small CNN for CIFAR-10-like images (3x32x32).

    Three Conv/BatchNorm/ReLU stages (the first two followed by 2x2
    max-pooling) feed a global average pool, then a three-layer MLP
    head produces the class logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # (in_channels, out_channels, apply 2x2 max-pool after the stage)
        stage_specs = [
            (3, 32, True),
            (32, 64, True),
            (64, 128, False),
        ]
        feature_layers = []
        for in_ch, out_ch, pool in stage_specs:
            feature_layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            feature_layers.append(nn.BatchNorm2d(out_ch))
            feature_layers.append(nn.ReLU())
            if pool:
                feature_layers.append(nn.MaxPool2d(2))
        # Global average pool collapses spatial dims to 1x1.
        feature_layers.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*feature_layers)

        self.classifier = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
# Instantiate in eval mode (disables dropout; BatchNorm uses running stats).
model = SimpleCNN()
model.eval()
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:>10,}")
print(f"Trainable parameters: {trainable_params:>10,}")
# 4 / 2 / 1 bytes per parameter for FP32 / FP16 / INT8 respectively.
print(f"Model size (FP32): {total_params * 4 / 1024:>10.1f} KB")
print(f"Model size (FP16): {total_params * 2 / 1024:>10.1f} KB")
print(f"Model size (INT8): {total_params * 1 / 1024:>10.1f} KB")

# Per-layer breakdown
print("\nPer-layer breakdown:")
print(f"{'Layer':<35} {'Shape':<25} {'Params':>10} {'Size (KB)':>10}")
print("-" * 85)
for name, param in model.named_parameters():
    n = param.numel()
    print(f"{name:<35} {str(list(param.shape)):<25} {n:>10,} {n * 4 / 1024:>10.1f}")

# 2. Profiling: Timing
def benchmark(model, input_data, n_runs=200, label="Model"):
    """Benchmark inference time with proper warmup."""
    with torch.no_grad():
        # Warm up so the timed loop measures steady-state performance.
        for _ in range(20):
            model(input_data)
        # Timed runs
        t_start = time.perf_counter()
        for _ in range(n_runs):
            model(input_data)
        t_end = time.perf_counter()
    elapsed = (t_end - t_start) / n_runs * 1000  # ms per batch
    print(f"{label:<30} {elapsed:>8.2f} ms/batch")
    return elapsed
# Benchmark at different batch sizes
print("Batch size effects on per-sample latency:\n")
batch_times = {}
for batch_size in (1, 4, 16, 32, 64, 128):
    batch = torch.randn(batch_size, 3, 32, 32)
    per_batch_ms = benchmark(model, batch, label=f"Batch size {batch_size}")
    # Normalize to per-sample cost to show how batching amortizes overhead.
    batch_times[batch_size] = per_batch_ms / batch_size
# Plot per-sample latency against batch size.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(list(batch_times.keys()), list(batch_times.values()), 'o-', markersize=8)
ax.set_xlabel('Batch Size')
ax.set_ylabel('Time per Sample (ms)')
ax.set_title('Batching Reduces Per-Sample Latency')
# Batch sizes are powers of two, so a log2 x-axis spaces them evenly.
ax.set_xscale('log', base=2)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 3. Profiling: torch.profiler
from torch.profiler import profile, ProfilerActivity

input_data = torch.randn(32, 3, 32, 32)

# Profile 10 forward passes on CPU, recording tensor shapes and memory usage.
with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    profile_memory=True,
) as prof:
    with torch.no_grad():
        for _ in range(10):
            output = model(input_data)

print("Top operations by CPU time:")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))

print("Top operations by memory usage:")
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

# 4. Memory Profiling
# Track activation sizes through the network
activation_sizes = []

def hook_fn(name):
    """Build a forward hook recording (layer, output shape, size in KB)."""
    def hook(module, input, output):
        # Only tensor outputs are measured; tuples/None are skipped.
        if isinstance(output, torch.Tensor):
            kb = output.numel() * output.element_size() / 1024
            activation_sizes.append((name, list(output.shape), kb))
    return hook

# Register hooks on leaf modules only (modules with no children).
hooks = [
    module.register_forward_hook(hook_fn(name))
    for name, module in model.named_modules()
    if not list(module.children())
]

with torch.no_grad():
    _ = model(torch.randn(1, 3, 32, 32))

# Always remove hooks once done so later runs are not instrumented.
for h in hooks:
    h.remove()

print(f"{'Layer':<35} {'Output Shape':<25} {'Size (KB)':>10}")
print("-" * 75)
total_act = 0
for layer_name, out_shape, size_kb in activation_sizes:
    print(f"{layer_name:<35} {str(out_shape):<25} {size_kb:>10.1f}")
    total_act += size_kb
print("-" * 75)
print(f"{'TOTAL':<35} {'':25} {total_act:>10.1f}")
# Visualize: weights vs activations
weight_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024
fig, ax = plt.subplots(figsize=(6, 4))
bars = ax.bar(['Weights', 'Activations (bs=1)'], [weight_size, total_act], color=['steelblue', 'coral'])
ax.set_ylabel('Memory (KB)')
ax.set_title('Memory Breakdown: Weights vs Activations')
# Annotate each bar with its value just above the bar top.
for bar, val in zip(bars, [weight_size, total_act]):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 5, f'{val:.0f} KB',
            ha='center', fontsize=11)
plt.tight_layout()
plt.show()

# 5. Dynamic Quantization
# Save the FP32 model so serialized sizes can be compared on disk.
torch.save(model.state_dict(), "model_fp32.pth")
fp32_size = os.path.getsize("model_fp32.pth")
# Dynamic quantization: only the nn.Linear modules are converted to
# int8 weights; the conv/BatchNorm layers remain FP32.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8,
)
torch.save(quantized_model.state_dict(), "model_int8.pth")
int8_size = os.path.getsize("model_int8.pth")
print(f"FP32 model size: {fp32_size / 1024:.1f} KB")
print(f"INT8 model size: {int8_size / 1024:.1f} KB")
print(f"Compression: {fp32_size / int8_size:.2f}x smaller")
# Benchmark FP32 vs INT8 on the same batch; only the Linear layers were
# quantized, so the speedup reflects the classifier head.
input_data = torch.randn(32, 3, 32, 32)
print("Inference speed comparison:")
fp32_time = benchmark(model, input_data, label="FP32 (original)")
int8_time = benchmark(quantized_model, input_data, label="INT8 (quantized)")
print(f"\nSpeedup: {fp32_time / int8_time:.2f}x")
# Compare FP32 vs INT8 outputs on the same input batch.
with torch.no_grad():
    fp32_out = model(input_data)
    int8_out = quantized_model(input_data)
fp32_preds = fp32_out.argmax(dim=1)
int8_preds = int8_out.argmax(dim=1)
# Fraction of samples where both models predict the same class.
agreement = (fp32_preds == int8_preds).float().mean()
max_diff = (fp32_out - int8_out).abs().max().item()
mean_diff = (fp32_out - int8_out).abs().mean().item()
print(f"Prediction agreement: {agreement:.1%}")
print(f"Max output difference: {max_diff:.6f}")
print(f"Mean output difference: {mean_diff:.6f}")
# Visualize weight distributions before/after quantization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# FP32 weights: flatten every parameter tensor into one vector
# (np.concatenate avoids building a huge Python list element by element).
fp32_weights = np.concatenate([p.detach().numpy().ravel() for p in model.parameters()])
axes[0].hist(fp32_weights, bins=100, alpha=0.7, color='steelblue', edgecolor='black', linewidth=0.5)
axes[0].set_title(f'FP32 Weight Distribution\n(model size: {fp32_size/1024:.0f} KB)')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Count')

# Output difference distribution
diffs = (fp32_out - int8_out).detach().numpy().flatten()
axes[1].hist(diffs, bins=50, alpha=0.7, color='coral', edgecolor='black', linewidth=0.5)
axes[1].set_title(f'Quantization Error Distribution\n(max: {max_diff:.4f}, mean: {mean_diff:.4f})')
axes[1].set_xlabel('FP32 output - INT8 output')
axes[1].set_ylabel('Count')
plt.tight_layout()
plt.show()

# 6. ONNX Export
import onnxruntime as ort

# Export to ONNX. Tracing needs a representative input; the batch
# dimension is declared dynamic below so any batch size works at runtime.
dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={"image": {0: "batch_size"}},  # batch dim is variable
    opset_version=17,
)
onnx_size = os.path.getsize("model.onnx")
print(f"ONNX model size: {onnx_size / 1024:.1f} KB")
# Load with ONNX Runtime, enabling all graph optimizations
# (constant folding, operator fusion, ...).
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession("model.onnx", opts)
# Smoke-test inference: ORT expects a float32 NumPy array keyed by input name.
input_np = np.random.randn(1, 3, 32, 32).astype(np.float32)
result = session.run(None, {"image": input_np})
print(f"ONNX output shape: {result[0].shape}")
# Benchmark: PyTorch vs ONNX Runtime
# Same underlying data for both runtimes (from_numpy shares memory).
batch_np = np.random.randn(32, 3, 32, 32).astype(np.float32)
batch_torch = torch.from_numpy(batch_np)

# PyTorch: warm up, then time 200 runs.
for _ in range(20):
    with torch.no_grad():
        model(batch_torch)
start = time.perf_counter()
for _ in range(200):
    with torch.no_grad():
        model(batch_torch)
pytorch_time = (time.perf_counter() - start) / 200 * 1000

# ONNX Runtime: identical warmup/timing protocol for a fair comparison.
for _ in range(20):
    session.run(None, {"image": batch_np})
start = time.perf_counter()
for _ in range(200):
    session.run(None, {"image": batch_np})
onnx_time = (time.perf_counter() - start) / 200 * 1000

print(f"PyTorch: {pytorch_time:.2f} ms/batch")
print(f"ONNX Runtime: {onnx_time:.2f} ms/batch")
print(f"Speedup: {pytorch_time / onnx_time:.2f}x")

# Verify outputs match (small differences expected from graph optimizations).
with torch.no_grad():
    pytorch_out = model(batch_torch).numpy()
onnx_out = session.run(None, {"image": batch_np})[0]
print(f"Max difference: {np.abs(pytorch_out - onnx_out).max():.8f}")

# 7. Final Comparison
# Collect all results
results = {
    "FP32 (PyTorch)": {"size_kb": fp32_size / 1024, "time_ms": fp32_time},
    "INT8 (Dynamic)": {"size_kb": int8_size / 1024, "time_ms": int8_time},
    "ONNX Runtime": {"size_kb": onnx_size / 1024, "time_ms": onnx_time},
}

# Print a summary table; speedup is relative to the FP32 PyTorch baseline.
separator = "=" * 60
print("\n" + separator)
print(f"{'Method':<25} {'Size (KB)':>12} {'Latency (ms)':>15} {'Speedup':>10}")
print(separator)
baseline_time = results["FP32 (PyTorch)"]["time_ms"]
for method, stats in results.items():
    ratio = baseline_time / stats["time_ms"]
    print(f"{method:<25} {stats['size_kb']:>12.1f} {stats['time_ms']:>15.2f} {ratio:>9.2f}x")
print(separator)
# Visualization: size vs speed
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
names = list(results.keys())
sizes = [results[n]["size_kb"] for n in names]
times = [results[n]["time_ms"] for n in names]
colors = ['steelblue', 'coral', 'mediumseagreen']

# Model size
bars = axes[0].bar(names, sizes, color=colors, edgecolor='black', linewidth=0.5)
axes[0].set_ylabel('Model Size (KB)')
axes[0].set_title('Model Size Comparison')
for bar, val in zip(bars, sizes):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                 f'{val:.0f}', ha='center', fontsize=10)

# Inference time
bars = axes[1].bar(names, times, color=colors, edgecolor='black', linewidth=0.5)
axes[1].set_ylabel('Inference Time (ms/batch)')
axes[1].set_title('Inference Speed Comparison (batch=32)')
for bar, val in zip(bars, times):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                 f'{val:.1f}', ha='center', fontsize=10)
plt.tight_layout()
plt.show()

# Summary
| Technique | Code Complexity | Size Reduction | Speed Improvement |
|---|---|---|---|
| Dynamic Quantization | 1 line | ~2-4x smaller | ~1.5-3x faster |
| ONNX Export | ~5 lines | Similar | ~1.5-5x faster |
| Static Quantization | ~15 lines + data | ~3-4x smaller | ~2-4x faster |
| Pruning | ~5 lines | ~2-10x smaller | Varies |
| Distillation | Full retraining | Custom | Custom |
Key takeaways:
- Always profile before optimizing — find the actual bottleneck
- Dynamic quantization is the easiest win (one line of code)
- ONNX Runtime gives free speedups via operator fusion
- Batching is a free optimization — larger batches = better throughput
- Combine techniques for maximum effect: quantize + ONNX + batch