Train on all tasks

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
import torch
import torch.nn as nn

from tueplots import bundles

# Also add despine to the bundle using rcParams
plt.rcParams['axes.spines.right'] = False
plt.rcParams[''] = False

# Increase font size to match Beamer template
plt.rcParams['font.size'] = 16
# Make background transparent
plt.rcParams['figure.facecolor'] = 'none'
import torch.distributions as dist

# Hyper-priors of the hierarchical model: every task draws its own mean and
# scale for the intercept α and the slope β from these distributions.
μ_α = dist.Normal(0.0, 2.0)
σ_α = dist.HalfNormal(1.0)
μ_β = dist.Normal(0.0, 3.0)
σ_β = dist.HalfNormal(1.0)

n_tasks = 11
# Per-task hyper-parameters, drawn in the fixed order μ_α, σ_α, μ_β, σ_β.
μ_α_samples, σ_α_samples, μ_β_samples, σ_β_samples = (
    prior.sample((n_tasks,)) for prior in (μ_α, σ_α, μ_β, σ_β)
)
α = dist.Normal(μ_α_samples, σ_α_samples).sample()
β = dist.Normal(μ_β_samples, σ_β_samples).sample()

# Observation-noise scale of each task, and the shared input grid.
σ = dist.HalfNormal(5.0).sample((n_tasks,))
x_lin = torch.linspace(-1, 1, 100)

# Noise-free linear function of every task: f(x) = α + β x.
true_fs = [intercept + slope * x_lin for intercept, slope in zip(α, β)]

# Noisy observations y ~ N(f(x), σ), drawn only after all true functions.
ys = [dist.Normal(f, noise).sample() for f, noise in zip(true_fs, σ)]

# Normalize the targets of every task with that task's *data* statistics.
# The true function goes through the same affine map as the noisy data so the
# two stay directly comparable when they are overlaid.
# x is not rescaled: x_lin already spans [-1, 1].
true_fs_norm = []
ys_norm = []
x_means_task = []  # not populated: x is left unscaled
x_stds_task = []  # not populated: x is left unscaled
y_means_task = []
y_stds_task = []

for y, f in zip(ys, true_fs):
    y_mean, y_std = y.mean(), y.std()
    y_means_task.append(y_mean)
    y_stds_task.append(y_std)
    ys_norm.append((y - y_mean) / y_std)
    # Use the data's statistics (not f's own) so f and y share one scale
    true_fs_norm.append((f - y_mean) / y_std)
# Plot the noisy data and the true function of the first 10 tasks on a 2 x 5
# grid (the grid has 10 panels, so the last of the n_tasks tasks is not shown).
fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):
    ax.plot(x_lin.cpu(), true_fs[i].cpu(), label='True function')
    ax.scatter(x_lin.cpu(), ys[i].cpu(), label='Data', s=4, alpha=0.2)
    # Show the task's sampled parameters in the title
    ax.set_title(f'Task {i}\n α={α[i]:.2f}, β={β[i]:.2f}, σ={σ[i]:.2f}')
plt.suptitle(r'True functions $(f(x) = \alpha + \beta x)$ and data $(y \sim \mathcal{N}(f(x), \sigma))$')
plt.savefig("../diagrams/metalearning/true.pdf", bbox_inches="tight")
# Plot the normalized functions and data of the first 10 tasks on a 2 x 5 grid.
# The titles show the task parameters before normalization.
fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):
    ax.plot(x_lin.cpu(), true_fs_norm[i].cpu(), label='True function')
    ax.scatter(x_lin.cpu(), ys_norm[i].cpu(), label='Data', s=4, alpha=0.2)
    ax.set_title(f'Task {i}\n α={α[i]:.2f}, β={β[i]:.2f}, σ={σ[i]:.2f}')
plt.suptitle(r'True functions $(f(x) = \alpha + \beta x)$ and data $(y \sim \mathcal{N}(f(x), \sigma))$')
Text(0.5, 0.98, 'True functions $(f(x) = \\alpha + \\beta x)$ and data $(y \\sim \\mathcal{N}(f(x), \\sigma))$')

# Plot the last task ("new task") with only a few observed points (the context)
context_size = 5
# Sample context positions without replacement from the whole input grid
context_idx = torch.randperm(len(x_lin))[:context_size]
context_x = x_lin[context_idx]
context_y = ys[-1][context_idx]

plt.scatter(context_x.cpu(), context_y.cpu(), label='Context', s=20, color='k')
plt.plot(x_lin.cpu(), true_fs[-1].cpu(), label='True function (to estimate)')
plt.title("New Task")
plt.legend()
plt.savefig("../diagrams/metalearning/context.pdf", bbox_inches="tight")

# Split the data of each task into two random halves: a "train" half and a
# held-out "test" half (later plots label the test half as "Context").

x_train = []
y_train = []

x_test = []
y_test = []

n_points = len(x_lin)
for i in range(n_tasks):
    # Random 50/50 split of the input grid for this task
    r_perm = torch.randperm(n_points)
    train_idx = r_perm[:n_points // 2]
    test_idx = r_perm[n_points // 2:]
    # NOTE(review): raw targets `ys` are split here; switch to `ys_norm`
    # if the normalized data is intended for training.
    x_train.append(x_lin[train_idx])
    y_train.append(ys[i][train_idx])
    x_test.append(x_lin[test_idx])
    y_test.append(ys[i][test_idx])
    print(x_train[i].shape, y_train[i].shape)
# Visualise the train/test split of the first 10 tasks on a 2 x 5 grid
fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)
split_styles = (('Train', x_train, y_train), ('Context', x_test, y_test))

for i, ax in enumerate(axes.ravel()):
    for split_label, xs_split, ys_split in split_styles:
        ax.scatter(xs_split[i].cpu(), ys_split[i].cpu(), label=split_label, s=4)
    ax.set_title(f'Task {i}')
# Hyper-net: a 2-input MLP whose 2 outputs are later used as the
# (weight, bias) of the 1-in/1-out linear target-net.
hyper_net = nn.Sequential(
    nn.Linear(2, 64),
    nn.SELU(),
    nn.Linear(64, 2),
)
target_net = nn.Linear(1, 1)
# Let us pick Task 0

task = 0
len_train = len(x_train[task])
x_train_task = x_train[task]
y_train_task = y_train[task]

# First half of the task's training points is the context; the second half
# is the target set whose y-values must be predicted.
x_c = x_train_task[:len_train // 2]
y_c = y_train_task[:len_train // 2]

x_t = x_train_task[len_train // 2:]
y_t = y_train_task[len_train // 2:]

# Stack x_c and y_c column-wise: one (x, y) row per context point
context = torch.cat([x_c.view(-1, 1), y_c.view(-1, 1)], dim=1)

print(context.shape, x_t.shape, y_t.shape)
# Pass every context row through the (untrained) hyper-net; each output row
# is a candidate (weight, bias) pair for the target-net.
hyper_out = hyper_net(context)
# Average over context points so the result does not depend on their order
weights = hyper_out.mean(dim=0)

# The target-net only provides the architecture; its parameters are swapped
# for the hyper-net output with torch.func.functional_call. Updating the dict
# returned by state_dict() does NOT modify the module, so calling the module
# directly would still use its own random initial parameters.
target_net_new = torch.nn.Linear(1, 1)
new_dict = target_net_new.state_dict()
new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})

# Predict on the target points with the hyper-net-generated parameters
y_pred = torch.func.functional_call(target_net_new, new_dict, x_t.view(-1, 1)).ravel()

criterion = torch.nn.MSELoss()

l = criterion(y_pred, y_t)
# Show what untrained hyper-nets produce: each freshly initialised hyper-net
# turns the same context into its own (weight, bias) for the target-net.
plt.scatter(x_t.cpu(), y_t.cpu(), label='Train')
plt.scatter(x_c.cpu(), y_c.cpu(), label='Context')
for i in range(10):
    # Define hyper_net and target_net architectures
    hyper_net = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.SELU(), torch.nn.Linear(64, 2))
    target_net = torch.nn.Linear(1, 1)  # Architecture template only
    # Learnt function: target_net evaluated with the hyper-net's parameters
    with torch.no_grad():
        weights = hyper_net(context).mean(dim=0)
        params = {'weight': weights[:1].view(1, 1), 'bias': weights[1:]}
        y_lin = torch.func.functional_call(target_net, params, x_lin.view(-1, 1)).ravel()
    plt.plot(x_lin.cpu(), y_lin.cpu(), label=f'Learnt function {i}')
# Put legend outside the plot (once, after all curves are drawn)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Train a fresh hyper-net on the single task selected above
hyper_net = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.SELU(), torch.nn.Linear(64, 2))
target_net = torch.nn.Linear(1, 1)  # Architecture template; its own parameters are never trained

optimizer = torch.optim.Adam(hyper_net.parameters(), lr=0.01)
criterion = nn.MSELoss()
# Parameter dict for target_net; its entries are overwritten with the
# hyper-net output on every step and passed to functional_call.
new_dict = target_net.state_dict()

for epoch in range(1000):
    optimizer.zero_grad()

    hyper_out = hyper_net(context)
    weights = hyper_out.mean(dim=0)
    new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})

    y_pred = torch.func.functional_call(target_net, new_dict, x_t.view(-1, 1)).ravel()

    l = criterion(y_pred, y_t)
    # Gradients flow through the generated weights back into hyper_net
    l.backward()
    optimizer.step()

    if epoch % 30 == 0:
        print(f'Epoch {epoch} loss {l:.2f}')

plt.scatter(x_t.cpu(), y_t.cpu(), label='Train')
plt.scatter(x_c.cpu(), y_c.cpu(), label='Context')
# Learnt function: target_net evaluated with the trained hyper-net's output
# (calling target_net directly would use its random initial parameters)
with torch.no_grad():
    weights = hyper_net(context).mean(dim=0)
    new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})
    y_lin = torch.func.functional_call(target_net, new_dict, x_lin.view(-1, 1)).ravel()
plt.plot(x_lin.cpu(), y_lin.cpu(), label='Learnt function')
plt.legend()
# Meta-train a fresh hyper-net across tasks
hyper_net = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.SELU(), torch.nn.Linear(64, 2))
target_net = torch.nn.Linear(1, 1)  # Architecture template; its own parameters are never trained

optimizer = torch.optim.Adam(hyper_net.parameters(), lr=0.01)
criterion = nn.MSELoss()
new_dict = target_net.state_dict()

# The last task is evaluated later as the unseen "new task", so keep it out
# of meta-training.
train_tasks = range(len(x_train) - 1)

for epoch in range(200):
    epoch_loss = 0.0
    for task in train_tasks:
        len_train = len(x_train[task])
        x_train_task = x_train[task]
        y_train_task = y_train[task]

        # First half of the training points is the context, second half the target set
        x_c = x_train_task[:len_train // 2]
        y_c = y_train_task[:len_train // 2]

        x_t = x_train_task[len_train // 2:]
        y_t = y_train_task[len_train // 2:]

        # Stack x_c and y_c column-wise: one (x, y) row per context point
        context = torch.cat([x_c.view(-1, 1), y_c.view(-1, 1)], dim=1)

        optimizer.zero_grad()
        hyper_out = hyper_net(context)
        weights = hyper_out.mean(dim=0)
        new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})

        y_pred = torch.func.functional_call(target_net, new_dict, x_t.view(-1, 1)).ravel()

        l = criterion(y_pred, y_t)
        l.backward()
        optimizer.step()
        epoch_loss += l.item()

    if epoch % 30 == 0:
        # Report the mean loss over tasks rather than only the last task's
        print(f'Epoch {epoch} loss {epoch_loss / len(train_tasks):.2f}')
Predict on all tasks

fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)
axes = axes.flatten()
for task, ax in zip(range(len(x_train)), axes):
    # Use all of the task's training points as the context
    context = torch.cat([x_train[task].view(-1, 1), y_train[task].view(-1, 1)], dim=1)

    # Sort the query points so the prediction line is drawn left to right
    x_query = x_test[task].sort().values
    with torch.no_grad():
        weights = hyper_net(context).mean(dim=0)
        new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})
        y_pred = torch.func.functional_call(target_net, new_dict, x_query.view(-1, 1)).ravel()

    ax.scatter(x_train[task].cpu(), y_train[task].cpu(), label='Train', s=4)
    ax.scatter(x_test[task].cpu(), y_test[task].cpu(), label='Context', s=4)
    ax.plot(x_query.cpu(), y_pred.cpu(), label='Learnt function', color='k')
    ax.set_title(f'Task {task}')

Predict on new task

fig, ax = plt.subplots()
# Build the context from the last task's training points
context = torch.cat([x_train[-1].view(-1, 1), y_train[-1].view(-1, 1)], dim=1)

# Cut the context to have only 5 points
context = context[:5]

print("context", context.shape)

# Sort the query points so the prediction line is drawn left to right
x_query = x_test[-1].sort().values
with torch.no_grad():
    weights = hyper_net(context).mean(dim=0)
    new_dict.update({'weight': weights[:1].view(1, 1), 'bias': weights[1:]})
    y_pred = torch.func.functional_call(target_net, new_dict, x_query.view(-1, 1)).ravel()

ax.scatter(context[:, 0].cpu(), context[:, 1].cpu(), label='Train', s=4)
ax.scatter(x_test[-1].cpu(), y_test[-1].cpu(), label='Context', s=4)
ax.plot(x_query.cpu(), y_pred.cpu(), label='Learnt function', color='k')
ax.set_title('New task')
Neural Processes

# Neural-process encoder: (x, y) pair -> representation vector.
# Decoder: [representation, x] -> predicted y.
_np_width = 128
encoder = nn.Sequential(
    nn.Linear(2, _np_width),
    nn.ReLU(),
    nn.Linear(_np_width, _np_width),
)
decoder = nn.Sequential(
    nn.Linear(_np_width + 1, _np_width),
    nn.ReLU(),
    nn.Linear(_np_width, 1),
)

Train on a single task

# Let us pick Task 4

task = 4
len_train = len(x_train[task])
x_train_task = x_train[task]
y_train_task = y_train[task]

# First half of the training points is the context, second half the target set
x_c = x_train_task[:len_train // 2]
y_c = y_train_task[:len_train // 2]

x_t = x_train_task[len_train // 2:]
y_t = y_train_task[len_train // 2:]

# Stack x_c and y_c column-wise: one (x, y) row per context point
context = torch.cat([x_c.view(-1, 1), y_c.view(-1, 1)], dim=1)

print(context.shape, x_t.shape, y_t.shape)
# Encode each context pair and average into one task representation
representation = encoder(context)
representation = representation.mean(dim=0, keepdim=True)
# Pair the task representation with every target x before decoding
target_repr = representation.repeat(x_t.shape[0], 1)
joint_target_x = torch.cat([target_repr, x_t.view(-1, 1)], dim=1)
pred = decoder(joint_target_x)
# Train the neural-process encoder and decoder jointly on the task above

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-2)
criterion = nn.MSELoss()

for epoch in range(1000):
    optimizer.zero_grad()

    representation = encoder(context)
    representation = representation.mean(dim=0, keepdim=True)
    target_repr = representation.repeat(x_t.shape[0], 1)
    joint_target_x = torch.cat([target_repr, x_t.view(-1, 1)], dim=1)
    # decoder returns (N, 1); flatten it so the loss compares (N,) with (N,)
    # instead of silently broadcasting to an (N, N) difference
    y_pred = decoder(joint_target_x).ravel()

    l = criterion(y_pred, y_t)
    l.backward()
    optimizer.step()

    if epoch % 30 == 0:
        print(f'Epoch {epoch} loss {l:.2f}')
plt.scatter(x_t.cpu(), y_t.cpu(), label='Train')
plt.scatter(x_c.cpu(), y_c.cpu(), label='Context')
# Learnt function: decode the task representation at every point of x_lin
with torch.no_grad():
    representation = encoder(context)
    representation = representation.mean(dim=0, keepdim=True)
    target_repr = representation.repeat(x_lin.shape[0], 1)
    joint_target_x = torch.cat([target_repr, x_lin.view(-1, 1)], dim=1)
    plt.plot(x_lin.cpu(), decoder(joint_target_x).ravel().cpu(), label='Learnt function')
plt.legend()
Train on all tasks

# Define a smaller neural-process encoder/decoder for meta-training across tasks
encoder = torch.nn.Sequential(torch.nn.Linear(2, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
decoder = torch.nn.Sequential(torch.nn.Linear(8+1, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)
criterion = nn.MSELoss()

# The last task is evaluated later as the unseen "new task", so keep it out
# of meta-training.
train_tasks = range(len(x_train) - 1)

for epoch in range(300):
    epoch_loss = 0.0
    for task in train_tasks:
        len_train = len(x_train[task])
        x_train_task = x_train[task]
        y_train_task = y_train[task]

        # First half of the training points is the context, second half the target set
        x_c = x_train_task[:len_train // 2]
        y_c = y_train_task[:len_train // 2]

        x_t = x_train_task[len_train // 2:]
        y_t = y_train_task[len_train // 2:]

        # Stack x_c and y_c column-wise: one (x, y) row per context point
        context = torch.cat([x_c.view(-1, 1), y_c.view(-1, 1)], dim=1)

        optimizer.zero_grad()
        representation = encoder(context)
        representation = representation.mean(dim=0, keepdim=True)
        target_repr = representation.repeat(x_t.shape[0], 1)
        joint_target_x = torch.cat([target_repr, x_t.view(-1, 1)], dim=1)
        # Flatten (N, 1) -> (N,) to match y_t and avoid (N, N) broadcasting
        y_pred = decoder(joint_target_x).ravel()

        l = criterion(y_pred, y_t)
        l.backward()
        optimizer.step()
        epoch_loss += l.item()

    if epoch % 30 == 0:
        # Report the mean loss over tasks rather than only the last task's
        print(f'Epoch {epoch} loss {epoch_loss / len(train_tasks):.2f}')
Predict on all tasks

fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)
axes = axes.flatten()
for task, ax in zip(range(len(x_train)), axes):
    # Use all of the task's training points as the context
    context = torch.cat([x_train[task].view(-1, 1), y_train[task].view(-1, 1)], dim=1)

    # Sort the query points: the decoder is nonlinear, so an unsorted x would
    # make the plotted line zig-zag back and forth
    x_query = x_test[task].sort().values
    with torch.no_grad():
        representation = encoder(context)
        representation = representation.mean(dim=0, keepdim=True)
        target_repr = representation.repeat(x_query.shape[0], 1)
        joint_target_x = torch.cat([target_repr, x_query.view(-1, 1)], dim=1)
        y_pred = decoder(joint_target_x).ravel()

    ax.scatter(x_train[task].cpu(), y_train[task].cpu(), label='Train', s=4)
    ax.scatter(x_test[task].cpu(), y_test[task].cpu(), label='Context', s=4)
    ax.plot(x_query.cpu(), y_pred.cpu(), label='Learnt function', color='k')
    ax.set_title(f'Task {task}')

### Hyper-Net for image reconstruction

# NOTE(review): the module name is missing here (`from import ...` is a
# SyntaxError); restore the module that provides `load_mnist`.
from import load_mnist

# Load MNIST; presumably returns (dataset, dataset_name) — TODO confirm.
# The printed repr shows a Dataset with `img` (sample, channel, x, y) and `label`.
ds, ds_name = load_mnist()
Dimensions:  (sample: 70000, channel: 1, x: 28, y: 28)
  * sample   (sample) int64 0 1 2 3 4 5 ... 69994 69995 69996 69997 69998 69999
  * channel  (channel) int64 0
  * x        (x) int64 27 26 25 24 23 22 21 20 19 18 17 ... 9 8 7 6 5 4 3 2 1 0
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 18 19 20 21 22 23 24 25 26 27
Data variables:
    img      (sample, channel, x, y) float32 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
    label    (sample) float32 5.0 0.0 4.0 1.0 9.0 2.0 ... 2.0 3.0 4.0 5.0 6.0
# Read the first 1000 images from the ds xarray dataset as a NumPy array of
# shape (1000, 1, 28, 28). Selecting along `sample` before `.values` means a
# lazily backed dataset only has to load those 1000 images, not all of them.

imgs = ds['img'].isel(sample=slice(0, 1000)).values
<xarray.DataArray 'img' (sample: 70000, channel: 1, x: 28, y: 28)>
array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],

       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],

       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],

       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],

       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)
  * sample   (sample) int64 0 1 2 3 4 5 ... 69994 69995 69996 69997 69998 69999
  * channel  (channel) int64 0
  * x        (x) int64 27 26 25 24 23 22 21 20 19 18 17 ... 9 8 7 6 5 4 3 2 1 0
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 18 19 20 21 22 23 24 25 26 27
# Wrap the NumPy image batch as a PyTorch tensor (shares memory with the array)
imgs = torch.from_numpy(imgs)
# Show the first 10 images on a 2 x 5 grid
fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)
axes = axes.flatten()
for i, ax in enumerate(axes):
    pixels = imgs[i].view(28, 28).cpu().numpy()
    ax.imshow(pixels, cmap='gray')
    ax.set_title(f'Image {i}')

### Coordinate MLP for image reconstruction

class CoordMLP(nn.Module):
    """Coordinate MLP with sine activations.

    Maps each input coordinate vector to an output vector through two hidden
    layers, each followed by ``torch.sin``.

    Args:
        in_dim: size of each input coordinate vector (2 for pixel positions).
        out_dim: size of each output (1 for a grey-scale intensity).
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, in_dim=2, out_dim=1, hidden_dim=64):
        # nn.Module.__init__ must run before sub-modules can be assigned
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return the output for a batch of coordinates ``x`` of shape (N, in_dim)."""
        x = torch.sin(self.fc1(x))
        x = torch.sin(self.fc2(x))
        return self.fc3(x)
coord = CoordMLP()

# Fit the model on the first image
optimizer = torch.optim.Adam(coord.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Pixel coordinates, scaled to [-1, 1] along each image axis
x = torch.linspace(-1, 1, 28)
y = torch.linspace(-1, 1, 28)

# Create a grid of x and y. indexing='ij' (matrix indexing) makes the
# row-major flattening below line up with the flattening of the image.
x_grid, y_grid = torch.meshgrid(x, y, indexing='ij')

# Flatten the grid
x_flat = x_grid.flatten()
y_flat = y_grid.flatten()

# One (x, y) coordinate row per pixel
inp = torch.cat([x_flat.view(-1, 1), y_flat.view(-1, 1)], dim=1)

# Create the target: one intensity per pixel
target = imgs[0].flatten()

for epoch in range(1000):
    optimizer.zero_grad()

    # Flatten (784, 1) -> (784,) so the loss does not broadcast to (784, 784)
    pred = coord(inp).ravel()
    l = criterion(pred, target)
    l.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f'Epoch {epoch} loss {l:.2f}')
# Plot the first 10 images with their digit labels
fig, axes = plt.subplots(2, 5, figsize=(8, 4), sharex=True, sharey=True)
axes = axes.flatten()
# Read the needed labels once instead of on every iteration
labels = ds["label"].values[:len(axes)]
for i, ax in enumerate(axes):
    # Each image is stored as (channel=1, 28, 28); imshow needs a 2-D array
    ax.imshow(imgs[i].view(28, 28).cpu().numpy(), cmap='gray')
    # Labels are stored as floats (e.g. 5.0); show them as integers
    ax.set_title(f'Label: {int(labels[i])}')
