import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
# Create "S" shaped data: three overlapping x-regions, each generated by its
# own linear function, concatenated and corrupted with Gaussian noise.
x_region1 = torch.linspace(0, 2, 150)
x_region2 = torch.linspace(1, 3, 150)
x_region3 = torch.linspace(2, 4, 150)


def f1(x):
    """Rising segment: 2x - 1."""
    return 2*x - 1


def f2(x):
    """Falling segment: -2x + 6."""
    return -2*x + 6


def f3(x):
    """Flat segment: constant 0.5 with the same shape/dtype/device as x."""
    return torch.full_like(x, 0.5)


f1_x = f1(x_region1)
f2_x = f2(x_region2)
f3_x = f3(x_region3)

# Inputs and noise-free targets for all three regions, shape (450,).
x_total = torch.cat([x_region1, x_region2, x_region3])
f_total = torch.cat([f1_x, f2_x, f3_x])
# Observed targets: add zero-mean Gaussian noise with std 0.2.
y_total = f_total + torch.randn_like(f_total)*0.2
# Scatter the noisy observations.
plt.plot(x_total, y_total, 'o', ms=10, c='k', alpha=0.2)
class Linear(nn.Module):
    """A single affine layer used as one expert in the mixture.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the affine map ``x @ W.T + b`` to the last dimension of x."""
        return self.linear(x)
# Three independent scalar-in, scalar-out linear experts.
e1 = Linear(1, 1)
e2 = Linear(1, 1)
e3 = Linear(1, 1)
class Gating(nn.Module):
    """Gating network: maps each input to a probability vector over experts.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: number of experts; the output's last dimension.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return softmax-normalised expert weights (last axis sums to 1).

        Normalising over the last axis is identical to ``dim=1`` for a 2-D
        (batch, input_dim) input, and also works for a single 1-D sample.
        """
        pred = F.softmax(self.linear(x), dim=-1)
        return pred
# Gating network mapping a scalar input to weights over the 3 experts.
g = Gating(1, 3)
# 100 evenly spaced test inputs on [0, 4], shaped (100, 1).
x_test = torch.linspace(0, 4, 100).view(-1, 1)

# Sanity-check the shapes of the (untrained) model's pieces.
weights = g(x_test)
print(weights.shape)
# torch.Size([100, 3])

predictions = torch.cat([e1(x_test), e2(x_test), e3(x_test)], dim=1)
print(predictions.shape)
# torch.Size([100, 3])

overall_prediction = torch.sum(weights*predictions, dim=1)
print(overall_prediction.shape)
# torch.Size([100])
def predict(g, e1, e2, e3, x):
    """Return the mixture-of-experts prediction for inputs x.

    Args:
        g: gating callable; g(x) returns one weight per expert along the
            last axis.
        e1, e2, e3: expert callables; each returns a single output along the
            last axis (e.g. shape (N, 1) for x of shape (N, 1)).
        x: input tensor.

    Returns:
        The gate-weighted sum of the expert outputs, reduced over the last
        axis (shape (N,) for a (N, 1) input).
    """
    weights = g(x)
    predictions = torch.cat([e1(x), e2(x), e3(x)], dim=-1)
    overall_prediction = torch.sum(weights*predictions, dim=-1)
    return overall_prediction
def plot(g, e1, e2, e3, x_test):
    """Plot the ground-truth segments, data, mixture and per-expert predictions.

    Reads the module-level globals x_region1..3, f1_x..f3_x, x_total and
    y_total for the reference curves and data points.

    Args:
        g: gating network passed through to predict().
        e1, e2, e3: expert modules.
        x_test: test inputs, shape (N, 1).
    """
    y_hat = predict(g, e1, e2, e3, x_test).detach().numpy()
    # Convert once so matplotlib never receives a grad-tracking tensor.
    x_np = x_test.detach().numpy()

    plt.plot(x_region1, f1_x, lw=3, label='f1')
    plt.plot(x_region2, f2_x, lw=3, label='f2')
    plt.plot(x_region3, f3_x, lw=3, label='f3')
    plt.plot(x_total, y_total, 'o', ms=5, c='k', alpha=0.1, label='data')
    plt.plot(x_np, y_hat, 'r', lw=3, label='Overall prediction')

    # Prediction of each expert.
    for idx, expert in enumerate((e1, e2, e3), start=1):
        y_hat_e = expert(x_test).detach().numpy()
        plt.plot(x_np, y_hat_e, '--', lw=2, label=f'Expert {idx}')

    # Legend outside the plot.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
# Visualise the untrained mixture before optimisation.
plot(g, e1, e2, e3, x_test)

# Jointly optimise all three experts and the gating network.
params = [p for module in (e1, e2, e3, g) for p in module.parameters()]
optimizer = torch.optim.Adam(params, lr=0.01)

# The inputs never change, so reshape to (N, 1) once rather than every step.
x_train = x_total.view(-1, 1)
for i in range(5000):
    optimizer.zero_grad()
    # predict() yields shape (N,), matching y_total: plain mean squared error.
    loss = F.mse_loss(predict(g, e1, e2, e3, x_train), y_total)
    loss.backward()
    optimizer.step()
    if i % 500 == 0:
        print(loss.item())
# Output (loss every 500 steps, from the original run):
# 2.34350848197937
# 0.3760591149330139
# 0.34844234585762024
# 0.3476337492465973
# 0.34751665592193604
# 0.34749242663383484
# 0.34747281670570374
# 0.34743985533714294
# 0.347318172454834
# 0.33742383122444153
# Visualise the trained mixture.
plot(g, e1, e2, e3, x_test)

# Gating weight assigned to each expert across the test inputs
# (softmax output, so the weights at each input sum to 1).
weights = g(x_test).detach().numpy()
for i in range(weights.shape[1]):
    plt.plot(x_test, weights[:, i], label=f'Weight {i}')
plt.legend()