from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
%config InlineBackend.figure_format = 'retina'
class ShapesDataset(Dataset):
    def __init__(self, n_samples=200, size=224):
        self.n_samples = n_samples
        self.size = size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return self.n_samples

    def _rand_color(self):
        # random pastel or bright color
        return tuple(np.random.randint(50, 255, 3))

    def _make_shape(self):
        size = self.size

        # random dark background color
        bg_color = tuple(np.random.randint(0, 100, 3))
        img = Image.new('RGB', (size, size), bg_color)
        mask = Image.new('L', (size, size), 0)
        draw_img = ImageDraw.Draw(img)
        draw_mask = ImageDraw.Draw(mask)

        shape = np.random.choice(['circle', 'square'])
        margin = size // 8
        min_extent = size // 8
        max_extent = size // 4

        cx = np.random.randint(margin, size - margin)
        cy = np.random.randint(margin, size - margin)
        half = np.random.randint(min_extent, max_extent)

        # random color and slight blur or transparency
        shape_color = self._rand_color()
        alpha = random.uniform(0.7, 1.0)

        if shape == 'circle':
            draw_img.ellipse((cx - half, cy - half, cx + half, cy + half), fill=shape_color)
            draw_mask.ellipse((cx - half, cy - half, cx + half, cy + half), fill=1)
        else:
            draw_img.rectangle((cx - half, cy - half, cx + half, cy + half), fill=shape_color)
            draw_mask.rectangle((cx - half, cy - half, cx + half, cy + half), fill=1)

        # randomly apply a slight Gaussian blur to the image
        if np.random.rand() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=np.random.uniform(0, 1.5)))

        # overlay random noise to make it look less perfect
        np_img = np.array(img).astype(np.float32)
        noise = np.random.randn(*np_img.shape) * 10  # light noise
        np_img = np.clip(np_img + noise, 0, 255).astype(np.uint8)
        img = Image.fromarray(np_img)

        img = self.to_tensor(img)
        mask = torch.tensor(np.array(mask), dtype=torch.long)
        return img, mask

    def __getitem__(self, idx):
        return self._make_shape()
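As a quick sanity check (a small sketch, assuming the class above has been run in the same session), each sample should come back as a float image tensor of shape (3, size, size) and an integer mask of shape (size, size) containing only 0 and 1:

# Inspect a single generated sample from the synthetic dataset
ds = ShapesDataset(n_samples=4, size=224)
img, mask = ds[0]

print(img.shape, img.dtype)    # torch.Size([3, 224, 224]) torch.float32
print(mask.shape, mask.dtype)  # torch.Size([224, 224]) torch.int64
print(mask.unique())           # tensor([0, 1]) -- background vs. shape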
from torch.utils.data import DataLoader
train_dataset = ShapesDataset(n_samples=100, size=224)
test_dataset = ShapesDataset(n_samples=20, size=224)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
import matplotlib.pyplot as plt
imgs, masks = next(iter(train_loader))

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axes[0, i].imshow(imgs[i].permute(1, 2, 0))
    axes[0, i].axis('off')
    axes[1, i].imshow(masks[i], cmap='gray')
    axes[1, i].axis('off')

axes[0, 0].set_title('Images')
axes[1, 0].set_title('Masks')
import os
import json
def save_shapes_dataset(
    root="shapes224",
    n_train=200,
    n_test=50,
    size=224,
    seed=42
):
    os.makedirs(root, exist_ok=True)
    np.random.seed(seed)
    torch.manual_seed(seed)

    splits = {"train": n_train, "test": n_test}
    dataset = ShapesDataset(n_samples=n_train + n_test, size=size)

    metadata = {}
    idx = 0

    for split, n_samples in splits.items():
        img_dir = os.path.join(root, split, "images")
        mask_dir = os.path.join(root, split, "masks")
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(mask_dir, exist_ok=True)

        for _ in range(n_samples):
            img, mask = dataset[idx]
            idx += 1

            img_pil = transforms.ToPILImage()(img)
            mask_pil = Image.fromarray(mask.numpy().astype("uint8"))

            img_name = f"{idx:04d}.png"
            mask_name = f"{idx:04d}.png"

            img_path = os.path.join("images", img_name)
            mask_path = os.path.join("masks", mask_name)

            img_pil.save(os.path.join(root, split, img_path))
            mask_pil.save(os.path.join(root, split, mask_path))

            metadata.setdefault(split, []).append({
                "image": img_path,
                "mask": mask_path,
                "width": size,
                "height": size
            })

    # save metadata
    meta_path = os.path.join(root, "metadata.json")
    with open(meta_path, "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved dataset to {root}/ with metadata.json")
save_shapes_dataset(root="shapes224", n_train=1000, n_test=1000)
Saved dataset to shapes224/ with metadata.json
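Given the paths built in save_shapes_dataset, metadata.json groups samples per split with relative image/mask paths. A quick way to confirm the layout (a sketch of the expected values, not captured notebook output; exact path separators depend on the OS):

# Inspect the metadata written by save_shapes_dataset
with open("shapes224/metadata.json") as f:
    meta = json.load(f)

print(list(meta.keys()))                      # ['train', 'test']
print(len(meta["train"]), len(meta["test"]))  # 1000 1000 for the call above
print(meta["train"][0])                       # e.g. {'image': 'images/0001.png', 'mask': 'masks/0001.png', 'width': 224, 'height': 224}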
class ShapesDiskDataset(Dataset):
    def __init__(self, root="shapes224", split="train", transform=None, target_transform=None):
        self.root = root
        self.split = split
        self.transform = transform or transforms.ToTensor()
        self.target_transform = target_transform
        self.meta_path = os.path.join(root, "metadata.json")
        with open(self.meta_path, "r") as f:
            metadata = json.load(f)
        self.samples = metadata[split]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        img_path = os.path.join(self.root, self.split, sample["image"])
        mask_path = os.path.join(self.root, self.split, sample["mask"])

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        img = self.transform(img)
        mask = torch.tensor(np.array(mask), dtype=torch.long)
        return img, mask
train_dataset = ShapesDiskDataset(root="shapes224", split="train")
print(f"Train samples: {len(train_dataset)}")

for i in range(3):
    img, mask = train_dataset[i]
    plt.subplot(2, 3, i + 1)
    plt.imshow(img.permute(1, 2, 0))
    plt.axis('off')
    plt.subplot(2, 3, i + 4)
    plt.imshow(mask, cmap='gray')
    plt.axis('off')
plt.show()
Train samples: 1000
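For training, the on-disk dataset can be batched exactly like the in-memory version earlier (a sketch reusing the same batch size as above):

# Wrap the disk-backed splits in DataLoaders
train_loader = DataLoader(ShapesDiskDataset(root="shapes224", split="train"), batch_size=8, shuffle=True)
test_loader = DataLoader(ShapesDiskDataset(root="shapes224", split="test"), batch_size=8, shuffle=False)

imgs, masks = next(iter(train_loader))
print(imgs.shape, masks.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 224, 224])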