# IMPORTANT: Set GPU before importing PyTorch!
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

VLM-Guided Image Generation
In this notebook, we explore how Vision-Language Models (VLMs) can guide image generation, not just image understanding. We’ll cover:
- Text-to-Image Generation: Using Stable Diffusion to generate images from text prompts
- VLM-Guided Refinement: Using CLIP/VLM to guide and improve generation quality
- Image-Conditioned Generation: Editing, inpainting, variations based on input images
- Iterative Refinement: VLM critiques → regenerate → improve
Why VLM-Guided Generation?
Traditional text-to-image models (Stable Diffusion, DALL-E) generate from text alone. But VLMs can:
- Verify quality: Check if the generated image matches the prompt
- Provide feedback: Suggest improvements (“add more detail”, “wrong color”)
- Iterative refinement: Generate → critique → refine → repeat
- Multi-modal conditioning: Use both text AND reference images
Architecture Overview
┌─────────────────────────────────────────────────────────┐
│ VLM-Guided Image Generation Pipeline │
└─────────────────────────────────────────────────────────┘
Input: "A cat wearing a red hat"
↓
[Stable Diffusion] → Generated Image v1
↓
[VLM Critic] → "The hat is blue, not red. Cat looks good."
↓
[Refine Prompt] → "A cat wearing a BRIGHT RED hat, realistic"
↓
[Stable Diffusion] → Generated Image v2 (improved!)
↓
[VLM Critic] → "Perfect match!"
Let’s get started!
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from transformers import AutoProcessor, CLIPModel, AutoTokenizer, AutoModel, AutoImageProcessor
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

1. Text-to-Image Generation with Stable Diffusion
We’ll start with Stable Diffusion, a popular open-source text-to-image model. We’ll use version 1.5, which is relatively small and fast to run.
# Load Stable Diffusion 1.5 (smaller, faster)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Loading Stable Diffusion 1.5...")
sd_pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
safety_checker=None, # Disable for speed
)
# Use faster scheduler (DPM++ 2M)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)
# Enable memory optimizations
if torch.cuda.is_available():
sd_pipe.enable_attention_slicing()
sd_pipe.enable_vae_slicing()
print("Stable Diffusion loaded!")def generate_image(prompt, num_inference_steps=25, guidance_scale=7.5, seed=42):
"""Generate image from text prompt using Stable Diffusion."""
generator = torch.Generator(device=device).manual_seed(seed)
with torch.inference_mode():
image = sd_pipe(
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
).images[0]
return image
# Test basic generation
test_prompts = [
"a photograph of a cat wearing a red hat, high quality",
"a beautiful sunset over mountains, realistic photography",
"a robot reading a book in a library, digital art",
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, prompt in enumerate(test_prompts):
print(f"Generating: {prompt}")
image = generate_image(prompt, seed=42+i)
axes[i].imshow(image)
axes[i].set_title(prompt[:40] + "...", fontsize=10)
axes[i].axis('off')
plt.tight_layout()
plt.show()

2. VLM Critic for Quality Assessment
Now let’s build a VLM critic that can:
1. Check if the generated image matches the prompt
2. Identify problems (“wrong color”, “missing objects”)
3. Provide feedback for refinement
We’ll use CLIP for image-text similarity scoring, and later add a captioning model (BLIP) for more detailed feedback.
# Load CLIP for image-text alignment scoring
print("Loading CLIP...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)
clip_model.eval()
print("CLIP loaded!")def compute_clip_score(image, text):
"""Compute CLIP similarity between image and text."""
inputs = clip_processor(
text=[text],
images=image,
return_tensors="pt",
padding=True
).to(device)
with torch.no_grad():
outputs = clip_model(**inputs)
# Image-text similarity logit (cosine similarity scaled by CLIP's learned logit_scale)
logits_per_image = outputs.logits_per_image
similarity = logits_per_image.item()
return similarity
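
# Note: CLIP's logits_per_image values are cosine similarities multiplied by the model's
# learned logit_scale (roughly 100 for the OpenAI checkpoints), which is why good matches
# land around 25-35 rather than near 1.0. A small sketch of the unscaled cosine similarity,
# reusing the clip_model/clip_processor loaded above:
def compute_clip_cosine(image, text):
    """Raw cosine similarity between image and text (typically ~0.2-0.35 for good matches)."""
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        image_emb = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
        text_emb = clip_model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    return (image_emb @ text_emb.T).item()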
def critique_image(image, prompt, attributes_to_check=None):
"""Critique generated image against prompt."""
# Overall similarity
overall_score = compute_clip_score(image, prompt)
# Check specific attributes if provided
attribute_scores = {}
if attributes_to_check:
for attr in attributes_to_check:
attr_score = compute_clip_score(image, attr)
attribute_scores[attr] = attr_score
# Simple critique logic
critique = {
"overall_score": overall_score,
"attribute_scores": attribute_scores,
"verdict": "Good match!" if overall_score > 25 else "Poor match, needs refinement"
}
return critique
# Test critique
test_image = generate_image("a cat wearing a red hat", seed=42)
critique = critique_image(
test_image,
"a cat wearing a red hat",
attributes_to_check=[
"a cat",
"a red hat",
"a blue hat", # Wrong - should score lower
"a dog", # Wrong - should score lower
]
)
print("\nCritique Results:")
print(f"Overall score: {critique['overall_score']:.2f}")
print(f"Verdict: {critique['verdict']}")
print("\nAttribute scores:")
for attr, score in critique['attribute_scores'].items():
print(f" '{attr}': {score:.2f}")
plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.title(f"Score: {critique['overall_score']:.2f}\n{critique['verdict']}")
plt.axis('off')
plt.show()

3. Iterative Refinement with VLM Feedback
Now the magic happens! We’ll implement an iterative refinement loop:
1. Generate an image from the initial prompt
2. VLM critiques the image
3. If the score is low, refine the prompt (add emphasis keywords or negative prompts; a sketch of negative prompts follows below)
4. Generate again with the refined prompt
5. Repeat until the quality threshold is met
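The negative prompts mentioned in step 3 aren’t used by generate_image above. As a minimal sketch (the negative_prompt parameter is part of the diffusers pipeline API; the example strings are just illustrative defaults), they can be passed straight to the pipeline:
def generate_image_with_negatives(prompt, negative_prompt="blurry, low quality, deformed", num_inference_steps=25, guidance_scale=7.5, seed=42):
    """Like generate_image, but steers sampling away from the terms in negative_prompt."""
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        image = sd_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,  # concepts to push the sampler away from
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    return image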
def refine_prompt_based_on_critique(original_prompt, critique, iteration):
"""Refine prompt based on VLM critique."""
# Simple refinement strategies:
# 1. Add emphasis keywords
# 2. Increase detail
# 3. Add quality boosters
refinements = [
f"{original_prompt}, highly detailed",
f"{original_prompt}, high quality, professional photography",
f"{original_prompt}, 8k, ultra realistic, sharp focus",
]
# Use different refinement based on iteration
refined_prompt = refinements[min(iteration, len(refinements)-1)]
return refined_prompt
def iterative_generate_with_refinement(prompt, max_iterations=3, target_score=28.0, seed=42):
"""Generate image with iterative VLM-guided refinement."""
history = []
current_prompt = prompt
for iteration in range(max_iterations):
print(f"\n--- Iteration {iteration + 1} ---")
print(f"Prompt: {current_prompt}")
# Generate image
image = generate_image(current_prompt, seed=seed+iteration)
# Critique
critique = critique_image(image, prompt)
score = critique['overall_score']
print(f"Score: {score:.2f}")
print(f"Verdict: {critique['verdict']}")
# Save to history
history.append({
"iteration": iteration + 1,
"prompt": current_prompt,
"image": image,
"score": score,
"critique": critique
})
# Check if target reached
if score >= target_score:
print(f"✓ Target score reached! ({score:.2f} >= {target_score})")
break
# Refine prompt for next iteration
if iteration < max_iterations - 1:
current_prompt = refine_prompt_based_on_critique(prompt, critique, iteration)
return history
# Test iterative refinement
test_prompt = "a red sports car in front of a modern building"
history = iterative_generate_with_refinement(test_prompt, max_iterations=3, target_score=27.0, seed=100)
# Visualize progression
fig, axes = plt.subplots(1, len(history), figsize=(5*len(history), 5))
if len(history) == 1:
axes = [axes]
for i, entry in enumerate(history):
axes[i].imshow(entry['image'])
axes[i].set_title(f"Iteration {entry['iteration']}\nScore: {entry['score']:.2f}", fontsize=10)
axes[i].axis('off')
plt.tight_layout()
plt.show()
print("\n=== Refinement Summary ===")
for entry in history:
    print(f"Iter {entry['iteration']}: Score {entry['score']:.2f} - {entry['prompt'][:60]}...")

4. Image-Conditioned Generation (Inpainting)
Beyond text-to-image, we can do image-conditioned generation:
- Inpainting: Fill in masked regions
- Image-to-image: Transform existing images (see the sketch below)
- Variations: Generate similar images with changes
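Image-to-image isn’t implemented later in this notebook, but a rough sketch with diffusers’ StableDiffusionImg2ImgPipeline (reusing the same SD 1.5 weights; the strength value is just an illustrative default) could look like:
from diffusers import StableDiffusionImg2ImgPipeline
img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,
).to(device)
def transform_image(init_image, prompt, strength=0.6, seed=42):
    """Re-render init_image toward prompt; lower strength preserves more of the original."""
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        return img2img_pipe(
            prompt=prompt,
            image=init_image.resize((512, 512)),
            strength=strength,
            generator=generator,
        ).images[0]
Note that loading a second pipeline costs extra GPU memory; in practice the text encoder, VAE, and UNet can be shared between pipelines.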
Let’s implement inpainting with Stable Diffusion.
# Load inpainting pipeline
print("Loading Stable Diffusion Inpainting...")
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
safety_checker=None,
)
inpaint_pipe = inpaint_pipe.to(device)
if torch.cuda.is_available():
inpaint_pipe.enable_attention_slicing()
inpaint_pipe.enable_vae_slicing()
print("Inpainting pipeline loaded!")def create_circular_mask(image_size, center, radius):
"""Create a circular mask for inpainting."""
mask = Image.new('L', image_size, 0)
draw = ImageDraw.Draw(mask)
draw.ellipse(
[
(center[0] - radius, center[1] - radius),
(center[0] + radius, center[1] + radius)
],
fill=255
)
return mask
def inpaint_image(image, mask, prompt, num_inference_steps=25, seed=42):
"""Inpaint masked region of image based on prompt."""
generator = torch.Generator(device=device).manual_seed(seed)
# Resize to 512x512 (the resolution the SD inpainting checkpoint was trained at)
image = image.resize((512, 512))
mask = mask.resize((512, 512))
with torch.inference_mode():
result = inpaint_pipe(
prompt=prompt,
image=image,
mask_image=mask,
num_inference_steps=num_inference_steps,
generator=generator,
).images[0]
return result
# Generate a base image
base_image = generate_image("a room with white walls and wooden floor", seed=200)
base_image = base_image.resize((512, 512))
# Create mask for center region
mask = create_circular_mask((512, 512), center=(256, 256), radius=100)
# Inpaint with different objects
inpaint_prompts = [
"a red chair in the center of the room",
"a green plant in a pot",
"a modern lamp",
]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Show original and mask
axes[0, 0].imshow(base_image)
axes[0, 0].set_title("Original Image")
axes[0, 0].axis('off')
axes[0, 1].imshow(mask, cmap='gray')
axes[0, 1].set_title("Mask (region to inpaint)")
axes[0, 1].axis('off')
axes[0, 2].axis('off')
# Generate inpainted versions
for i, prompt in enumerate(inpaint_prompts):
print(f"Inpainting: {prompt}")
inpainted = inpaint_image(base_image, mask, prompt, seed=200+i)
axes[1, i].imshow(inpainted)
axes[1, i].set_title(f"Inpainted: {prompt}", fontsize=9)
axes[1, i].axis('off')
plt.tight_layout()
plt.show()

5. VLM-Guided Inpainting with Quality Control
Combine inpainting with VLM critique for quality-controlled image editing:
1. User specifies the region to edit and the desired content
2. Inpaint the region
3. VLM checks if the edit matches the description
4. If not, try again with a refined prompt
def vlm_guided_inpainting(image, mask, prompt, max_attempts=3, target_score=25.0, seed=42):
"""Inpaint with VLM quality control."""
attempts = []
current_prompt = prompt
for attempt in range(max_attempts):
print(f"\n--- Attempt {attempt + 1} ---")
print(f"Prompt: {current_prompt}")
# Inpaint
result = inpaint_image(image, mask, current_prompt, seed=seed+attempt)
# Critique the inpainted result
critique = critique_image(result, prompt)
score = critique['overall_score']
print(f"Score: {score:.2f}")
attempts.append({
"attempt": attempt + 1,
"prompt": current_prompt,
"result": result,
"score": score
})
if score >= target_score:
print(f"✓ Quality threshold reached!")
break
# Refine prompt
if attempt < max_attempts - 1:
current_prompt = f"{prompt}, highly detailed, realistic"
# Return best result
best = max(attempts, key=lambda x: x['score'])
return best, attempts
# Test VLM-guided inpainting
test_inpaint_prompt = "a beautiful red rose in a vase"
best, attempts = vlm_guided_inpainting(
base_image,
mask,
test_inpaint_prompt,
max_attempts=2,
target_score=24.0,
seed=300
)
# Visualize attempts
fig, axes = plt.subplots(1, len(attempts)+2, figsize=(5*(len(attempts)+2), 5))
axes[0].imshow(base_image)
axes[0].set_title("Original")
axes[0].axis('off')
axes[1].imshow(mask, cmap='gray')
axes[1].set_title("Mask")
axes[1].axis('off')
for i, att in enumerate(attempts):
axes[i+2].imshow(att['result'])
marker = "✓" if att == best else ""
axes[i+2].set_title(f"Attempt {att['attempt']} {marker}\nScore: {att['score']:.2f}", fontsize=9)
axes[i+2].axis('off')
plt.tight_layout()
plt.show()

6. Multi-Prompt Generation and Selection
Another strategy: Generate multiple variations with different prompts/seeds, then use VLM to select the best one.
def generate_and_select_best(base_prompt, num_variations=4, seeds=None):
"""Generate multiple variations and select best using VLM."""
if seeds is None:
seeds = list(range(num_variations))
# Prompt variations (different phrasings)
prompt_variations = [
base_prompt,
f"{base_prompt}, professional photography",
f"{base_prompt}, digital art, trending on artstation",
f"{base_prompt}, ultra realistic, 8k, detailed",
]
candidates = []
print(f"Generating {num_variations} variations...")
for i in range(num_variations):
prompt = prompt_variations[i % len(prompt_variations)]
seed = seeds[i]
print(f" Variation {i+1}: seed={seed}")
image = generate_image(prompt, seed=seed, num_inference_steps=20)
# Score against base prompt
score = compute_clip_score(image, base_prompt)
candidates.append({
"index": i,
"prompt": prompt,
"seed": seed,
"image": image,
"score": score
})
# Select best
best = max(candidates, key=lambda x: x['score'])
print(f"\n✓ Best: Variation {best['index']+1} (score: {best['score']:.2f})")
return best, candidates
# Test multi-generation selection
test_prompt = "a magical forest with glowing mushrooms"
best, candidates = generate_and_select_best(test_prompt, num_variations=4, seeds=[10, 20, 30, 40])
# Visualize all candidates
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for i, cand in enumerate(candidates):
axes[i].imshow(cand['image'])
marker = "★ BEST" if cand == best else ""
axes[i].set_title(f"Variation {cand['index']+1} {marker}\nScore: {cand['score']:.2f}\nSeed: {cand['seed']}", fontsize=9)
axes[i].axis('off')
plt.tight_layout()
plt.show()

7. Advanced: Caption-Based Improvement Loop
We can use a captioning VLM to describe the generated image, then compare the caption to the original prompt to identify discrepancies.
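One simple way to “compare the caption to the original prompt” is to embed both with CLIP’s text encoder and measure their similarity. A minimal sketch, reusing the clip_model/clip_processor loaded earlier:
def caption_prompt_similarity(prompt, caption):
    """Cosine similarity between prompt and caption in CLIP text-embedding space."""
    inputs = clip_processor(text=[prompt, caption], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_emb = clip_model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    return (text_emb[0] @ text_emb[1]).item()
A low similarity here is a hint that the generation drifted from the prompt, even when the raw image-text CLIP score looks acceptable.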
# Load a simple captioning model (BLIP)
from transformers import BlipProcessor, BlipForConditionalGeneration
print("Loading BLIP for captioning...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = blip_model.to(device)
blip_model.eval()
print("BLIP loaded!")def caption_image(image):
"""Generate caption for image using BLIP."""
inputs = blip_processor(image, return_tensors="pt").to(device)
with torch.no_grad():
generated_ids = blip_model.generate(**inputs, max_length=50)
caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
return caption
def caption_based_refinement(prompt, max_iterations=3, seed=42):
"""Generate with caption-based feedback loop."""
history = []
current_prompt = prompt
for iteration in range(max_iterations):
print(f"\n--- Iteration {iteration + 1} ---")
print(f"Prompt: {current_prompt}")
# Generate
image = generate_image(current_prompt, seed=seed+iteration, num_inference_steps=20)
# Caption the result
generated_caption = caption_image(image)
print(f"Generated caption: '{generated_caption}'")
# Score alignment
score = compute_clip_score(image, prompt)
print(f"CLIP score: {score:.2f}")
history.append({
"iteration": iteration + 1,
"prompt": current_prompt,
"image": image,
"caption": generated_caption,
"score": score
})
# Check if satisfactory
if score > 27.0:
print("✓ Good quality achieved!")
break
# Refine: combine original prompt with generated caption insights
if iteration < max_iterations - 1:
current_prompt = f"{prompt}, {generated_caption}, high quality"
return history
# Test caption-based refinement
test_prompt = "a cyberpunk city at night with neon lights"
history = caption_based_refinement(test_prompt, max_iterations=2, seed=500)
# Visualize
fig, axes = plt.subplots(1, len(history), figsize=(7*len(history), 7))
if len(history) == 1:
axes = [axes]
for i, entry in enumerate(history):
axes[i].imshow(entry['image'])
axes[i].set_title(
f"Iteration {entry['iteration']}\n"
f"Score: {entry['score']:.2f}\n"
f"Caption: {entry['caption'][:50]}...",
fontsize=9
)
axes[i].axis('off')
plt.tight_layout()
plt.show()

Summary and Comparison
Let’s compare all the VLM-guided generation strategies we’ve explored:
import pandas as pd
strategies = {
"Strategy": [
"Basic Text-to-Image",
"Iterative Refinement",
"Multi-Variation Selection",
"VLM-Guided Inpainting",
"Caption-Based Loop",
],
"Description": [
"Single generation from text prompt",
"Generate → critique → refine prompt → regenerate",
"Generate N variations → VLM selects best",
"Inpaint region → VLM quality check → retry if needed",
"Generate → caption → compare caption to prompt → refine",
],
"Pros": [
"Fast, simple",
"Improves quality iteratively",
"Explores diverse options",
"Precise region control",
"Detailed feedback via captions",
],
"Cons": [
"No quality control",
"Multiple generations needed",
"Computationally expensive (N generations)",
"Requires mask definition",
"Caption may miss details",
],
"Best Use Case": [
"Quick prototyping",
"High-quality single image",
"Finding best from options",
"Editing specific regions",
"Understanding generation failures",
],
}
df = pd.DataFrame(strategies)
print("\nVLM-Guided Generation Strategies Comparison:")
print(df.to_string(index=False))

Key Takeaways
1. VLM as Quality Gatekeeper
- Use CLIP scores to verify image-text alignment
- Reject low-quality generations automatically
- Iterate until quality threshold is met
2. Iterative Refinement is Powerful
- Simple prompt refinement can significantly improve results
- 2-3 iterations are usually enough to reach the quality threshold
- Each iteration can build on feedback from the previous attempt
3. Multi-Variation Selection
- Generate diverse options (different seeds/prompts)
- Let VLM pick the best match
- More expensive but explores solution space better
4. Image-Conditioned Generation
- Inpainting allows precise region control
- VLM can verify edits match descriptions
- Useful for iterative image editing workflows
5. Caption-Based Feedback
- Captions reveal what the model actually generated
- Compare caption to prompt to identify gaps
- Helps diagnose systematic generation failures
Practical Recommendations
For production systems:
- Use multi-variation selection with VLM scoring (see the batched scoring sketch below)
- Set quality thresholds based on your use case
- Log scores and prompts for analysis
For user-facing apps:
- Show multiple variations, let users + VLM co-select
- Implement “regenerate” with automatic refinement
- Use inpainting for editing workflows
For research:
- Use caption-based loops to understand model behavior
- Experiment with different VLM critics (CLIP, BLIP, custom)
- Try adversarial refinement (VLM as discriminator)
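For the multi-variation route, scoring all candidates in a single CLIP forward pass is cheaper than looping over them one at a time. A minimal sketch, assuming images is a list of PIL candidates and reusing the CLIP model loaded earlier:
def pick_best_candidate(images, base_prompt):
    """Score a batch of candidate images against one prompt; return the index of the best."""
    inputs = clip_processor(text=[base_prompt], images=images, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        scores = clip_model(**inputs).logits_per_image.squeeze(-1)  # shape: (num_images,)
    best_idx = int(scores.argmax().item())
    return best_idx, scores.tolist()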
Next Steps
To extend this work:
1. Fine-tune VLM critic: Train on your specific domain for better scoring
2. Learned refinement: Train a model to predict prompt refinements
3. Multi-modal fusion: Combine CLIP scores, captions, and object detection
4. Interactive editing: Let users provide feedback alongside VLM
5. Efficiency: Use distilled/quantized VLMs for faster critique
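As one small step toward point 5, the existing CLIP critic can be run under float16 autocast on GPU; a sketch (a distilled or quantized critic would go further):
def compute_clip_score_fp16(image, text):
    """Same scoring as compute_clip_score, run under mixed precision when CUDA is available."""
    inputs = clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
        return clip_model(**inputs).logits_per_image.item()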
Further Reading:
- Stable Diffusion: https://github.com/Stability-AI/stablediffusion
- CLIP: https://github.com/openai/CLIP
- BLIP: https://github.com/salesforce/BLIP
- Diffusers library: https://github.com/huggingface/diffusers
GPU Memory Cleanup
# Delete all models
del sd_pipe
del inpaint_pipe
del clip_model
del blip_model
# Clear cache
gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared!")
if torch.cuda.is_available():
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")