import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
# PyTorch device CUDA0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Capturing uncertainty in neural nets:
- aleatoric uncertainty
  - homoskedastic noise (fixed)
  - homoskedastic noise (learnt)
  - heteroskedastic noise
- epistemic uncertainty (Laplace approximation)
- both aleatoric and epistemic uncertainty (see the decomposition sketched after this list)
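Before diving in, one way to see how the two kinds of uncertainty relate (a standard decomposition, stated here only for orientation): for a predictive distribution obtained by averaging over the weight posterior, the law of total variance gives

\[
\operatorname{Var}[y \mid x] \;=\; \underbrace{\mathbb{E}_{\theta}\!\left[\sigma^{2}(x;\theta)\right]}_{\text{aleatoric}} \;+\; \underbrace{\operatorname{Var}_{\theta}\!\left[\mu(x;\theta)\right]}_{\text{epistemic}},
\]

where \(\mu(x;\theta)\) and \(\sigma(x;\theta)\) are the predicted mean and noise standard deviation and \(\theta\) is drawn from the (approximate) posterior over weights.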
Imports
from tueplots import bundles
plt.rcParams.update(bundles.beamer_moml())
# Also add despine to the bundle using rcParams
"axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams[
# Increase font size to match Beamer template
"font.size"] = 16
plt.rcParams[# Make background transparent
"figure.facecolor"] = "none" plt.rcParams[
Models capturing aleatoric uncertainty
Dataset containing homoskedastic noise
# Generate data
torch.manual_seed(42)
N = 100
x_lin = torch.linspace(-1, 1, N)

f = lambda x: 0.5 * x**2 + 0.25 * x**3

eps = torch.randn(N) * 0.2

y = f(x_lin) + eps

# Move to GPU
x_lin = x_lin.to(device)
y = y.to(device)
# Plot data and true function
plt.plot(x_lin.cpu(), y.cpu(), "o", label="data")
plt.plot(x_lin.cpu(), f(x_lin).cpu(), label="true function")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
<matplotlib.legend.Legend at 0x7fae8c30f3a0>
Case 1.1: Models assuming homoskedastic noise
Case 1.1.1: Homoskedastic noise is fixed beforehand and not learned
class HomoskedasticNNFixedNoise(torch.nn.Module):
    def __init__(self, n_hidden=10):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden)
        self.fc3 = torch.nn.Linear(n_hidden, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        mu_hat = self.fc3(x)
        return mu_hat
def loss_homoskedastic_fixed_noise(model, x, y, log_noise_std):
    mu_hat = model(x).squeeze()
    assert mu_hat.shape == y.shape
    noise_std = torch.exp(log_noise_std).expand_as(mu_hat)
    dist = torch.distributions.Normal(mu_hat, noise_std)
    return -dist.log_prob(y).mean()
model_1 = HomoskedasticNNFixedNoise()
# Move to GPU
model_1.to(device)
HomoskedasticNNFixedNoise(
(fc1): Linear(in_features=1, out_features=10, bias=True)
(fc2): Linear(in_features=10, out_features=10, bias=True)
(fc3): Linear(in_features=10, out_features=1, bias=True)
)
fixed_log_noise_std = torch.log(torch.tensor(0.5)).to(device)
loss_homoskedastic_fixed_noise(model_1, x_lin[:, None], y, fixed_log_noise_std)
tensor(0.3774, device='cuda:0', grad_fn=<NegBackward0>)
def plot_model_1(
    model, l="Untrained model", log_noise_param=fixed_log_noise_std, aleatoric=True
):
    with torch.no_grad():
        y_hat = model(x_lin[:, None]).squeeze().cpu()
        std = torch.exp(log_noise_param).cpu()

    plt.scatter(x_lin.cpu(), y.cpu(), s=10, color="C0", label="Data")
    plt.plot(x_lin.cpu(), f(x_lin.cpu()), color="C1", label="True function")
    plt.plot(x_lin.cpu(), y_hat.cpu(), color="C2", label=l)
    if aleatoric:
        # Plot the +- 2 sigma region
        plt.fill_between(
            x_lin.cpu(), y_hat - 2 * std, y_hat + 2 * std, alpha=0.2, color="C2"
        )
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()


plot_model_1(model_1, "Untrained Model")
def train(model, loss_func, params, x, y, log_noise_param, n_epochs=1000, lr=0.01):
    optimizer = torch.optim.Adam(params, lr=lr)
    for epoch in range(n_epochs):
        # Print every 50 epochs
        if epoch % 50 == 0:
            # Calculate the noise standard deviation
            noise_std = torch.exp(log_noise_param)
            print(f"Epoch {epoch}: loss {loss_func(model, x, y, log_noise_param)}")
        optimizer.zero_grad()
        loss = loss_func(model, x, y, log_noise_param)
        loss.backward()
        optimizer.step()
    return loss.item()
params = list(model_1.parameters())

train(
    model_1,
    loss_homoskedastic_fixed_noise,
    params,
    x_lin[:, None],
    y,
    fixed_log_noise_std,
    n_epochs=1000,
    lr=0.001,
)
Epoch 0: loss 0.37737852334976196
Epoch 50: loss 0.36576324701309204
Epoch 100: loss 0.36115822196006775
Epoch 150: loss 0.35357466340065
Epoch 200: loss 0.3445783853530884
Epoch 250: loss 0.33561965823173523
Epoch 300: loss 0.3253604769706726
Epoch 350: loss 0.3168313503265381
Epoch 400: loss 0.3112250566482544
Epoch 450: loss 0.30827972292900085
Epoch 500: loss 0.30689162015914917
Epoch 550: loss 0.30621206760406494
Epoch 600: loss 0.3059086203575134
Epoch 650: loss 0.30568134784698486
Epoch 700: loss 0.3052031695842743
Epoch 750: loss 0.305046409368515
Epoch 800: loss 0.30500170588493347
Epoch 850: loss 0.3049851953983307
Epoch 900: loss 0.30497878789901733
Epoch 950: loss 0.3049766719341278
0.3049757182598114
"Trained Model") plot_model_1(model_1,
Case 1.1.2: Homoskedastic noise is learnt from the data
The model is the same as in Case 1.1.1, but the noise level is now learned from the data: the loss is the same Gaussian negative log-likelihood, with the log noise standard deviation \(\log\sigma\) treated as a learnable parameter and optimized jointly with the network weights.
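Concretely, with a single learnable \(\log\sigma\), the training objective below is the average Gaussian negative log-likelihood

\[
-\frac{1}{N}\sum_{i=1}^{N}\log \mathcal{N}\!\left(y_i \mid \mu_\theta(x_i), \sigma^{2}\right)
\;=\; \log\sigma \;+\; \frac{1}{2N\sigma^{2}}\sum_{i=1}^{N}\left(y_i-\mu_\theta(x_i)\right)^{2} \;+\; \frac{1}{2}\log 2\pi,
\]

so for a fixed mean function the optimal \(\sigma\) is the root-mean-square residual.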
# Define the loss function
def loss_homoskedastic_learned_noise(model, x, y, log_noise_std):
    mu_hat = model(x).squeeze()
    noise_std = torch.exp(log_noise_std).expand_as(mu_hat)
    dist = torch.distributions.Normal(mu_hat, noise_std)
    return -dist.log_prob(y).mean()
model_2 = HomoskedasticNNFixedNoise()
log_noise_param = torch.nn.Parameter(torch.tensor(0.0).to(device))

# Move to GPU
model_2.to(device)
HomoskedasticNNFixedNoise(
(fc1): Linear(in_features=1, out_features=10, bias=True)
(fc2): Linear(in_features=10, out_features=10, bias=True)
(fc3): Linear(in_features=10, out_features=1, bias=True)
)
# Plot the untrained model
plot_model_1(model_2, "Untrained Model", log_noise_param=log_noise_param)
# Train the model
params = list(model_2.parameters()) + [log_noise_param]

train(
    model_2,
    loss_homoskedastic_fixed_noise,
    params,
    x_lin[:, None],
    y,
    log_noise_param,
    n_epochs=1000,
    lr=0.01,
)
Epoch 0: loss 0.9752065539360046
Epoch 50: loss 0.4708317220211029
Epoch 100: loss 0.06903044134378433
Epoch 150: loss -0.19027572870254517
Epoch 200: loss -0.31068307161331177
Epoch 250: loss -0.3400699496269226
Epoch 300: loss -0.3422872722148895
Epoch 350: loss -0.3424949049949646
Epoch 400: loss -0.3450787365436554
Epoch 450: loss -0.348661333322525
Epoch 500: loss -0.3499276638031006
Epoch 550: loss -0.34957197308540344
Epoch 600: loss -0.3541718125343323
Epoch 650: loss -0.3527887165546417
Epoch 700: loss -0.3585814833641052
Epoch 750: loss -0.3589327931404114
Epoch 800: loss -0.3592424690723419
Epoch 850: loss -0.3592303693294525
Epoch 900: loss -0.3538239300251007
Epoch 950: loss -0.35970067977905273
-0.3580436110496521
# Plot the trained model
plot_model_1(model_2, "Trained Model", log_noise_param=log_noise_param)
Case 1.2: Models assuming heteroskedastic noise
Heteroskedastic noise model
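For orientation, the per-point objective in the heteroskedastic case lets the noise level depend on the input through the network:

\[
-\log \mathcal{N}\!\left(y_i \mid \mu_\theta(x_i), \sigma_\theta(x_i)^{2}\right)
\;=\; \log \sigma_\theta(x_i) \;+\; \frac{\left(y_i-\mu_\theta(x_i)\right)^{2}}{2\,\sigma_\theta(x_i)^{2}} \;+\; \frac{1}{2}\log 2\pi,
\]

where the network below outputs \(\mu_\theta(x)\) and \(\log\sigma_\theta(x)\) as its two output units.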
class HeteroskedasticNN(torch.nn.Module):
    def __init__(self, n_hidden=10):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden)
        self.fc3 = torch.nn.Linear(n_hidden, 2)  # we learn both mu and log_noise_std

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        z = self.fc3(x)
        mu_hat = z[:, 0]
        log_noise_std = z[:, 1]
        return mu_hat, log_noise_std
model_3 = HeteroskedasticNN()
# Move to GPU
model_3.to(device)
HeteroskedasticNN(
(fc1): Linear(in_features=1, out_features=10, bias=True)
(fc2): Linear(in_features=10, out_features=10, bias=True)
(fc3): Linear(in_features=10, out_features=2, bias=True)
)
def _plot(y_hat, std, l="Untrained model", aleatoric=True):
    plt.scatter(x_lin.cpu(), y.cpu(), s=10, color="C0", label="Data")
    plt.plot(x_lin.cpu(), f(x_lin.cpu()), color="C1", label="True function")
    plt.plot(x_lin.cpu(), y_hat.cpu(), color="C2", label=l)
    if aleatoric:
        # Plot the +- 2 sigma region
        plt.fill_between(
            x_lin.cpu(), y_hat - 2 * std, y_hat + 2 * std, alpha=0.2, color="C2"
        )

    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()
def plot_heteroskedastic_model(
    model, l="Untrained model", log_noise_param=fixed_log_noise_std
):
    with torch.no_grad():
        y_hat, log_noise_std = model(x_lin[:, None])
        std = torch.exp(log_noise_std).cpu()
        y_hat = y_hat.cpu()

    _plot(y_hat, std, l)


plot_heteroskedastic_model(model_3, "Untrained Model")
# Train
params = list(model_3.parameters())


def loss_heteroskedastic(model, x, y):
    mu_hat, log_noise_std = model(x)
    noise_std = torch.exp(log_noise_std)
    dist = torch.distributions.Normal(mu_hat, noise_std)
    return -dist.log_prob(y).mean()
def train_heteroskedastic(model, loss_func, params, x, y, n_epochs=1000, lr=0.01):
    optimizer = torch.optim.Adam(params, lr=lr)
    for epoch in range(n_epochs):
        # Print every 50 epochs
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: loss {loss_func(model, x, y)}")
        optimizer.zero_grad()
        loss = loss_func(model, x, y)
        loss.backward()
        optimizer.step()
    return loss.item()
train_heteroskedastic(
    model_3, loss_heteroskedastic, params, x_lin[:, None], y, n_epochs=1000, lr=0.001
)
Epoch 0: loss 0.8471216559410095
Epoch 50: loss 0.5730974674224854
Epoch 100: loss 0.21664416790008545
Epoch 150: loss -0.0901612937450409
Epoch 200: loss -0.18870659172534943
Epoch 250: loss -0.21993499994277954
Epoch 300: loss -0.23635315895080566
Epoch 350: loss -0.24628539383411407
Epoch 400: loss -0.2610110640525818
Epoch 450: loss -0.2744489908218384
Epoch 500: loss -0.28417855501174927
Epoch 550: loss -0.29257726669311523
Epoch 600: loss -0.30057546496391296
Epoch 650: loss -0.3061988353729248
Epoch 700: loss -0.31035324931144714
Epoch 750: loss -0.3138866126537323
Epoch 800: loss -0.3167155683040619
Epoch 850: loss -0.3188284635543823
Epoch 900: loss -0.32069918513298035
Epoch 950: loss -0.32232046127319336
-0.323951780796051
# Plot the trained model
plot_heteroskedastic_model(model_3, "Trained Model")
Data with heteroskedastic noise
# Now, let us try these on some data that is not homoskedastic
# Generate data
torch.manual_seed(42)
N = 100
x_lin = torch.linspace(-1, 1, N)

f = lambda x: 0.5 * x**2 + 0.25 * x**3

eps = torch.randn(N) * (0.1 + 0.4 * x_lin)

y = f(x_lin) + eps

# Move to GPU
x_lin = x_lin.to(device)
y = y.to(device)
# Plot data and true function
plt.plot(x_lin.cpu(), y.cpu(), "o", label="data")
plt.plot(x_lin.cpu(), f(x_lin).cpu(), label="true function")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
<matplotlib.legend.Legend at 0x7fae80735370>
# Now, fit the homoskedastic model
model_4 = HomoskedasticNNFixedNoise()
# Move to GPU
model_4.to(device)

fixed_log_noise_std = torch.log(torch.tensor(0.5)).to(device)

# Plot the untrained model
plot_model_1(model_4, "Untrained Model", log_noise_param=fixed_log_noise_std)

# Train the model
params = list(model_4.parameters())
train(
    model_4,
    loss_homoskedastic_fixed_noise,
    params,
    x_lin[:, None],
    y,
    fixed_log_noise_std,
    n_epochs=1000,
    lr=0.001,
)

# Plot the trained model
plot_model_1(model_4, "Trained Model", log_noise_param=fixed_log_noise_std)
Epoch 0: loss 0.4041783809661865
Epoch 50: loss 0.3925568163394928
Epoch 100: loss 0.38583633303642273
Epoch 150: loss 0.3760771155357361
Epoch 200: loss 0.3653101623058319
Epoch 250: loss 0.3552305996417999
Epoch 300: loss 0.3461504280567169
Epoch 350: loss 0.3392468988895416
Epoch 400: loss 0.33482256531715393
Epoch 450: loss 0.3323317766189575
Epoch 500: loss 0.3309779167175293
Epoch 550: loss 0.3298887312412262
Epoch 600: loss 0.3292028307914734
Epoch 650: loss 0.32881179451942444
Epoch 700: loss 0.3285224139690399
Epoch 750: loss 0.3282739520072937
Epoch 800: loss 0.32803454995155334
Epoch 850: loss 0.32778945565223694
Epoch 900: loss 0.3275414705276489
Epoch 950: loss 0.3272767663002014
# Now, fit the homoskedastic model with learned noise
model_5 = HomoskedasticNNFixedNoise()
log_noise_param = torch.nn.Parameter(torch.tensor(0.0).to(device))

# Move to GPU
model_5.to(device)

# Plot the untrained model
plot_model_1(model_5, "Untrained Model", log_noise_param=log_noise_param)
# Train the model
params = list(model_5.parameters()) + [log_noise_param]

train(
    model_5,
    loss_homoskedastic_fixed_noise,
    params,
    x_lin[:, None],
    y,
    log_noise_param,
    n_epochs=1000,
    lr=0.01,
)

# Plot the trained model
plot_model_1(model_5, "Trained Model", log_noise_param=log_noise_param)
Epoch 0: loss 0.9799039959907532
Epoch 50: loss 0.487409383058548
Epoch 100: loss 0.10527730733156204
Epoch 150: loss -0.13744640350341797
Epoch 200: loss -0.225139319896698
Epoch 250: loss -0.24497787654399872
Epoch 300: loss -0.24193085730075836
Epoch 350: loss -0.2481481432914734
Epoch 400: loss -0.2538861632347107
Epoch 450: loss -0.25253206491470337
Epoch 500: loss -0.2516896426677704
Epoch 550: loss -0.2537526786327362
Epoch 600: loss -0.25108057260513306
Epoch 650: loss -0.2549664378166199
Epoch 700: loss -0.2547767460346222
Epoch 750: loss -0.25436392426490784
Epoch 800: loss -0.2519540786743164
Epoch 850: loss -0.25377777218818665
Epoch 900: loss -0.2500659227371216
Epoch 950: loss -0.2547629475593567
# Now, fit the heteroskedastic model
model_6 = HeteroskedasticNN()
# Move to GPU
model_6.to(device)

# Plot the untrained model
plot_heteroskedastic_model(model_6, "Untrained Model")

# Train the model
params = list(model_6.parameters())

train_heteroskedastic(
    model_6, loss_heteroskedastic, params, x_lin[:, None], y, n_epochs=1000, lr=0.001
)
Epoch 0: loss 0.857245147228241
Epoch 50: loss 0.5856861472129822
Epoch 100: loss 0.24207858741283417
Epoch 150: loss -0.006745247635990381
Epoch 200: loss -0.0828532800078392
Epoch 250: loss -0.13739755749702454
Epoch 300: loss -0.19300477206707
Epoch 350: loss -0.24690848588943481
Epoch 400: loss -0.29788991808891296
Epoch 450: loss -0.33707040548324585
Epoch 500: loss -0.368070513010025
Epoch 550: loss -0.3885715901851654
Epoch 600: loss -0.40538740158081055
Epoch 650: loss -0.4239954948425293
Epoch 700: loss -0.4402863085269928
Epoch 750: loss -0.45339828729629517
Epoch 800: loss -0.46312591433525085
Epoch 850: loss -0.4695749878883362
Epoch 900: loss -0.47566917538642883
Epoch 950: loss -0.4823095500469208
-0.4882388710975647
# Plot the trained model
plot_heteroskedastic_model(model_6, "Trained Model")
Epistemic Uncertainty: Bayesian NN with Laplace approximation
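As a brief reminder of the Laplace approximation: the posterior over weights is approximated by a Gaussian centred at the MAP estimate, with covariance given by the inverse Hessian of the negative log joint,

\[
p(\theta \mid \mathcal{D}) \;\approx\; \mathcal{N}\!\left(\theta \mid \theta_{\text{MAP}},\, H^{-1}\right),
\qquad
H \;=\; \nabla^{2}_{\theta}\left[-\log p(\mathcal{D} \mid \theta) - \log p(\theta)\right]\Big|_{\theta=\theta_{\text{MAP}}}.
\]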
model_7 = HomoskedasticNNFixedNoise()
# Move to GPU
model_7.to(device)
def negative_log_prior_last_layer(model):
    log_prior = torch.distributions.Normal(0, 1).log_prob(model.fc3.weight).sum()
    return -log_prior


def negative_log_prior(model):
    log_prior = 0
    for param in model.parameters():
        log_prior += torch.distributions.Normal(0, 1).log_prob(param).sum()
    return -log_prior
def negative_log_likelihood(model, x, y, log_noise_std):
    mu_hat = model(x).squeeze()
    assert mu_hat.shape == y.shape
    noise_std = torch.exp(log_noise_std).expand_as(mu_hat)
    dist = torch.distributions.Normal(mu_hat, noise_std)
    return -dist.log_prob(y).sum()
def negative_log_joint(model, x, y, log_noise_std):
    return negative_log_likelihood(model, x, y, log_noise_std) + negative_log_prior(
        model
    )


def negative_log_joint_last_layer(model, x, y, log_noise_std):
    return negative_log_likelihood(
        model, x, y, log_noise_std
    ) + negative_log_prior_last_layer(model)
negative_log_prior(model_7), negative_log_prior_last_layer(model_7)
(tensor(135.7282, device='cuda:0', grad_fn=<NegBackward0>),
tensor(9.3809, device='cuda:0', grad_fn=<NegBackward0>))
negative_log_likelihood(model_7, x_lin[:, None], y, fixed_log_noise_std)
tensor(42.9679, device='cuda:0', grad_fn=<NegBackward0>)
negative_log_joint(
    model_7, x_lin[:, None], y, fixed_log_noise_std
), negative_log_joint_last_layer(model_7, x_lin[:, None], y, fixed_log_noise_std)
(tensor(178.6961, device='cuda:0', grad_fn=<AddBackward0>),
tensor(52.3488, device='cuda:0', grad_fn=<AddBackward0>))
# Find the MAP estimate
params = list(model_7.parameters())

train(
    model_7,
    negative_log_joint_last_layer,
    params,
    x_lin[:, None],
    y,
    fixed_log_noise_std,
    n_epochs=1000,
    lr=0.01,
)
Epoch 0: loss 52.34880065917969
Epoch 50: loss 42.465431213378906
Epoch 100: loss 41.26828384399414
Epoch 150: loss 40.83299255371094
Epoch 200: loss 40.593467712402344
Epoch 250: loss 40.43766784667969
Epoch 300: loss 40.352142333984375
Epoch 350: loss 40.29512405395508
Epoch 400: loss 40.25150680541992
Epoch 450: loss 40.21705627441406
Epoch 500: loss 40.18949890136719
Epoch 550: loss 40.16746520996094
Epoch 600: loss 40.14928436279297
Epoch 650: loss 40.134132385253906
Epoch 700: loss 40.12129211425781
Epoch 750: loss 39.98218536376953
Epoch 800: loss 39.971900939941406
Epoch 850: loss 39.92092514038086
Epoch 900: loss 39.89830780029297
Epoch 950: loss 39.89033889770508
39.86236572265625
# Plot the trained model (MAP)
plot_model_1(
    model_7, "Trained Model (MAP)", log_noise_param=fixed_log_noise_std, aleatoric=False
)
Which weights to consider?
Goal: compute the Hessian of the negative log joint with respect to the last-layer weights only.
Challenge: the negative log joint is a function of all the weights, not just the last-layer weights.
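The approach taken below is to freeze all other weights at their MAP values and view the negative log joint as a function of the last-layer weight matrix \(W_3\) alone (via functools.partial and torch.func.functional_call, introduced in the asides that follow):

\[
\mathcal{L}(W_3) \;=\; -\log p\!\left(\mathcal{D} \mid \theta_{\text{rest}}^{\text{MAP}}, W_3\right) \;-\; \log p(W_3),
\qquad
H \;=\; \nabla^{2}_{W_3}\,\mathcal{L}(W_3)\Big|_{W_3=W_3^{\text{MAP}}},
\]

where \(W_3\) denotes fc3.weight and \(\theta_{\text{rest}}\) collects all remaining parameters.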
Aside on functools.partial
The functools module is for higher-order functions: functions that act on or return other functions. In general, any callable object can be treated as a function for the purposes of this module.
print(int("1001", base=2), int("1001", base=4), int("1001"))
from functools import partial
base_two = partial(int, base=2)
base_two.__doc__ = "Convert base 2 string to an int."
print(base_two)
print(base_two.__doc__)
print(help(base_two))
print(base_two("1001"))
9 65 1001
functools.partial(<class 'int'>, base=2)
Convert base 2 string to an int.
Help on partial:
functools.partial(<class 'int'>, base=2)
Convert base 2 string to an int.
None
9
dict(model_7.named_parameters())
{'fc1.weight': Parameter containing:
tensor([[-1.0390],
[ 1.8074],
[-0.8156],
[ 0.4683],
[-0.2185],
[-1.9768],
[ 0.4070],
[ 0.1533],
[ 0.4458],
[ 1.7506]], device='cuda:0', requires_grad=True),
'fc1.bias': Parameter containing:
tensor([ 1.0401, 1.7753, 0.8183, -0.8602, -0.7015, -1.7802, -0.8812, -0.5012,
-0.9206, -1.0502], device='cuda:0', requires_grad=True),
'fc2.weight': Parameter containing:
tensor([[-1.8896e-01, -3.1175e-01, -1.9410e-01, 1.2058e-01, 2.6375e-01,
-9.4066e-02, -9.1984e-02, 1.6885e-01, -1.5602e-01, -1.4952e-01],
[ 6.1491e-01, -2.3305e+00, 7.1547e-01, 2.7935e-01, 5.4229e-02,
-4.4665e+00, -1.8417e-01, -4.3629e-03, 1.7388e-02, 7.7613e-02],
[ 2.4445e-01, 2.5985e-01, 8.6094e-02, 9.4790e-03, -1.5800e-01,
4.8416e+00, -2.5323e-02, -2.7836e-01, 2.2070e-01, -1.4078e+00],
[-7.7526e-01, 2.4683e-01, -1.4079e+00, -1.2232e-01, -6.1607e-02,
-4.4529e-01, -2.0109e-01, -5.1614e-02, 2.3994e-01, 1.7788e+00],
[ 2.0120e-01, -1.8883e-01, -2.0688e-01, 2.7597e-01, 1.1186e-01,
8.4014e-03, 4.2796e-02, -2.5415e-01, -1.0558e-01, 3.0441e-01],
[-3.4170e-01, 4.4826e-01, -8.3544e-01, -1.7690e-01, -6.4572e-03,
-4.6640e-01, -3.9215e-02, 1.2869e-01, -3.0933e-01, 1.1246e+00],
[-2.0909e-01, -1.5434e-01, 1.2140e-01, 2.5144e-01, -8.6428e-02,
-1.2983e-01, -2.8594e-01, -1.6307e-01, -2.7690e-01, -7.2375e-02],
[ 5.0661e-02, -4.1058e+00, 2.6399e-01, 1.9840e-01, -3.0957e-01,
6.6729e+00, 1.0314e-01, -6.4972e-02, -3.4461e-02, -1.4278e-01],
[ 1.2533e-01, -9.9739e-01, 1.1985e-01, 2.0404e-02, 6.3569e-02,
6.0431e+00, -5.2103e-02, -1.8004e-01, -5.1145e-02, 2.5648e-01],
[-4.4407e-01, -1.0588e-01, -5.1852e-01, 1.6580e-01, 1.1684e-01,
1.3406e-02, 1.3572e-01, 3.5567e-04, 1.7499e-01, 7.3931e-02]],
device='cuda:0', requires_grad=True),
'fc2.bias': Parameter containing:
tensor([-0.0464, -0.0669, 0.2887, -0.1685, -0.2300, 0.3106, -0.0705, -0.7083,
-0.8258, -0.2307], device='cuda:0', requires_grad=True),
'fc3.weight': Parameter containing:
tensor([[ 2.8882e-24, 1.9463e-01, -1.0051e-01, 1.8243e-01, -2.4331e-25,
1.3989e-01, 2.1068e-24, -3.3861e-01, -2.2423e-01, 9.3494e-16]],
device='cuda:0', requires_grad=True),
'fc3.bias': Parameter containing:
tensor([0.1158], device='cuda:0', requires_grad=True)}
Aside on state_dict in PyTorch
model_7.state_dict()
OrderedDict([('fc1.weight',
tensor([[-1.0390],
[ 1.8074],
[-0.8156],
[ 0.4683],
[-0.2185],
[-1.9768],
[ 0.4070],
[ 0.1533],
[ 0.4458],
[ 1.7506]], device='cuda:0')),
('fc1.bias',
tensor([ 1.0401, 1.7753, 0.8183, -0.8602, -0.7015, -1.7802, -0.8812, -0.5012,
-0.9206, -1.0502], device='cuda:0')),
('fc2.weight',
tensor([[-1.8896e-01, -3.1175e-01, -1.9410e-01, 1.2058e-01, 2.6375e-01,
-9.4066e-02, -9.1984e-02, 1.6885e-01, -1.5602e-01, -1.4952e-01],
[ 6.1491e-01, -2.3305e+00, 7.1547e-01, 2.7935e-01, 5.4229e-02,
-4.4665e+00, -1.8417e-01, -4.3629e-03, 1.7388e-02, 7.7613e-02],
[ 2.4445e-01, 2.5985e-01, 8.6094e-02, 9.4790e-03, -1.5800e-01,
4.8416e+00, -2.5323e-02, -2.7836e-01, 2.2070e-01, -1.4078e+00],
[-7.7526e-01, 2.4683e-01, -1.4079e+00, -1.2232e-01, -6.1607e-02,
-4.4529e-01, -2.0109e-01, -5.1614e-02, 2.3994e-01, 1.7788e+00],
[ 2.0120e-01, -1.8883e-01, -2.0688e-01, 2.7597e-01, 1.1186e-01,
8.4014e-03, 4.2796e-02, -2.5415e-01, -1.0558e-01, 3.0441e-01],
[-3.4170e-01, 4.4826e-01, -8.3544e-01, -1.7690e-01, -6.4572e-03,
-4.6640e-01, -3.9215e-02, 1.2869e-01, -3.0933e-01, 1.1246e+00],
[-2.0909e-01, -1.5434e-01, 1.2140e-01, 2.5144e-01, -8.6428e-02,
-1.2983e-01, -2.8594e-01, -1.6307e-01, -2.7690e-01, -7.2375e-02],
[ 5.0661e-02, -4.1058e+00, 2.6399e-01, 1.9840e-01, -3.0957e-01,
6.6729e+00, 1.0314e-01, -6.4972e-02, -3.4461e-02, -1.4278e-01],
[ 1.2533e-01, -9.9739e-01, 1.1985e-01, 2.0404e-02, 6.3569e-02,
6.0431e+00, -5.2103e-02, -1.8004e-01, -5.1145e-02, 2.5648e-01],
[-4.4407e-01, -1.0588e-01, -5.1852e-01, 1.6580e-01, 1.1684e-01,
1.3406e-02, 1.3572e-01, 3.5567e-04, 1.7499e-01, 7.3931e-02]],
device='cuda:0')),
('fc2.bias',
tensor([-0.0464, -0.0669, 0.2887, -0.1685, -0.2300, 0.3106, -0.0705, -0.7083,
-0.8258, -0.2307], device='cuda:0')),
('fc3.weight',
tensor([[ 2.8882e-24, 1.9463e-01, -1.0051e-01, 1.8243e-01, -2.4331e-25,
1.3989e-01, 2.1068e-24, -3.3861e-01, -2.2423e-01, 9.3494e-16]],
device='cuda:0')),
('fc3.bias', tensor([0.1158], device='cuda:0'))])
Aside on torch.func.functional_call
import torch
import torch.nn as nn
from torch.func import functional_call, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

y1 = functional_call(model, dict(model.named_parameters()), x)
print(dict(model.named_parameters()))
y2 = model(x)
print("*" * 20)
print(y1)
print(y2)
{'weight': Parameter containing:
tensor([[ 0.5021, 0.3455, 0.0904],
[ 0.1903, 0.5480, -0.3725],
[-0.2621, 0.4038, -0.3950]], requires_grad=True), 'bias': Parameter containing:
tensor([-0.3184, 0.4215, 0.1822], requires_grad=True)}
********************
tensor([[-0.3097, -0.1970, -0.7683],
[-0.0129, 0.6332, 0.0520],
[-1.2995, -0.4907, -0.0705],
[-0.0834, 0.6312, 0.1858]], grad_fn=<AddmmBackward0>)
tensor([[-0.3097, -0.1970, -0.7683],
[-0.0129, 0.6332, 0.0520],
[-1.2995, -0.4907, -0.0705],
[-0.0834, 0.6312, 0.1858]], grad_fn=<AddmmBackward0>)
from functools import partial


def functional_last_layer_neg_log_prior(last_layer_weights):
    log_prior = torch.distributions.Normal(0, 1).log_prob(last_layer_weights).sum()
    return -log_prior


def functional_neg_log_likelihood(state_dict, model, x, y, log_noise_std):
    out = torch.func.functional_call(model, state_dict, x)
    mu_hat = out.squeeze()
    assert mu_hat.shape == y.shape
    noise_std = torch.exp(log_noise_std).expand_as(mu_hat)
    dist = torch.distributions.Normal(mu_hat, noise_std)
    return -dist.log_prob(y).sum()


def functional_last_layer_neg_log_joint(
    last_layer_weights, state_dict, model, x, y, log_noise_std
):
    state_dict["fc3.weight"] = last_layer_weights
    return functional_neg_log_likelihood(
        state_dict, model, x, y, log_noise_std
    ) + functional_last_layer_neg_log_prior(last_layer_weights)
state_dict = model_7.state_dict()
last_layer_weights = state_dict["fc3.weight"]

partial_func = partial(
    functional_last_layer_neg_log_joint,
    state_dict=state_dict,
    model=model_7,
    x=x_lin[:, None],
    y=y,
    log_noise_std=fixed_log_noise_std,
)
H = torch.func.hessian(partial_func)(last_layer_weights)
print(H.shape)
H = H[0, :, 0, :]
print(H.shape)
torch.Size([1, 10, 1, 10])
torch.Size([10, 10])
plt.imshow(H.cpu().numpy())
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7fae349d7be0>
import seaborn as sns

sns.heatmap(H.cpu().numpy())
plt.gca().set_aspect("equal")
cov = torch.inverse(H)
laplace_posterior = torch.distributions.MultivariateNormal(
    last_layer_weights.ravel(), cov
)
last_layer_weights_samples = laplace_posterior.sample((101,))[..., None]
last_layer_weights_samples.shape
torch.Size([101, 10, 1])
"fc3.weight"].shape, last_layer_weights_samples[0].shape, x_lin[
state_dict[None
:, ].shape
(torch.Size([1, 10]), torch.Size([10, 1]), torch.Size([100, 1]))
def forward_pass(last_layer_weight):
    state_dict = model_7.state_dict()
    state_dict["fc3.weight"] = last_layer_weight.reshape(1, -1)
    return torch.func.functional_call(model_7, state_dict, x_lin[:, None]).squeeze()


forward_pass(last_layer_weights_samples[0]).shape
torch.Size([100])
mc_outputs = torch.vmap(forward_pass)(last_layer_weights_samples)
print(mc_outputs.shape)
torch.Size([101, 100])
mean_mc_outputs = mc_outputs.mean(0)
std_mc_outputs = mc_outputs.std(0)
mean_mc_outputs.shape, std_mc_outputs.shape
(torch.Size([100]), torch.Size([100]))
with torch.no_grad():
    plt.scatter(x_lin.cpu(), y.cpu(), s=10, color="C0", label="Data")
    plt.plot(x_lin.cpu(), f(x_lin.cpu()), color="C1", label="True function")
    plt.plot(
        x_lin.cpu(), mean_mc_outputs.cpu(), color="C2", label="Laplace approximation"
    )
    # Plot the +- 2 sigma epistemic uncertainty region
    plt.fill_between(
        x_lin.cpu(),
        mean_mc_outputs.cpu() - 2 * std_mc_outputs.cpu(),
        mean_mc_outputs.cpu() + 2 * std_mc_outputs.cpu(),
        alpha=0.2,
        color="C2",
    )
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()
Case 3: Both aleatoric and epistemic uncertainty
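A minimal sketch of one way to combine the two uncertainties, assuming the fixed aleatoric noise exp(fixed_log_noise_std) used for model_7 and the epistemic spread std_mc_outputs from the Laplace samples above; the two variances simply add:

# Sketch (not from the original cells): total predictive uncertainty for model_7,
# adding the fixed aleatoric variance to the epistemic variance from the
# last-layer Laplace samples (assumes mean_mc_outputs, std_mc_outputs,
# and fixed_log_noise_std from the cells above).
aleatoric_std = torch.exp(fixed_log_noise_std)
total_std = torch.sqrt(std_mc_outputs**2 + aleatoric_std**2)

with torch.no_grad():
    plt.scatter(x_lin.cpu(), y.cpu(), s=10, color="C0", label="Data")
    plt.plot(x_lin.cpu(), f(x_lin.cpu()), color="C1", label="True function")
    plt.plot(x_lin.cpu(), mean_mc_outputs.cpu(), color="C2", label="Predictive mean")
    plt.fill_between(
        x_lin.cpu(),
        mean_mc_outputs.cpu() - 2 * total_std.cpu(),
        mean_mc_outputs.cpu() + 2 * total_std.cpu(),
        alpha=0.2,
        color="C2",
        label=r"$\pm 2\sigma$ (aleatoric + epistemic)",
    )
    plt.xlabel("$x$")
    plt.ylabel("$y$")
    plt.legend()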