import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Generating data with heteroscedastic noise
# Inputs: 200 evenly spaced points on [-5, 5].
x_lin = torch.linspace(-5, 5, 200)


def f_true(x):
    """Return the noise-free target function sin(x) + 0.5 * x."""
    return torch.sin(x) + 0.5 * x


# Input-dependent noise: the Gaussian term is scaled by 0.1 * (x + 5), so its
# spread grows from 0 at x = -5 to 1 at x = 5; a small deterministic
# cosine-shaped offset is added on top.
eps = torch.randn_like(x_lin)*0.1*(x_lin+5) + 0.1*torch.cos(x_lin)*0.1*(x_lin-5)

# Observed targets: true function plus noise.
y_lin = f_true(x_lin) + eps
# Visualise the data: true function as a dashed black line, noisy samples as
# semi-transparent dots.
plt.plot(x_lin, f_true(x_lin), 'k--')
plt.plot(x_lin, y_lin, 'o', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')
# Output: Text(0, 0.5, 'y')
# Define a simple neural network with three outputs corresponding to the three quantiles
class Net(nn.Module):
    """Small MLP mapping one input feature to three outputs.

    Architecture: Linear(1, 20) -> ReLU -> Linear(20, 3). Each of the three
    output columns is meant to be trained as one quantile estimate.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 3)

    def forward(self, x):
        """Return the three raw outputs for each row of ``x``.

        ``x`` must have a trailing dimension of size 1 (the input size of
        ``fc1``); the result has a trailing dimension of size 3.
        """
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
def loss(y_pred, y, tau):
    """Return the mean pinball (quantile) loss of ``y_pred`` against ``y``.

    With residual ``e = y - y_pred`` the per-element loss is
    ``max(tau * e, (tau - 1) * e)``: under-prediction (``e > 0``) is weighted
    by ``tau`` and over-prediction (``e < 0``) by ``1 - tau``.

    Args:
        y_pred: Tensor of predicted values.
        y: Tensor of targets, broadcastable against ``y_pred``.
        tau: Quantile level used as the asymmetry weight (e.g. 0.1, 0.5, 0.9).

    Returns:
        A scalar tensor holding the mean loss over all elements.
    """
    # Sign convention: residual is target minus prediction.
    e = y - y_pred
    return torch.mean(torch.max(tau*e, (tau-1)*e))
# Build the model and an Adam optimizer over all of its parameters.
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# Sanity check: run the untrained network on the first input point, reshaped
# to (1, 1); the result is a (1, 3) tensor with one value per quantile output.
net(x_lin[0].reshape(-1, 1))
# Output: tensor([[-1.3076, -1.6280, -0.6429]], grad_fn=<AddmmBackward0>)
# Quantile levels to estimate; position i in this list corresponds to output
# column i of the network.
taus = [0.1, 0.5, 0.9]
def plot_quantiles(net):
    """Plot the true function, the noisy samples and ``net``'s quantile curves.

    Reads the module-level ``x_lin``, ``y_lin``, ``f_true`` and ``taus``.
    One curve is drawn per entry of ``taus``, taken from the matching output
    column of ``net`` and labelled with the quantile level.

    Args:
        net: Callable mapping an (N, 1) tensor to an (N, len(taus)) tensor.
    """
    y_pred = net(x_lin.reshape(-1, 1))
    plt.plot(x_lin, f_true(x_lin), 'k--')
    plt.plot(x_lin, y_lin, 'o', alpha=0.5)
    for i, tau in enumerate(taus):
        # detach() drops the autograd graph so the values can be plotted.
        plt.plot(x_lin, y_pred[:, i].detach(), label=tau)
    plt.legend()


# Plot the predictions of the network before training
plot_quantiles(net)
# Train for 200 full-batch steps, minimising the sum of the pinball losses of
# the three quantile outputs.
for epoch in range(200):
    optimizer.zero_grad()
    y_pred = net(x_lin.unsqueeze(1))
    # Output column i is scored against the targets at quantile level taus[i].
    loss_val = sum(loss(y_pred[:, i], y_lin, tau) for i, tau in enumerate(taus))
    loss_val.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print('Epoch {}: loss {}'.format(epoch, loss_val.item()))
# Output:
# Epoch 0: loss 1.4152570962905884
# Epoch 100: loss 0.4124695360660553
# Plot the quantile curves again, now with the trained network.
plot_quantiles(net)

# Recompute predictions with the final weights. The `y_pred` left over from
# the training loop was produced before the last optimizer step and still
# carries an autograd graph, so it is not reused here.
with torch.no_grad():
    y_pred = net(x_lin.unsqueeze(1))

# Empirical coverage: the fraction of targets lying below each predicted
# quantile curve, to be compared against the nominal level tau.
for i, tau in enumerate(taus):
    print(f'Fraction of points lesser than {tau}th quantile: {(y_lin < y_pred[:, i]).float().mean():0.2f}')
# Output:
# Fraction of points lesser than 0.1th quantile: 0.08
# Fraction of points lesser than 0.5th quantile: 0.49
# Fraction of points lesser than 0.9th quantile: 0.90