Manifold Hypothesis: High-dimensional data lies on (or near) a low-dimensional manifold
Augmentation explores the manifold:
Interpolation on the manifold: x̃ = λ·x_i + (1 − λ)·x_j
where λ ∈ [0, 1] and x_i, x_j are nearby training examples; if both lie on the data manifold, x̃ is likely to stay close to it.
Theoretical benefit: Augmented samples fill in sparsely covered regions of the manifold, so the model sees a denser sample of the underlying distribution and generalizes better.
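A minimal sketch of this interpolation, assuming X_original is the same array used in the t-SNE visualization below:
import numpy as np

# Convex combination of two nearby training points; by the manifold hypothesis
# the result should still be a plausible sample
def interpolate(x1, x2, lam):
    return lam * x1 + (1 - lam) * x2

x_new = interpolate(X_original[0], X_original[1], lam=np.random.uniform(0.3, 0.7))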
# Visualize: Project to 2D with t-SNE
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Fit t-SNE on original and augmented data together so both share one embedding space
all_2d = TSNE(n_components=2).fit_transform(np.vstack([X_original, X_augmented]))
original_2d = all_2d[:len(X_original)]
augmented_2d = all_2d[len(X_original):]

# Augmented data fills the manifold more densely
plt.scatter(original_2d[:, 0], original_2d[:, 1], label='Original')
plt.scatter(augmented_2d[:, 0], augmented_2d[:, 1], alpha=0.3, label='Augmented')
plt.legend()
plt.show()
Mixup: Linear interpolation of examples and labels
Formula: x̃ = λ·x_i + (1 − λ)·x_j,  ỹ = λ·y_i + (1 − λ)·y_j
where λ ~ Beta(α, α) and (x_i, y_i), (x_j, y_j) are two randomly paired training examples.
Theoretical motivation: Mixup minimizes a vicinal (neighborhood-smoothed) risk rather than the empirical risk, encouraging the model to behave linearly between training examples; this acts as a regularizer and improves calibration and robustness to label noise.
Implementation:
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Mix a batch of inputs; return mixed inputs, both label sets, and the mixing factor."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Training
mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
pred = model(mixed_x)
loss = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
CutMix: Replace patches instead of blending
Advantages over Mixup: mixed images stay locally natural (no ghosting from blending), the regional dropout forces the model to attend to less discriminative parts, and localization ability is preserved or improved.
import numpy as np
import torch

def cutmix(x, y, alpha=1.0):
    """CutMix augmentation: paste a random patch from a shuffled copy of the batch."""
    lam = np.random.beta(alpha, alpha)
    batch_size, _, H, W = x.shape
    index = torch.randperm(batch_size)
    # Random box whose area is roughly (1 - lam) of the image
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    # Apply patch
    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    # Adjust lambda to exactly match the pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
    return x, y, y[index], lam
MoEx (Moment Exchange): swap the per-instance feature moments (mean and standard deviation taken at a normalization layer) of one example with those of another, and interpolate the labels accordingly
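A minimal sketch of the idea, assuming features are activations from an intermediate layer (this simplifies the published method, which operates at a normalization layer):
import torch

def moex(features, labels, eps=1e-5):
    """Swap each instance's feature mean/std with those of a shuffled partner."""
    batch_size = features.size(0)
    index = torch.randperm(batch_size)
    dims = list(range(1, features.dim()))   # all non-batch dimensions
    mean = features.mean(dim=dims, keepdim=True)
    std = features.std(dim=dims, keepdim=True) + eps
    normalized = (features - mean) / std
    # Re-inject the partner's moments
    mixed = normalized * std[index] + mean[index]
    return mixed, labels, labels[index]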
Modern approach: Use diffusion models to generate variations
Workflow:
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

def diffusion_augment(image, prompt, strength=0.3):
    """Augment an image by partially re-generating it with a diffusion model."""
    # strength: how much to change (0 = no change, 1 = complete re-generation)
    augmented = pipe(
        prompt=prompt,
        image=image,
        strength=strength,
        guidance_scale=7.5
    ).images[0]
    return augmented

# Example: augment a dog image
prompt = "a photo of a dog"
aug_image = diffusion_augment(original_image, prompt, strength=0.2)
Benefits: realistic, semantically diverse variations that go far beyond geometric and color transforms; rare conditions can be targeted through the prompt
Challenges: computationally expensive, requires a GPU, and generated images can drift from the original class (label noise)
SimCLR: Self-supervised learning via contrastive loss
Key idea: Different augmentations of same image should have similar representations
Augmentation composition:
import torchvision.transforms as transforms

# SimCLR augmentation pipeline
simclr_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
        transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),
    transforms.ToTensor(),
])

# Create two different views of the same image
view1 = simclr_transform(image)
view2 = simclr_transform(image)
# Contrastive loss: the two views should map to similar representations
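The contrastive objective typically used here is NT-Xent (normalized temperature-scaled cross-entropy); a minimal sketch, assuming z1 and z2 are the projection-head outputs for the two views:
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss for N positive pairs (z1[i], z2[i])."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)                 # [2N, D]
    sim = z @ z.T / temperature                    # scaled cosine similarities
    n = z1.size(0)
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float('-inf'))     # drop self-similarity
    # The positive for index i is i + N (and vice versa)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)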
Findings: the composition of random cropping with strong color distortion is the critical pair; removing either one sharply hurts representation quality, and contrastive learning benefits from stronger augmentation than supervised training does.
MoCo v2 augmentation:
moco_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23)], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
Queue-based approach: keys from a slowly updated momentum encoder are pushed into a large FIFO queue that serves as a big, consistent pool of negatives without requiring huge batches.
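A minimal sketch of the two updates involved (assuming encoder_q/encoder_k are the query and key encoders, the queue is a [D, K] tensor, and queue_ptr is a one-element tensor):
import torch

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m=0.999):
    """Slowly move the key encoder toward the query encoder."""
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.data.mul_(m).add_(p_q.data, alpha=1 - m)

@torch.no_grad()
def dequeue_and_enqueue(queue, queue_ptr, keys):
    """Replace the oldest negatives in the queue with the newest keys ([B, D])."""
    batch_size = keys.size(0)
    ptr = int(queue_ptr)
    queue[:, ptr:ptr + batch_size] = keys.T
    queue_ptr[0] = (ptr + batch_size) % queue.size(1)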
Result: state-of-the-art self-supervised representations at the time of publication
Goal: Learn invariant features across environments
Formulation (IRMv1): L = Σ_e R^e(Φ) + λ_IRM · ‖∇_{w | w=1.0} R^e(w · Φ)‖²
where: R^e is the risk in environment e, Φ is the shared predictor, and the penalty measures how far a fixed "dummy" classifier w = 1.0 is from being optimal in every environment.
Augmentation as environments:
import torch
import torch.nn.functional as F

def irm_loss(model, x, y, augmentations, lambda_irm=1.0):
    """IRM-style loss that treats each augmentation as a separate environment."""
    total_loss = 0
    penalty = 0
    for aug in augmentations:
        # Environment: a different augmentation of the same batch
        x_aug = aug(x)
        # Standard loss in this environment
        pred = model(x_aug)
        loss = F.cross_entropy(pred, y)
        total_loss += loss
        # Gradient penalty to encourage invariance (simplified: taken w.r.t. the
        # model parameters rather than IRMv1's fixed dummy classifier)
        grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        penalty += sum((g ** 2).sum() for g in grad)
    return total_loss + lambda_irm * penalty
Benefit: Robust to distribution shift
Unsupervised Data Augmentation (UDA):
Idea: Model predictions should be consistent under augmentation
import torch
import torch.nn.functional as F

def uda_loss(model, x_unlabeled, strong_aug, weak_aug):
    """UDA consistency loss on unlabeled data."""
    # Weakly augmented prediction provides a fixed target distribution
    with torch.no_grad():
        weak_pred = model(weak_aug(x_unlabeled))
        pseudo_label = torch.softmax(weak_pred, dim=1)
    # Strongly augmented prediction should match it
    strong_pred = model(strong_aug(x_unlabeled))
    # Consistency loss
    loss = F.kl_div(
        F.log_softmax(strong_pred, dim=1),
        pseudo_label,
        reduction='batchmean'
    )
    return loss
FixMatch: UDA + pseudo-labeling with confidence threshold
import torch
import torch.nn.functional as F

def fixmatch_loss(model, x_unlabeled, weak_aug, strong_aug, threshold=0.95):
    """FixMatch semi-supervised loss."""
    # Weak augmentation produces pseudo-labels (no gradient through them)
    with torch.no_grad():
        weak_pred = model(weak_aug(x_unlabeled))
        max_probs, pseudo_labels = torch.max(torch.softmax(weak_pred, dim=1), dim=1)
    # Only keep high-confidence predictions
    mask = max_probs >= threshold
    if mask.sum() == 0:
        return torch.tensor(0.0, device=x_unlabeled.device)
    # Strongly augmented view must predict the pseudo-labels
    strong_pred = model(strong_aug(x_unlabeled))
    loss = F.cross_entropy(strong_pred[mask], pseudo_labels[mask])
    return loss
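A sketch of how this combines with the supervised loss in one training step (labeled_loader, unlabeled_loader, weak_aug, strong_aug, and the weight lambda_u are assumed to be defined elsewhere):
for (x_l, y_l), x_u in zip(labeled_loader, unlabeled_loader):
    sup_loss = F.cross_entropy(model(weak_aug(x_l)), y_l)
    unsup_loss = fixmatch_loss(model, x_u, weak_aug, strong_aug, threshold=0.95)
    loss = sup_loss + lambda_u * unsup_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()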
Neural Augmentation: Learn transformation parameters
Approach:
import torch
import torch.nn as nn
from kornia.geometry.transform import rotate  # differentiable w.r.t. the angle

class LearnableAugmentation(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable parameters controlling augmentation strength
        self.rotation_range = nn.Parameter(torch.tensor(15.0))
        self.brightness_factor = nn.Parameter(torch.tensor(0.2))

    def forward(self, x):
        # Sample random strengths scaled by the learned ranges
        angle = torch.rand(x.size(0), device=x.device) * self.rotation_range
        brightness = 1 + torch.rand(1, device=x.device) * self.brightness_factor
        # Use differentiable ops so gradients reach the parameters
        x_aug = rotate(x, angle)
        x_aug = x_aug * brightness
        return x_aug

# Training: backprop through the augmentation
learnable_aug = LearnableAugmentation()
optimizer = torch.optim.Adam(list(model.parameters()) + list(learnable_aug.parameters()))

for x, y in dataloader:
    optimizer.zero_grad()
    x_aug = learnable_aug(x)
    pred = model(x_aug)
    loss = criterion(pred, y)
    loss.backward()  # gradients flow through the augmentation parameters
    optimizer.step()
Benefit: Automatically tune augmentation strength
Adversarial examples: Inputs with small perturbations that fool model
PGD (Projected Gradient Descent): x^(t+1) = Π_ε( x^(t) + α · sign(∇_x L(f(x^(t)), y)) ), where Π_ε projects back into the ℓ∞ ball of radius ε around the clean input x.
Adversarial training: min_θ E[ max_{‖δ‖∞ ≤ ε} L(f_θ(x + δ), y) ], i.e. train on the worst-case perturbation of each input, usually combined with the clean loss.
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, epsilon=0.3, alpha=0.01, num_iter=10):
    """Generate an adversarial example with PGD."""
    x_adv = x.clone().detach()
    for _ in range(num_iter):
        x_adv.requires_grad_(True)
        # Compute loss
        pred = model(x_adv)
        loss = F.cross_entropy(pred, y)
        # Gradient w.r.t. the input only (leaves model gradients untouched)
        grad = torch.autograd.grad(loss, x_adv)[0]
        # Gradient ascent step
        x_adv = x_adv + alpha * grad.sign()
        # Project back into the epsilon ball and the valid pixel range
        perturbation = torch.clamp(x_adv - x, -epsilon, epsilon)
        x_adv = torch.clamp(x + perturbation, 0, 1).detach()
    return x_adv

# Training with adversarial examples
for x, y in dataloader:
    optimizer.zero_grad()
    # Standard loss on clean inputs
    pred = model(x)
    loss_clean = F.cross_entropy(pred, y)
    # Loss on adversarial inputs
    x_adv = pgd_attack(model, x, y)
    pred_adv = model(x_adv)
    loss_adv = F.cross_entropy(pred_adv, y)
    # Combined loss
    loss = loss_clean + 0.5 * loss_adv
    loss.backward()
    optimizer.step()
Goal: Learn which augmentations help for specific tasks
Meta-augmentation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaAugmentation:
    def __init__(self, augmentations):
        self.augmentations = augmentations
        # Learnable weight for each candidate augmentation
        self.weights = nn.Parameter(torch.ones(len(augmentations)))

    def sample_augmentation(self):
        """Sample an augmentation index according to the learned weights."""
        probs = F.softmax(self.weights, dim=0)
        idx = torch.multinomial(probs, 1).item()
        return idx, self.augmentations[idx]

    def meta_train(self, model, meta_train_tasks, num_epochs):
        """Meta-training loop (sketch)."""
        for epoch in range(num_epochs):
            for task in meta_train_tasks:
                # Sample an augmentation for this task
                idx, aug = self.sample_augmentation()
                # Train on augmented data
                x_aug, y_aug = aug(task.x_train), task.y_train
                model.train_step(x_aug, y_aug)
                # Evaluate on the task's validation split
                val_loss = model.evaluate(task.x_val, task.y_val)
                # The sampled index is discrete; see the weight update sketched below
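Because the sampled index is discrete, val_loss.backward() alone cannot update self.weights; a REINFORCE-style update is one workaround (a sketch, with an assumed step size lr):
def update_weights(meta_aug, chosen_idx, val_loss, lr=0.01):
    """Score-function (REINFORCE) update for the augmentation weights."""
    probs = F.softmax(meta_aug.weights, dim=0)
    log_prob = torch.log(probs[chosen_idx])
    reward = -val_loss.detach()       # lower validation loss = higher reward
    (-log_prob * reward).backward()   # score-function gradient estimate
    with torch.no_grad():
        meta_aug.weights -= lr * meta_aug.weights.grad
        meta_aug.weights.grad = None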
Benefit: Task-specific augmentation policies
Computational cost:
| Augmentation | Cost (ms/image) | Speedup Strategy |
|---|---|---|
| Horizontal flip | 0.1 | Already fast |
| Rotation | 2.0 | Use smaller angles |
| Color jitter | 1.5 | GPU acceleration |
| Cutout | 0.5 | Already fast |
| Mixup | 0.2 | In-place operations |
| Back-translation | 500.0 | Cache results |
| Diffusion models | 5000.0 | Offline generation |
Optimization strategies:
# 1. GPU acceleration (batched transforms with Kornia)
import kornia

transform = kornia.augmentation.AugmentationSequential(
    kornia.augmentation.RandomRotation(30),
    kornia.augmentation.ColorJitter(0.2, 0.2, 0.2, 0.1),
    data_keys=["input"]
)
# Apply on GPU (batched)
x_aug = transform(x_gpu)  # much faster than per-image CPU transforms

# 2. Caching expensive augmentations
class CachedAugmentation:
    def __init__(self, aug_fn, cache_size=10000):
        self.aug_fn = aug_fn
        self.cache = {}

    def __call__(self, x, idx):
        if idx not in self.cache:
            self.cache[idx] = self.aug_fn(x)
        return self.cache[idx]

# 3. Parallel augmentation
from concurrent.futures import ThreadPoolExecutor

def parallel_augment(images, aug_fn, n_workers=8):
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        augmented = list(executor.map(aug_fn, images))
    return augmented
SaliencyMix: Mix based on saliency maps
import torch

def saliencymix(x, y, saliency_fn):
    """Saliency-guided mixing (simplified element-wise variant; the published
    SaliencyMix pastes a patch centered on the saliency peak)."""
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    # Saliency maps for each image and its mixing partner
    sal_a = saliency_fn(x)
    sal_b = saliency_fn(x[index])
    # At each location keep whichever image is more salient
    mask = (sal_a > sal_b).float()
    mixed_x = mask * x + (1 - mask) * x[index]
    # Label weight proportional to how much of image A was kept
    lam = mask.sum() / mask.numel()
    return mixed_x, y, y[index], lam
PuzzleMix: choose the mixing mask by solving an optimization problem that preserves each image's salient regions and local statistics
Co-Mixup: jointly optimize the mixing across a whole batch so that salient content is preserved while the mixed outputs stay diverse
Population Based Augmentation (PBA):
Algorithm:
import random
import numpy as np

class AugmentationPolicy:
    def __init__(self):
        # ALL_OPS: assumed list of candidate operations
        self.ops = random.sample(ALL_OPS, k=5)
        self.probs = np.random.uniform(0, 1, size=5)
        self.magnitudes = np.random.uniform(0, 1, size=5)

    def mutate(self):
        """Perturb one probability or one magnitude."""
        idx = random.randint(0, 4)
        if random.random() < 0.5:
            self.probs[idx] = np.clip(self.probs[idx] + np.random.normal(0, 0.1), 0, 1)
        else:
            self.magnitudes[idx] = np.clip(self.magnitudes[idx] + np.random.normal(0, 0.1), 0, 1)

    def crossover(self, other):
        """Combine two policies op-by-op."""
        child = AugmentationPolicy()
        for i in range(5):
            parent = self if random.random() < 0.5 else other
            child.ops[i] = parent.ops[i]
            child.probs[i] = parent.probs[i]
            child.magnitudes[i] = parent.magnitudes[i]
        return child

# Population based training
population = [AugmentationPolicy() for _ in range(16)]

for generation in range(50):
    # Evaluate each policy by training a model with it
    scores = []
    for policy in population:
        model = train_model(policy)
        scores.append(evaluate_model(model))
    # Select the top performers
    top_indices = np.argsort(scores)[-8:]
    survivors = [population[i] for i in top_indices]
    # Create the next generation via crossover and mutation
    population = survivors.copy()
    for _ in range(8):
        parent1, parent2 = random.sample(survivors, 2)
        child = parent1.crossover(parent2)
        child.mutate()
        population.append(child)
Problem: Rare classes benefit more from augmentation
Class-balanced augmentation:
import random

class ClassBalancedAugmentation:
    def __init__(self, class_counts):
        # Augmentation probability per class: rarer classes get augmented more often
        total = sum(class_counts)
        self.aug_probs = {
            cls: 1.0 - (count / total)
            for cls, count in enumerate(class_counts)
        }

    def __call__(self, x, y):
        """Apply strong or weak augmentation depending on class frequency."""
        aug_prob = self.aug_probs[y]
        if random.random() < aug_prob:
            # Strong augmentation for rare classes
            x = strong_augment(x)
        else:
            # Weak augmentation for common classes
            x = weak_augment(x)
        return x, y

# Example: CIFAR-10-LT (long-tailed, imbalanced)
class_counts = [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
aug = ClassBalancedAugmentation(class_counts)
Remix: a Mixup variant for imbalanced data that mixes inputs as usual but assigns the mixed label in favor of the minority class (see the sketch below)
BBN: Bilateral-Branch Network with a conventional branch and a re-balancing branch, each fed by a different sampler and gradually re-weighted over training
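A sketch of the Remix label-assignment rule as commonly described, where n_i and n_j are the class sizes of the two mixed examples and κ, τ are hyperparameters (e.g. κ=3, τ=0.5):
def remix_lambda_y(lam_x, n_i, n_j, kappa=3.0, tau=0.5):
    """Label-mixing factor for y~ = lam_y * y_i + (1 - lam_y) * y_j."""
    if n_i / n_j >= kappa and lam_x < tau:
        return 0.0   # label goes entirely to the minority class j
    if n_j / n_i >= kappa and (1 - lam_x) < tau:
        return 1.0   # label goes entirely to the minority class i
    return lam_x     # otherwise behave like standard Mixup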
Video-specific challenges: augmentations must respect temporal structure; independently augmenting each frame breaks motion continuity, and clips vary in length and frame rate
Temporal augmentation:
import random
import numpy as np

def temporal_augment(video, clip_len=64):
    """Augment a video clip along the time axis (video: [T, H, W, C] array)."""
    # 1. Temporal crop
    if len(video) > clip_len:
        start = random.randint(0, len(video) - clip_len)
        video = video[start:start + clip_len]
    # 2. Temporal sub-sampling
    stride = random.choice([1, 2])
    video = video[::stride]
    # 3. Temporal jittering: randomly drop a frame
    if random.random() < 0.3:
        drop_idx = random.randint(0, len(video) - 1)
        video = np.delete(video, drop_idx, axis=0)
    # 4. Speed perturbation (resample_video: assumed helper)
    speed_factor = random.uniform(0.8, 1.2)
    video = resample_video(video, speed_factor)
    return video
Spatial + Temporal: spatial transforms (crop, flip, color) should be sampled once per clip and applied identically to every frame so the clip stays temporally coherent, as in the sketch below.
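A sketch of clip-consistent spatial augmentation (frames is assumed to be a list of PIL images; crop and flip parameters are sampled once per clip):
import random
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

def consistent_spatial_augment(frames, size=224):
    """Apply the same spatial transform to every frame of a clip."""
    # Sample crop and flip parameters once for the whole clip
    i, j, h, w = transforms.RandomResizedCrop.get_params(
        frames[0], scale=(0.5, 1.0), ratio=(3 / 4, 4 / 3))
    do_flip = random.random() < 0.5
    out = []
    for frame in frames:
        frame = TF.resized_crop(frame, i, j, h, w, [size, size])
        if do_flip:
            frame = TF.hflip(frame)
        out.append(frame)
    return out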
Point cloud augmentation:
import numpy as np

def pointcloud_augment(points):
    """Augment a 3D point cloud (points: [N, 3] array)."""
    # 1. Random rotation about the z-axis
    angle = np.random.uniform(0, 2 * np.pi)
    rotation_matrix = np.array([
        [np.cos(angle), -np.sin(angle), 0],
        [np.sin(angle),  np.cos(angle), 0],
        [0, 0, 1]
    ])
    points = points @ rotation_matrix.T
    # 2. Random scaling
    scale = np.random.uniform(0.8, 1.2)
    points = points * scale
    # 3. Random jitter
    noise = np.random.normal(0, 0.02, size=points.shape)
    points = points + noise
    # 4. Random point dropout
    keep_mask = np.random.random(len(points)) > 0.1
    points = points[keep_mask]
    return points
PointAugment: Learnable augmentation for point clouds
PointMixup: Mixup between point clouds, interpolating matched points found via an optimal assignment
Graph-specific augmentation:
import torch

def graph_augment(graph):
    """Augment a graph (a PyTorch Geometric Data object is assumed)."""
    # 1. Edge dropping
    edge_mask = torch.rand(graph.num_edges) > 0.1
    graph.edge_index = graph.edge_index[:, edge_mask]
    # 2. Node dropping (keep the induced subgraph)
    node_mask = torch.rand(graph.num_nodes) > 0.1
    graph = graph.subgraph(node_mask)
    # 3. Feature masking: zero out a random subset of feature columns
    feat_mask = torch.rand(graph.x.size(1)) > 0.2
    graph.x[:, ~feat_mask] = 0
    # 4. Edge perturbation: add random edges
    n_new_edges = int(0.1 * graph.num_edges)
    src = torch.randint(0, graph.num_nodes, (n_new_edges,))
    dst = torch.randint(0, graph.num_nodes, (n_new_edges,))
    new_edges = torch.stack([src, dst])
    graph.edge_index = torch.cat([graph.edge_index, new_edges], dim=1)
    return graph
GraphCL: Contrastive learning for graphs with augmentation
M-Mix: Mixup for molecular graphs
How to measure augmentation quality?
1. Downstream Performance:
def evaluate_augmentation(aug_fn, data, test_data):
    """Evaluate an augmentation by its effect on downstream accuracy
    (train_model and evaluate are assumed helpers)."""
    # Train with augmentation
    model_aug = train_model(data, augmentation=aug_fn)
    acc_aug = evaluate(model_aug, test_data)
    # Train without augmentation
    model_no_aug = train_model(data, augmentation=None)
    acc_no_aug = evaluate(model_no_aug, test_data)
    improvement = acc_aug - acc_no_aug
    return improvement
2. Diversity Score:
import numpy as np
from scipy.spatial.distance import pdist

def diversity_score(original, augmented):
    """Measure diversity of augmented samples via mean pairwise distance."""
    all_samples = np.vstack([original, augmented])
    distances = pdist(all_samples)
    # Higher mean distance = more diverse
    return np.mean(distances)
3. Invariance Test:
def invariance_score(model, x, augmentations):
    """Measure how invariant the model's predictions are to a set of augmentations."""
    original_pred = model(x)
    disagreements = []
    for aug in augmentations:
        aug_pred = model(aug(x))
        disagreement = (original_pred.argmax(dim=1) != aug_pred.argmax(dim=1)).float().mean()
        disagreements.append(disagreement.item())
    # Lower disagreement = better invariance
    return 1 - sum(disagreements) / len(disagreements)
Idea: Start with weak augmentation, gradually increase strength
Progressive augmentation:
import albumentations as A

class CurriculumAugmentation:
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def get_augmentation(self):
        """Return an augmentation pipeline based on training progress."""
        # Linearly increase augmentation strength over training
        progress = self.current_epoch / self.max_epochs
        if progress < 0.3:
            # Early: weak augmentation
            return A.Compose([
                A.HorizontalFlip(p=0.5),
            ])
        elif progress < 0.7:
            # Mid: medium augmentation
            return A.Compose([
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.3),
            ])
        else:
            # Late: strong augmentation
            return A.Compose([
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=30, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.GaussNoise(p=0.3),
                A.Cutout(num_holes=8, max_h_size=16, max_w_size=16, p=0.5),
            ])

    def update_epoch(self, epoch):
        self.current_epoch = epoch

# Training loop
curr_aug = CurriculumAugmentation(max_epochs=100)
for epoch in range(100):
    augmentation = curr_aug.get_augmentation()
    for x, y in dataloader:
        x_aug = augmentation(image=x)['image']  # x: HWC numpy image
        # Train...
    curr_aug.update_epoch(epoch)
Cross-modal augmentation: Augment multiple modalities consistently
Example: Image + Text
import random

def multimodal_augment(image, caption):
    """Augment an image and its caption together (helper functions assumed)."""
    # Image augmentation
    if random.random() < 0.5:
        image = horizontal_flip(image)
        # Keep the caption consistent: "person on left" → "person on right"
        caption = flip_spatial_words(caption)
    # Color augmentation
    if random.random() < 0.3:
        image = grayscale(image)
        # Update caption: remove color words that no longer apply
        caption = remove_color_adjectives(caption)
    # Text augmentation (image unchanged)
    caption = synonym_replacement(caption)
    return image, caption
Audio + Text (speech recognition):
import random

def audio_text_augment(audio, transcript):
    """Augment audio and transcript together (helper functions assumed)."""
    # Speed perturbation (transcript unchanged: same words)
    speed = random.uniform(0.9, 1.1)
    audio = change_speed(audio, speed)
    # Add noise (audio only)
    audio = add_noise(audio, snr_db=random.uniform(10, 30))
    # Text augmentation: simulate recognition errors
    transcript = simulate_asr_errors(transcript, error_rate=0.05)
    return audio, transcript
Stable Diffusion for augmentation:
from diffusers import StableDiffusionImg2ImgPipeline

# Image-to-image pipeline, since generation is guided by the original image
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")

def foundation_augment(original_image, class_name, n_augments=5):
    """Generate augmented images with a diffusion model."""
    # Create prompts from the class name
    prompts = [
        f"a photo of a {class_name}",
        f"a {class_name} in different lighting",
        f"a {class_name} from a different angle",
        f"a {class_name} with a different background",
        f"a high quality photo of a {class_name}",
    ]
    augmented_images = []
    for prompt in prompts[:n_augments]:
        # Generate, guided by the original image
        image = pipe(
            prompt=prompt,
            image=original_image,
            strength=0.5,  # how much to change
        ).images[0]
        augmented_images.append(image)
    return augmented_images

# Example: augment "cat" images
cat_images = foundation_augment(original_cat_image, "cat", n_augments=5)
Benefits: effectively unlimited, semantically diverse samples; rare conditions and under-represented classes can be synthesized on demand
Challenges: heavy compute cost, risk of distribution shift and label noise in the generated images, and the need to verify that class identity is preserved
Question: Do augmentations learned on one dataset transfer to others?
Empirical findings:
ImageNet → Other Vision Tasks: policies discovered on ImageNet (e.g. AutoAugment, RandAugment) generally transfer well to other natural-image datasets
Natural Images → Medical Images: transfer is much weaker; flips, color jitter, and aggressive crops can change the label or destroy diagnostic detail, so weaker, domain-aware policies are needed
Practical implications: reuse published policies as a starting point for natural images, but tune the strength (or re-run the search) for specialized domains
# Use pre-discovered policies as a starting point for your domain
from timm.data.auto_augment import auto_augment_transform

# For natural images
transform = auto_augment_transform('original', {})
# For more sensitive domains: same policy with added magnitude noise ('mstd0.5')
transform = auto_augment_transform('original-mstd0.5', {})
Domain-specific search still recommended for best results
Theoretical Foundations: the manifold hypothesis and vicinal risk minimization explain why interpolation-style augmentation helps
Advanced Mixing Strategies: Mixup, CutMix, MoEx, SaliencyMix, PuzzleMix, Co-Mixup
Modern Approaches: contrastive pipelines (SimCLR, MoCo v2), consistency training (UDA, FixMatch), diffusion and foundation-model generation
Specialized Domains: video, point clouds, graphs, and multimodal data each need their own transforms
Meta-Strategies: learnable augmentation, population-based search, curriculum schedules, meta-learned policies
Production Considerations: computational cost, GPU acceleration and caching, class imbalance, and measuring augmentation quality
Common interview questions on data augmentation:
"When does data augmentation help and when can it hurt?"
"What is Mixup and why does it work?"
Augmentation expands effective training data
Choose augmentations that preserve semantics
Domain-specific strategies matter
Start simple, measure impact
Next week: Using LLMs for feature extraction
Lab: Implement and compare augmentation strategies
Measure impact on model performance