import torch
import torch.autograd.functional as F
import torch.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import ipywidgets as widgets
import seaborn as sns
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
from tueplots import bundles
plt.rcParams.update(bundles.beamer_moml())
#plt.rcParams.update(bundles.icml2022())
# Also add despine to the bundle using rcParams
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
# Increase font size to match Beamer template
plt.rcParams['font.size'] = 16
# Make background transparent
plt.rcParams['figure.facecolor'] = 'none'
import hamiltorch
hamiltorch.set_random_seed(123)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # device(type='cuda')
gt_distribution = torch.distributions.Normal(0, 1)
# Samples from the ground truth distribution
def sample_gt(n):
    return gt_distribution.sample((n,))
samples = sample_gt(1000)
x_lin = torch.linspace(-3, 3, 1000)
y_lin = torch.exp(gt_distribution.log_prob(x_lin))
plt.plot(x_lin, y_lin, label='Ground truth')
plt.legend()

# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return gt_distribution.log_prob(x)
logprob(torch.tensor([0.0]))  # tensor([-0.9189])
# Markov chain with a plain random-walk proposal (no accept/reject step)
x_start = torch.tensor([0.0])
samples = []
for i in range(100):
    prop = torch.distributions.Normal(x_start, 10).sample()
    samples.append(prop)
    x_start = prop
plt.plot(torch.stack(samples).ravel())
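The chain above simply wanders: every proposal is accepted, so the samples do not follow gt_distribution. As an illustrative aside (not part of the original notebook, and not the hamiltorch workflow used below), a minimal Metropolis-Hastings correction on top of the same random-walk proposal, reusing the logprob defined earlier, could look like this:

# Minimal Metropolis-Hastings sketch: add an accept/reject step so the chain
# actually targets gt_distribution.
torch.manual_seed(0)
x_curr = torch.tensor([0.0])
mh_samples = []
for i in range(1000):
    prop = torch.distributions.Normal(x_curr, 1.0).sample()  # symmetric proposal
    log_alpha = logprob(prop) - logprob(x_curr)               # log acceptance ratio
    if torch.log(torch.rand(1)) < log_alpha:
        x_curr = prop                                         # accept, otherwise stay put
    mh_samples.append(x_curr)
plt.figure()
plt.plot(torch.stack(mh_samples).ravel(), label='MH trace')
plt.legend()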
# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
hamiltorch.set_random_seed(123)
params_hmc = hamiltorch.sample(log_prob_func=logprob, params_init=x0,
                               num_samples=num_samples, step_size=step_size,
                               num_steps_per_sample=num_steps_per_sample)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 5000/5000 samples, Acceptance Rate 0.99
params_hmc = torch.stack(params_hmc)

def run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample):
    torch.manual_seed(123)
    params_hmc = hamiltorch.sample(log_prob_func=logprob, params_init=x0,
                                   num_samples=num_samples, step_size=step_size,
                                   num_steps_per_sample=num_steps_per_sample)
    return torch.stack(params_hmc)
params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 5000/5000 samples, Acceptance Rate 0.99
params_hmc.shape  # torch.Size([5000, 1])
# Trace plot
plt.plot(params_hmc, label='Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter value')

# View the first 50 samples
plt.plot(params_hmc[:50], label='Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter value')

# KDE plot
plt.figure()
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
plt.plot(x_lin, y_lin, label='Ground truth')
plt.xlabel('Parameter value')
plt.ylabel('Density')
plt.legend()

# Create MP4 HTML5 video showing sampling process
def create_mp4_samples(samples, x_lin, y_lin, filename='samples.mp4', dpi=600):
    fig, ax = plt.subplots(figsize=(4, 2), dpi=dpi)
    ax.set_xlim(-3, 3)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Parameter value')
    ax.set_ylabel('Density')
    ax.plot(x_lin, y_lin, label='Ground truth')
    # Add an "x" marker to the plot for each sample at y=0
    x_marker, = ax.plot([], [], 'x', color='C1', label='Samples')
    ax.legend()

    def init():
        x_marker.set_data([], [])
        return x_marker,

    def animate(i):
        x_marker.set_data(samples[:i], torch.zeros(i))
        return x_marker,

    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=len(samples), interval=20, blit=True)
    anim.save(filename, dpi=dpi, writer='ffmpeg')
create_mp4_samples(params_hmc[:100], x_lin, y_lin, filename='../figures/sampling/mcmc/normal.mp4', dpi=600)
from IPython.display import Video
Video('../figures/sampling/mcmc/normal.mp4', width=400)

# Sample from a mixture of Gaussians
mog = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.3, 0.7])),
    component_distribution=dist.Normal(torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 0.5]))
)
samples = mog.sample((1000,))
sns.kdeplot(samples.numpy(), label='Samples', shade=True, color='C1')

# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return mog.log_prob(x)
logprob(torch.tensor([0.0]))  # tensor([-4.1114])
params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 5000/5000 samples, Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc, label='Trace')
y_lin = torch.exp(mog.log_prob(x_lin))
# KDE plot
plt.figure()
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
# Limit KDE plot to range of ground truth
plt.xlim(-3, 3)
plt.plot(x_lin, y_lin, label='Ground truth')
plt.xlabel('Parameter value')
plt.ylabel('Density')
plt.legend()

# Create MP4 HTML5 video showing sampling process
create_mp4_samples(params_hmc[:500], x_lin, y_lin, filename='../figures/sampling/mcmc/mog.mp4', dpi=600)
Video('../figures/sampling/mcmc/mog.mp4', width=400)

Sampling from an unnormalized distribution

def p_tilde(x):
    # Unnormalized density: the standard normal density multiplied by its
    # normalising constant Z = sqrt(2*pi)
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x).exp() * Z

def p_tilde_log_prob(x):
    # Log of the unnormalized density: the normal log-density plus log Z
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x) + torch.log(Z)

# Plot unnormalized distribution
x_lin = torch.linspace(-3, 3, 1000)
y_lin = p_tilde(x_lin)
plt.plot(x_lin, y_lin, label='Unnormalized distribution')
# Plot normalized distribution
plt.plot(x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label='Normalized distribution')
plt.legend()
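Since HMC only uses the gradient of the log density, the additive constant $\log Z$ in p_tilde_log_prob has no effect on the dynamics. A quick sanity check (added here for illustration, not in the original notebook):

# Sanity check: the constant log Z shifts the log density but not its gradient,
# which is all that HMC's dynamics depend on.
x = torch.tensor([0.7], requires_grad=True)
grad_unnorm = torch.autograd.grad(p_tilde_log_prob(x).sum(), x)[0]
x = torch.tensor([0.7], requires_grad=True)
grad_norm = torch.autograd.grad(dist.Normal(0, 1).log_prob(x).sum(), x)[0]
print(grad_unnorm, grad_norm)  # identical up to floating point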

# HMC over unnormalized distribution
# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return p_tilde_log_prob(x)

# HMC
params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 5000/5000 samples, Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc[:500], label='Trace')
# KDE plot
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
plt.plot(x_lin, y_lin, label='Unnormalized distribution', lw=2)
plt.plot(x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label='Normalized distribution', lw=2)
plt.legend()

Coin Toss
Working with probabilities
prior = dist.Beta(1, 1)
data = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])
n = len(data)
def log_prior(theta):
    return prior.log_prob(theta)

def log_likelihood(theta):
    return dist.Bernoulli(theta).log_prob(data).sum()

def log_joint(theta):
    return log_prior(theta) + log_likelihood(theta)

try:
    params_hmc_theta = run_hmc(log_joint, torch.tensor([0.5]), 5000, 0.3, 5)
except Exception as e:
    print(e)
# Output: Expected value argument (Tensor of shape (1,)) to be within the support
# (Interval(lower_bound=0.0, upper_bound=1.0)) of the distribution Beta(), but found
# invalid values: tensor([-0.1017], requires_grad=True)
Working with logits
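Sampling the constrained probability directly fails because the leapfrog steps can leave the support of the Beta prior. Instead we sample an unconstrained logit $l$ with $\theta = \sigma(l)$. Under this change of variables, $p(l) = p_\theta(\sigma(l))\,\sigma(l)(1-\sigma(l))$, i.e. $\log p(l) = \log p_\theta(\sigma(l)) + \log \sigma(l) + \log(1-\sigma(l))$. The log-Jacobian term is included in log_prior below (added here) so that $\sigma(l)$ keeps the intended Beta prior.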
# Let us work instead with logits
def log_prior(logits):
    theta = torch.sigmoid(logits)
    # Change of variables: add the log-Jacobian of the sigmoid transform,
    # log dtheta/dlogits = log theta + log (1 - theta)
    return (prior.log_prob(theta) + torch.log(theta) + torch.log(1 - theta)).sum()

def log_likelihood(logits):
    return dist.Bernoulli(logits=logits).log_prob(data).sum()

def log_joint(logits):
    return log_prior(logits) + log_likelihood(logits)

params_hmc_logits = run_hmc(log_joint, torch.tensor([0.0]), 1000, 0.3, 5)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 1000/1000 samples, Acceptance Rate 0.99
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(params_hmc_logits[:500], label='Trace for logits')
ax[1].plot(torch.sigmoid(params_hmc_logits[:500]), label='Trace for probabilities')
ax[0].set_ylabel('Logits')
ax[1].set_ylabel('Computed Probabilities\n (From Logits)')
ax[1].set_xlabel('Iteration')

# Create a function to update the KDE plot with the specified bw_adjust value
def update_kde_plot(bw_adjust):
    plt.clf()  # Clear the previous plot
    plt.hist(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), bins=100, density=True, label='Samples (Histogram)', color='C2', alpha=0.5, lw=1)
    sns.kdeplot(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), label='Samples (KDE)', shade=False, color='C1', clip=(0, 1), bw_adjust=bw_adjust, lw=2)
    x_lin = torch.linspace(0, 1, 1000)
    y_lin = dist.Beta(1 + 3, 1 + 2).log_prob(x_lin).exp()
    plt.plot(x_lin, y_lin, label='True posterior')
    plt.legend()

# Create the slider widget for bw_adjust
bw_adjust_slider = widgets.FloatSlider(value=0.1, min=0.01, max=4.0, step=0.01, description='bw_adjust:')
# Create the interactive plot
interactive_plot = widgets.interactive(update_kde_plot, bw_adjust=bw_adjust_slider)
# Display the interactive plot
display(interactive_plot)

# Plot histogram of samples
plt.hist(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), bins=100, density=True, label='Samples (Histogram)', color='C2', alpha=0.5, lw=1 )
# Plot posterior KDE using seaborn but clip to [0, 1]
sns.kdeplot(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), label='Samples (KDE)', shade=False, color='C1', clip=(0, 1), bw_adjust=0.1, lw=2)
# True posterior
x_lin = torch.linspace(0, 1, 1000)
y_lin = dist.Beta(1+3, 1+2).log_prob(x_lin).exp()
plt.plot(x_lin, y_lin, label='True posterior')
plt.legend()
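As a quick numerical check (added for illustration, not in the original notebook), the Monte Carlo estimate of the posterior mean can be compared against the analytic Beta(1+3, 1+2) mean:

# Quick check: posterior mean from the samples vs. the analytic Beta(4, 3) mean 4/7 ~ 0.571
theta_samples = torch.sigmoid(params_hmc_logits[:, 0])
print(theta_samples.mean().item())            # Monte Carlo estimate
print(dist.Beta(1 + 3, 1 + 2).mean.item())    # analytic posterior mean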

# Linear regression for 1 dimensional input using HMC
x_lin = torch.linspace(-3, 3, 90)
theta_0_true = torch.tensor([2.0])
theta_1_true = torch.tensor([3.0])
f = lambda x: theta_0_true + theta_1_true * x
eps = torch.randn_like(x_lin) *1.0
y_lin = f(x_lin) + eps
plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth')
plt.xlabel('x')
plt.ylabel('y')

# Estimate theta_0, theta_1 using HMC, assuming the noise variance is known to be 1
def logprob(theta):
    y_pred = theta[0] + x_lin * theta[1]
    return dist.Normal(y_pred, 1).log_prob(y_lin).sum()

def log_prior(theta):
    return dist.Normal(0, 1).log_prob(theta).sum()

def log_posterior(theta):
    return logprob(theta) + log_prior(theta)

params_hmc_lin_reg = run_hmc(log_posterior, torch.tensor([0.0, 0.0]), 1000, 0.05, 10)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 1000/1000 samples, Acceptance Rate 0.95
params_hmc_lin_reg
# tensor([[0.0000, 0.0000],
#         [1.8463, 1.5993],
#         [2.1721, 3.8480],
#         ...,
#         [1.9128, 2.9478],
#         [2.1689, 2.9928],
#         [1.8748, 3.0217]])
lps = []
for p in params_hmc_lin_reg:
    lps.append(log_posterior(p))
plt.plot(torch.stack(lps).ravel()[100:])
log_posterior(params_hmc_lin_reg[0]), log_posterior(params_hmc_lin_reg[1]), log_posterior(params_hmc_lin_reg[2])
# (tensor(-1554.2708), tensor(-392.7923), tensor(-242.6276))
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)
for i, param_vals in enumerate(params_hmc_lin_reg.T):
    axes[i].plot(param_vals, label='Trace')
    axes[i].set_xlabel('Iteration')
    axes[i].set_ylabel(fr'$\theta_{i}$')
# Plot the true values as well
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axhline(param_vals.numpy(), color='C1', label='Ground truth')
    axes[i].legend()

# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)
for i, param_vals in enumerate(params_hmc_lin_reg.T):
    sns.kdeplot(param_vals.detach().numpy(), label='Samples', shade=True, color='C1', ax=axes[i])
    axes[i].set_ylabel(fr'$\theta_{i}$')
# Plot the true values as well
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axvline(param_vals.numpy(), color='C0', label='Ground truth')
    axes[i].legend()
# Plot the posterior predictive distribution
plt.figure()
plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
plt.xlabel('x')
plt.ylabel('y')
# Get posterior samples; discard the first 100 samples as burn-in
posterior_samples = params_hmc_lin_reg[100:].detach()
y_hat = posterior_samples[:, 0].unsqueeze(1) + x_lin * posterior_samples[:, 1].unsqueeze(1)
# Plot mean and 95% confidence interval
plt.plot(x_lin, y_hat.mean(axis=0), label='Mean', color='C2')
plt.fill_between(x_lin, y_hat.mean(axis=0) - 2 * y_hat.std(axis=0), y_hat.mean(axis=0) + 2 * y_hat.std(axis=0), alpha=0.5, label='95% CI', color='C2')
plt.legend()
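Another way to look at the same posterior (an optional sketch, not in the original notebook) is to draw a few individual regression lines from the posterior samples instead of only the mean and 95% band:

# Optional: plot a thinned subset of posterior regression lines alongside the data
plt.figure()
plt.scatter(x_lin, y_lin, label='Data', color='C0')
for theta in posterior_samples[::100]:
    plt.plot(x_lin, theta[0] + theta[1] * x_lin, color='C2', alpha=0.3)
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()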

# Using a neural network with HMC
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        x = self.fc1(x)
        return x

net = Net()
net
# Net(
#   (fc1): Linear(in_features=1, out_features=1, bias=True)
# )
net.state_dict()
# OrderedDict([('fc1.weight', tensor([[0.7689]])), ('fc1.bias', tensor([0.2034]))])
hamiltorch.util.flatten(net)  # tensor([0.7689, 0.2034], grad_fn=<CatBackward0>)
theta_params = hamiltorch.util.flatten(net) + 1.0
theta_params  # tensor([1.7689, 1.2034], grad_fn=<AddBackward0>)
from nn_manual_hmc import log_joint as log_joint_nn
params_hmc = run_hmc(log_joint_nn, torch.tensor([0.2, 0.5]), 1000, 0.05, 5)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 1000/1000 samples, Acceptance Rate 0.94
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)
for i, param_vals in enumerate(params_hmc.T):
    axes[i].plot(param_vals, label='Trace')
    axes[i].set_xlabel('Iteration')
    axes[i].set_ylabel(fr'$\theta_{i}$')

# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)
for i, param_vals in enumerate(params_hmc.T):
    sns.kdeplot(param_vals.detach().numpy(), label='Samples', shade=True, color='C1', ax=axes[i])
    axes[i].set_ylabel(fr'$\theta_{i}$')
# Mark the true values
axes[0].axvline(theta_0_true.numpy(), color='C0', label='Ground truth')
axes[1].axvline(theta_1_true.numpy(), color='C0', label='Ground truth')

# Get posterior samples (the first few could be discarded as burn-in)
posterior_samples = params_hmc.detach()
y_preds = []
with torch.no_grad():
    for theta in posterior_samples:
        params_list = hamiltorch.util.unflatten(net, theta)
        params = net.state_dict()
        for i, (name, _) in enumerate(params.items()):
            params[name] = params_list[i]
        y_pred = torch.func.functional_call(net, params, x_lin.unsqueeze(1)).squeeze()
        y_preds.append(y_pred)

torch.stack(y_preds).shape  # torch.Size([1000, 90])
y_mean = torch.stack(y_preds).mean(axis=0)
y_std = torch.stack(y_preds).std(axis=0)
plt.plot(x_lin, y_mean, label='Mean', color='C2')
plt.fill_between(x_lin, y_mean - 2 * y_std, y_mean + 2 * y_std, alpha=0.5, label='95% CI', color='C2')
plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
plt.xlabel('x')
plt.ylabel('y')

step_size = 0.0005
num_samples = 1000
L = 30
burn = -1
store_on_GPU = False
debug = False
model_loss = 'regression'
mass = 1.0
# Effect of tau
# Set tau = 1000. for a stiffer function (the stronger prior keeps the weights small)
# Set tau = 1. for a more flexible fit
tau = 1.0  # Prior Precision
tau_out = 110.4439498986428 # Output Precision
r = 0 # Random seed
tau_list = []
for w in net.parameters():
    tau_list.append(tau)  # set the prior precision to be the same for each set of weights
tau_list = torch.tensor(tau_list)
# Set initial weights
params_init = hamiltorch.util.flatten(net).clone()
# Set the Inverse of the Mass matrix
inv_mass = torch.ones(params_init.shape) / mass
integrator = hamiltorch.Integrator.EXPLICIT  # defined here but not passed to sample_model below, so the default implicit integrator is used
sampler = hamiltorch.Sampler.HMC
hamiltorch.set_random_seed(r)
params_hmc_f = hamiltorch.sample_model(net, x_lin.view(-1, 1), y_lin.view(-1, 1), params_init=params_init,
                                       model_loss=model_loss, num_samples=100,
                                       burn=burn, inv_mass=inv_mass, step_size=step_size,
                                       num_steps_per_sample=L, tau_out=tau_out, tau_list=tau_list,
                                       debug=debug, store_on_GPU=store_on_GPU,
                                       sampler=sampler)
# Drop the initial state; with store_on_GPU=False the samples stay on the CPU
params_hmc_gpu = [ll for ll in params_hmc_f[1:]]
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 100/100 samples, Acceptance Rate 1.00
torch.stack(params_hmc_gpu).shape  # torch.Size([100, 2])
# Let's predict over the entire range [-3, 3]
pred_list, log_probs_f = hamiltorch.predict_model(net, x=x_lin.view(-1, 1), y=y_lin.view(-1, 1), samples=params_hmc_gpu,
                                                  model_loss=model_loss, tau_out=tau_out,
                                                  tau_list=tau_list)
pred_list.shape  # torch.Size([100, 90, 1])
plt.plot(x_lin, pred_list.mean(axis=0).ravel())
# Plot the true function
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
# Plot standard deviation
plt.fill_between(x_lin, pred_list.mean(axis=0).ravel() - 2 * pred_list.std(axis=0).ravel(), pred_list.mean(axis=0).ravel() + 2 * pred_list.std(axis=0).ravel(), alpha=0.5, label='95% CI', color='C2')
plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()

from nn_manual_hmc_classification import log_joint, x_moon, y_moon, net_classification
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)

net_classification
# Net_Classification(
#   (fc1): Linear(in_features=2, out_features=5, bias=True)
#   (fc2): Linear(in_features=5, out_features=5, bias=True)
#   (fc3): Linear(in_features=5, out_features=1, bias=True)
# )
hamiltorch.util.flatten(net_classification).shape  # torch.Size([51])
# Number of params in the network
D = hamiltorch.util.flatten(net_classification).shape[0]
log_joint(torch.zeros(D).to(device))  # tensor(-740.0131, device='cuda:0')
params_hmc = run_hmc(log_joint, torch.zeros(D, device=device), 2000, 0.01, 2)
# Output: Sampling (Sampler.HMC; Integrator.IMPLICIT), 2000/2000 samples, Acceptance Rate 0.96
params_hmc.shape  # torch.Size([2000, 51])
plt.plot(params_hmc[:, 2].cpu().numpy())

# Get posterior predictive over the 2D grid
posterior_samples = params_hmc.detach()
# Discard the first 1000 samples as burn-in
posterior_samples = posterior_samples[1000:]
y_preds = []
n_grid = 200
lims = 4
twod_grid = torch.tensor(np.meshgrid(np.linspace(-lims, lims, n_grid), np.linspace(-lims, lims, n_grid))).float().to(device)
with torch.no_grad():
    for theta in posterior_samples:
        params_list = hamiltorch.util.unflatten(net_classification, theta)
        params = net_classification.state_dict()
        for i, (name, _) in enumerate(params.items()):
            params[name] = params_list[i]
        y_pred = torch.func.functional_call(net_classification, params, twod_grid.view(2, -1).T).squeeze()
        y_preds.append(y_pred)

x_moon.shape  # torch.Size([1000, 2])
y_preds[0].shape  # torch.Size([40000])
logits = torch.stack(y_preds).mean(axis=0).reshape(n_grid, n_grid)
logits
# tensor([[-20.4181, -19.7507, -19.0839,  ...,   4.1155,   4.1154,   4.1154],
#         ...,
#         [-99.7477, -99.2237, -98.7036,  ...,  -6.3935,  -6.0654,  -5.7471]],
#        device='cuda:0')
probs = torch.sigmoid(logits)
probs
# tensor([[1.3568e-09, 2.6447e-09, 5.1520e-09,  ..., 9.8394e-01, 9.8394e-01, 9.8394e-01],
#         ...,
#         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6697e-03, 2.3164e-03, 3.1817e-03]],
#        device='cuda:0')
# Plot the posterior predictive distribution decision boundary
plt.figure()
plt.contourf(twod_grid[0].cpu().numpy(), twod_grid[1].cpu().numpy(), probs.cpu().numpy(), cmap='bwr', alpha=0.5)
plt.colorbar()
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)

# Plot the variance of the posterior predictive distribution
plt.figure()
plt.contourf(twod_grid[0].cpu().numpy(), twod_grid[1].cpu().numpy(), torch.stack(y_preds).std(axis=0).reshape(n_grid, n_grid).cpu().numpy(), cmap='bwr', alpha=0.5)
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)
plt.colorbar()
