import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F# Generating data with heteroscedastic noise
# Generate a 1-D regression dataset with heteroscedastic noise: the random
# part of the noise has standard deviation 0.1 * (x + 5), i.e. it grows
# linearly from 0 at x = -5 to 1 at x = 5.
torch.manual_seed(0)  # make the sampled data (and hence the printed results) reproducible

x_lin = torch.linspace(-5, 5, 200)


def f_true(x):
    """Noise-free target function sin(x) + 0.5 * x."""
    return torch.sin(x) + 0.5 * x


# NOTE(review): the second term below contains no random draw, so it is a
# deterministic offset added to every point rather than noise -- confirm this
# is intended (it shifts the conditional median of y away from f_true).
eps = torch.randn_like(x_lin) * 0.1 * (x_lin + 5) + 0.1 * torch.cos(x_lin) * 0.1 * (x_lin - 5)
y_lin = f_true(x_lin) + eps

plt.plot(x_lin, f_true(x_lin), 'k--')
plt.plot(x_lin, y_lin, 'o', alpha=0.5)
plt.xlabel('x')
plt.ylabel('y')

# Define a simple neural network with one output per quantile (three by default)
class Net(nn.Module):
    """One-hidden-layer MLP mapping a scalar input to ``n_quantiles`` outputs.

    Each output column is intended to be trained as a separate quantile
    estimate.  Nothing in the architecture forces the columns to be ordered,
    so the predicted quantiles may cross.

    Args:
        hidden: width of the single ReLU hidden layer (default 20).
        n_quantiles: number of output columns (default 3).
    """

    def __init__(self, hidden=20, n_quantiles=3):
        super().__init__()
        self.fc1 = nn.Linear(1, hidden)
        self.fc2 = nn.Linear(hidden, n_quantiles)

    def forward(self, x):
        """Map inputs of shape (batch, 1) to outputs of shape (batch, n_quantiles)."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)
def loss(y_pred, y, tau):
    """Mean pinball (quantile) loss for quantile level ``tau``.

    For residual e = y - y_pred the per-element loss is
    max(tau * e, (tau - 1) * e): under-predictions (e > 0) cost ``tau`` per
    unit and over-predictions (e < 0) cost ``1 - tau`` per unit, so the
    minimizer is the tau-quantile of y.

    Args:
        y_pred: predicted quantile values; must broadcast against ``y``.
        y: observed targets.
        tau: quantile level, expected in (0, 1).

    Returns:
        Scalar tensor: mean of the per-element pinball losses.
    """
    e = y - y_pred  # positive when the model under-predicts
    return torch.mean(torch.maximum(tau * e, (tau - 1) * e))
net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# Sanity check: a single input point reshaped to a (1, 1) batch yields one
# value per quantile column, i.e. an output of shape (1, 3) before training.
net(x_lin[0].reshape(-1, 1))

# Quantile levels, one per output column of `net` (column i <-> taus[i]).
taus = [0.1, 0.5, 0.9]
# Plot the network's quantile predictions over the data
def plot_quantiles(net):
    """Plot the true function, the data, and one curve per quantile of ``net``.

    Reads the module-level ``x_lin``, ``y_lin``, ``f_true`` and ``taus``;
    column i of ``net``'s output is labelled with ``taus[i]``.  Each call
    opens a new figure so that successive calls (e.g. before and after
    training) do not draw on top of each other.
    """
    # Plotting only: no autograd graph needed.
    with torch.no_grad():
        y_pred = net(x_lin.reshape(-1, 1))
    plt.figure()
    plt.plot(x_lin, f_true(x_lin), 'k--')
    plt.plot(x_lin, y_lin, 'o', alpha=0.5)
    for i, tau in enumerate(taus):
        plt.plot(x_lin, y_pred[:, i], label=tau)
    plt.legend()
# Predictions of the untrained network, for reference.
plot_quantiles(net)

# Full-batch training: all quantile columns are fit jointly by summing their
# pinball losses, each column i against its own level taus[i].
for epoch in range(200):
    optimizer.zero_grad()
    y_pred = net(x_lin.unsqueeze(1))  # shape (len(x_lin), len(taus))
    loss_val = sum(loss(y_pred[:, i], y_lin, tau) for i, tau in enumerate(taus))
    loss_val.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print('Epoch {}: loss {}'.format(epoch, loss_val.item()))

# Predictions after training.
plot_quantiles(net)
# Empirical calibration check: for a well-fit tau-quantile, roughly a
# fraction tau of the observations should lie below the predicted curve.
# Recompute the predictions with the final weights -- the `y_pred` left over
# from the training loop was computed before the last optimizer step.
with torch.no_grad():
    y_pred = net(x_lin.unsqueeze(1))
for i, tau in enumerate(taus):
    print(f'Fraction of points lesser than {tau}th quantile: {(y_lin < y_pred[:, i]).float().mean():0.2f}')