Using Diffusion to Generate Images from Text

ML
Author

Nipun Batra

Published

January 1, 2024

References

  1. PromptHero guide
%pip install --upgrade \
  diffusers \
  transformers \
  safetensors \
  sentencepiece \
  accelerate \
  bitsandbytes \
  torch \
  huggingface_hub --quiet
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.15.1+cu118 requires torch==2.0.0, but you have torch 2.1.2 which is incompatible.
torchaudio 2.0.1+cu118 requires torch==2.0.0, but you have torch 2.1.2 which is incompatible.
Note: you may need to restart the kernel to use updated packages.
from huggingface_hub import login

# Authenticate this session with the Hugging Face Hub before downloading models.
# NOTE(review): called without a token, so this presumably prompts interactively;
# pass `token=` to run non-interactively.
login()

Basic Imports

# Imports for plotting, image handling and the diffusion pipeline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from PIL import Image
import io
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL
import torch
# Load the Openjourney text-to-image pipeline with half-precision weights.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney", torch_dtype=torch.float16)

# Replace the pipeline's scheduler with a multistep DPM-Solver built from the
# existing scheduler's configuration.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Swap in the half-precision autoencoder from "stabilityai/sd-vae-ft-mse".
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch.float16,
).to("cuda")
pipe.vae = vae

# Move all pipeline components onto the GPU.
pipe = pipe.to("cuda")
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
def generate_images(prompt, num_steps, num_variations, prompt_guidance, dimensions, seeds=None):
    """Generate ``num_variations`` images for ``prompt`` with the global ``pipe``.

    Args:
        prompt: Text prompt; the same prompt is used for every variation.
        num_steps: Number of inference (denoising) steps.
        num_variations: Number of images generated in a single batch.
        prompt_guidance: Value passed to the pipeline as ``guidance_scale``.
        dimensions: ``(width, height)`` of each image in pixels. This is the
            convention of the settings cell and of ``display_images``, whose
            ``figsize`` is (width, height).
        seeds: Optional sequence of ``num_variations`` integer seeds, one per
            image, for reproducible output. When omitted, seeds are drawn
            uniformly from [0, 65000].

    Returns:
        The ``.images`` list produced by the pipeline.

    Raises:
        ValueError: if ``seeds`` is given with a length other than
            ``num_variations``.
    """
    # Local import so this function works even if the cell importing
    # `random` has not been run yet.
    import random

    if seeds is None:
        seeds = [random.randint(0, 65000) for _ in range(num_variations)]
    elif len(seeds) != num_variations:
        raise ValueError(
            f"expected {num_variations} seeds, got {len(seeds)}"
        )

    # `dimensions` is (width, height). The previous version passed them to the
    # pipeline swapped, which turned portrait settings into landscape images.
    width, height = dimensions[0], dimensions[1]

    # One seeded generator per image, so each variation is individually reproducible.
    generators = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]

    return pipe(
        prompt=num_variations * [prompt],
        num_inference_steps=num_steps,
        guidance_scale=prompt_guidance,
        height=height,
        width=width,
        generator=generators,
    ).images
import random

# Settings for image generation
# NOTE(review): `--ar 2:3` is Midjourney syntax. The diffusers pipeline
# presumably treats it as ordinary prompt text; the actual aspect ratio
# comes from `dimensions` below.
prompt = 'Small happy dog and owner learning to walk on a rainy day. Colored photography. Leica lens. Hi-res. hd 8k --ar 2:3'
num_steps = 150
num_variations = 4
prompt_guidance = 8
# (width, height) tuple in pixels; Stable Diffusion expects multiples of 8.
dimensions = (400, 600)
images = generate_images(prompt, num_steps, num_variations, prompt_guidance, dimensions)


def display_images(images, num_variations, dimensions):
    """Show up to ``num_variations`` images side by side in a single row.

    Args:
        images: Sequence of images that ``Axes.imshow`` can render
            (e.g. PIL images or arrays).
        num_variations: Number of columns, i.e. images to draw.
        dimensions: ``(width, height)`` of one image in pixels, used to
            size the figure.
    """
    width, height = dimensions[0], dimensions[1]
    columns, rows = num_variations, 1

    # Give every column room for one image. The previous figsize ignored the
    # column count, so a row of N images was squeezed into a figure sized for
    # one image and left mostly blank. At matplotlib's default 100 dpi, each
    # image is drawn at roughly its native resolution.
    fig = plt.figure(figsize=(columns * width / 100, rows * height / 100))

    # Slicing tolerates receiving fewer images than columns; indexing would raise.
    for position, img in enumerate(images[:columns], start=1):
        ax = fig.add_subplot(rows, columns, position)
        ax.imshow(img)
        ax.axis("off")  # hide axes, ticks and frame
    plt.show()
    
display_images(images, num_variations, dimensions)

# Reuse the same generation settings for a series of further prompts,
# generating and showing each batch in turn.
for prompt in (
    "Batman, cinematic lighting, dark background, very high resolution 3D render.",
    "A logo for a research group in India called Sustainability lab that works on AI for sustainability. Show AI as the central theme. Show applications in health, air quality",
    "Portrait of a small kid with a big smile.",
    "A photorealistic render of an academic campus in India, on the edges of a river, low-light, evening.",
):
    images = generate_images(prompt, num_steps, num_variations, prompt_guidance, dimensions)
    display_images(images, num_variations, dimensions)