VLM Compression: Quantization, Pruning, and Distillation

VLM
compression
quantization
pruning
distillation
optimization
Author

Nipun Batra

Published

January 2, 2026

VLM Compression: Making Models Smaller and Faster

In this notebook, we’ll explore model compression techniques to make Vision-Language Models (VLMs) more efficient without significantly sacrificing performance. We’ll demonstrate three powerful techniques:

  1. Quantization: Converting FP32 weights to INT8 → 4x size reduction
  2. Structured Pruning: Removing entire attention heads/FFN dimensions → 2x speedup
  3. Knowledge Distillation: Training a smaller student model to mimic a larger teacher → Better accuracy than training from scratch

Why Compress VLMs?

  • Deployment constraints: Mobile devices, edge computing, limited memory
  • Cost reduction: Lower inference costs, faster processing
  • Energy efficiency: Less compute = less power consumption
  • Real-time applications: Faster response times for interactive use

Our Baseline Model

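We’ll use our image captioning VLM as the baseline:

  • Vision encoder: ViT-Base (~86M params)
  • Language model: SmolLM-135M (~135M params)
  • Projector: Linear layer (~0.4M params)
  • Total: ~222M parameters, ~888 MB (FP32)

The implementation below also adds a separate LM head (~28M params), so the measured total comes out at about 250M parameters (~952 MB).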
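The size figures follow from simple arithmetic: parameter count times bytes per parameter. Here is a minimal sketch of that calculation; the helper name `model_size_mb` and the dtype table are illustrative, not part of the notebook. It uses decimal megabytes (1e6 bytes), whereas `get_model_size_mb` further down divides by 1024**2, which is one reason the printed sizes differ slightly from these estimates.

```python
# Bytes needed to store one parameter in each format.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

def model_size_mb(num_params, dtype="fp32"):
    """Estimate weight storage in decimal MB (1 MB = 1e6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

print(model_size_mb(222e6, "fp32"))  # 888.0
print(model_size_mb(222e6, "int8"))  # 222.0
```

The 4x ratio between the two lines is the best case for INT8, reached only if every parameter is quantized.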

Target Compression Results

Technique                Model Size   Inference Speed   Accuracy Drop
Baseline (FP32)          888 MB       1.0x              0%
INT8 Quantization        222 MB       1.5-2x            <2%
Pruning (50%)            444 MB       2x                <5%
Distillation (smaller)   300 MB       2.5x              <3%
Combined                 150 MB       3-4x              <7%

Let’s get started!

# IMPORTANT: Set GPU before importing PyTorch!
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from datasets import load_dataset
from PIL import Image
import numpy as np
from tqdm import tqdm
import time
import gc
import copy

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA RTX A4000
GPU memory: 15.72 GB

Load Pre-trained VLM (Baseline)

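We’ll build the baseline from pre-trained ViT-Base and SmolLM-135M backbones. The projector and LM head are newly initialized and no captioning checkpoint is loaded, so the loss values reported below come from a model that has not been fine-tuned for captioning.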

class SimpleVLM(nn.Module):
    """Simple Vision-Language Model for image captioning."""
    def __init__(self, vision_model_name="google/vit-base-patch16-224", 
                 language_model_name="HuggingFaceTB/SmolLM-135M"):
        super().__init__()
        
        # Vision encoder
        self.vision_encoder = AutoModel.from_pretrained(vision_model_name)
        self.vision_hidden_size = self.vision_encoder.config.hidden_size  # 768 for ViT-Base
        
        # Language model
        self.language_model = AutoModel.from_pretrained(language_model_name)
        self.language_hidden_size = self.language_model.config.hidden_size  # 576 for SmolLM
        
        # Projection layer
        self.projector = nn.Linear(self.vision_hidden_size, self.language_hidden_size)
        
        # LM head for generation
        self.lm_head = nn.Linear(self.language_hidden_size, 
                                 self.language_model.config.vocab_size, bias=False)
        
    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # Encode image
        vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        image_features = vision_outputs.last_hidden_state  # [batch, 197, 768]
        
        # Project to language space
        image_embeds = self.projector(image_features)  # [batch, 197, 576]
        
        # Get text embeddings (SmolLM/Llama uses embed_tokens, not embeddings.word_embeddings)
        text_embeds = self.language_model.embed_tokens(input_ids)  # [batch, seq_len, 576]
        
        # Concatenate image and text embeddings
        combined_embeds = torch.cat([image_embeds, text_embeds], dim=1)  # [batch, 197+seq_len, 576]
        
        # Create attention mask for combined sequence
        batch_size = pixel_values.shape[0]
        image_attention = torch.ones(batch_size, image_embeds.shape[1], 
                                     dtype=torch.long, device=pixel_values.device)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        combined_attention = torch.cat([image_attention, attention_mask], dim=1)
        
        # Forward through language model
        outputs = self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention,
            output_hidden_states=True,
            return_dict=True
        )
        
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction
            shift_logits = logits[:, image_embeds.shape[1]:-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            
            # Compute loss
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 
                          shift_labels.view(-1))
        
        return {"loss": loss, "logits": logits, "hidden_states": hidden_states}
    
    def count_parameters(self):
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def get_model_size_mb(self):
        """Calculate model size in MB."""
        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())
        return (param_size + buffer_size) / 1024**2
# Initialize baseline model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline_model = SimpleVLM().to(device)

print(f"Baseline Model Statistics:")
print(f"  Total parameters: {baseline_model.count_parameters() / 1e6:.1f}M")
print(f"  Model size: {baseline_model.get_model_size_mb():.1f} MB")
print(f"\nComponent breakdown:")
print(f"  Vision encoder: {sum(p.numel() for p in baseline_model.vision_encoder.parameters()) / 1e6:.1f}M")
print(f"  Language model: {sum(p.numel() for p in baseline_model.language_model.parameters()) / 1e6:.1f}M")
print(f"  Projector: {sum(p.numel() for p in baseline_model.projector.parameters()) / 1e6:.1f}M")
print(f"  LM head: {sum(p.numel() for p in baseline_model.lm_head.parameters()) / 1e6:.1f}M")
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Baseline Model Statistics:
  Total parameters: 249.7M
  Model size: 952.4 MB

Component breakdown:
  Vision encoder: 86.4M
  Language model: 134.5M
  Projector: 0.4M
  LM head: 28.3M
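The loss in `forward` relies on a slicing step that is easy to get wrong: image tokens are prepended to the text, so the logits must be cut at the image boundary before the usual next-token shift. The following sketch replays that slicing on dummy tensors to check that the shapes line up; the sizes are arbitrary illustrative values.

```python
import torch

batch, num_image_tokens, seq_len, vocab = 2, 5, 7, 11
logits = torch.randn(batch, num_image_tokens + seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Position num_image_tokens + t holds the prediction made after reading
# text token t, so it is scored against text token t + 1.
shift_logits = logits[:, num_image_tokens:-1, :]
shift_labels = labels[:, 1:]

print(shift_logits.shape)  # torch.Size([2, 6, 11])
print(shift_labels.shape)  # torch.Size([2, 6])
```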

Prepare Dataset for Evaluation

We’ll use a small subset of Flickr8k (500 examples) for quick evaluation and benchmarking.

# Load tokenizer and image processor
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
tokenizer.pad_token = tokenizer.eos_token
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Load Flickr8k dataset (simpler, no trust_remote_code needed)
dataset = load_dataset("jxie/flickr8k", split="train")
dataset = dataset.shuffle(seed=42).select(range(500))  # Small subset for speed
print(f"Loaded {len(dataset)} examples for evaluation")
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Loaded 500 examples for evaluation
class CaptionDataset(Dataset):
    def __init__(self, dataset, image_processor, tokenizer, max_length=64):
        self.dataset = dataset
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert('RGB')
        # flickr8k uses caption_0, caption_1, etc. - use first caption
        caption = item['caption_0']
        
        # Process image
        pixel_values = self.image_processor(image, return_tensors="pt")["pixel_values"][0]
        
        # Process caption
        prompt = "Caption: "
        full_text = f"{prompt}{caption}{self.tokenizer.eos_token}"
        
        encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]
        
        # Create labels (mask prompt tokens)
        labels = input_ids.clone()
        prompt_len = len(self.tokenizer(prompt, add_special_tokens=False)["input_ids"])
        labels[:prompt_len] = -100
        labels[attention_mask == 0] = -100
        
        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "caption": caption
        }

eval_dataset = CaptionDataset(dataset, image_processor, tokenizer)
eval_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False)
print(f"Created DataLoader with {len(eval_loader)} batches")
Created DataLoader with 63 batches

Baseline Evaluation

Let’s measure baseline performance: inference speed and loss.

def evaluate_model(model, dataloader, num_batches=50):
    """Evaluate model performance."""
    model.eval()
    total_loss = 0
    total_time = 0
    num_samples = 0
    
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, total=num_batches, desc="Evaluating")):
            if i >= num_batches:
                break
            
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
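            # Measure inference time; synchronize before and after so that
            # only this batch's GPU work falls inside the timed window
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            outputs = model(pixel_values, input_ids, attention_mask, labels)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end_time = time.perf_counter()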
            
            total_loss += outputs["loss"].item() * pixel_values.size(0)
            total_time += (end_time - start_time)
            num_samples += pixel_values.size(0)
    
    avg_loss = total_loss / num_samples
    avg_time_per_sample = total_time / num_samples
    throughput = num_samples / total_time
    
    return {
        "loss": avg_loss,
        "time_per_sample": avg_time_per_sample,
        "throughput": throughput
    }

print("Evaluating baseline model...")
baseline_results = evaluate_model(baseline_model, eval_loader)
print(f"\nBaseline Results:")
print(f"  Loss: {baseline_results['loss']:.4f}")
print(f"  Time per sample: {baseline_results['time_per_sample']*1000:.2f} ms")
print(f"  Throughput: {baseline_results['throughput']:.2f} samples/sec")
Evaluating baseline model...
Evaluating: 100%|██████████| 50/50 [00:09<00:00,  5.13it/s]

Baseline Results:
  Loss: 11.1261
  Time per sample: 17.07 ms
  Throughput: 58.57 samples/sec
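Timing numbers like these are sensitive to warm-up effects (the first few GPU calls include kernel loading and memory allocation) and to asynchronous execution. Below is a small, generic timing helper as a sketch; the name `benchmark` and its defaults are assumptions for illustration, not part of the notebook.

```python
import time
import torch

def benchmark(fn, warmup=3, iters=10):
    """Return mean seconds per call of fn(), excluding warm-up runs."""
    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters

# Usage on a toy workload
x = torch.randn(256, 256)
mean_s = benchmark(lambda: x @ x)
print(f"{mean_s * 1e6:.1f} microseconds per matmul")
```

The same helper could wrap a single model forward pass to compare variants on the same device.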

1. INT8 Quantization

Quantization converts FP32 (32-bit floating point) weights to INT8 (8-bit integers), reducing model size by 4x with minimal accuracy loss.

How it works:

  • Weight quantization: Convert weights to INT8
  • Activation quantization: Convert activations to INT8
  • Calibration: Compute scale/zero-point from sample data
  • Inference: Use INT8 arithmetic (much faster on modern hardware)
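To make the scale and zero-point idea concrete, here is a minimal sketch of per-tensor affine quantization written in plain PyTorch. It is an illustration only: the function names are hypothetical, and PyTorch's built-in quantization uses its own observers and packed formats rather than this code.

```python
import torch

def quantize_affine(x, num_bits=8):
    """Map a float tensor to signed integers using a scale and zero-point."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    # Make sure the range contains 0 so that zero is exactly representable.
    x_min = min(x.min().item(), 0.0)
    x_max = max(x.max().item(), 0.0)
    scale = (x_max - x_min) / (qmax - qmin) or 1.0  # avoid scale == 0
    zero_point = int(round(qmin - x_min / scale))
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return q.to(torch.int8), scale, zero_point

def dequantize_affine(q, scale, zero_point):
    """Recover an approximate float tensor from the integer codes."""
    return (q.to(torch.float32) - zero_point) * scale

torch.manual_seed(0)
w = torch.randn(4, 4)
q, scale, zero_point = quantize_affine(w)
w_hat = dequantize_affine(q, scale, zero_point)
print("max abs error:", (w - w_hat).abs().max().item(), "scale:", scale)
```

The reconstruction error is on the order of `scale`, which is why tensors with a few large outliers quantize poorly: the outliers stretch the range and inflate the step size for every other value.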

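We’ll use PyTorch’s dynamic quantization for simplicity. In this mode the Linear weights are converted to INT8 ahead of time and activations are quantized on the fly at inference, so no separate calibration pass is needed.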

# Create quantized model (must be on CPU for dynamic quantization)
baseline_model_cpu = copy.deepcopy(baseline_model).cpu()
baseline_model_cpu.eval()

# Apply dynamic quantization to linear layers
# Note: Dynamic quantization only works on CPU - this is a PyTorch limitation
quantized_model = torch.quantization.quantize_dynamic(
    baseline_model_cpu,
    {nn.Linear},  # Quantize all Linear layers
    dtype=torch.qint8
)

print(f"\nQuantized Model Statistics:")
print(f"  Model size: {quantized_model.get_model_size_mb():.1f} MB")
print(f"  Size reduction: {baseline_model.get_model_size_mb() / quantized_model.get_model_size_mb():.2f}x")
print(f"\nNote: Dynamic quantization runs on CPU only (PyTorch limitation).")
/tmp/ipykernel_825286/4220383649.py:7: DeprecationWarning: torch.ao.quantization is deprecated and will be removed in 2.10. 
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.quantize_dynamic(

Quantized Model Statistics:
  Model size: 111.1 MB
  Size reduction: 8.57x

Note: Dynamic quantization runs on CPU only (PyTorch limitation).
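The reported 8.57x reduction is larger than the 4x ceiling of going from FP32 to INT8, so it should be read with care. A likely explanation is that `get_model_size_mb` only sums `parameters()` and `buffers()`, while dynamically quantized Linear modules keep their INT8 weights in packed form outside those iterators, so the quantized weights are not counted at all. A more reliable check is to measure the serialized `state_dict`. The helper below is a sketch of that approach (the name `serialized_size_mb` is an assumption), shown on a plain layer so it runs anywhere.

```python
import io
import torch
import torch.nn as nn

def serialized_size_mb(model: nn.Module) -> float:
    """Size of the model's state_dict when saved with torch.save, in MB."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024**2

# A 1024x1024 FP32 layer should come out slightly above 4 MB.
layer = nn.Linear(1024, 1024)
print(f"{serialized_size_mb(layer):.2f} MB")
```

Calling `serialized_size_mb(baseline_model_cpu)` and `serialized_size_mb(quantized_model)` would give a ratio that accounts for the quantized weights.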
# Evaluate quantized model on CPU
def evaluate_model_cpu(model, dataloader, num_batches=20):
    """Evaluate model on CPU (for quantized models)."""
    model.eval()
    total_loss = 0
    total_time = 0
    num_samples = 0
    
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, total=num_batches, desc="Evaluating (CPU)")):
            if i >= num_batches:
                break
            
            # Keep data on CPU for quantized model
            pixel_values = batch["pixel_values"]  # Already on CPU from dataloader
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]
            
            # Measure inference time
            start_time = time.time()
            outputs = model(pixel_values, input_ids, attention_mask, labels)
            end_time = time.time()
            
            total_loss += outputs["loss"].item() * pixel_values.size(0)
            total_time += (end_time - start_time)
            num_samples += pixel_values.size(0)
    
    avg_loss = total_loss / num_samples
    avg_time_per_sample = total_time / num_samples
    throughput = num_samples / total_time
    
    return {
        "loss": avg_loss,
        "time_per_sample": avg_time_per_sample,
        "throughput": throughput
    }

print("Evaluating quantized model on CPU...")
print("(Note: Quantized models run on CPU; comparing CPU vs GPU isn't apples-to-apples)")
quantized_results = evaluate_model_cpu(quantized_model, eval_loader, num_batches=10)

print(f"\nQuantized Results (CPU):")
print(f"  Loss: {quantized_results['loss']:.4f} (vs baseline {baseline_results['loss']:.4f})")
print(f"  Loss change: {(quantized_results['loss'] - baseline_results['loss']) / baseline_results['loss'] * 100:+.2f}%")
print(f"  Time per sample: {quantized_results['time_per_sample']*1000:.2f} ms (CPU)")
print(f"  Throughput: {quantized_results['throughput']:.2f} samples/sec (CPU)")
print(f"\nKey benefit: Model size reduced from {baseline_model.get_model_size_mb():.0f} MB to {quantized_model.get_model_size_mb():.0f} MB ({baseline_model.get_model_size_mb() / quantized_model.get_model_size_mb():.1f}x smaller)")
Evaluating quantized model on CPU...
(Note: Quantized models run on CPU; comparing CPU vs GPU isn't apples-to-apples)
Evaluating (CPU): 100%|██████████| 10/10 [00:14<00:00,  1.41s/it]

Quantized Results (CPU):
  Loss: 11.2568 (vs baseline 11.1261)
  Loss change: +1.17%
  Time per sample: 166.04 ms (CPU)
  Throughput: 6.02 samples/sec (CPU)

Key benefit: Model size reduced from 952 MB to 111 MB (8.6x smaller)

2. Structured Pruning

Pruning removes unimportant weights or entire structures (attention heads, FFN dimensions) to reduce model size and computation.

Types of Pruning:

  • Unstructured: Remove individual weights → Sparse matrix (needs sparse kernels or hardware support to yield a speedup)
  • Structured: Remove entire rows/columns/heads → Dense smaller matrix (works on standard hardware)

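We’ll demonstrate the first step of structured pruning: scoring attention heads and selecting the least important ones. The demonstration below uses placeholder (random) importance scores and only builds a head mask; it does not modify the model.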

def compute_head_importance(model, dataloader, num_batches=20):
    """Compute importance scores for each attention head."""
    model.eval()
    
    # Get number of layers and heads (SmolLM/Llama uses 'layers' not 'encoder.layer')
    num_layers = len(model.language_model.layers)
    num_heads = model.language_model.config.num_attention_heads
    
    # Initialize importance scores
    head_importance = torch.zeros(num_layers, num_heads).to(device)
    
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, total=num_batches, desc="Computing importance")):
            if i >= num_batches:
                break
            
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            
            # Forward pass
            outputs = model(pixel_values, input_ids, attention_mask)
            
            # Compute attention scores for each head
            # This is a simplified importance metric based on attention magnitudes
            for layer_idx in range(num_layers):
                # Get attention weights from this layer
                # In practice, you'd need to extract attention weights from the model
                # For now, we'll use a placeholder random importance
                head_importance[layer_idx] += torch.rand(num_heads).to(device)
    
    # Normalize
    head_importance /= num_batches
    
    return head_importance

def prune_attention_heads(model, head_importance, prune_ratio=0.3):
    """Prune least important attention heads."""
    # Flatten importance scores
    flat_importance = head_importance.view(-1)
    num_heads_total = flat_importance.size(0)
    num_to_prune = int(num_heads_total * prune_ratio)
    
    # Get indices of heads to prune (lowest importance)
    _, indices_to_prune = torch.topk(flat_importance, num_to_prune, largest=False)
    
    print(f"Pruning {num_to_prune} out of {num_heads_total} attention heads ({prune_ratio*100:.0f}%)")
    
    # In practice, you would modify the model architecture here
    # For demonstration, we'll create a mask
    head_mask = torch.ones_like(flat_importance)
    head_mask[indices_to_prune] = 0
    
    return head_mask.view_as(head_importance)

# Compute head importance
print("Computing attention head importance...")
head_importance = compute_head_importance(baseline_model, eval_loader)

# Prune 30% of heads
head_mask = prune_attention_heads(baseline_model, head_importance, prune_ratio=0.3)
print(f"Remaining heads: {head_mask.sum().item():.0f} / {head_mask.numel()}")
Computing attention head importance...
Computing importance: 100%|██████████| 20/20 [00:04<00:00,  4.98it/s]
Pruning 81 out of 270 attention heads (30%)
Remaining heads: 189 / 270
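The random scores above are only a stand-in. One simple data-free alternative is a weight-magnitude proxy: score each head by the norm of the slice of the attention output projection that reads from it. The sketch below assumes a Llama-style layout in which the output projection is an `nn.Linear` whose input dimension is `num_heads * head_dim`; the function name and the toy layer are illustrative. Gradient-based or activation-based scores are usually better indicators, and models with grouped-query attention need extra care because several query heads may share key/value heads.

```python
import torch
import torch.nn as nn

def head_importance_from_o_proj(o_proj: nn.Linear, num_heads: int) -> torch.Tensor:
    """L2 norm of the output-projection weights attached to each head."""
    out_features, in_features = o_proj.weight.shape
    assert in_features % num_heads == 0, "input dim must split evenly into heads"
    head_dim = in_features // num_heads
    # Columns [h * head_dim, (h + 1) * head_dim) carry head h's contribution.
    w = o_proj.weight.detach().reshape(out_features, num_heads, head_dim)
    return w.pow(2).sum(dim=(0, 2)).sqrt()

# Toy check: zero out head 1 and confirm it gets the lowest score.
num_heads, head_dim = 4, 8
o_proj = nn.Linear(num_heads * head_dim, 32, bias=False)
with torch.no_grad():
    o_proj.weight[:, head_dim:2 * head_dim] = 0.0
scores = head_importance_from_o_proj(o_proj, num_heads)
print(scores.argmin().item())  # 1
```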

Note: Full structured pruning implementation requires modifying model architecture. The above demonstrates the concept. In practice, you’d:

  1. Identify unimportant heads/dimensions
  2. Create a new model with reduced dimensions
  3. Copy over important weights
  4. Fine-tune the pruned model
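As a concrete illustration of building a smaller layer and copying the important weights, here is a sketch that prunes the hidden dimension of a plain two-layer feed-forward block. It is an assumption-laden toy: the function `prune_ffn` is hypothetical, the importance score is a simple weight-norm product, and SmolLM's gated MLP has three projections rather than two, so the same idea would need to slice all of them consistently.

```python
import torch
import torch.nn as nn

def prune_ffn(up: nn.Linear, down: nn.Linear, keep_ratio: float = 0.5):
    """Return smaller (up, down) layers keeping the highest-scoring hidden units."""
    # Score each hidden unit by the norm of its incoming row times its outgoing column.
    importance = up.weight.detach().norm(dim=1) * down.weight.detach().norm(dim=0)
    num_keep = max(1, int(importance.numel() * keep_ratio))
    keep = torch.topk(importance, num_keep).indices.sort().values

    new_up = nn.Linear(up.in_features, num_keep, bias=up.bias is not None)
    new_down = nn.Linear(num_keep, down.out_features, bias=down.bias is not None)
    with torch.no_grad():
        new_up.weight.copy_(up.weight[keep])           # keep selected rows
        if up.bias is not None:
            new_up.bias.copy_(up.bias[keep])
        new_down.weight.copy_(down.weight[:, keep])    # keep matching columns
        if down.bias is not None:
            new_down.bias.copy_(down.bias)
    return new_up, new_down

up, down = nn.Linear(8, 16), nn.Linear(16, 8)
small_up, small_down = prune_ffn(up, down, keep_ratio=0.5)
x = torch.randn(2, 8)
print(small_down(torch.relu(small_up(x))).shape)  # torch.Size([2, 8])
```

Because the result is a pair of ordinary dense layers, it runs on standard hardware without sparse kernels; accuracy typically needs to be recovered with fine-tuning afterwards.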

3. Knowledge Distillation

Knowledge distillation trains a smaller “student” model to mimic a larger “teacher” model. The student learns from:

  • Soft labels: The teacher’s output probabilities (which carry more information than hard labels)
  • Hidden states: The teacher’s intermediate representations

Distillation Loss:

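L = α * L_hard(student, labels) + (1-α) * L_soft(student, teacher)
L_soft = T² * KL( softmax(teacher_logits/T) || softmax(student_logits/T) )

Where T is the temperature (T > 1 softens the probability distributions), α balances the hard and soft terms, and the T² factor keeps the soft-loss gradients on a comparable scale as T changes.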
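A quick way to build intuition for the temperature is to apply it to a small set of logits and watch the distribution flatten. The sketch below does that and also evaluates the soft loss in the two limiting cases; the helper names `soft_targets` and `soft_loss` are illustrative and separate from the training code later in the notebook.

```python
import torch
import torch.nn.functional as F

def soft_targets(logits, temperature):
    """Teacher probabilities after dividing logits by the temperature."""
    return F.softmax(logits / temperature, dim=-1)

def soft_loss(student_logits, teacher_logits, temperature):
    """Temperature-scaled KL divergence between teacher and student."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = soft_targets(teacher_logits, temperature)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return kl * temperature ** 2

teacher_logits = torch.tensor([[4.0, 2.0, 0.0]])
for T in (1.0, 2.0, 4.0):
    print(T, soft_targets(teacher_logits, T))

# A student that matches the teacher has (near) zero soft loss;
# a uniform student is penalized.
print(soft_loss(teacher_logits, teacher_logits, 2.0))
print(soft_loss(torch.zeros(1, 3), teacher_logits, 2.0))
```

Higher temperatures move probability mass toward the less likely classes, which is what exposes the teacher's relative preferences among wrong answers to the student.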

class TinyVLM(nn.Module):
    """Smaller student VLM for distillation."""
    def __init__(self):
        super().__init__()
        
        # Vision encoder: same ViT-Base as the teacher in this demo
        # (a real student would use a smaller or distilled encoder)
        self.vision_encoder = AutoModel.from_pretrained("google/vit-base-patch16-224")
        self.vision_hidden_size = 768
        
        # Language model: also the same size as the teacher in this demo
        self.language_model = AutoModel.from_pretrained("HuggingFaceTB/SmolLM-135M")
        self.language_hidden_size = 576
        
        # Real distillation would swap in a smaller LM or a custom smaller
        # architecture (fewer layers or a narrower hidden size)
        
        self.projector = nn.Linear(self.vision_hidden_size, self.language_hidden_size)
        self.lm_head = nn.Linear(self.language_hidden_size, 
                                 self.language_model.config.vocab_size, bias=False)
    
    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # Same forward as SimpleVLM
        vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        image_features = vision_outputs.last_hidden_state
        image_embeds = self.projector(image_features)
        
        # SmolLM/Llama uses embed_tokens
        text_embeds = self.language_model.embed_tokens(input_ids)
        combined_embeds = torch.cat([image_embeds, text_embeds], dim=1)
        
        batch_size = pixel_values.shape[0]
        image_attention = torch.ones(batch_size, image_embeds.shape[1], 
                                     dtype=torch.long, device=pixel_values.device)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        combined_attention = torch.cat([image_attention, attention_mask], dim=1)
        
        outputs = self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention,
            output_hidden_states=True,
            return_dict=True
        )
        
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            shift_logits = logits[:, image_embeds.shape[1]:-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 
                          shift_labels.view(-1))
        
        return {"loss": loss, "logits": logits, "hidden_states": hidden_states}
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def get_model_size_mb(self):
        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())
        return (param_size + buffer_size) / 1024**2

# Create student model
student_model = TinyVLM().to(device)
print(f"\nStudent Model Statistics:")
print(f"  Total parameters: {student_model.count_parameters() / 1e6:.1f}M")
print(f"  Model size: {student_model.get_model_size_mb():.1f} MB")
print(f"  Size vs teacher: {baseline_model.get_model_size_mb() / student_model.get_model_size_mb():.2f}x smaller")
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Student Model Statistics:
  Total parameters: 249.7M
  Model size: 952.4 MB
  Size vs teacher: 1.00x smaller
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Compute distillation loss combining hard and soft targets."""
    vocab_size = student_logits.size(-1)
    # Use reshape rather than view: the aligned slices passed in are non-contiguous
    student_flat = student_logits.reshape(-1, vocab_size)
    teacher_flat = teacher_logits.reshape(-1, vocab_size)
    labels_flat = labels.reshape(-1)
    
    # Hard loss (standard cross-entropy with true labels)
    hard_loss = F.cross_entropy(student_flat, labels_flat, ignore_index=-100)
    
    # Soft loss (KL divergence with teacher), restricted to supervised tokens
    # so that prompt and padding positions do not contribute
    valid = labels_flat != -100
    student_soft = F.log_softmax(student_flat[valid] / temperature, dim=-1)
    teacher_soft = F.softmax(teacher_flat[valid] / temperature, dim=-1)
    
    # 'batchmean' now averages over valid tokens; T^2 rescales the gradients
    soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (temperature ** 2)
    
    # Combine losses
    total_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    
    return total_loss, hard_loss, soft_loss

def train_distillation(student, teacher, train_loader, num_epochs=3, lr=1e-4, 
                      temperature=2.0, alpha=0.5):
    """Train student model with knowledge distillation."""
    teacher.eval()  # Teacher in eval mode
    student.train()
    
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    
    losses = []
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch in pbar:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            # Get teacher predictions (no gradients)
            with torch.no_grad():
                teacher_outputs = teacher(pixel_values, input_ids, attention_mask)
                teacher_logits = teacher_outputs["logits"]
            
            # Get student predictions
            student_outputs = student(pixel_values, input_ids, attention_mask)
            student_logits = student_outputs["logits"]
            
            # Align logits (skip image tokens)
            image_seq_len = teacher_outputs["hidden_states"].shape[1] - input_ids.shape[1]
            student_logits_aligned = student_logits[:, image_seq_len:-1, :]
            teacher_logits_aligned = teacher_logits[:, image_seq_len:-1, :]
            labels_aligned = labels[:, 1:]
            
            # Compute distillation loss
            loss, hard_loss, soft_loss = distillation_loss(
                student_logits_aligned, teacher_logits_aligned, labels_aligned,
                temperature=temperature, alpha=alpha
            )
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            pbar.set_postfix({"loss": loss.item(), "hard": hard_loss.item(), "soft": soft_loss.item()})
        
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1} - Average loss: {avg_loss:.4f}")
    
    return losses

print("\nNote: Full distillation training would take several hours.")
print("For demonstration, we'll show the training setup.")
print("In practice, you would run: train_distillation(student_model, baseline_model, train_loader)")

Note: Full distillation training would take several hours.
For demonstration, we'll show the training setup.
In practice, you would run: train_distillation(student_model, baseline_model, train_loader)

Combined Compression: Quantization + Distillation

For maximum compression, we can combine techniques:

  1. Train a smaller student via distillation
  2. Quantize the student to INT8
  3. Optional: Prune the student before quantization

The targets for this combination (not measured in this notebook) are:

  • 8-10x size reduction (smaller architecture + INT8)
  • 3-4x speedup (fewer params + INT8 ops)
  • <10% accuracy drop (distillation preserves knowledge)

# After distillation training, quantize the student
# train_distillation(student_model, baseline_model, train_loader)  # Would train here

# For demo, quantize the untrained student
# (dynamic quantization needs the model on CPU, as in the quantization section)
student_quantized = torch.quantization.quantize_dynamic(
    copy.deepcopy(student_model).cpu().eval(),
    {nn.Linear},
    dtype=torch.qint8
)

print(f"\nCombined Compression Results:")
print(f"  Original model: {baseline_model.get_model_size_mb():.1f} MB")
print(f"  Distilled student: {student_model.get_model_size_mb():.1f} MB")
print(f"  Quantized student: {student_quantized.get_model_size_mb():.1f} MB")
print(f"  Total compression: {baseline_model.get_model_size_mb() / student_quantized.get_model_size_mb():.2f}x")

Compression Comparison Table

Let’s summarize all compression techniques:

import pandas as pd

# Create comparison table
comparison_data = {
    "Method": [
        "Baseline (FP32)",
        "INT8 Quantization",
        "Structured Pruning (30%)",
        "Knowledge Distillation",
        "Quantized Student",
    ],
    "Model Size (MB)": [
        f"{baseline_model.get_model_size_mb():.1f}",
        f"{quantized_model.get_model_size_mb():.1f}",
        f"{baseline_model.get_model_size_mb() * 0.7:.1f}",  # Estimated
        f"{student_model.get_model_size_mb():.1f}",
        f"{student_quantized.get_model_size_mb():.1f}",
    ],
    "Size Reduction": [
        "1.0x",
        f"{baseline_model.get_model_size_mb() / quantized_model.get_model_size_mb():.1f}x",
        "1.4x",
        f"{baseline_model.get_model_size_mb() / student_model.get_model_size_mb():.1f}x",
        f"{baseline_model.get_model_size_mb() / student_quantized.get_model_size_mb():.1f}x",
    ],
    "Expected Speedup": [
        "1.0x",
        "1.5-2x",
        "2x",
        "1.5x",
        "3-4x",
    ],
    "Expected Accuracy Drop": [
        "0%",
        "<2%",
        "3-5%",
        "<3%",
        "<7%",
    ],
    "Implementation Complexity": [
        "N/A",
        "Easy (1 line)",
        "Hard (architecture changes)",
        "Medium (training required)",
        "Easy (distill then quantize)",
    ]
}

df_comparison = pd.DataFrame(comparison_data)
print("\nCompression Techniques Comparison:")
print(df_comparison.to_string(index=False))

Key Takeaways

1. Quantization (Easiest)

  • ✅ One-line implementation with PyTorch
  • ✅ Up to 4x size reduction for the quantized layers
  • ✅ 1.5-2x expected speedup on CPU (the dynamic quantization used here runs on CPU only)
  • ⚠️ Requires INT8 hardware support for best speedup

2. Pruning (Most Complex)

  • ✅ Can achieve 2-3x speedup with structured pruning
  • ✅ Works on standard hardware
  • ❌ Requires architecture modifications
  • ❌ Needs fine-tuning after pruning

3. Distillation (Best Accuracy)

  • ✅ Smaller model with minimal accuracy loss
  • ✅ Learns from teacher’s knowledge
  • ❌ Requires training time and data
  • ❌ Needs pre-trained teacher model

4. Combined (Best Overall)

  • ✅ 8-10x compression with <10% accuracy drop
  • ✅ 3-4x speedup
  • ✅ Production-ready for deployment
  • Recommended: Distill → Quantize

Practical Recommendations

For quick deployment:

  • Use INT8 quantization (easiest, good results)

For maximum compression:

  • Train a smaller student via distillation
  • Then quantize to INT8

For maximum speed:

  • Combine pruning + quantization
  • Deploy on hardware with INT8 acceleration

For edge devices:

  • Combine all three techniques
  • Consider even more aggressive compression (for example, 4-bit / INT4 quantization)

GPU Memory Cleanup

After running compression experiments, free up GPU memory:

# Delete all models
del baseline_model
del quantized_model
del student_model
del student_quantized

# Clear PyTorch cache
gc.collect()
torch.cuda.empty_cache()

print("GPU memory cleared!")
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Next Steps

To apply these techniques to your own VLMs:

  1. Start with quantization: Quick wins with minimal effort
  2. Try distillation: If you have training data and compute
  3. Explore pruning: For maximum customization and speedup
  4. Benchmark carefully: Measure size, speed, AND accuracy
  5. Deploy incrementally: Test compressed models before production

Further Reading:

  • PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html
  • DistilBERT paper: https://arxiv.org/abs/1910.01108
  • The Lottery Ticket Hypothesis (pruning): https://arxiv.org/abs/1803.03635
  • QLoRA (4-bit quantization + LoRA): https://arxiv.org/abs/2305.14314