Model Profiling & Quantization — Notebook

Week 13 · CS 203 · Software Tools and Techniques for AI

Prof. Nipun Batra · IIT Gandhinagar

This notebook covers:

1. Model architecture and parameter counting
2. Profiling (timing, `torch.profiler`, memory)
3. Dynamic quantization
4. ONNX export and inference
5. Comparing all optimizations

# Install dependencies (uncomment if needed)
# !pip install torch torchvision onnx onnxruntime numpy matplotlib
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# Report the runtime environment.
# NOTE(review): none of the cells shown here move the model or its inputs to
# `device`, so every measurement below runs on the CPU even when CUDA is
# reported as available.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
for _line in (
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {_cuda_available}",
    f"Using device: {device}",
):
    print(_line)

plt.style.use('seaborn-v0_8-whitegrid')

1. Define Model

class SimpleCNN(nn.Module):
    """Small CNN for CIFAR-10-like images (3x32x32).

    Three conv stages (Conv3x3 -> BatchNorm -> ReLU) widen the channels
    3 -> 32 -> 64 -> 128. The first two stages are followed by 2x2 max-pooling
    and the last by global average pooling. An MLP head (128 -> 256 -> 64 ->
    num_classes, with dropout after the first layer) then produces the logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One feature stage. Layers are created in the same order as they
            # appear in the Sequential, so parameter names (features.0, .1, ...)
            # and initialization order are unchanged.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        self.features = nn.Sequential(
            *conv_bn_relu(3, 32), nn.MaxPool2d(2),
            *conv_bn_relu(32, 64), nn.MaxPool2d(2),
            *conv_bn_relu(64, 128), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, 64), nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.features(x)  # (batch, 128, 1, 1) after global pooling
        return self.classifier(torch.flatten(pooled, 1))

model = SimpleCNN()
model.eval()  # eval mode: Dropout disabled, BatchNorm uses its running statistics
print(model)

# Count parameters. Nothing in this notebook freezes parameters, so the two
# counts are expected to match.
params = list(model.parameters())
total_params = sum(p.numel() for p in params)
trainable_params = sum(p.numel() for p in params if p.requires_grad)

print(f"Total parameters:     {total_params:>10,}")
print(f"Trainable parameters: {trainable_params:>10,}")
# Idealized sizes at 4 / 2 / 1 bytes per parameter; buffers (e.g. BatchNorm
# running stats) and file-format overhead are not included.
for dtype_name, nbytes in (("FP32", 4), ("FP16", 2), ("INT8", 1)):
    print(f"Model size ({dtype_name}):    {total_params * nbytes / 1024:>10.1f} KB")

# Per-layer breakdown (sizes assume FP32, i.e. 4 bytes per parameter)
print("\nPer-layer breakdown:")
print(f"{'Layer':<35} {'Shape':<25} {'Params':>10} {'Size (KB)':>10}")
print("-" * 85)
for pname, tensor in model.named_parameters():
    count = tensor.numel()
    shape_text = str(list(tensor.shape))
    print(f"{pname:<35} {shape_text:<25} {count:>10,} {count * 4 / 1024:>10.1f}")

2. Profiling: Timing

def benchmark(model, input_data, n_runs=200, label="Model", n_warmup=20):
    """Measure the mean wall-clock time of one call ``model(input_data)``.

    ``n_warmup`` untimed calls run first, so one-time setup costs (e.g. lazy
    initialisation, allocator warm-up) are excluded. Then ``n_runs`` calls are
    timed with ``time.perf_counter`` and a one-line summary is printed.

    Args:
        model: Any callable accepting ``input_data`` -- an ``nn.Module`` or a
            wrapper such as a lambda around an ONNX Runtime session.
        input_data: Argument passed to ``model`` on every call.
        n_runs: Number of timed calls; must be positive.
        label: Text printed in front of the timing.
        n_warmup: Number of untimed warm-up calls (default 20).

    Returns:
        Mean time per call in milliseconds.

    Raises:
        ValueError: If ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")

    # CUDA kernels launch asynchronously. Without an explicit synchronize the
    # timer would mostly measure launch overhead, not the computation itself.
    on_cuda = isinstance(input_data, torch.Tensor) and input_data.is_cuda

    with torch.no_grad():
        for _ in range(n_warmup):
            model(input_data)
        if on_cuda:
            torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(n_runs):
            model(input_data)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / n_runs * 1000

    print(f"{label:<30} {elapsed:>8.2f} ms/batch")
    return elapsed

# Benchmark at different batch sizes. Dividing the per-batch time by the batch
# size gives the *amortized* time per sample (inverse throughput). This is not
# per-request latency: a request served in a batch waits for the whole batch,
# and the per-batch time printed by benchmark() typically grows with batch size.
print("Batch size effects on per-sample (amortized) time:\n")
batch_times = {}
for bs in [1, 4, 16, 32, 64, 128]:
    x = torch.randn(bs, 3, 32, 32)
    ms = benchmark(model, x, label=f"Batch size {bs}")
    batch_times[bs] = ms / bs  # amortized per-sample time
# Visualize batch size effect
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(list(batch_times.keys()), list(batch_times.values()), 'o-', markersize=8)
ax.set_xlabel('Batch Size')
ax.set_ylabel('Time per Sample (ms)')
ax.set_title('Batching Reduces Per-Sample Cost (Higher Throughput)')
ax.set_xscale('log', base=2)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

3. Profiling: torch.profiler

from torch.profiler import profile, ProfilerActivity

# Profile 10 forward passes on a 32-image batch (CPU activity only), recording
# operator input shapes and per-operator memory allocations.
input_data = torch.randn(32, 3, 32, 32)

profiler_kwargs = dict(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    profile_memory=True,
)
with profile(**profiler_kwargs) as prof, torch.no_grad():
    for _ in range(10):
        output = model(input_data)

averages = prof.key_averages()  # events aggregated per operator name
print("Top operations by CPU time:")
print(averages.table(sort_by="cpu_time_total", row_limit=15))
print("Top operations by memory usage:")
print(averages.table(sort_by="self_cpu_memory_usage", row_limit=10))

4. Memory Profiling

# Records appended by the forward hooks: (module name, output shape, size in KB).
activation_sizes = []


def hook_fn(name):
    """Build a forward hook that logs the size of a module's tensor output.

    Args:
        name: Module name stored alongside each record.

    Returns:
        A hook ``(module, input, output)`` that appends
        ``(name, list(output.shape), size_kb)`` to ``activation_sizes``.
        Outputs that are not a single ``torch.Tensor`` are ignored.
    """
    def hook(module, input, output):
        if not isinstance(output, torch.Tensor):
            return
        nbytes = output.numel() * output.element_size()
        activation_sizes.append((name, list(output.shape), nbytes / 1024))
    return hook

# Attach a forward hook to every leaf module (one with no children), run one
# forward pass on a single image, then detach the hooks. The try/finally
# guarantees removal even if the forward pass raises. Otherwise the hooks
# would stay attached and keep appending to activation_sizes on every later
# call to model(...), including the benchmarks below.
hooks = []
for name, module in model.named_modules():
    if next(module.children(), None) is None:  # leaf modules only
        hooks.append(module.register_forward_hook(hook_fn(name)))

try:
    with torch.no_grad():
        _ = model(torch.randn(1, 3, 32, 32))
finally:
    for h in hooks:
        h.remove()

print(f"{'Layer':<35} {'Output Shape':<25} {'Size (KB)':>10}")
print("-" * 75)
# TOTAL is the sum of every leaf module's output size, not the peak memory
# held at any single moment of the forward pass.
total_act = 0
for name, shape, size in activation_sizes:
    print(f"{name:<35} {str(shape):<25} {size:>10.1f}")
    total_act += size
print("-" * 75)
print(f"{'TOTAL':<35} {'':25} {total_act:>10.1f}")
# Visualize: weights vs activations
# Parameter memory in KB. Buffers such as BatchNorm running statistics are
# not counted.
weight_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024

memory_values = [weight_size, total_act]
fig, ax = plt.subplots(figsize=(6, 4))
bars = ax.bar(['Weights', 'Activations (bs=1)'], memory_values, color=['steelblue', 'coral'])
ax.set_ylabel('Memory (KB)')
ax.set_title('Memory Breakdown: Weights vs Activations')
# bar_label puts each label a fixed number of points above its bar, so the
# placement does not depend on the data scale as a fixed data-space offset would.
ax.bar_label(bars, labels=[f'{val:.0f} KB' for val in memory_values], padding=3, fontsize=11)
ax.margins(y=0.1)  # headroom so the tallest bar's label stays inside the axes
plt.tight_layout()
plt.show()

5. Dynamic Quantization

# Save the original FP32 weights so their on-disk size can be compared.
torch.save(model.state_dict(), "model_fp32.pth")
fp32_size = os.path.getsize("model_fp32.pth")

# Dynamic quantization: only the layer types listed ({nn.Linear}) get int8
# weights. Conv/BatchNorm layers stay FP32, so the file shrinks by less than
# the ideal 4x. quantize_dynamic returns a new module (inplace defaults to
# False), leaving `model` untouched. `torch.ao.quantization` is the current
# home of this API; `torch.quantization` is a legacy alias.
quantized_model = torch.ao.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8,
)

torch.save(quantized_model.state_dict(), "model_int8.pth")
int8_size = os.path.getsize("model_int8.pth")

print(f"FP32 model size: {fp32_size / 1024:.1f} KB")
print(f"INT8 model size: {int8_size / 1024:.1f} KB")
print(f"Compression:     {fp32_size / int8_size:.2f}x smaller")
# Benchmark: FP32 vs INT8. Both models are timed on the same fresh 32-image
# batch, so the ratio compares like with like.
input_data = torch.randn(32, 3, 32, 32)

print("Inference speed comparison:")
fp32_time = benchmark(model, input_data, n_runs=200, label="FP32 (original)")
int8_time = benchmark(quantized_model, input_data, n_runs=200, label="INT8 (quantized)")
print("\nSpeedup: {:.2f}x".format(fp32_time / int8_time))
# Compare FP32 and INT8 logits on the same batch.
with torch.no_grad():
    fp32_out, int8_out = model(input_data), quantized_model(input_data)

# Top-1 agreement: fraction of samples whose predicted class is unchanged.
fp32_preds, int8_preds = fp32_out.argmax(dim=1), int8_out.argmax(dim=1)
agreement = fp32_preds.eq(int8_preds).float().mean()

# Element-wise absolute error between the two sets of logits.
abs_err = (fp32_out - int8_out).abs()
max_diff, mean_diff = abs_err.max().item(), abs_err.mean().item()

print(f"Prediction agreement: {agreement:.1%}")
print(f"Max output difference: {max_diff:.6f}")
print(f"Mean output difference: {mean_diff:.6f}")
# Visualize weight distributions before/after quantization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# All FP32 parameters flattened into one array: conv/linear weights and biases
# plus BatchNorm affine parameters. np.concatenate avoids building a Python
# list holding one float object per parameter.
fp32_weights = np.concatenate(
    [p.detach().cpu().numpy().ravel() for p in model.parameters()]
)
axes[0].hist(fp32_weights, bins=100, alpha=0.7, color='steelblue', edgecolor='black', linewidth=0.5)
axes[0].set_title(f'FP32 Weight Distribution\n(model size: {fp32_size/1024:.0f} KB)')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Count')

# Output difference distribution (FP32 logits minus INT8 logits)
diffs = (fp32_out - int8_out).detach().cpu().numpy().ravel()
axes[1].hist(diffs, bins=50, alpha=0.7, color='coral', edgecolor='black', linewidth=0.5)
axes[1].set_title(f'Quantization Error Distribution\n(max: {max_diff:.4f}, mean: {mean_diff:.4f})')
axes[1].set_xlabel('FP32 output - INT8 output')
axes[1].set_ylabel('Count')

plt.tight_layout()
plt.show()

6. ONNX Export

import onnxruntime as ort

# Export the FP32 model to ONNX. The batch dimension is marked dynamic on BOTH
# the input and the output. If only the input were marked, the exported graph
# would declare a fixed (1, num_classes) output shape, which would not match
# when the session is later run on a batch of 32.
dummy_input = torch.randn(1, 3, 32, 32)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={"image": {0: "batch_size"}, "logits": {0: "batch_size"}},
    opset_version=17,
)

onnx_size = os.path.getsize("model.onnx")
print(f"ONNX model size: {onnx_size / 1024:.1f} KB")

# Load with ONNX Runtime with all graph optimizations enabled. The provider is
# pinned to CPU so the comparison with the CPU PyTorch model is like-for-like.
# Builds exposing several providers (e.g. onnxruntime-gpu) also require an
# explicit `providers` list.
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession("model.onnx", opts, providers=["CPUExecutionProvider"])

# Test inference on a single random image
input_np = np.random.randn(1, 3, 32, 32).astype(np.float32)
result = session.run(None, {"image": input_np})
print(f"ONNX output shape: {result[0].shape}")
# Benchmark: PyTorch vs ONNX Runtime on the same 32-image batch. This reuses
# the benchmark() helper (20 warm-up + 200 timed calls) so both backends are
# timed exactly like every other measurement in this notebook.
batch_np = np.random.randn(32, 3, 32, 32).astype(np.float32)
batch_torch = torch.from_numpy(batch_np)  # shares memory with batch_np

pytorch_time = benchmark(model, batch_torch, label="PyTorch")
# benchmark() simply calls its first argument, so a session wrapper works too.
# Its torch.no_grad() context has no effect on ONNX Runtime.
onnx_time = benchmark(
    lambda x: session.run(None, {"image": x}), batch_np, label="ONNX Runtime"
)
print(f"Speedup:      {pytorch_time / onnx_time:.2f}x")

# Verify that both backends produce (numerically) the same logits
with torch.no_grad():
    pytorch_out = model(batch_torch).numpy()
onnx_out = session.run(None, {"image": batch_np})[0]
print(f"Max difference: {np.abs(pytorch_out - onnx_out).max():.8f}")

7. Final Comparison

# Collect all results.
# NOTE: the FP32 baseline is fp32_time from the quantization section, while
# onnx_time comes from a separate run in the ONNX section, so run-to-run noise
# is part of the reported speedups.
results = {
    "FP32 (PyTorch)": {"size_kb": fp32_size / 1024, "time_ms": fp32_time},
    "INT8 (Dynamic)": {"size_kb": int8_size / 1024, "time_ms": int8_time},
    "ONNX Runtime": {"size_kb": onnx_size / 1024, "time_ms": onnx_time},
}

header = f"{'Method':<25} {'Size (KB)':>12} {'Latency (ms)':>15} {'Speedup':>10}"
rule = "=" * len(header)  # derived from the header so it spans the full row width

print("\n" + rule)
print(header)
print(rule)
baseline_time = results["FP32 (PyTorch)"]["time_ms"]
for name, data in results.items():
    speedup = baseline_time / data["time_ms"]
    print(f"{name:<25} {data['size_kb']:>12.1f} {data['time_ms']:>15.2f} {speedup:>9.2f}x")
print(rule)
# Visualization: size vs speed
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

names = list(results.keys())
sizes = [results[n]["size_kb"] for n in names]
times = [results[n]["time_ms"] for n in names]
colors = ['steelblue', 'coral', 'mediumseagreen']

# (axis, values, y-label, title, label format) for each panel
panels = [
    (axes[0], sizes, 'Model Size (KB)', 'Model Size Comparison', '%.0f'),
    (axes[1], times, 'Inference Time (ms/batch)', 'Inference Speed Comparison (batch=32)', '%.1f'),
]
for ax, values, ylabel, title, fmt in panels:
    bars = ax.bar(names, values, color=colors, edgecolor='black', linewidth=0.5)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # bar_label offsets each label in points above its bar, so the placement
    # does not depend on the data scale as a fixed data-space offset would.
    ax.bar_label(bars, fmt=fmt, padding=3, fontsize=10)
    ax.margins(y=0.1)  # headroom so the tallest bar's label is not clipped

plt.tight_layout()
plt.show()

Summary

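| Technique | Code Complexity | Size Reduction | Speed Improvement |
|---|---|---|---|
| Dynamic Quantization | 1 line | ~2-4x smaller | ~1.5-3x faster |
| ONNX Export | ~5 lines | Similar | ~1.5-5x faster |
| Static Quantization | ~15 lines + calibration data | ~3-4x smaller | ~2-4x faster |
| Pruning | ~5 lines | ~2-10x smaller | Varies |
| Distillation | Full retraining | Custom | Custom |

These are typical ranges; actual gains depend on the model and hardware. In this notebook, for example, dynamic quantization only converts the `nn.Linear` layers.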

Key takeaways:

- Always profile before optimizing — find the actual bottleneck.
- Dynamic quantization is the easiest win (one line of code).
- ONNX Runtime gives free speedups via graph optimizations such as operator fusion.
- Batching is a cheap optimization — larger batches give better throughput (lower per-sample cost), at the price of higher per-request latency.
- Combine techniques for maximum effect: quantize + ONNX + batch.