Building a Minimal Vision-Language Model from Scratch

VLM
vision-language
multimodal
deep-learning
transformers
educational
Author

Nipun Batra

Published

December 24, 2025

Introduction

Vision-Language Models (VLMs) have revolutionized how AI systems understand and reason about images. Models like GPT-4V, LLaVA, and Gemini can describe images, answer questions about them, and even follow complex visual instructions.

But how do these models actually work? At their core, VLMs combine three key components:

  1. Vision Encoder: Converts images into meaningful feature representations
  2. Projection Layer: Bridges the gap between vision and language embedding spaces
  3. Language Model: Generates text conditioned on the visual features

In this notebook, we’ll build a minimal VLM from scratch using small, publicly available models:

  - Vision: Google’s ViT-Large (304M parameters)
  - Language: HuggingFace’s SmolLM-360M (360M parameters)
  - Dataset: Flickr8k (a small subset, for educational purposes)

The goal is educational - understanding the architecture, not building a production model.

Setup

!uv pip install -q transformers datasets torch torchvision pillow accelerate einops timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    ViTModel,
    ViTImageProcessor,
)
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Print torch version
print(f"Torch version :{torch.__version__}")
Using device: cuda
Torch version :2.6.0+cu124
%config InlineBackend.figure_format = 'retina'

Part 1: The Vision Encoder

The vision encoder’s job is to convert an image into a sequence of meaningful feature vectors. We’ll use a pretrained Vision Transformer (ViT), which:

  1. Splits the image into 16x16 patches
  2. Projects each patch into an embedding
  3. Processes patches through transformer layers
  4. Outputs a sequence of features (one per patch + a CLS token)

For a 224x224 image with 16x16 patches, we get:

  - (224/16) x (224/16) = 14 x 14 = 196 patches
  - Plus 1 CLS token = 197 tokens total
  - Each token has dimension 1024 (for ViT-Large)
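To make the patch arithmetic concrete, here is a minimal sketch of the first two steps in plain PyTorch. It is not the Hugging Face ViT code: `ViTModel` computes the patch embedding with a strided `Conv2d` (equivalent to a linear layer applied to flattened patches) and also adds learned position embeddings, which this sketch omits. `patchify`, `patch_embed`, and `cls_token` are illustrative names, and the weights are random.

```python
import torch
import torch.nn as nn

def patchify(images: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split (batch, C, H, W) images into flattened, non-overlapping patches."""
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    x = images.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)                  # (b, rows, cols, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)

images = torch.randn(2, 3, 224, 224)
patches = patchify(images)                           # (2, 196, 3*16*16 = 768)

# Toy stand-ins for the learned patch embedding and CLS token (ViT-Large width)
hidden = 1024
patch_embed = nn.Linear(3 * 16 * 16, hidden)
cls_token = nn.Parameter(torch.zeros(1, 1, hidden))

tokens = patch_embed(patches)                        # (2, 196, 1024)
tokens = torch.cat([cls_token.expand(images.shape[0], -1, -1), tokens], dim=1)
print(patches.shape, tokens.shape)                   # (2, 196, 768) and (2, 197, 1024)
```

Patches are numbered row by row across the 14 x 14 grid, which is why a 196-long vector of per-patch values can later be reshaped back to `(14, 14)` for plotting.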

# Load pretrained ViT-Large (larger than ViT-Base for better performance)
vision_model_name = "google/vit-large-patch16-224"
vision_encoder = ViTModel.from_pretrained(vision_model_name)
image_processor = ViTImageProcessor.from_pretrained(vision_model_name)

# Freeze the vision encoder - we won't train it
for param in vision_encoder.parameters():
    param.requires_grad = False

vision_encoder = vision_encoder.to(device)
vision_encoder.eval()

print(f"Vision encoder hidden size: {vision_encoder.config.hidden_size}")
print(f"Number of patches: {(224//16)**2} + 1 CLS token = {(224//16)**2 + 1} tokens")
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-large-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Vision encoder hidden size: 1024
Number of patches: 196 + 1 CLS token = 197 tokens
# Test the vision encoder with a sample image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url)
sample_image = Image.open(BytesIO(response.content))

plt.figure(figsize=(6, 6))
plt.imshow(sample_image)
plt.title("Sample Image from COCO")
plt.axis('off')
plt.show()

# Process the image
inputs = image_processor(sample_image, return_tensors="pt").to(device)

with torch.no_grad():
    vision_outputs = vision_encoder(**inputs)

# Get the sequence of patch features (excluding pooler output)
image_features = vision_outputs.last_hidden_state
print(f"Image features shape: {image_features.shape}")
print(f"  - Batch size: {image_features.shape[0]}")
print(f"  - Sequence length (patches + CLS): {image_features.shape[1]}")
print(f"  - Hidden dimension: {image_features.shape[2]}")

Image features shape: torch.Size([1, 197, 1024])
  - Batch size: 1
  - Sequence length (patches + CLS): 197
  - Hidden dimension: 1024

Part 2: The Language Model

For the language model, we’ll use SmolLM-360M, a small but capable language model from HuggingFace. It has:

  - 360M parameters (larger than SmolLM-135M, for better performance)
  - A hidden dimension of 960
  - Modest hardware requirements

The key insight is that we need to project our vision features (1024-dim from ViT-Large) into the LLM’s embedding space (960-dim).
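Why is the projection needed at all? The image and text sequences are joined along the sequence dimension, and `torch.cat` requires every other dimension to match. Here is a quick sketch with random tensors in this notebook’s shapes (the 3 text tokens are just an example):

```python
import torch

vision_tokens = torch.randn(1, 197, 1024)   # ViT-Large output width
text_embeds = torch.randn(1, 3, 960)        # SmolLM-360M embedding width

try:
    torch.cat([vision_tokens, text_embeds], dim=1)
    cat_ok = True
except RuntimeError as err:                 # widths 1024 vs 960 do not match
    cat_ok = False
    print("Cannot concatenate:", err)

# Mapping the image tokens to the LM width makes the concatenation well-defined
to_lm_space = torch.nn.Linear(1024, 960)
combined = torch.cat([to_lm_space(vision_tokens), text_embeds], dim=1)
print(combined.shape)                       # torch.Size([1, 200, 960])
```

In the real model, this map is the trained `VisionProjector` from Part 3, not an untrained `nn.Linear`.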

# Load the language model (SmolLM-360M for better performance)
lm_model_name = "HuggingFaceTB/SmolLM-360M"
tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
language_model = AutoModelForCausalLM.from_pretrained(lm_model_name)

# Add padding token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

language_model = language_model.to(device)

print(f"Language model hidden size: {language_model.config.hidden_size}")
print(f"Vocabulary size: {language_model.config.vocab_size}")
print(f"Number of layers: {language_model.config.num_hidden_layers}")
Language model hidden size: 960
Vocabulary size: 49152
Number of layers: 32
# Test the language model
test_text = "A photo of"
inputs = tokenizer(test_text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = language_model.generate(
        **inputs,
        max_new_tokens=20,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id
    )

print(f"Generated text: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
Generated text: A photo of a school playground with a huge number of kids playing with their friends. Kids are using their hands to

Part 3: The Projection Layer

The projection layer is the bridge between vision and language. It transforms:

  - From: vision features (batch, 197, 1024)
  - To: language-compatible embeddings (batch, 197, 960)

We’ll use a small multi-layer projection:

  1. Linear layer (1024 -> 960)
  2. GELU activation
  3. LayerNorm for stability
  4. A second Linear layer (960 -> 960) for extra capacity in the alignment

This mirrors the two-layer MLP projector with GELU used in LLaVA-1.5, with a LayerNorm added between the two linear layers.
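As a sanity check on the parameter count printed below, the four layers can be tallied by hand. A `Linear(d_in, d_out)` layer has `d_in * d_out` weights plus `d_out` biases, a `LayerNorm(d)` has a scale and a shift of size `d` each, and GELU has no parameters. Here is a small pure-Python sketch (the helper names are illustrative):

```python
def linear_params(d_in: int, d_out: int) -> int:
    return d_in * d_out + d_out                      # weight matrix + bias

def layernorm_params(d: int) -> int:
    return 2 * d                                     # elementwise scale (gamma) + shift (beta)

vision_dim, language_dim = 1024, 960
total = (
    linear_params(vision_dim, language_dim)          # Linear 1024 -> 960
    + layernorm_params(language_dim)                 # LayerNorm(960); GELU adds nothing
    + linear_params(language_dim, language_dim)      # Linear 960 -> 960
)
print(f"{total:,}")                                  # 1,908,480
```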

class VisionProjector(nn.Module):
    """Projects vision features into the language model's embedding space."""
    
    def __init__(self, vision_dim: int, language_dim: int):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(vision_dim, language_dim),
            nn.GELU(),
            nn.LayerNorm(language_dim),
            nn.Linear(language_dim, language_dim),  # Additional layer for better alignment
        )
    
    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """Project vision features to language embedding space.
        
        Args:
            vision_features: (batch, seq_len, vision_dim)
        Returns:
            projected: (batch, seq_len, language_dim)
        """
        return self.projection(vision_features)


# Create projector
vision_dim = vision_encoder.config.hidden_size  # 1024 (ViT-Large)
language_dim = language_model.config.hidden_size  # 960 (SmolLM-360M)

projector = VisionProjector(vision_dim, language_dim).to(device)

print(f"Projector: {vision_dim} -> {language_dim}")
print(f"Trainable parameters: {sum(p.numel() for p in projector.parameters()):,}")
Projector: 1024 -> 960
Trainable parameters: 1,908,480
# Test the projector
with torch.no_grad():
    projected_features = projector(image_features)

print(f"Original vision features: {image_features.shape}")
print(f"Projected features: {projected_features.shape}")
Original vision features: torch.Size([1, 197, 1024])
Projected features: torch.Size([1, 197, 960])

Part 4: The Complete VLM Architecture

Now let’s combine everything into a single model. The forward pass:

  1. Encode image -> vision features (batch, 197, 1024)
  2. Project -> language-space features (batch, 197, 960)
  3. Embed text -> text embeddings (batch, text_len, 960)
  4. Concatenate -> [image_embeds, text_embeds] (batch, 197 + text_len, 960)
  5. Run the LLM -> next-token loss on the caption during training, or new caption tokens at generation time
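Before reading `MiniVLM.forward`, it may help to see steps 4 and 5 with toy tensors. The sketch below builds the combined embeddings, attention mask, and labels the same way `forward` does (image positions get label -100). It then computes the loss by hand with the one-position shift that Hugging Face causal-LM models apply internally when `labels` are passed. All sizes and the random "logits" are stand-ins, not the real model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_img, n_txt, hidden, vocab = 2, 5, 4, 8, 11        # toy sizes, not the real model's
image_embeds = torch.randn(batch, n_img, hidden)            # stand-in for projected ViT tokens
text_embeds = torch.randn(batch, n_txt, hidden)             # stand-in for LLM token embeddings
input_ids = torch.randint(0, vocab, (batch, n_txt))
attention_mask = torch.tensor([[1, 1, 1, 0],                # first caption ends in one pad token
                               [1, 1, 1, 1]])
labels = input_ids.masked_fill(attention_mask == 0, -100)   # padding is not a target

combined_embeds = torch.cat([image_embeds, text_embeds], dim=1)
combined_mask = torch.cat(
    [torch.ones(batch, n_img, dtype=attention_mask.dtype), attention_mask], dim=1)
combined_labels = torch.cat(
    [torch.full((batch, n_img), -100, dtype=labels.dtype), labels], dim=1)

# Stand-in for the LM's output logits over the combined sequence
logits = torch.randn(batch, n_img + n_txt, vocab)

# Causal LM loss: the logit at position t is scored against the token at position t + 1
shift_logits = logits[:, :-1].reshape(-1, vocab)
shift_labels = combined_labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
num_targets = int((shift_labels != -100).sum())
print(combined_embeds.shape, combined_mask.shape, num_targets)   # (2, 9, 8) (2, 9) 7
```

Because of the shift, the last image position is the one trained to predict the first caption token. Image positions themselves never contribute to the loss.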
class MiniVLM(nn.Module):
    """A minimal Vision-Language Model."""
    
    def __init__(
        self,
        vision_encoder: ViTModel,
        language_model: AutoModelForCausalLM,
        projector: VisionProjector,
        tokenizer: AutoTokenizer,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_model = language_model
        self.projector = projector
        self.tokenizer = tokenizer
        
        # Freeze vision encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
    
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image and project to language space."""
        with torch.no_grad():
            vision_outputs = self.vision_encoder(pixel_values=pixel_values)
        image_features = vision_outputs.last_hidden_state
        projected = self.projector(image_features)
        return projected
    
    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
    ):
        """Forward pass for training.
        
        Args:
            pixel_values: Image tensor (batch, 3, 224, 224)
            input_ids: Text token IDs (batch, text_len)
            attention_mask: Attention mask for text (batch, text_len)
            labels: Target token IDs for loss computation (batch, text_len)
        """
        batch_size = pixel_values.shape[0]
        
        # 1. Encode and project image
        image_embeds = self.encode_image(pixel_values)  # (batch, 197, hidden)
        num_image_tokens = image_embeds.shape[1]
        
        # 2. Get text embeddings from LLM's embedding layer
        text_embeds = self.language_model.get_input_embeddings()(input_ids)  # (batch, text_len, hidden)
        
        # 3. Concatenate: [IMAGE TOKENS] [TEXT TOKENS]
        combined_embeds = torch.cat([image_embeds, text_embeds], dim=1)
        
        # 4. Create attention mask for combined sequence
        image_attention = torch.ones(
            (batch_size, num_image_tokens),
            dtype=attention_mask.dtype,
            device=attention_mask.device
        )
        combined_attention = torch.cat([image_attention, attention_mask], dim=1)
        
        # 5. Create labels: -100 for image tokens (ignore in loss), actual labels for text
        if labels is not None:
            image_labels = torch.full(
                (batch_size, num_image_tokens),
                fill_value=-100,  # -100 is ignored by CrossEntropyLoss
                dtype=labels.dtype,
                device=labels.device
            )
            combined_labels = torch.cat([image_labels, labels], dim=1)
        else:
            combined_labels = None
        
        # 6. Forward through language model
        outputs = self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention,
            labels=combined_labels,
            return_dict=True,
        )
        
        return outputs
    
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7,
        do_sample: bool = True,
    ) -> str:
        """Generate a caption for an image."""
        self.eval()
        
        # Encode image
        image_embeds = self.encode_image(pixel_values)  # (1, 197, hidden)
        
        # Start with a prompt
        prompt = "This image shows"
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(pixel_values.device)
        
        # Generate token by token. The image embeddings are prepended to the
        # growing token sequence at every step (there is no KV cache), so each
        # step re-runs the language model over the whole sequence.
        generated_ids = prompt_ids.clone()
        
        for _ in range(max_new_tokens):
            # Get current embeddings
            current_embeds = self.language_model.get_input_embeddings()(generated_ids)
            full_embeds = torch.cat([image_embeds, current_embeds], dim=1)
            
            # Forward pass
            outputs = self.language_model(inputs_embeds=full_embeds)
            next_token_logits = outputs.logits[:, -1, :]
            
            # Sample or greedy
            if do_sample:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            
            # Stop if EOS
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)


# Create the VLM
vlm = MiniVLM(
    vision_encoder=vision_encoder,
    language_model=language_model,
    projector=projector,
    tokenizer=tokenizer,
)
vlm = vlm.to(device)

# Count parameters
total_params = sum(p.numel() for p in vlm.parameters())
trainable_params = sum(p.numel() for p in vlm.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")
Total parameters: 668,080,832
Trainable parameters: 363,729,600
Frozen parameters: 304,351,232
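`MiniVLM.generate` re-runs the language model over the full [image tokens, prompt, generated tokens] sequence at every step because it never reuses the attention key/value cache. A back-of-the-envelope count shows what that costs for the settings used later in `generate_caption`: 197 image tokens and `max_new_tokens=30`, with a prompt assumed here to be 3 tokens long. The pure-Python helper `positions_processed` is hypothetical.

```python
def positions_processed(n_image: int, n_prompt: int, n_new: int) -> tuple[int, int]:
    """Count token positions pushed through the LM while generating n_new tokens."""
    prefix = n_image + n_prompt
    # Without a cache: step i re-encodes the prefix plus the i tokens generated so far
    no_cache = sum(prefix + i for i in range(n_new))
    # With a KV cache: encode the prefix once, then feed one new token per later step
    with_cache = prefix + (n_new - 1)
    return no_cache, with_cache

print(positions_processed(197, 3, 30))   # (6435, 229)
```

For short captions, this overhead is tolerable, which is presumably why the notebook keeps the simpler loop. Hugging Face causal LMs can return and accept `past_key_values`, which is the usual route to the cached behaviour.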

Part 5: Loading the Flickr8k Dataset

We’ll use a subset of the Flickr8k dataset - a popular image captioning benchmark with 8000 images. Each image has 5 captions - we’ll randomly sample one per image during training.

# Load Flickr8k captions dataset
dataset = load_dataset("jxie/flickr8k", split="train")

# Use 2000 samples for better training
num_samples = 2000
dataset = dataset.shuffle(seed=42).select(range(num_samples))

print(f"Dataset size: {len(dataset)}")
print(f"Sample item keys: {dataset[0].keys()}")
Generating train split: 100%|██████████| 6000/6000 [00:01<00:00, 4764.89 examples/s]
Generating validation split: 100%|██████████| 1000/1000 [00:00<00:00, 4347.67 examples/s]
Generating test split: 100%|██████████| 1000/1000 [00:00<00:00, 5419.11 examples/s]
Dataset size: 2000
Sample item keys: dict_keys(['image', 'caption_0', 'caption_1', 'caption_2', 'caption_3', 'caption_4'])
# Look at a sample - Flickr8k has caption_0 through caption_4
sample = dataset[0]
print(f"Image: {type(sample['image'])}")
print(f"Captions:")
for i in range(5):
    print(f"  {i}: {sample[f'caption_{i}']}")

plt.figure(figsize=(6, 6))
plt.imshow(sample['image'])
plt.title(f"Caption: {sample['caption_0'][:60]}...")
plt.axis('off')
plt.show()
Image: <class 'PIL.JpegImagePlugin.JpegImageFile'>
Captions:
  0: Boys with their backs against an incoming wave .
  1: Four boys are about to be hit by an approaching wave .
  2: Four people sitting in the path of a wave .
  3: Four young men are sitting on the beach under a crashing wave .
  4: three boys sitting in sand getting splashed by wave

import random

class Flickr8kDataset(Dataset):
    """Dataset for Flickr8k image-caption pairs."""
    
    def __init__(self, hf_dataset, image_processor, tokenizer, max_length=64):
        self.dataset = hf_dataset
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Process image
        image = item['image'].convert('RGB')
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values.squeeze(0)
        
        # Randomly select one of the 5 captions
        caption_idx = random.randint(0, 4)
        caption = item[f'caption_{caption_idx}']
        
        # Tokenize caption
        encoding = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Labels: mask padding tokens with -100
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        
        return {
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }


# Create dataset and dataloader
train_dataset = Flickr8kDataset(dataset, image_processor, tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=4,  # Small batch for memory
    shuffle=True,
    num_workers=0,  # Avoid multiprocessing issues
)

print(f"Number of batches: {len(train_loader)}")
Number of batches: 500
# Test a batch
batch = next(iter(train_loader))
print(f"Batch pixel_values shape: {batch['pixel_values'].shape}")
print(f"Batch input_ids shape: {batch['input_ids'].shape}")
print(f"Batch labels shape: {batch['labels'].shape}")
Batch pixel_values shape: torch.Size([4, 3, 224, 224])
Batch input_ids shape: torch.Size([4, 64])
Batch labels shape: torch.Size([4, 64])
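One likely reason the trained captions further down run on until the token budget is exhausted: the dataset tokenizes each caption without an end-of-sequence token. As far as I can tell, the SmolLM tokenizer does not append one by default, which you can check with `tokenizer(caption).input_ids`. The model therefore never sees a target telling it to stop. A possible fix (an assumption, not tested in this notebook) is to tokenize `caption + self.tokenizer.eos_token` instead. Because this notebook sets `pad_token = eos_token`, how the padding is masked then matters. The toy sketch below uses made-up token ids:

```python
import torch

eos_id = 0                      # made-up id; in this notebook pad_token == eos_token
caption_ids = [17, 42, 5]       # made-up token ids for one caption
max_length = 6

ids = caption_ids + [eos_id]                    # append EOS so the model can learn to stop
pad = max_length - len(ids)
input_ids = torch.tensor(ids + [eos_id] * pad)  # padded with the same id
attention_mask = torch.tensor([1] * len(ids) + [0] * pad)

# Masking by attention_mask (what Flickr8kDataset does) keeps the real EOS as a target
labels_by_mask = input_ids.masked_fill(attention_mask == 0, -100)
# Masking by token id would silently drop it as well, because pad == eos
labels_by_id = input_ids.masked_fill(input_ids == eos_id, -100)

print(labels_by_mask.tolist())  # [17, 42, 5, 0, -100, -100]
print(labels_by_id.tolist())    # [17, 42, 5, -100, -100, -100]
```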

Part 6: Untrained Model Predictions (Baseline)

Before training, let’s see what our VLM generates. Since the projection layer has random weights, the projected image tokens carry no usable signal for the language model. It effectively ignores the image and continues the prompt with fluent but unrelated (sometimes degenerate) text. This establishes a baseline to appreciate the improvement after training.

def generate_caption(model, image, image_processor, device):
    """Generate a caption for an image."""
    model.eval()
    
    # Process image
    if isinstance(image, str):
        # URL
        response = requests.get(image)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    
    # Generate
    caption = model.generate(
        pixel_values,
        max_new_tokens=30,
        temperature=0.8,
        do_sample=True,
    )
    
    return caption, image
# Store some test images and their ground truth captions for before/after comparison
test_indices = [0, 10, 50, 100]
test_samples = [(dataset[i]['image'].convert('RGB'), dataset[i]['caption_0']) for i in test_indices]

# Generate captions with UNTRAINED model
print("=" * 60)
print("UNTRAINED MODEL PREDICTIONS (Random projector weights)")
print("=" * 60)

untrained_captions = []
for i, (image, gt_caption) in enumerate(test_samples):
    caption, _ = generate_caption(vlm, image, image_processor, device)
    untrained_captions.append(caption)
    print(f"\nImage {i+1}:")
    print(f"  Generated: {caption}")
    print(f"  GT: {gt_caption[:80]}...")
============================================================
UNTRAINED MODEL PREDICTIONS (Random projector weights)
============================================================

Image 1:
  Generated: This image shows cross-hat on the back of the dinosaur, showing that the massive scales are the Z-axis. This is a rebound, not a right-
  GT: Boys with their backs against an incoming wave ....

Image 2:
  Generated: This image shows the two simple types of the intonation of the vowel.
Comments (28) (Comment)

Create and save a GAMS model from
  GT: A child wearing a white shirt hangs from the playground equipment ....

Image 3:
  Generated: This image shows alert WLLK wordword w wordword message w wordw w wordw
K-word w w w w w w w w w w
  GT: Two girls take in the view ....

Image 4:
  Generated: This image shows the former Prinsland Castle at the Pematang Ilir. It is one of the most important and largest 5-meters buildings in
  GT: A black dog barking ....
# Visualize the untrained predictions
import textwrap

def wrap_text(text, width=40):
    """Wrap text for better display in titles."""
    return '\n'.join(textwrap.wrap(text, width=width))

fig, axes = plt.subplots(2, 2, figsize=(14, 16))
axes = axes.flatten()

for i, ((image, gt_caption), ax) in enumerate(zip(test_samples, axes)):
    ax.imshow(image)
    
    # Create wrapped title text
    gen_text = wrap_text(f"UNTRAINED: {untrained_captions[i]}", 45)
    gt_text = wrap_text(f"GT: {gt_caption}", 45)
    
    ax.set_xlabel(f"{gen_text}\n\n{gt_text}", fontsize=9, ha='center')
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle("Untrained VLM Predictions (Random noise!)", fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

Part 7: Training Loop

Now let’s train our VLM! We’ll train only:

  1. The projection layer (bridges vision to language)
  2. The language model (learns to generate captions)

The vision encoder stays frozen.
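Here is a small sketch of what "frozen" means in practice, using toy `nn.Linear` modules as stand-ins (`encoder` for the ViT, `projector` for everything trainable). Parameters with `requires_grad=False` receive no gradient. Because `train_vlm` builds its optimizer only from parameters that require grad, they are never updated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(8, 6)     # toy stand-in for the frozen vision encoder
projector = nn.Linear(6, 4)   # toy stand-in for the trainable parts
for p in encoder.parameters():
    p.requires_grad = False

# Same filtering as train_vlm: only parameters that require grad go to the optimizer
modules = nn.ModuleList([encoder, projector])
trainable = [p for p in modules.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

encoder_before = encoder.weight.detach().clone()
loss = projector(encoder(torch.randn(3, 8))).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(len(trainable), encoder.weight.grad is None, projector.weight.grad is not None)  # 2 True True
```

Setting `requires_grad=False` keeps the weights out of the optimizer. The `torch.no_grad()` block in `MiniVLM.encode_image` additionally avoids storing the encoder’s activations for backpropagation, which saves memory.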

import os

def train_vlm(model, train_loader, num_epochs=3, lr=1e-4, checkpoint_path="vlm_checkpoint.pt"):
    """Train the VLM with checkpoint support for preemptible/resumable training."""
    
    # Only optimize trainable parameters
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    
    # Try to load checkpoint for resumable training
    start_epoch = 0
    losses = []
    
    if os.path.exists(checkpoint_path):
        print(f"Found checkpoint: {checkpoint_path}")
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            start_epoch = checkpoint.get('epoch', 0)
            losses = checkpoint.get('losses', [])
            model.projector.load_state_dict(checkpoint['projector_state_dict'])
            model.language_model.load_state_dict(checkpoint['language_model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Move optimizer state to device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
            print(f"Resumed from epoch {start_epoch}")
        except Exception as e:
            print(f"Could not load checkpoint: {e}")
    
    model.train()
    # But keep vision encoder in eval mode
    model.vision_encoder.eval()
    
    try:
        for epoch in range(start_epoch, num_epochs):
            epoch_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            
            for batch in progress_bar:
                # Move batch to device
                pixel_values = batch['pixel_values'].to(device)
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                # Forward pass
                outputs = model(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                
                loss = outputs.loss
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
                optimizer.step()
                
                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
            
            avg_loss = epoch_loss / len(train_loader)
            losses.append(avg_loss)
            print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")
            
            # Save checkpoint after each epoch
            torch.save({
                'epoch': epoch + 1,
                'losses': losses,
                'projector_state_dict': model.projector.state_dict(),
                'language_model_state_dict': model.language_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")
    
    except KeyboardInterrupt:
        print("\n" + "="*70)
        print("Training interrupted!")
        print(f"Completed {len(losses)} epochs")
        # Save checkpoint on interrupt
        torch.save({
            'epoch': len(losses),
            'losses': losses,
            'projector_state_dict': model.projector.state_dict(),
            'language_model_state_dict': model.language_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")
        print("Run training again to resume from this checkpoint.")
        print("="*70)
    
    return losses
# Train for more epochs with larger dataset
# Training is preemptible - interrupt with Ctrl+C and re-run to resume
losses = train_vlm(vlm, train_loader, num_epochs=6, lr=2e-4, checkpoint_path="vlm_caption_checkpoint.pt")
Found checkpoint: vlm_caption_checkpoint.pt
Resumed from epoch 2
Epoch 3/6: 100%|██████████| 500/500 [03:46<00:00,  2.21it/s, loss=1.7564]
Epoch 3 - Average Loss: 2.3993
Checkpoint saved to vlm_caption_checkpoint.pt
Epoch 4/6:  67%|██████▋   | 336/500 [03:37<01:46,  1.55it/s, loss=2.3601]

======================================================================
Training interrupted!
Completed 3 epochs
Checkpoint saved to vlm_caption_checkpoint.pt
Run training again to resume from this checkpoint.
======================================================================
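`train_vlm` writes the checkpoint directly to its final path with `torch.save`, both after each epoch and in the `KeyboardInterrupt` handler. If the process is killed mid-write, which is a real risk on preemptible machines, the file can be left truncated and the next run will fail to resume from it. A common remedy is to write to a temporary file in the same directory and then swap it into place with `os.replace`, which the Python docs describe as atomic on POSIX. The helper name `save_checkpoint_atomically` is hypothetical.

```python
import os
import tempfile
import torch

def save_checkpoint_atomically(state: dict, path: str) -> None:
    """Write to a temp file next to `path`, then atomically move it into place."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")  # same filesystem as target
    os.close(fd)
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)   # readers see either the old file or the new one
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)      # do not leave half-written temp files behind
        raise

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "vlm_caption_checkpoint.pt")
    save_checkpoint_atomically({"epoch": 3, "weights": torch.zeros(2)}, ckpt_path)
    restored = torch.load(ckpt_path)
    leftovers = [name for name in os.listdir(tmp_dir) if name.endswith(".tmp")]
```

Note also that the interrupt handler saves mid-epoch weights under the last completed epoch number. A resumed run therefore repeats that epoch on top of partially updated weights, which is usually harmless for a tutorial like this one.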
# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(range(1, len(losses)+1), losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

Part 8: Trained Model Predictions - Before vs After!

Let’s compare the untrained and trained models on the same images. You should see a clear improvement: the captions now mention the right kind of scene, although the details are often wrong.

# Generate captions with TRAINED model on the same test images
print("=" * 60)
print("TRAINED MODEL PREDICTIONS (After training)")
print("=" * 60)

trained_captions = []
for i, (image, gt_caption) in enumerate(test_samples):
    caption, _ = generate_caption(vlm, image, image_processor, device)
    trained_captions.append(caption)
    print(f"\nImage {i+1}:")
    print(f"  BEFORE (untrained): {untrained_captions[i]}")
    print(f"  AFTER (trained):    {caption}")
    print(f"  Ground Truth:       {gt_caption[:80]}...")
============================================================
TRAINED MODEL PREDICTIONS (After training)
============================================================

Image 1:
  BEFORE (untrained): This image shows cross-hat on the back of the dinosaur, showing that the massive scales are the Z-axis. This is a rebound, not a right-
  AFTER (trained):    This image shows three people or one-two aged people on a huge wave .A group is sitting on a bench and taking pictures .The two people and a lake
  Ground Truth:       Boys with their backs against an incoming wave ....

Image 2:
  BEFORE (untrained): This image shows the two simple types of the intonation of the vowel.
Comments (28) (Comment)

Create and save a GAMS model from
  AFTER (trained):    This image shows a child in a red dress in jumping up in the air .a boy in a pink striped shirt is skating .A child in blue is climbing over
  Ground Truth:       A child wearing a white shirt hangs from the playground equipment ....

Image 3:
  BEFORE (untrained): This image shows alert WLLK wordword w wordword message w wordw w wordw
K-word w w w w w w w w w w
  AFTER (trained):    This image shows two women standing in front of a building . One is sitting alone and one is sitting in a crowd . a woman and one child reach for the fountain
  Ground Truth:       Two girls take in the view ....

Image 4:
  BEFORE (untrained): This image shows the former Prinsland Castle at the Pematang Ilir. It is one of the most important and largest 5-meters buildings in
  AFTER (trained):    This image shows a dog with a red muzzle in its mouth is running . smiling . in the middle of the field . .A black dog is shaking its ear over
  Ground Truth:       A black dog barking ....
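Every caption here begins with "This image shows" because `generate` prepends that prompt and decodes it together with the new tokens. However, the prompt was never part of the training sequences, which are just [image tokens][caption]. One possible refinement, not something this notebook does, is to train on [image tokens][prompt][caption][EOS] and exclude the prompt from the loss. Training and generation would then see the same format. Here is a sketch of the label construction, with a hypothetical helper and made-up token ids:

```python
def build_prompted_example(prompt_ids, caption_ids, eos_id, pad_id, max_length):
    """Hypothetical helper: (input_ids, attention_mask, labels) for [prompt][caption][EOS].

    The prompt stays in the input but is excluded from the loss with -100; the image
    tokens would still be prepended (and masked) inside MiniVLM.forward as before.
    """
    input_ids = (prompt_ids + caption_ids + [eos_id])[:max_length]
    labels = ([-100] * len(prompt_ids) + caption_ids + [eos_id])[:max_length]
    attention_mask = [1] * len(input_ids)
    pad = max_length - len(input_ids)
    return (
        input_ids + [pad_id] * pad,
        attention_mask + [0] * pad,
        labels + [-100] * pad,        # padding is never a target
    )

ids, mask, labels = build_prompted_example(
    prompt_ids=[101, 102, 103],   # made-up ids standing in for "This image shows"
    caption_ids=[7, 8, 9],        # made-up ids for a short caption
    eos_id=0, pad_id=0, max_length=10,
)
print(labels)  # [-100, -100, -100, 7, 8, 9, 0, -100, -100, -100]
```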
# Side-by-side visual comparison: BEFORE vs AFTER training
fig, axes = plt.subplots(4, 2, figsize=(16, 24))

for i, (image, gt_caption) in enumerate(test_samples):
    # Left column: Untrained
    axes[i, 0].imshow(image)
    untrained_text = wrap_text(f"BEFORE: {untrained_captions[i]}", 40)
    axes[i, 0].set_xlabel(untrained_text, fontsize=9, color='red', ha='center')
    axes[i, 0].set_xticks([])
    axes[i, 0].set_yticks([])
    
    # Right column: Trained  
    axes[i, 1].imshow(image)
    trained_text = wrap_text(f"AFTER: {trained_captions[i]}", 40)
    gt_text = wrap_text(f"GT: {gt_caption}", 40)
    axes[i, 1].set_xlabel(f"{trained_text}\n\n{gt_text}", fontsize=9, color='green', ha='center')
    axes[i, 1].set_xticks([])
    axes[i, 1].set_yticks([])

plt.suptitle("Before vs After Training Comparison", fontsize=16, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

# Test on completely new images from the web (not in training set)
test_urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",  # cats
    "http://images.cocodataset.org/val2017/000000037777.jpg",  # sports
    "http://images.cocodataset.org/val2017/000000087038.jpg",  # food
]

fig, axes = plt.subplots(1, 3, figsize=(15, 6))

for url, ax in zip(test_urls, axes):
    caption, image = generate_caption(vlm, url, image_processor, device)
    
    ax.imshow(image)
    caption_text = wrap_text(f"Generated: {caption}", 35)
    ax.set_xlabel(caption_text, fontsize=10, ha='center')
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle("Trained VLM on New Images (Not in Training Set)", fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

Part 9: Understanding the Architecture

Let’s visualize what’s happening inside our VLM:

# Visualize the flow of information
sample = dataset[0]
image = sample['image'].convert('RGB')
pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

with torch.no_grad():
    # Step 1: Vision encoding
    vision_outputs = vlm.vision_encoder(pixel_values=pixel_values)
    vision_features = vision_outputs.last_hidden_state
    print(f"Step 1 - Vision Encoder Output: {vision_features.shape}")
    print(f"         Shape meaning: (batch={vision_features.shape[0]}, patches={vision_features.shape[1]}, dim={vision_features.shape[2]})")
    
    # Step 2: Projection
    projected = vlm.projector(vision_features)
    print(f"\nStep 2 - After Projection: {projected.shape}")
    print(f"         Shape meaning: (batch={projected.shape[0]}, patches={projected.shape[1]}, lm_dim={projected.shape[2]})")
    
    # Step 3: Text embedding (example)
    sample_text = "A photo of"
    text_ids = tokenizer.encode(sample_text, return_tensors="pt").to(device)
    text_embeds = vlm.language_model.get_input_embeddings()(text_ids)
    print(f"\nStep 3 - Text Embeddings: {text_embeds.shape}")
    print(f"         Shape meaning: (batch={text_embeds.shape[0]}, tokens={text_embeds.shape[1]}, lm_dim={text_embeds.shape[2]})")
    
    # Step 4: Concatenation
    combined = torch.cat([projected, text_embeds], dim=1)
    print(f"\nStep 4 - Combined Embeddings: {combined.shape}")
    print(f"         Total sequence: {projected.shape[1]} image tokens + {text_embeds.shape[1]} text tokens = {combined.shape[1]}")
Step 1 - Vision Encoder Output: torch.Size([1, 197, 1024])
         Shape meaning: (batch=1, patches=197, dim=1024)

Step 2 - After Projection: torch.Size([1, 197, 960])
         Shape meaning: (batch=1, patches=197, lm_dim=960)

Step 3 - Text Embeddings: torch.Size([1, 3, 960])
         Shape meaning: (batch=1, tokens=3, lm_dim=960)

Step 4 - Combined Embeddings: torch.Size([1, 200, 960])
         Total sequence: 197 image tokens + 3 text tokens = 200
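Step 1 reports 197 tokens. That is 14 x 14 = 196 patches plus one CLS token: a 224-pixel image cut into 16-pixel patches gives 224 / 16 = 14 patches per side. The helper below is a hypothetical sketch, not part of the notebook. It makes that arithmetic explicit and shows how the language model's input length grows with image resolution. It assumes the CLS token is projected and kept, as the Step 2 shape indicates.

```python
def vit_num_tokens(image_size=224, patch_size=16, cls_token=True):
    """Number of tokens a ViT emits for a square image (hypothetical helper)."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    per_side = image_size // patch_size          # 224 // 16 = 14
    return per_side * per_side + (1 if cls_token else 0)


def combined_length(num_text_tokens, **vit_kwargs):
    """Sequence length the LM sees: all projected image tokens followed by the text tokens."""
    return vit_num_tokens(**vit_kwargs) + num_text_tokens


print(vit_num_tokens())                    # 197 (196 patches + CLS)
print(combined_length(3))                  # 200, matching Step 4
print(vit_num_tokens(image_size=384))      # 577: higher resolution means far more image tokens
```

Every image token takes up a slot in the LM's context. This is why higher-resolution inputs get expensive quickly with the simple concatenation design used here.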
# Visualize per-patch feature magnitude (a simple proxy, NOT an attention map):
# bright patches have large activations, which does not necessarily mean the LM attends to them

# Get patch features (excluding CLS token)
patch_features = vision_features[0, 1:, :]  # Remove CLS token
patch_norms = torch.norm(patch_features, dim=-1)  # Feature magnitude per patch

# Reshape the 196 patch values to a 14x14 grid (224 px image / 16 px patches = 14 per side)
patch_grid = patch_norms.reshape(14, 14).cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.imshow(image)
ax1.set_title("Original Image")
ax1.axis('off')

im = ax2.imshow(patch_grid, cmap='hot')
ax2.set_title("Vision Feature Magnitude per Patch")
ax2.axis('off')
plt.colorbar(im, ax=ax2, fraction=0.046)

plt.tight_layout()
plt.show()
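The heatmap above shows feature norms, not attention. Hugging Face's `ViTModel` can also return per-layer attention weights when called with `output_attentions=True`; each layer's tensor has shape (batch, heads, seq, seq). The sketch below shows one common way to turn a single layer's weights into a 14x14 map. It averages over heads, takes the CLS token's row, and drops the CLS column. A random softmax-normalized tensor stands in for real weights, and `cls_attention_map` is a hypothetical helper.

```python
import numpy as np


def cls_attention_map(attn, grid_size=14):
    """attn: (heads, seq, seq) attention weights for one image and one layer,
    with seq = 1 (CLS) + grid_size**2 patches. Returns a (grid_size, grid_size) map."""
    heads, seq_q, seq_k = attn.shape
    assert seq_q == seq_k == 1 + grid_size * grid_size
    mean_attn = attn.mean(axis=0)        # average over heads -> (seq, seq)
    cls_to_patches = mean_attn[0, 1:]    # row 0 = CLS query; drop column 0 (CLS key)
    return cls_to_patches.reshape(grid_size, grid_size)


# Stand-in for real weights: ViT-Large has 16 heads; each row is softmax-normalized over keys
rng = np.random.default_rng(0)
logits = rng.normal(size=(16, 197, 197))
attn = np.exp(logits - logits.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)

amap = cls_attention_map(attn)
print(amap.shape)  # (14, 14)
```

With the real encoder, a call like `out = vlm.vision_encoder(pixel_values=pixel_values, output_attentions=True)` followed by `cls_attention_map(out.attentions[-1][0].cpu().numpy())` should give a last-layer map. Depending on your transformers version, you may need to load the ViT with `attn_implementation="eager"` before attention weights are returned. This map still only describes attention inside the vision encoder, not how the language model uses the image tokens.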

Part 10: Save Model for Future Use

Let’s save the trained model so we can use it in future notebooks (e.g., for instruction fine-tuning or object detection tasks).

import os

# Create directory for saving the model
save_dir = "mini-vlm-flickr8k"
os.makedirs(save_dir, exist_ok=True)

# Save the projector weights (the main trained component)
torch.save(vlm.projector.state_dict(), f"{save_dir}/projector.pt")

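# Save the projector and fine-tuned language model weights, plus a small config.
# The frozen vision encoder is not saved; it is reloaded from its pretrained checkpoint.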
torch.save({
    'projector_state_dict': vlm.projector.state_dict(),
    'language_model_state_dict': vlm.language_model.state_dict(),
    'config': {
        'vision_model_name': vision_model_name,
        'lm_model_name': lm_model_name,
        'vision_dim': vision_dim,
        'language_dim': language_dim,
    }
}, f"{save_dir}/mini_vlm_full.pt")

# Also save the tokenizer and image processor config for easy loading
tokenizer.save_pretrained(f"{save_dir}/tokenizer")
image_processor.save_pretrained(f"{save_dir}/image_processor")

print(f"Model saved to {save_dir}/")
print(f"Contents: {os.listdir(save_dir)}")
Model saved to mini-vlm-flickr8k/
Contents: ['image_processor', 'mini_vlm_full.pt', 'projector.pt', 'tokenizer']
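Because `mini_vlm_full.pt` contains the full language model weights, it is far larger than `projector.pt`. A rough estimate is simply parameters x bytes per parameter. The sketch below assumes about 2M projector parameters and 360M LM parameters, taken from the summary table. It shows fp32 (4 bytes) and bf16 (2 bytes) because this section does not state the dtype the weights were trained in.

```python
def checkpoint_size_gib(num_params, bytes_per_param=4):
    """Approximate on-disk size of a state dict, ignoring small serialization overhead."""
    return num_params * bytes_per_param / 1024**3


# Parameter counts are the approximate figures from the summary table
components = [("projector (~2M)", 2_000_000), ("SmolLM-360M", 360_000_000)]
dtypes = [("fp32", 4), ("bf16", 2)]

for name, n in components:
    for dtype, nbytes in dtypes:
        print(f"{name:16s} {dtype}: {checkpoint_size_gib(n, nbytes):.2f} GiB")
```

If disk space matters and the LM was not fine-tuned, saving only `projector.pt` is enough, since both pretrained backbones can be re-downloaded.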
# Helper function to load the model (useful for the next notebook).
# Assumes the VisionProjector and MiniVLM classes defined earlier in this notebook are in scope.
def load_mini_vlm(save_dir, device=None):
    """Load the saved MiniVLM model."""
    from transformers import AutoModelForCausalLM, AutoTokenizer, ViTModel, ViTImageProcessor

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the checkpoint on CPU first; the assembled model is moved to `device` below
    checkpoint = torch.load(f"{save_dir}/mini_vlm_full.pt", map_location="cpu")
    config = checkpoint['config']

    # Recreate components (the frozen vision encoder comes straight from its pretrained checkpoint)
    vision_encoder = ViTModel.from_pretrained(config['vision_model_name'])
    language_model = AutoModelForCausalLM.from_pretrained(config['lm_model_name'])
    tokenizer = AutoTokenizer.from_pretrained(f"{save_dir}/tokenizer")
    image_processor = ViTImageProcessor.from_pretrained(f"{save_dir}/image_processor")

    # Recreate projector and load its trained weights
    projector = VisionProjector(config['vision_dim'], config['language_dim'])
    projector.load_state_dict(checkpoint['projector_state_dict'])

    # Load fine-tuned language model weights
    language_model.load_state_dict(checkpoint['language_model_state_dict'])

    # Create VLM in inference mode (call vlm.train() before further fine-tuning)
    vlm = MiniVLM(vision_encoder, language_model, projector, tokenizer)
    vlm = vlm.to(device)
    vlm.eval()

    return vlm, image_processor, tokenizer

print("Helper function 'load_mini_vlm' defined for loading in future notebooks.")
Helper function 'load_mini_vlm' defined for loading in future notebooks.
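The save/load pattern above stores only the trained pieces plus a config, then rebuilds fresh modules from that config and copies the weights in. The sketch below checks the same round trip end to end using tiny hypothetical stand-in modules (not the notebook's `VisionProjector` or SmolLM) and an in-memory buffer in place of `mini_vlm_full.pt`. If the round trip works, the reloaded model produces exactly the same outputs.

```python
import io

import torch
import torch.nn as nn


def make_modules(vision_dim, language_dim):
    """Hypothetical tiny stand-ins for the projector and the language model."""
    projector = nn.Sequential(
        nn.Linear(vision_dim, language_dim), nn.GELU(), nn.Linear(language_dim, language_dim)
    )
    lm = nn.Linear(language_dim, language_dim)
    return projector, lm


config = {"vision_dim": 8, "language_dim": 4}
torch.manual_seed(0)
projector, lm = make_modules(**config)

# Save state dicts + config, mirroring the checkpoint layout used above
buf = io.BytesIO()  # in-memory stand-in for mini_vlm_full.pt
torch.save(
    {
        "projector_state_dict": projector.state_dict(),
        "language_model_state_dict": lm.state_dict(),
        "config": config,
    },
    buf,
)
buf.seek(0)

# Reload on CPU, rebuild fresh (randomly initialized) modules from the config, copy weights in
ckpt = torch.load(buf, map_location="cpu", weights_only=True)
projector2, lm2 = make_modules(**ckpt["config"])
projector2.load_state_dict(ckpt["projector_state_dict"])  # strict=True: key mismatches raise
lm2.load_state_dict(ckpt["language_model_state_dict"])

x = torch.randn(2, config["vision_dim"])
same = torch.allclose(lm(projector(x)), lm2(projector2(x)))
print(same)  # True
```

Loading with `map_location="cpu"` lets the checkpoint open on a machine without a GPU. The default `strict=True` in `load_state_dict` catches architecture mismatches early.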

Summary

We built a minimal Vision-Language Model with three key components:

Component        Model              Parameters   Trainable
Vision Encoder   ViT-Large-16-224   304M         No (frozen)
Projection       2-layer MLP        ~2M          Yes
Language Model   SmolLM-360M        360M         Yes
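The ~2M figure for the projector can be checked with simple arithmetic. The sketch below assumes a two-layer MLP of 1024 -> 960 -> 960 (ViT-Large output dim to SmolLM-360M hidden dim, as in Part 9) with biases. The hidden size is an assumption, since the projector definition is not repeated in this section. Activations such as GELU add no parameters.

```python
def mlp_params(in_dim, hidden_dim, out_dim, bias=True):
    """Parameter count of Linear(in, hidden) -> activation -> Linear(hidden, out)."""
    layer1 = in_dim * hidden_dim + (hidden_dim if bias else 0)
    layer2 = hidden_dim * out_dim + (out_dim if bias else 0)
    return layer1 + layer2


n = mlp_params(1024, 960, 960)   # assumed projector shape
print(f"{n:,}")                  # 1,906,560, i.e. about 2M
```

In the notebook itself, `sum(p.numel() for p in vlm.projector.parameters())` gives the exact count for the projector that was actually trained.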

Key Takeaways

  1. Architecture: VLMs combine vision encoders and language models with a projection layer
  2. Training: We freeze the vision encoder and train only the projector + LLM
  3. Data: Image-caption pairs (Flickr8k) teach the model to describe images
  4. Scaling: Production VLMs use larger models and much more data
  5. Resumable Training: Training is preemptible; interrupt it with Ctrl+C and re-run the cell to resume from the last checkpoint
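The resumable-training idea can be sketched in plain Python. The checkpoint stores a step counter and is written atomically (write to a temp file, then `os.replace`), so an interrupt never leaves a half-written file. Catching `KeyboardInterrupt` saves progress before exiting. This is a hypothetical illustration, not the notebook's training loop. A real version would also save model and optimizer state dicts with `torch.save`, and `interrupt_at` exists here only to simulate pressing Ctrl+C.

```python
import json
import os
import tempfile


def load_state(path):
    """Resume from an existing checkpoint, or start fresh."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {"step": 0}


def save_state(path, state):
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)  # atomic rename: the checkpoint is either old or new, never partial


def train_resumable(path, total_steps, save_every=10, interrupt_at=None):
    state = load_state(path)
    try:
        while state["step"] < total_steps:
            if state["step"] == interrupt_at:
                raise KeyboardInterrupt  # simulate pressing Ctrl+C
            state["step"] += 1  # stand-in for one optimizer step
            if state["step"] % save_every == 0:
                save_state(path, state)
    except KeyboardInterrupt:
        save_state(path, state)
        print(f"Interrupted at step {state['step']}; checkpoint saved.")
        return state
    save_state(path, state)
    return state


path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
first = train_resumable(path, total_steps=100, interrupt_at=25)
second = train_resumable(path, total_steps=100)  # re-run: picks up at step 25
print(first, second)  # {'step': 25} {'step': 100}
```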

Limitations of This Minimal Model

  • Small training set (2000 images) vs millions in production
  • Simple MLP projection vs. the cross-attention layers (e.g., Flamingo) or learned query modules (e.g., BLIP-2's Q-Former) used by some larger VLMs
  • No instruction tuning for chat-like interactions
  • No visual grounding or object detection capabilities
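To make the projection-vs-cross-attention contrast concrete, the NumPy sketch below runs both on random data with the shapes from Part 9. With (a) the MLP approach used here, each patch is projected independently and prepended to the text, so the LM's sequence grows to 200 tokens. With (b) single-head cross-attention, the text tokens query the image features, and the sequence stays at the 3 text tokens. All weights are random, the GELU activation is an assumption, and real cross-attention layers add multiple heads, output projections, and (in Flamingo) gating.

```python
import numpy as np

rng = np.random.default_rng(0)
P, T, d_v, d = 197, 3, 1024, 960   # image tokens, text tokens, vision dim, LM dim (from Part 9)

image_feats = rng.normal(size=(P, d_v))
text_embeds = rng.normal(size=(T, d))


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))


# (a) MLP projection: map each patch independently, then prepend to the text sequence
W1 = rng.normal(size=(d_v, d)) * 0.02
W2 = rng.normal(size=(d, d)) * 0.02
projected = gelu(image_feats @ W1) @ W2                         # (P, d)
sequence = np.concatenate([projected, text_embeds], axis=0)     # (P + T, d): LM sees 200 tokens

# (b) Single-head cross-attention: text queries attend over image keys/values
Wq = rng.normal(size=(d, d)) * 0.02
Wk = rng.normal(size=(d_v, d)) * 0.02
Wv = rng.normal(size=(d_v, d)) * 0.02
q, k, v = text_embeds @ Wq, image_feats @ Wk, image_feats @ Wv
scores = q @ k.T / np.sqrt(d)                                   # (T, P)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)                  # softmax over image tokens
attended = text_embeds + weights @ v                            # residual add; length stays T

print(sequence.shape, attended.shape)  # (200, 960) (3, 960)
```

The trade-off is that (a) reuses the LM unchanged but spends context on every image token, while (b) keeps the sequence short at the cost of adding new layers inside the LM.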

Next Steps

In the next blog post, we’ll take this trained VLM and:

  1. Instruction fine-tune it on object detection data
  2. Teach it to respond to questions like “What objects are in this image?”
  3. Explore visual grounding capabilities

The saved model is in mini-vlm-flickr8k/ and can be loaded with load_mini_vlm().

Cleanup: Release GPU Memory

When done with the notebook, run this cell to free up GPU memory for other tasks.

def cleanup_gpu_memory():
    """Clean up GPU memory by deleting models and tensors, then clearing the CUDA cache."""
    import gc

    g = globals()  # the notebook namespace

    def drop(*names):
        """Delete each listed global independently; return True if any of them existed."""
        found = False
        for name in names:
            if name in g:
                del g[name]
                found = True
        return found

    # Delete model components
    for name in ["vlm", "vision_encoder", "language_model", "projector"]:
        if drop(name):
            print(f"Deleted {name}")

    # Delete cached GPU tensors, including the intermediates from Part 9.
    # Each name is handled separately, so one missing name no longer skips the rest.
    drop("image_features", "projected_features", "vision_features",
         "pixel_values", "projected", "text_ids", "text_embeds", "combined",
         "patch_features", "patch_norms")

    # Delete data loaders and dataset
    if drop("train_loader", "train_dataset", "dataset"):
        print("Deleted data loaders and dataset")

    # Note: any other object still referencing the model (e.g., an optimizer) keeps its tensors alive

    # Force garbage collection
    gc.collect()

    # Clear CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Print memory stats
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print("\nGPU Memory after cleanup:")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved:  {reserved:.2f} GB")

    print("\nGPU memory cleanup complete!")

# Run cleanup
cleanup_gpu_memory()
Deleted vlm
Deleted vision_encoder
Deleted language_model
Deleted projector
Deleted data loaders and dataset

GPU Memory after cleanup:
  Allocated: 0.02 GB
  Reserved:  0.05 GB

GPU memory cleanup complete!