from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch
%config InlineBackend.figure_format = 'retina'
class MaskedImageDataset(Dataset):
    """Wrap an image dataset so each item is a randomly pixel-masked copy of the image.

    Items of ``base_dataset`` are expected to be ``(image, label)`` pairs whose
    image is a ``(C, H, W)`` tensor (as produced by ``transforms.ToTensor()``);
    the label is discarded. ``__getitem__`` returns ``(masked_image, image)``.
    A single random ``(1, H, W)`` mask is drawn per access and shared by all
    channels, so a hidden pixel is zeroed in every channel. Reading the same
    index twice gives two different masks.

    Args:
        base_dataset: indexable dataset yielding ``(image, label)`` pairs.
        mask_ratio: expected fraction of pixels to hide, in ``[0, 1]``.

    Raises:
        ValueError: if ``mask_ratio`` lies outside ``[0, 1]``.
    """

    def __init__(self, base_dataset, mask_ratio=0.75):
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
        self.base = base_dataset
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]  # ignore label
        # torch.rand_like samples uniformly from [0, 1), so a pixel is kept with
        # probability 1 - mask_ratio; using >= makes 0.0 keep every pixel and
        # 1.0 hide every pixel exactly.
        keep = torch.rand_like(img[:1]) >= self.mask_ratio  # [1,H,W]
        masked_img = img * keep  # zero out masked pixels; mask broadcasts over channels
        return masked_img, img  # (X=masked, Y=original)
from torchvision import datasets
# Resize every image to 224x224 and convert it to a float (C, H, W) tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

from torch.utils.data import Subset

# full CIFAR10
train_base = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_base = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# take first 1000 samples from each
train_small = Subset(train_base, range(1000))
test_small = Subset(test_base, range(1000))

# wrap with masking dataset: items become (masked_image, original_image)
train_masked = MaskedImageDataset(train_small, mask_ratio=0.75)
test_masked = MaskedImageDataset(test_small, mask_ratio=0.75)
# dataloaders
from torch.utils.data import DataLoader

# Shuffle the training set so every epoch sees a new batch order; keep the
# test order fixed so evaluation is reproducible.
train_loader = DataLoader(train_masked, batch_size=16, shuffle=True)
test_loader = DataLoader(test_masked, batch_size=16, shuffle=False)
print(len(train_loader.dataset), len(test_loader.dataset))
# Output:
# 1000 1000
# Inspect one (masked, original) pair.
x, y = train_masked[0]
print(x.shape, y.shape)  # both [3,224,224]
import matplotlib.pyplot as plt

plt.subplot(1, 2, 1)
plt.imshow(x.permute(1, 2, 0))  # masked; (C, H, W) -> (H, W, C) for imshow
plt.title("Masked")
plt.subplot(1, 2, 2)
plt.imshow(y.permute(1, 2, 0))  # original
plt.title("Original")
plt.show()
# Output:
# torch.Size([3, 224, 224]) torch.Size([3, 224, 224])
import torch.nn as nn
class SimpleAutoencoder(nn.Module):
    """Plain convolutional encoder-decoder without skip connections.

    The encoder has three ``Conv2d(kernel=4, stride=2, padding=1)`` layers, each
    halving the spatial size (224 -> 112 -> 56 -> 28). The decoder has three
    matching ``ConvTranspose2d`` layers, each doubling it back. A final sigmoid
    keeps outputs in [0, 1], the same range as ``transforms.ToTensor()`` images.
    """

    def __init__(self):
        super().__init__()
        # Layer order is kept fixed so state_dict keys (encoder.0, encoder.2, ...)
        # stay stable for saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),    # 112×112
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),   # 56×56
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # 28×28
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 56×56
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 112×112
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 224×224
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct a (N, 3, H, W) batch.

        The output has the same H and W as the input when both are multiples of 8.
        """
        z = self.encoder(x)
        out = self.decoder(z)
        return out
model = SimpleAutoencoder()
# Prefer a CUDA GPU, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)
def visualize_reconstruction(model, dataset, idx=0, device='cpu'):
    """Visualize one reconstruction from a masked dataset.

    Shows the masked input, the original image and the model's reconstruction
    side by side.

    Args:
        model: network applied to a (1, C, H, W) batch; its output is plotted
            as an image, so it should have the same layout.
        dataset: indexable dataset whose items are ``(masked_image,
            original_image)`` pairs of (C, H, W) tensors.
        idx: index of the sample to display.
        device: device the model lives on; both images are moved there.
    """
    was_training = model.training
    model.eval()
    x_masked, y_full = dataset[idx]
    x_masked = x_masked.unsqueeze(0).to(device)
    y_full = y_full.unsqueeze(0).to(device)
    with torch.no_grad():
        recon = model(x_masked)
    # Restore the caller's mode so later code does not silently run in eval mode.
    model.train(was_training)

    # Plot side-by-side
    imgs = [x_masked[0], y_full[0], recon[0]]
    titles = ["Masked", "Original", "Reconstructed"]
    plt.figure(figsize=(9, 3))
    for pos, (img, title) in enumerate(zip(imgs, titles), start=1):
        plt.subplot(1, 3, pos)
        # (C, H, W) -> (H, W, C) CPU numpy array for imshow; clamp to the valid [0, 1] range.
        plt.imshow(img.permute(1, 2, 0).cpu().clamp(0, 1).numpy())
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()
# Baseline: reconstruction from the still-untrained SimpleAutoencoder.
visualize_reconstruction(model, train_masked, idx=1, device=device)
# Fresh model, Adam optimizer and pixel-wise MSE reconstruction loss.
model = SimpleAutoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()
def train_autoencoder(model, train_loader, optimizer, criterion, device, epochs=5):
    """Train ``model`` to reconstruct original images from their masked versions.

    Args:
        model: network mapping a masked batch to a reconstruction of the same shape.
        train_loader: iterable of ``(x_masked, y_full)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(pred, y_full)``. It is assumed to
            average over the batch (as ``nn.MSELoss()`` does by default),
            because it is re-weighted by batch size below.
        device: device to move the model and every batch to.
        epochs: number of passes over ``train_loader``.

    Returns:
        List with the sample-weighted mean training loss of each epoch.
    """
    model.to(device)
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_seen = 0
        for x_masked, y_full in train_loader:
            x_masked, y_full = x_masked.to(device), y_full.to(device)

            # Forward → Backward → Optimize
            pred = model(x_masked)
            loss = criterion(pred, y_full)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = x_masked.size(0)
            total_loss += loss.item() * batch_size
            n_seen += batch_size

        # Average over the samples actually seen rather than len(dataset): the two
        # differ with drop_last=True or a subset sampler. max() avoids 0/0 on an
        # empty loader.
        epoch_loss = total_loss / max(n_seen, 1)
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {epoch_loss:.4f}")
    return history
train_autoencoder(model, train_loader, optimizer, criterion, device, epochs=10)
# Output recorded from a previous run:
# Epoch 1/10 | Train Loss: 0.0397
# Epoch 2/10 | Train Loss: 0.0257
# Epoch 3/10 | Train Loss: 0.0226
# Epoch 4/10 | Train Loss: 0.0211
# Epoch 5/10 | Train Loss: 0.0200
# Epoch 6/10 | Train Loss: 0.0193
# Epoch 7/10 | Train Loss: 0.0191
# Epoch 8/10 | Train Loss: 0.0191
# Epoch 9/10 | Train Loss: 0.0189
# Epoch 10/10 | Train Loss: 0.0182
# Reconstruction after training the SimpleAutoencoder.
visualize_reconstruction(model, train_masked, idx=1, device=device)
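# Next, we improve the model with U-Net-style skip connections: each decoder stage also receives the encoder feature map of the same resolution, concatenated along the channel axis.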
import torch.nn.functional as F
class SkipAutoencoder(nn.Module):
    """Convolutional autoencoder with U-Net-style skip connections.

    Uses the same stride-2 encoder/decoder resolutions as ``SimpleAutoencoder``.
    In addition, the decoder concatenates (along channels) the encoder feature
    map of matching resolution before its next upsampling step, which is why
    ``dec2`` and ``dec1`` take doubled input channels.
    """

    def __init__(self):
        super().__init__()
        self.enc1 = nn.Conv2d(3, 32, 4, 2, 1)
        self.enc2 = nn.Conv2d(32, 64, 4, 2, 1)
        self.enc3 = nn.Conv2d(64, 128, 4, 2, 1)
        self.dec3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.dec2 = nn.ConvTranspose2d(64 + 64, 32, 4, 2, 1)  # y3 ++ x2
        self.dec1 = nn.ConvTranspose2d(32 + 32, 3, 4, 2, 1)   # y2 ++ x1

    def forward(self, x):
        """Reconstruct a (N, 3, H, W) batch; sizes in comments are for 224×224 input.

        H and W must halve cleanly three times (e.g. multiples of 8) so that the
        skip tensors line up for ``torch.cat``.
        """
        x1 = F.relu(self.enc1(x))   # 112x112
        x2 = F.relu(self.enc2(x1))  # 56x56
        x3 = F.relu(self.enc3(x2))  # 28x28
        y3 = F.relu(self.dec3(x3))  # 56x56
        y2 = F.relu(self.dec2(torch.cat([y3, x2], 1)))  # skip from enc2, 112x112
        y1 = torch.sigmoid(self.dec1(torch.cat([y2, x1], 1)))  # skip from enc1, 224x224
        return y1
# Fresh skip-connection model with the same optimizer and loss as before.
model = SkipAutoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()
train_autoencoder(model, train_loader, optimizer, criterion, device, epochs=10)
# Output recorded from a previous run:
# Epoch 1/10 | Train Loss: 0.0449
# Epoch 2/10 | Train Loss: 0.0283
# Epoch 3/10 | Train Loss: 0.0236
# Epoch 4/10 | Train Loss: 0.0218
# Epoch 5/10 | Train Loss: 0.0195
# Epoch 6/10 | Train Loss: 0.0187
# Epoch 7/10 | Train Loss: 0.0174
# Epoch 8/10 | Train Loss: 0.0173
# Epoch 9/10 | Train Loss: 0.0164
# Epoch 10/10 | Train Loss: 0.0159
# Reconstruction after training the SkipAutoencoder.
visualize_reconstruction(model, train_masked, idx=1, device=device)
# The skip connections clearly help reconstruct the original image from the masked input: the colors match much better and the details are sharper. The final training loss is also lower (0.0159 vs. 0.0182 after 10 epochs).
# Masking entire patches instead of individual pixels
class PatchMaskedDataset(Dataset):
    """Wrap an image dataset so each item has whole square patches zeroed out (MAE-style).

    Items of ``base_dataset`` are expected to be ``(image, label)`` pairs whose
    image is a ``(C, H, W)`` tensor (as produced by ``transforms.ToTensor()``);
    the label is discarded. ``__getitem__`` returns ``(masked_image, image)``.
    A fresh random set of patches is hidden on every access.

    Args:
        base_dataset: indexable dataset yielding ``(image, label)`` pairs.
        patch_size: side length in pixels of each square patch; the ``H`` and
            ``W`` of every image must be multiples of it.
        mask_ratio: fraction of patches to hide, in ``[0, 1]``. The number of
            hidden patches is ``int(total_patches * mask_ratio)`` (rounded down).

    Raises:
        ValueError: on a non-positive ``patch_size``, a ``mask_ratio`` outside
            ``[0, 1]``, or (in ``__getitem__``) an image whose size is not
            divisible by ``patch_size``.
    """

    def __init__(self, base_dataset, patch_size=16, mask_ratio=0.75):
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
        self.base_dataset = base_dataset
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        img, _ = self.base_dataset[idx]  # labels are not used for MAE
        _, H, W = img.shape
        p = self.patch_size
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if H % p or W % p:
            raise ValueError(f"image size {H}x{W} is not divisible by patch_size={p}")

        # --- Split into a grid of patches ---
        h_patches, w_patches = H // p, W // p
        total_patches = h_patches * w_patches

        # --- Choose patches to mask ---
        # torch's RNG (as in MaskedImageDataset) is controlled by torch.manual_seed
        # and is reseeded per DataLoader worker, unlike numpy's global RNG.
        num_mask = int(total_patches * self.mask_ratio)
        masked_ids = torch.randperm(total_patches)[:num_mask]
        patch_mask = torch.zeros(total_patches, dtype=torch.bool)
        patch_mask[masked_ids] = True

        # Expand the (h_patches, w_patches) patch grid to a per-pixel (H, W) mask.
        pixel_mask = (
            patch_mask.view(h_patches, w_patches)
            .repeat_interleave(p, dim=0)
            .repeat_interleave(p, dim=1)
        )
        # masked_fill returns a new tensor, so the original image is untouched;
        # the (H, W) mask broadcasts over the channel dimension.
        img_masked = img.masked_fill(pixel_mask.to(img.device), 0.0)
        return img_masked, img
# Resize every image to 224x224 (divisible by the 16-pixel patch size) and
# convert it to a float (C, H, W) tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_base = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_base = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Take small subset for educational speed
from torch.utils.data import Subset
train_small = Subset(train_base, range(1000))
test_small = Subset(test_base, range(1000))

# Wrap with patch-mask dataset: items become (patch_masked_image, original_image)
train_pmask = PatchMaskedDataset(train_small, patch_size=16, mask_ratio=0.75)
test_pmask = PatchMaskedDataset(test_small, patch_size=16, mask_ratio=0.75)

from torch.utils.data import DataLoader
train_loader = DataLoader(train_pmask, batch_size=8, shuffle=True)
test_loader = DataLoader(test_pmask, batch_size=8, shuffle=False)
# Fresh skip-connection model for the patch-masking task.
model = SkipAutoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()
# Baseline: reconstruction from the still-untrained model.
visualize_reconstruction(model, train_pmask, idx=1, device=device)
train_autoencoder(model, train_loader, optimizer, criterion, device, epochs=15)
# Output recorded from a previous run:
# Epoch 1/15 | Train Loss: 0.0412
# Epoch 2/15 | Train Loss: 0.0267
# Epoch 3/15 | Train Loss: 0.0239
# Epoch 4/15 | Train Loss: 0.0220
# Epoch 5/15 | Train Loss: 0.0208
# Epoch 6/15 | Train Loss: 0.0198
# Epoch 7/15 | Train Loss: 0.0192
# Epoch 8/15 | Train Loss: 0.0180
# Epoch 9/15 | Train Loss: 0.0180
# Epoch 10/15 | Train Loss: 0.0174
# Epoch 11/15 | Train Loss: 0.0168
# Epoch 12/15 | Train Loss: 0.0159
# Epoch 13/15 | Train Loss: 0.0156
# Epoch 14/15 | Train Loss: 0.0150
# Epoch 15/15 | Train Loss: 0.0150
# Reconstruction of a patch-masked sample after training.
visualize_reconstruction(model, train_pmask, idx=1, device=device)