Basic Imports

import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import pandas as pd

dist = torch.distributions

sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Each ptX holds 4 MC-sample probability vectors over 2 classes
pt1 = torch.tensor([[1.0, 0.0],
                    [1.0, 0.0],
                    [1.0, 0.0],
                    [1.0, 0.0]])

pt2 = torch.tensor([[0.5, 0.5],
                    [0.5, 0.5],
                    [0.5, 0.5],
                    [0.5, 0.5]])

pt3 = torch.tensor([[0.0, 1.0],
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 0.0]])
def avg_prob(pt):
    return torch.mean(pt, dim=0)

avg_prob(pt1), avg_prob(pt2), avg_prob(pt3)
(tensor([1., 0.]), tensor([0.5000, 0.5000]), tensor([0.5000, 0.5000]))
def predictive_entropy(pt):
    avg = avg_prob(pt)
    return -torch.sum(avg * torch.log2(avg))

predictive_entropy(pt1), predictive_entropy(pt2), predictive_entropy(pt3)
(tensor(nan), tensor(1.), tensor(1.))
# Numerically stable version: add a small epsilon inside the log so that
# probabilities of exactly 0 do not produce nan via 0 * log2(0)
def predictive_entropy(pt):
    avg = avg_prob(pt)
    return -torch.sum(avg * torch.log2(avg + 1e-8))

predictive_entropy(pt1), predictive_entropy(pt2), predictive_entropy(pt3)
(tensor(-0.), tensor(1.), tensor(1.))
def expected_entropy(pt):
    return torch.mean(torch.sum(-pt * torch.log2(pt + 1e-8), dim=1))

expected_entropy(pt1), expected_entropy(pt2), expected_entropy(pt3)
(tensor(0.), tensor(1.), tensor(0.))
def mutual_information(pt):
    return predictive_entropy(pt) - expected_entropy(pt)

mutual_information(pt1), mutual_information(pt2), mutual_information(pt3)
(tensor(-0.), tensor(0.), tensor(1.))
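In notation (a reference sketch of the standard uncertainty decomposition these three functions implement, with the log taken base 2 to match the code): given T MC samples p_t of the class-probability vector,

$$
\bar{p} = \frac{1}{T}\sum_{t=1}^{T} p_t, \qquad
\mathbb{H}[\bar{p}] = -\sum_{c} \bar{p}_c \log_2 \bar{p}_c, \qquad
\mathbb{E}_t\!\left[\mathbb{H}[p_t]\right] = -\frac{1}{T}\sum_{t=1}^{T}\sum_{c} p_{t,c} \log_2 p_{t,c},
$$

$$
\mathrm{MI} = \mathbb{H}[\bar{p}] - \mathbb{E}_t\!\left[\mathbb{H}[p_t]\right].
$$

This is why pt2 (every sample is 50/50) gets predictive entropy of 1 bit but mutual information 0, while pt3 (samples that confidently disagree) gets the same predictive entropy but mutual information of 1 bit: only disagreement between samples, not per-sample uncertainty, shows up in the MI.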
Example from Gal et al.
# Training data (regions of x with their labels)
# Between -5 and -3,   y = 0, with n1 points
# Between -3 and -2.5, y = 1, with n2 points
# Between 1.5 and 2.5, y = 0, with n3 points
# Between 2.5 and 5,   y = 1, with n4 points
fac = 30

# Generate data (counts scaled down by fac)
n1 = 1000 // fac
n2 = 500 // fac
n3 = 750 // fac
n4 = 1250 // fac

x1 = dist.Uniform(-5, -3).sample((n1,))
x2 = dist.Uniform(-3, -2.5).sample((n2,))
x3 = dist.Uniform(1.5, 2.5).sample((n3,))
x4 = dist.Uniform(2.5, 5).sample((n4,))

x = torch.cat([x1, x2, x3, x4])
y = torch.cat([torch.zeros(n1), torch.ones(n2), torch.zeros(n3), torch.ones(n4)])
plt.figure(figsize=(8, 6))
plt.scatter(x, y, s=1)
plt.xlabel("x")
plt.ylabel("y")
# Move to GPU
x = x.cuda()
y = y.cuda()
# Simple MLP with 3 linear layers and dropout in between
import torch.nn.functional as F


class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_prob):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
        self.dropout = torch.nn.Dropout(dropout_prob)

    def forward(self, x):
        x = torch.nn.GELU()(self.fc1(x))
        x = self.dropout(x)
        x = torch.nn.GELU()(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
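As a quick sanity check (not part of the original run; `_check_model` and `_dummy` are illustrative names), a dummy batch through an untrained copy of the network should come back with one logit per input:

# Sketch: forward a dummy batch through a CPU copy of the MLP
_check_model = MLP(input_dim=1, hidden_dim=16, output_dim=1, dropout_prob=0.2)
_dummy = torch.randn(8, 1)           # batch of 8 scalar inputs
print(_check_model(_dummy).shape)    # expected: torch.Size([8, 1])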
# Training loop
def train(model, x, y, optimizer, loss_fn, num_epochs):
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred.squeeze(), y)
        loss.backward()
        optimizer.step()
        if epoch % 300 == 0:
            print(f"Epoch {epoch}, loss {loss.item():.4f}")
# Train the model
model = MLP(1, 16, 1, 0.2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()
train(model, x.unsqueeze(1), y, optimizer, loss_fn, 4000)
Epoch 0, loss 0.7011
Epoch 300, loss 0.0829
Epoch 600, loss 0.0474
Epoch 900, loss 0.0192
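Before switching dropout back on for MC sampling, a deterministic fit check on the training set can be done as below (a sketch not in the original run; `probs` and `acc` are illustrative names):

# Sketch: training-set accuracy with dropout disabled
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(x.unsqueeze(1))).squeeze()
    acc = ((probs > 0.5).float() == y).float().mean()
print(acc)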
# At test time, we want to use MC dropout to get the predictive distribution
def predict(model, x, num_mc_samples):
    model.train()  # keep dropout active for stochastic forward passes
    y_preds = []
    for _ in range(num_mc_samples):
        y_pred = torch.sigmoid(model(x.unsqueeze(1)))
        y_preds.append(y_pred.detach().cpu().numpy())
    return np.concatenate(y_preds, axis=1)
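Note that model.train() also flips any other training-time layers a model might have (batch norm, etc.). An alternative sketch that enables only the dropout modules is shown here; `enable_mc_dropout` is a hypothetical helper, equivalent to model.train() for this particular MLP since it has no other stochastic layers:

# Sketch: keep the model in eval mode but switch dropout modules to train mode
def enable_mc_dropout(model):
    model.eval()
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()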
# Plot the predictive distribution
x_test = torch.linspace(-7, 7, 1000).cuda()
y_test = predict(model, x_test, 2000)  # 2000 MC samples

y_test_orig = torch.from_numpy(y_test)
y_test_orig.shape
torch.Size([1000, 2000])
plt.figure(figsize=(8, 6))
plt.scatter(x.cpu(), y.cpu(), s=50, alpha=0.5)
plt.plot(x_test.cpu(), torch.mean(y_test_orig, axis=1).cpu(), color="black", label="Mean prediction")
plt.fill_between(x_test.cpu(), np.percentile(y_test_orig.cpu(), 2.5, axis=1), np.percentile(y_test_orig.cpu(), 97.5, axis=1), alpha=0.3, color="black", label="95% CI")
plt.xlabel("x")
plt.ylabel("y")
# Legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
y_test_orig
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,
0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.0430e-26, 0.0000e+00,
0.0000e+00],
...,
[5.0677e-01, 1.0000e+00, 2.8454e-01, ..., 1.0000e+00, 1.0000e+00,
1.0000e+00],
[1.0000e+00, 1.0000e+00, 1.0000e+00, ..., 6.4444e-01, 1.0000e+00,
1.0000e+00],
[9.9873e-01, 1.4383e-01, 3.9857e-01, ..., 1.0000e+00, 1.0000e+00,
1.0000e+00]])
# Convert y_test to (n_samples, n_mc_samples, n_classes)
y_test = torch.stack([1 - y_test_orig, y_test_orig], axis=2)
y_test.shape
torch.Size([1000, 2000, 2])
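As an illustrative consistency check (not in the original), each (example, MC sample) slice of the stacked tensor should be a valid two-class distribution:

# Sketch: rows of y_test should sum to 1 across the class dimension
assert torch.allclose(y_test.sum(dim=-1), torch.ones_like(y_test[..., 0]))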
# Compute the predictive entropy for a single example
predictive_entropy(y_test[0])
tensor(0.0214)
# Find MI for a single example
mutual_information(y_test[0])
tensor(0.0182)
plt.figure(figsize=(8, 6))
plt.scatter(x.cpu(), y.cpu(), s=10, alpha=0.5)
plt.plot(x_test.cpu(), torch.mean(y_test_orig, axis=1).cpu(), color="black", label="Mean prediction")
plt.fill_between(x_test.cpu(), np.percentile(y_test_orig.cpu(), 2.5, axis=1), np.percentile(y_test_orig.cpu(), 97.5, axis=1), alpha=0.3, color="black", label="95% CI")

# Use vmap from torch to compute the predictive entropy for all examples
pred_entropy_vals = torch.func.vmap(predictive_entropy)(y_test)
plt.plot(x_test.cpu(), pred_entropy_vals, color="red", label="Predictive entropy")

# Use vmap from torch to compute the mutual information for all examples
mi_vals = torch.func.vmap(mutual_information)(y_test)
plt.plot(x_test.cpu(), mi_vals, color="blue", label="Mutual information")

# Legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
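In the active-learning setting of Gal et al., scores like these drive acquisition: the points with the highest mutual information are the ones the model would ask to have labelled next. A minimal sketch of that ranking step (`k`, `top_mi`, `top_idx` are illustrative names):

# Sketch: the k test locations with the highest mutual information
k = 5
top_mi, top_idx = torch.topk(mi_vals, k)
print(x_test.cpu()[top_idx], top_mi)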