VLM-Guided Image Generation: From Text and Images to Pixels

VLM
image-generation
diffusion
stable-diffusion
CLIP
text-to-image
Author

Nipun Batra

Published

January 3, 2026

VLM-Guided Image Generation

In this notebook, we explore how Vision-Language Models (VLMs) can generate images, not just understand them. We’ll cover:

  1. Text-to-Image Generation: Using Stable Diffusion to generate images from text prompts
  2. VLM-Guided Refinement: Using CLIP/VLM to guide and improve generation quality
  3. Image-Conditioned Generation: Editing, inpainting, variations based on input images
  4. Iterative Refinement: VLM critiques → regenerate → improve

Why VLM-Guided Generation?

Traditional text-to-image models (Stable Diffusion, DALL-E) generate from text alone. But VLMs can:

  • Verify quality: Check whether the generated image matches the prompt
  • Provide feedback: Suggest improvements (“add more detail”, “wrong color”)
  • Iterative refinement: Generate → critique → refine → repeat
  • Multi-modal conditioning: Use both text AND reference images

Architecture Overview

┌─────────────────────────────────────────────────────────┐
│  VLM-Guided Image Generation Pipeline                  │
└─────────────────────────────────────────────────────────┘

Input: "A cat wearing a red hat"
   ↓
[Stable Diffusion] → Generated Image v1
   ↓
[VLM Critic] → "The hat is blue, not red. Cat looks good."
   ↓
[Refine Prompt] → "A cat wearing a BRIGHT RED hat, realistic"
   ↓
[Stable Diffusion] → Generated Image v2 (improved!)
   ↓
[VLM Critic] → "Perfect match!"
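In code, the whole pipeline reduces to a short loop. The sketch below is schematic only: generate_image, critique_image, and refine_prompt_based_on_critique are the helpers we build in the sections that follow, and keeping the best-scoring image around is an extra convenience not shown in the diagram.

# Schematic sketch of the pipeline above (the helper functions are defined later in this notebook)
def vlm_guided_generation(prompt, max_rounds=3, target_score=28.0):
    current_prompt = prompt
    best_image, best_score = None, float("-inf")
    for round_idx in range(max_rounds):
        image = generate_image(current_prompt)            # Stable Diffusion
        critique = critique_image(image, prompt)          # VLM critic (CLIP)
        score = critique["overall_score"]
        if score > best_score:                            # keep the best attempt so far
            best_image, best_score = image, score
        if score >= target_score:                         # good enough -> stop
            break
        current_prompt = refine_prompt_based_on_critique(prompt, critique, round_idx)
    return best_image, best_score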

Let’s get started!

# IMPORTANT: Set GPU before importing PyTorch!
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from transformers import AutoProcessor, CLIPModel, AutoTokenizer, AutoModel, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

1. Text-to-Image Generation with Stable Diffusion

We’ll start with Stable Diffusion, the popular open-source text-to-image model. We’ll use the smaller SD 1.5 version for faster generation.

# Load Stable Diffusion 1.5 (smaller, faster)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading Stable Diffusion 1.5...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,  # Disable for speed
)

# Use faster scheduler (DPM++ 2M)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)

# Enable memory optimizations
if torch.cuda.is_available():
    sd_pipe.enable_attention_slicing()
    sd_pipe.enable_vae_slicing()

print("Stable Diffusion loaded!")
def generate_image(prompt, num_inference_steps=25, guidance_scale=7.5, seed=42):
    """Generate image from text prompt using Stable Diffusion."""
    generator = torch.Generator(device=device).manual_seed(seed)
    
    with torch.inference_mode():
        image = sd_pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    
    return image

# Test basic generation
test_prompts = [
    "a photograph of a cat wearing a red hat, high quality",
    "a beautiful sunset over mountains, realistic photography",
    "a robot reading a book in a library, digital art",
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, prompt in enumerate(test_prompts):
    print(f"Generating: {prompt}")
    image = generate_image(prompt, seed=42+i)
    axes[i].imshow(image)
    axes[i].set_title(prompt[:40] + "...", fontsize=10)
    axes[i].axis('off')
plt.tight_layout()
plt.show()

2. VLM Critic for Quality Assessment

Now let’s build a VLM critic that can:

  1. Check if the generated image matches the prompt
  2. Identify problems (“wrong color”, “missing objects”)
  3. Provide feedback for refinement

We’ll use CLIP for image-text similarity scoring, and optionally a language model for detailed critique.

# Load CLIP for image-text alignment scoring
print("Loading CLIP...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)
clip_model.eval()
print("CLIP loaded!")
def compute_clip_score(image, text):
    """Compute CLIP similarity between image and text."""
    inputs = clip_processor(
        text=[text],
        images=image,
        return_tensors="pt",
        padding=True
    ).to(device)
    
    with torch.no_grad():
        outputs = clip_model(**inputs)
        # logits_per_image is the image-text cosine similarity scaled by CLIP's
        # learned temperature (exp(logit_scale) ~ 100), hence scores around 20-35
        logits_per_image = outputs.logits_per_image
        similarity = logits_per_image.item()
    
    return similarity

def critique_image(image, prompt, attributes_to_check=None):
    """Critique generated image against prompt."""
    # Overall similarity
    overall_score = compute_clip_score(image, prompt)
    
    # Check specific attributes if provided
    attribute_scores = {}
    if attributes_to_check:
        for attr in attributes_to_check:
            attr_score = compute_clip_score(image, attr)
            attribute_scores[attr] = attr_score
    
    # Simple critique logic
    critique = {
        "overall_score": overall_score,
        "attribute_scores": attribute_scores,
        "verdict": "Good match!" if overall_score > 25 else "Poor match, needs refinement"
    }
    
    return critique

# Test critique
test_image = generate_image("a cat wearing a red hat", seed=42)
critique = critique_image(
    test_image,
    "a cat wearing a red hat",
    attributes_to_check=[
        "a cat",
        "a red hat",
        "a blue hat",  # Wrong - should score lower
        "a dog",  # Wrong - should score lower
    ]
)

print("\nCritique Results:")
print(f"Overall score: {critique['overall_score']:.2f}")
print(f"Verdict: {critique['verdict']}")
print("\nAttribute scores:")
for attr, score in critique['attribute_scores'].items():
    print(f"  '{attr}': {score:.2f}")

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.title(f"Score: {critique['overall_score']:.2f}\n{critique['verdict']}")
plt.axis('off')
plt.show()

3. Iterative Refinement with VLM Feedback

Now the magic happens! We’ll implement an iterative refinement loop:

  1. Generate an image from the initial prompt
  2. The VLM critiques the image
  3. If the score is low, refine the prompt (add emphasis, negative prompts; a negative-prompt variant is sketched after the refinement helper below)
  4. Generate again with the refined prompt
  5. Repeat until the quality threshold is met

def refine_prompt_based_on_critique(original_prompt, critique, iteration):
    """Refine prompt based on VLM critique."""
    # Simple refinement strategies:
    # 1. Add emphasis keywords
    # 2. Increase detail
    # 3. Add quality boosters
    
    refinements = [
        f"{original_prompt}, highly detailed",
        f"{original_prompt}, high quality, professional photography",
        f"{original_prompt}, 8k, ultra realistic, sharp focus",
    ]
    
    # Use different refinement based on iteration
    refined_prompt = refinements[min(iteration, len(refinements)-1)]
    
    return refined_prompt
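
The list above also mentioned negative prompts. The helper here only adds emphasis, but Stable Diffusion’s pipeline accepts a negative_prompt argument that pushes generation away from unwanted attributes. An optional variant (not used in the loop below) might look like this:

# Optional variant (not used in the refinement loop below): steer away from
# attributes the critic flags by passing a negative prompt to Stable Diffusion.
def generate_image_with_negative(prompt, negative_prompt="blurry, low quality, distorted",
                                 num_inference_steps=25, guidance_scale=7.5, seed=42):
    """Like generate_image, but also conditioned on a negative prompt."""
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        image = sd_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    return image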

def iterative_generate_with_refinement(prompt, max_iterations=3, target_score=28.0, seed=42):
    """Generate image with iterative VLM-guided refinement."""
    history = []
    
    current_prompt = prompt
    
    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")
        print(f"Prompt: {current_prompt}")
        
        # Generate image
        image = generate_image(current_prompt, seed=seed+iteration)
        
        # Critique
        critique = critique_image(image, prompt)
        score = critique['overall_score']
        
        print(f"Score: {score:.2f}")
        print(f"Verdict: {critique['verdict']}")
        
        # Save to history
        history.append({
            "iteration": iteration + 1,
            "prompt": current_prompt,
            "image": image,
            "score": score,
            "critique": critique
        })
        
        # Check if target reached
        if score >= target_score:
            print(f"✓ Target score reached! ({score:.2f} >= {target_score})")
            break
        
        # Refine prompt for next iteration
        if iteration < max_iterations - 1:
            current_prompt = refine_prompt_based_on_critique(prompt, critique, iteration)
    
    return history

# Test iterative refinement
test_prompt = "a red sports car in front of a modern building"
history = iterative_generate_with_refinement(test_prompt, max_iterations=3, target_score=27.0, seed=100)

# Visualize progression
fig, axes = plt.subplots(1, len(history), figsize=(5*len(history), 5))
if len(history) == 1:
    axes = [axes]

for i, entry in enumerate(history):
    axes[i].imshow(entry['image'])
    axes[i].set_title(f"Iteration {entry['iteration']}\nScore: {entry['score']:.2f}", fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

print("\n=== Refinement Summary ===")
for entry in history:
    print(f"Iter {entry['iteration']}: Score {entry['score']:.2f} - {entry['prompt'][:60]}...")

4. Image-Conditioned Generation (Inpainting)

Beyond text-to-image, we can do image-conditioned generation:

  • Inpainting: Fill in masked regions
  • Image-to-image: Transform existing images (see the sketch after this list)
  • Variations: Generate similar images with changes
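
Image-to-image is not implemented in this notebook, but for reference, a minimal sketch using diffusers’ StableDiffusionImg2ImgPipeline with the same SD 1.5 weights could look like the following; the strength value is an assumption, and lower values preserve more of the input image.

# Sketch of the image-to-image mode listed above (not run in this notebook).
from diffusers import StableDiffusionImg2ImgPipeline

img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,
).to(device)

def image_to_image(init_image, prompt, strength=0.6, seed=42):
    """Transform an existing image toward a new prompt."""
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        return img2img_pipe(
            prompt=prompt,
            image=init_image.resize((512, 512)),
            strength=strength,      # 0 keeps the input, 1 ignores it
            guidance_scale=7.5,
            generator=generator,
        ).images[0]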

Let’s implement inpainting with Stable Diffusion.

# Load inpainting pipeline
print("Loading Stable Diffusion Inpainting...")
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,
)
inpaint_pipe = inpaint_pipe.to(device)

if torch.cuda.is_available():
    inpaint_pipe.enable_attention_slicing()
    inpaint_pipe.enable_vae_slicing()

print("Inpainting pipeline loaded!")
def create_circular_mask(image_size, center, radius):
    """Create a circular mask for inpainting."""
    mask = Image.new('L', image_size, 0)
    draw = ImageDraw.Draw(mask)
    draw.ellipse(
        [
            (center[0] - radius, center[1] - radius),
            (center[0] + radius, center[1] + radius)
        ],
        fill=255
    )
    return mask

def inpaint_image(image, mask, prompt, num_inference_steps=25, seed=42):
    """Inpaint masked region of image based on prompt."""
    generator = torch.Generator(device=device).manual_seed(seed)
    
    # Resize to 512x512 (SD requirement)
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))
    
    with torch.inference_mode():
        result = inpaint_pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    
    return result

# Generate a base image
base_image = generate_image("a room with white walls and wooden floor", seed=200)
base_image = base_image.resize((512, 512))

# Create mask for center region
mask = create_circular_mask((512, 512), center=(256, 256), radius=100)

# Inpaint with different objects
inpaint_prompts = [
    "a red chair in the center of the room",
    "a green plant in a pot",
    "a modern lamp",
]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Show original and mask
axes[0, 0].imshow(base_image)
axes[0, 0].set_title("Original Image")
axes[0, 0].axis('off')

axes[0, 1].imshow(mask, cmap='gray')
axes[0, 1].set_title("Mask (region to inpaint)")
axes[0, 1].axis('off')

axes[0, 2].axis('off')

# Generate inpainted versions
for i, prompt in enumerate(inpaint_prompts):
    print(f"Inpainting: {prompt}")
    inpainted = inpaint_image(base_image, mask, prompt, seed=200+i)
    axes[1, i].imshow(inpainted)
    axes[1, i].set_title(f"Inpainted: {prompt}", fontsize=9)
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

5. VLM-Guided Inpainting with Quality Control

Combine inpainting with VLM critique for quality-controlled image editing:

  1. The user specifies the region to edit and the desired content
  2. Inpaint the region
  3. The VLM checks if the edit matches the description
  4. If not, try again with a refined prompt

def vlm_guided_inpainting(image, mask, prompt, max_attempts=3, target_score=25.0, seed=42):
    """Inpaint with VLM quality control."""
    attempts = []
    
    current_prompt = prompt
    
    for attempt in range(max_attempts):
        print(f"\n--- Attempt {attempt + 1} ---")
        print(f"Prompt: {current_prompt}")
        
        # Inpaint
        result = inpaint_image(image, mask, current_prompt, seed=seed+attempt)
        
        # Critique the inpainted result
        critique = critique_image(result, prompt)
        score = critique['overall_score']
        
        print(f"Score: {score:.2f}")
        
        attempts.append({
            "attempt": attempt + 1,
            "prompt": current_prompt,
            "result": result,
            "score": score
        })
        
        if score >= target_score:
            print(f"✓ Quality threshold reached!")
            break
        
        # Refine prompt
        if attempt < max_attempts - 1:
            current_prompt = f"{prompt}, highly detailed, realistic"
    
    # Return best result
    best = max(attempts, key=lambda x: x['score'])
    return best, attempts

# Test VLM-guided inpainting
test_inpaint_prompt = "a beautiful red rose in a vase"
best, attempts = vlm_guided_inpainting(
    base_image,
    mask,
    test_inpaint_prompt,
    max_attempts=2,
    target_score=24.0,
    seed=300
)

# Visualize attempts
fig, axes = plt.subplots(1, len(attempts)+2, figsize=(5*(len(attempts)+2), 5))

axes[0].imshow(base_image)
axes[0].set_title("Original")
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title("Mask")
axes[1].axis('off')

for i, att in enumerate(attempts):
    axes[i+2].imshow(att['result'])
    marker = "✓" if att == best else ""
    axes[i+2].set_title(f"Attempt {att['attempt']} {marker}\nScore: {att['score']:.2f}", fontsize=9)
    axes[i+2].axis('off')

plt.tight_layout()
plt.show()

6. Multi-Prompt Generation and Selection

Another strategy: generate multiple variations with different prompts/seeds, then let the VLM select the best one.

def generate_and_select_best(base_prompt, num_variations=4, seeds=None):
    """Generate multiple variations and select best using VLM."""
    if seeds is None:
        seeds = list(range(num_variations))
    
    # Prompt variations (different phrasings)
    prompt_variations = [
        base_prompt,
        f"{base_prompt}, professional photography",
        f"{base_prompt}, digital art, trending on artstation",
        f"{base_prompt}, ultra realistic, 8k, detailed",
    ]
    
    candidates = []
    
    print(f"Generating {num_variations} variations...")
    for i in range(num_variations):
        prompt = prompt_variations[i % len(prompt_variations)]
        seed = seeds[i]
        
        print(f"  Variation {i+1}: seed={seed}")
        image = generate_image(prompt, seed=seed, num_inference_steps=20)
        
        # Score against base prompt
        score = compute_clip_score(image, base_prompt)
        
        candidates.append({
            "index": i,
            "prompt": prompt,
            "seed": seed,
            "image": image,
            "score": score
        })
    
    # Select best
    best = max(candidates, key=lambda x: x['score'])
    print(f"\n✓ Best: Variation {best['index']+1} (score: {best['score']:.2f})")
    
    return best, candidates

# Test multi-generation selection
test_prompt = "a magical forest with glowing mushrooms"
best, candidates = generate_and_select_best(test_prompt, num_variations=4, seeds=[10, 20, 30, 40])

# Visualize all candidates
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for i, cand in enumerate(candidates):
    axes[i].imshow(cand['image'])
    marker = "★ BEST" if cand == best else ""
    axes[i].set_title(f"Variation {cand['index']+1} {marker}\nScore: {cand['score']:.2f}\nSeed: {cand['seed']}", fontsize=9)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

7. Advanced: Caption-Based Improvement Loop

We can use a captioning VLM to describe the generated image, then compare the caption to the original prompt to identify discrepancies.

# Load a simple captioning model (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration

print("Loading BLIP for captioning...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = blip_model.to(device)
blip_model.eval()
print("BLIP loaded!")
def caption_image(image):
    """Generate caption for image using BLIP."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs, max_length=50)
        caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    
    return caption
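
To actually compare the caption against the prompt (rather than only scoring the image), one option is to reuse CLIP’s text encoder on both strings. The helper below is an assumption, not part of the loop that follows, but it shows how a discrepancy could be quantified:

# Sketch: quantify prompt/caption agreement with CLIP's text encoder.
# A low value suggests the generation drifted from the prompt (e.g. wrong color).
def prompt_caption_agreement(prompt, caption):
    inputs = clip_processor(text=[prompt, caption], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_emb = clip_model.get_text_features(**inputs)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    return (text_emb[0] @ text_emb[1]).item()   # cosine similarity in [-1, 1]

# Example (hypothetical strings):
# prompt_caption_agreement("a cat wearing a red hat", "a cat wearing a blue hat")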

def caption_based_refinement(prompt, max_iterations=3, seed=42):
    """Generate with caption-based feedback loop."""
    history = []
    current_prompt = prompt
    
    for iteration in range(max_iterations):
        print(f"\n--- Iteration {iteration + 1} ---")
        print(f"Prompt: {current_prompt}")
        
        # Generate
        image = generate_image(current_prompt, seed=seed+iteration, num_inference_steps=20)
        
        # Caption the result
        generated_caption = caption_image(image)
        print(f"Generated caption: '{generated_caption}'")
        
        # Score alignment
        score = compute_clip_score(image, prompt)
        print(f"CLIP score: {score:.2f}")
        
        history.append({
            "iteration": iteration + 1,
            "prompt": current_prompt,
            "image": image,
            "caption": generated_caption,
            "score": score
        })
        
        # Check if satisfactory
        if score > 27.0:
            print("✓ Good quality achieved!")
            break
        
        # Refine: combine original prompt with generated caption insights
        if iteration < max_iterations - 1:
            current_prompt = f"{prompt}, {generated_caption}, high quality"
    
    return history

# Test caption-based refinement
test_prompt = "a cyberpunk city at night with neon lights"
history = caption_based_refinement(test_prompt, max_iterations=2, seed=500)

# Visualize
fig, axes = plt.subplots(1, len(history), figsize=(7*len(history), 7))
if len(history) == 1:
    axes = [axes]

for i, entry in enumerate(history):
    axes[i].imshow(entry['image'])
    axes[i].set_title(
        f"Iteration {entry['iteration']}\n"
        f"Score: {entry['score']:.2f}\n"
        f"Caption: {entry['caption'][:50]}...",
        fontsize=9
    )
    axes[i].axis('off')

plt.tight_layout()
plt.show()

Summary and Comparison

Let’s compare all the VLM-guided generation strategies we’ve explored:

import pandas as pd

strategies = {
    "Strategy": [
        "Basic Text-to-Image",
        "Iterative Refinement",
        "Multi-Variation Selection",
        "VLM-Guided Inpainting",
        "Caption-Based Loop",
    ],
    "Description": [
        "Single generation from text prompt",
        "Generate → critique → refine prompt → regenerate",
        "Generate N variations → VLM selects best",
        "Inpaint region → VLM quality check → retry if needed",
        "Generate → caption → compare caption to prompt → refine",
    ],
    "Pros": [
        "Fast, simple",
        "Improves quality iteratively",
        "Explores diverse options",
        "Precise region control",
        "Detailed feedback via captions",
    ],
    "Cons": [
        "No quality control",
        "Multiple generations needed",
        "Computationally expensive (N generations)",
        "Requires mask definition",
        "Caption may miss details",
    ],
    "Best Use Case": [
        "Quick prototyping",
        "High-quality single image",
        "Finding best from options",
        "Editing specific regions",
        "Understanding generation failures",
    ],
}

df = pd.DataFrame(strategies)
print("\nVLM-Guided Generation Strategies Comparison:")
print(df.to_string(index=False))

Key Takeaways

1. VLM as Quality Gatekeeper

  • Use CLIP scores to verify image-text alignment
  • Reject low-quality generations automatically
  • Iterate until quality threshold is met

2. Iterative Refinement is Powerful

  • Simple prompt refinement can significantly improve results
  • 2-3 iterations usually enough to reach quality goals
  • Each iteration learns from previous failures

3. Multi-Variation Selection

  • Generate diverse options (different seeds/prompts)
  • Let VLM pick the best match
  • More expensive but explores solution space better

4. Image-Conditioned Generation

  • Inpainting allows precise region control
  • VLM can verify edits match descriptions
  • Useful for iterative image editing workflows

5. Caption-Based Feedback

  • Captions reveal what the model actually generated
  • Compare caption to prompt to identify gaps
  • Helps diagnose systematic generation failures

Practical Recommendations

For production systems:

  • Use multi-variation selection with VLM scoring
  • Set quality thresholds based on your use case
  • Log scores and prompts for analysis

For user-facing apps:

  • Show multiple variations, let users + VLM co-select
  • Implement “regenerate” with automatic refinement
  • Use inpainting for editing workflows

For research:

  • Use caption-based loops to understand model behavior
  • Experiment with different VLM critics (CLIP, BLIP, custom)
  • Try adversarial refinement (VLM as discriminator)

Next Steps

To extend this work:

  1. Fine-tune VLM critic: Train on your specific domain for better scoring
  2. Learned refinement: Train a model to predict prompt refinements
  3. Multi-modal fusion: Combine CLIP scores, captions, and object detection
  4. Interactive editing: Let users provide feedback alongside the VLM
  5. Efficiency: Use distilled/quantized VLMs for faster critique

Further Reading:

  • Stable Diffusion: https://github.com/Stability-AI/stablediffusion
  • CLIP: https://github.com/openai/CLIP
  • BLIP: https://github.com/salesforce/BLIP
  • Diffusers library: https://github.com/huggingface/diffusers

GPU Memory Cleanup

# Delete all models
del sd_pipe
del inpaint_pipe
del clip_model
del blip_model

# Clear cache
gc.collect()
torch.cuda.empty_cache()

print("GPU memory cleared!")
if torch.cuda.is_available():
    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")