import torch
import torch.autograd.functional as F
import torch.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import ipywidgets as widgets
import seaborn as sns
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
Sampling from an unnormalized distribution
from tueplots import bundles
plt.rcParams.update(bundles.beamer_moml())
# plt.rcParams.update(bundles.icml2022())
# Also add despine to the bundle using rcParams
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
# Increase font size to match Beamer template
plt.rcParams['font.size'] = 16
# Make background transparent
plt.rcParams['figure.facecolor'] = 'none'
import hamiltorch
hamiltorch.set_random_seed(123)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
gt_distribution = torch.distributions.Normal(0, 1)

# Samples from the ground truth distribution
def sample_gt(n):
    return gt_distribution.sample((n,))

samples = sample_gt(1000)
x_lin = torch.linspace(-3, 3, 1000)
y_lin = torch.exp(gt_distribution.log_prob(x_lin))

plt.plot(x_lin, y_lin, label='Ground truth')
plt.legend()
<matplotlib.legend.Legend at 0x7efdfc51dd30>
# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return gt_distribution.log_prob(x)

logprob(torch.tensor([0.0]))
tensor([-0.9189])
# Markov chain
x_start = torch.tensor([0.0])
samples = []
for i in range(100):
    prop = torch.distributions.Normal(x_start, 10).sample()
    samples.append(prop)
    x_start = prop
plt.plot(torch.stack(samples).ravel())
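The chain above accepts every proposal, so it simply explores the wide N(x, 10) proposal rather than any target. For contrast, a minimal random-walk Metropolis-Hastings sketch (the function name and proposal scale are my own choices, not part of hamiltorch):

# Random-walk Metropolis-Hastings targeting `logprob`: accept a proposal with
# probability min(1, p(prop) / p(x)); otherwise keep the current state.
def metropolis_hastings(logprob, x0, n_iter=1000, prop_scale=1.0):
    x = x0.clone()
    chain = []
    for _ in range(n_iter):
        prop = torch.distributions.Normal(x, prop_scale).sample()
        log_alpha = logprob(prop) - logprob(x)
        if torch.rand(1).log() < log_alpha:
            x = prop
        chain.append(x.clone())
    return torch.stack(chain)

mh_samples = metropolis_hastings(logprob, torch.tensor([0.0]))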
# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
hamiltorch.set_random_seed(123)
params_hmc = hamiltorch.sample(log_prob_func=logprob, params_init=x0,
                               num_samples=num_samples, step_size=step_size,
                               num_steps_per_sample=num_steps_per_sample)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:07 | 0d:00:00:00 | #################### | 5000/5000 | 686.98
Acceptance Rate 0.99
params_hmc = torch.tensor(params_hmc)
def run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample):
    torch.manual_seed(123)
    params_hmc = hamiltorch.sample(log_prob_func=logprob, params_init=x0,
                                   num_samples=num_samples, step_size=step_size,
                                   num_steps_per_sample=num_steps_per_sample)
    return torch.stack(params_hmc)

params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:07 | 0d:00:00:00 | #################### | 5000/5000 | 676.50
Acceptance Rate 0.99
params_hmc.shape
torch.Size([5000, 1])
# Trace plot
plt.plot(params_hmc, label='Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter value')
Text(0, 0.5, 'Parameter value')
# View the first 50 samples
plt.plot(params_hmc[:50], label='Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter value')
Text(0, 0.5, 'Parameter value')
# KDE plot
import seaborn as sns
plt.figure()
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
plt.plot(x_lin, y_lin, label='Ground truth')
plt.xlabel('Parameter value')
plt.ylabel('Density')
plt.legend()
<matplotlib.legend.Legend at 0x7f2dbc637d30>
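As a quick sanity check (a minimal sketch, not part of the original pipeline), the sample mean and standard deviation should be close to the N(0, 1) target:

# Summary statistics of the chain versus the known target N(0, 1)
flat = params_hmc.ravel()
print(flat.mean().item(), flat.std().item())  # expect values near 0 and 1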
# Create MP4 HTML5 video showing sampling process
def create_mp4_samples(samples, x_lin, y_lin, filename='samples.mp4', dpi=600):
    fig, ax = plt.subplots(figsize=(4, 2), dpi=dpi)
    ax.set_xlim(-3, 3)
    ax.set_ylim(0, 1)
    ax.set_xlabel('Parameter value')
    ax.set_ylabel('Density')
    ax.plot(x_lin, y_lin, label='Ground truth')
    ax.legend()

    # Add an "x" marker to the plot for each sample at y=0
    x_marker, = ax.plot([], [], 'x', color='C1', label='Samples')
    ax.legend()

    def init():
        x_marker.set_data([], [])
        return x_marker,

    def animate(i):
        x_marker.set_data(samples[:i], torch.zeros(i))
        return x_marker,

    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=len(samples), interval=20, blit=True)
    anim.save(filename, dpi=dpi, writer='ffmpeg')
create_mp4_samples(params_hmc[:100], x_lin, y_lin, filename='../figures/sampling/mcmc/normal.mp4', dpi=600)

from IPython.display import Video
Video('../figures/sampling/mcmc/normal.mp4', width=400)
# sample from Mixture of Gaussians
mog = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.3, 0.7])),
    component_distribution=dist.Normal(torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 0.5]))
)

samples = mog.sample((1000,))
sns.kdeplot(samples.numpy(), label='Samples', shade=True, color='C1')
<AxesSubplot:ylabel='Density'>
# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return mog.log_prob(x)

logprob(torch.tensor([0.0]))
tensor([-4.1114])
params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 459.19
Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc, label='Trace')
y_lin = torch.exp(mog.log_prob(x_lin))

# KDE plot
plt.figure()
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
# Limit the KDE plot to the range of the ground truth
plt.xlim(-3, 3)
plt.plot(x_lin, y_lin, label='Ground truth')
plt.xlabel('Parameter value')
plt.ylabel('Density')
plt.legend()
<matplotlib.legend.Legend at 0x7f2dbc3a65b0>
# Create MP4 HTML5 video showing sampling process
create_mp4_samples(params_hmc[:500], x_lin, y_lin, filename='../figures/sampling/mcmc/mog.mp4', dpi=600)
Video('../figures/sampling/mcmc/mog.mp4', width=400)
def p_tilde(x):
    # normalising constant for the standard normal distribution
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x).exp() * Z

def p_tilde_log_prob(x):
    # normalising constant for the standard normal distribution
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x) + torch.log(Z)
# Plot unnormalized distribution
x_lin = torch.linspace(-3, 3, 1000)
y_lin = p_tilde(x_lin)
plt.plot(x_lin, y_lin, label='Unnormalized distribution')
# Plot normalized distribution
plt.plot(x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label='Normalized distribution')
plt.legend()
<matplotlib.legend.Legend at 0x7f2dbc385c70>
# HMC over unnormalized distribution
# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    return p_tilde_log_prob(x)

# HMC
params_hmc = run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 479.28
Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc[:500], label='Trace')
# KDE plot
sns.kdeplot(params_hmc.ravel().detach().numpy(), label='Samples', shade=True, color='C1')
plt.plot(x_lin, y_lin, label='Unnormalized distribution', lw=2)
plt.plot(x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label='Normalized distribution', lw=2)
plt.legend()
<matplotlib.legend.Legend at 0x7f2dbc2657f0>
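The samples match the normalized density even though HMC only ever sees the unnormalized log-density: adding log Z shifts the log-probability by a constant, so the gradient used in the leapfrog steps and the Metropolis acceptance ratio are unchanged. A quick check (variable names here are mine):

# The constant log Z does not affect the gradient HMC uses
x = torch.tensor([0.7], requires_grad=True)
g_unnorm = torch.autograd.grad(p_tilde_log_prob(x).sum(), x)[0]
g_norm = torch.autograd.grad(dist.Normal(0, 1).log_prob(x).sum(), x)[0]
print(g_unnorm, g_norm)  # identical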
Coin Toss
Working with probabilities
prior = dist.Beta(1, 1)
data = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])
n = len(data)

def log_prior(theta):
    return prior.log_prob(theta)

def log_likelihood(theta):
    return dist.Bernoulli(theta).log_prob(data).sum()

def log_joint(theta):
    return log_prior(theta) + log_likelihood(theta)
try:
    params_hmc_theta = run_hmc(log_joint, torch.tensor([0.5]), 5000, 0.3, 5)
except Exception as e:
    print(e)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
Expected value argument (Tensor of shape (1,)) to be within the support (Interval(lower_bound=0.0, upper_bound=1.0)) of the distribution Beta(), but found invalid values:
tensor([-0.1017], requires_grad=True)
Working with logits
# Let us work instead with logits
def log_prior(logits):
    return prior.log_prob(torch.sigmoid(logits)).sum()

def log_likelihood(logits):
    return dist.Bernoulli(logits=logits).log_prob(data).sum()

def log_joint(logits):
    return log_prior(logits) + log_likelihood(logits)

params_hmc_logits = run_hmc(log_joint, torch.tensor([0.0]), 1000, 0.3, 5)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 275.76
Acceptance Rate 0.99
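An aside: log_prior above evaluates the Beta density at sigmoid(logits) without the change-of-variables Jacobian, so the prior it implies on the logits is not exactly the one induced by Beta(1, 1) on theta. If that correction is wanted, torch.distributions.TransformedDistribution adds the log-Jacobian term automatically; a minimal sketch (variable names are mine):

# Prior on logits induced by theta ~ Beta(1, 1) with logits = logit(theta);
# TransformedDistribution accounts for the log|d theta / d logits| term.
logit_prior = dist.TransformedDistribution(dist.Beta(1.0, 1.0), [dist.SigmoidTransform().inv])

def log_prior_with_jacobian(logits):
    return logit_prior.log_prob(logits).sum()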
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(params_hmc_logits[:500], label='Trace for logits')
ax[1].plot(torch.sigmoid(params_hmc_logits[:500]), label='Trace for probabilities')
ax[0].set_ylabel('Logits')
ax[1].set_ylabel('Computed Probabilities\n (From Logits)')
ax[1].set_xlabel('Iteration')
Text(0.5, 0, 'Iteration')
# Create a function to update the KDE plot with the specified bw_adjust value
def update_kde_plot(bw_adjust):
    # Clear the previous plot
    plt.clf()
    plt.hist(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), bins=100, density=True,
             label='Samples (Histogram)', color='C2', alpha=0.5, lw=1)
    sns.kdeplot(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), label='Samples (KDE)',
                shade=False, color='C1', clip=(0, 1), bw_adjust=bw_adjust, lw=2)
    x_lin = torch.linspace(0, 1, 1000)
    y_lin = dist.Beta(1 + 3, 1 + 2).log_prob(x_lin).exp()
    plt.plot(x_lin, y_lin, label='True posterior')
    plt.legend()

# Create the slider widget for bw_adjust
bw_adjust_slider = widgets.FloatSlider(value=0.1, min=0.01, max=4.0, step=0.01, description='bw_adjust:')

# Create the interactive plot
interactive_plot = widgets.interactive(update_kde_plot, bw_adjust=bw_adjust_slider)

# Display the interactive plot
display(interactive_plot)
# Plot histogram of samples
plt.hist(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), bins=100, density=True,
         label='Samples (Histogram)', color='C2', alpha=0.5, lw=1)

# Plot posterior KDE using seaborn but clip to [0, 1]
sns.kdeplot(torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(), label='Samples (KDE)',
            shade=False, color='C1', clip=(0, 1), bw_adjust=0.1, lw=2)
# True posterior
x_lin = torch.linspace(0, 1, 1000)
y_lin = dist.Beta(1 + 3, 1 + 2).log_prob(x_lin).exp()
plt.plot(x_lin, y_lin, label='True posterior')
plt.legend()
<matplotlib.legend.Legend at 0x7f1850e258b0>
# Linear regression for 1 dimensional input using HMC
x_lin = torch.linspace(-3, 3, 90)
theta_0_true = torch.tensor([2.0])
theta_1_true = torch.tensor([3.0])
f = lambda x: theta_0_true + theta_1_true * x
eps = torch.randn_like(x_lin) * 1.0
y_lin = f(x_lin) + eps

plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth')
plt.xlabel('x')
plt.ylabel('y')
Text(0, 0.5, 'y')
# Estimate theta_0, theta_1 using HMC, assuming the noise variance is known to be 1
def logprob(theta):
    y_pred = theta[0] + x_lin * theta[1]
    return dist.Normal(y_pred, 1).log_prob(y_lin).sum()

def log_prior(theta):
    return dist.Normal(0, 1).log_prob(theta).sum()

def log_posterior(theta):
    return logprob(theta) + log_prior(theta)

params_hmc_lin_reg = run_hmc(log_posterior, torch.tensor([0.0, 0.0]), 1000, 0.05, 10)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:06 | 0d:00:00:00 | #################### | 1000/1000 | 157.83
Acceptance Rate 0.95
params_hmc_lin_reg
tensor([[0.0000, 0.0000],
[1.8463, 1.5993],
[2.1721, 3.8480],
...,
[1.9128, 2.9478],
[2.1689, 2.9928],
[1.8748, 3.0217]])
lps = []
for p in params_hmc_lin_reg:
    lps.append(log_posterior(p))

plt.plot(torch.stack(lps).ravel()[100:])

log_posterior(params_hmc_lin_reg[0]), log_posterior(params_hmc_lin_reg[1]), log_posterior(params_hmc_lin_reg[2])
(tensor(-1554.2708), tensor(-392.7923), tensor(-242.6276))
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    axes[i].plot(param_vals, label='Trace')
    axes[i].set_xlabel('Iteration')
    axes[i].set_ylabel(fr'$\theta_{i}$')

# Plot the true values as well
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axhline(param_vals.numpy(), color='C1', label='Ground truth')
    axes[i].legend()
# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    sns.kdeplot(param_vals.detach().numpy(), label='Samples', shade=True, color='C1', ax=axes[i])
    axes[i].set_ylabel(fr'$\theta_{i}$')

# Plot the true values as well
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axvline(param_vals.numpy(), color='C0', label='Ground truth')
    axes[i].legend()
# Plot the posterior predictive distribution
plt.figure()
plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
plt.xlabel('x')
plt.ylabel('y')

# Get posterior samples. Discard the first 100 samples as burn-in
posterior_samples = params_hmc_lin_reg[100:].detach()
y_hat = posterior_samples[:, 0].unsqueeze(1) + x_lin * posterior_samples[:, 1].unsqueeze(1)

# Plot mean and 95% confidence interval
plt.plot(x_lin, y_hat.mean(axis=0), label='Mean', color='C2')
plt.fill_between(x_lin, y_hat.mean(axis=0) - 2 * y_hat.std(axis=0),
                 y_hat.mean(axis=0) + 2 * y_hat.std(axis=0), alpha=0.5, label='95% CI', color='C2')
plt.legend()
<matplotlib.legend.Legend at 0x7f2db1983640>
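Note that the band above reflects only the posterior uncertainty in (theta_0, theta_1); it does not include the known observation noise. A hedged sketch of the full posterior predictive interval, adding the sigma = 1 noise variance (variable names are mine):

# Posterior predictive interval: add the known noise variance (sigma = 1)
# to the variance coming from the posterior over (theta_0, theta_1)
y_pred_std = (y_hat.var(axis=0) + 1.0 ** 2).sqrt()
plt.fill_between(x_lin, y_hat.mean(axis=0) - 2 * y_pred_std,
                 y_hat.mean(axis=0) + 2 * y_pred_std,
                 alpha=0.3, color='C3', label='95% predictive interval')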
# Using a neural network with HMC
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        x = self.fc1(x)
        return x

net = Net()
net
Net(
(fc1): Linear(in_features=1, out_features=1, bias=True)
)
net.state_dict()
OrderedDict([('fc1.weight', tensor([[0.7689]])),
('fc1.bias', tensor([0.2034]))])
hamiltorch.util.flatten(net)
tensor([0.7689, 0.2034], grad_fn=<CatBackward0>)
theta_params = hamiltorch.util.flatten(net) + 1.0
theta_params
tensor([1.7689, 1.2034], grad_fn=<AddBackward0>)
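For reference, hamiltorch.util.unflatten (used later for prediction) maps such a flat vector back to a list of per-parameter tensors; a quick round-trip check:

# Round-trip: flat vector -> list of tensors with the same shapes as net.parameters()
params_list = hamiltorch.util.unflatten(net, theta_params)
[p.shape for p in params_list]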
from nn_manual_hmc import log_joint as log_joint_nn
params_hmc = run_hmc(log_joint_nn, torch.tensor([0.2, 0.5]), 1000, 0.05, 5)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:04 | 0d:00:00:00 | #################### | 1000/1000 | 202.36
Acceptance Rate 0.94
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    axes[i].plot(param_vals, label='Trace')
    axes[i].set_xlabel('Iteration')
    axes[i].set_ylabel(fr'$\theta_{i}$')
# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    sns.kdeplot(param_vals.detach().numpy(), label='Samples', shade=True, color='C1', ax=axes[i])
    axes[i].set_ylabel(fr'$\theta_{i}$')

# Mark the true values
axes[0].axvline(theta_0_true.numpy(), color='C0', label='Ground truth')
axes[1].axvline(theta_1_true.numpy(), color='C0', label='Ground truth')
<matplotlib.lines.Line2D at 0x7f2db24ad8e0>
# Get posterior samples (no burn-in removed here)
posterior_samples = params_hmc.detach()
y_preds = []
with torch.no_grad():
    for theta in posterior_samples:
        params_list = hamiltorch.util.unflatten(net, theta)
        params = net.state_dict()
        for i, (name, _) in enumerate(params.items()):
            params[name] = params_list[i]
        y_pred = torch.func.functional_call(net, params, x_lin.unsqueeze(1)).squeeze()
        y_preds.append(y_pred)
torch.stack(y_preds).shape
torch.Size([1000, 90])
y_mean = torch.stack(y_preds).mean(axis=0)
y_std = torch.stack(y_preds).std(axis=0)

plt.plot(x_lin, y_mean, label='Mean', color='C2')
plt.fill_between(x_lin, y_mean - 2 * y_std, y_mean + 2 * y_std, alpha=0.5, label='95% CI', color='C2')

plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')
plt.xlabel('x')
plt.ylabel('y')
Text(0, 0.5, 'y')
step_size = 0.0005
num_samples = 1000
L = 30
burn = -1
store_on_GPU = False
debug = False
model_loss = 'regression'
mass = 1.0

# Effect of tau
# Set tau = 1000. to see a function that is less bendy (weights restricted to small bends)
# Set tau = 1. for a more flexible fit
tau = 1.0  # Prior Precision
tau_out = 110.4439498986428  # Output Precision
r = 0  # Random seed
tau_list = []
for w in net.parameters():
    # set the prior precision to be the same for each set of weights
    tau_list.append(tau)
tau_list = torch.tensor(tau_list)
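Since tau is a precision, each parameter group effectively gets a zero-mean Gaussian prior with standard deviation 1/sqrt(tau), which is why a large tau pulls the weights (and the fitted function) towards zero. A quick check, assuming this reading of tau_list:

# Implied prior standard deviation per parameter group (precision -> std)
prior_std = 1.0 / torch.sqrt(tau_list)
prior_std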
# Set initial weights
params_init = hamiltorch.util.flatten(net).clone()
# Set the inverse of the mass matrix
inv_mass = torch.ones(params_init.shape) / mass

integrator = hamiltorch.Integrator.EXPLICIT
sampler = hamiltorch.Sampler.HMC

hamiltorch.set_random_seed(r)
params_hmc_f = hamiltorch.sample_model(net, x_lin.view(-1, 1), y_lin.view(-1, 1), params_init=params_init,
                                       model_loss=model_loss, num_samples=100,
                                       burn=burn, inv_mass=inv_mass, step_size=step_size,
                                       num_steps_per_sample=L, tau_out=tau_out, tau_list=tau_list,
                                       debug=debug, store_on_GPU=store_on_GPU,
                                       sampler=sampler)

# params_hmc_f is a list of flat parameter tensors on the CPU; drop the initial sample
params_hmc_gpu = [ll for ll in params_hmc_f[1:]]
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:02 | 0d:00:00:00 | #################### | 100/100 | 49.73
Acceptance Rate 1.00
torch.stack(params_hmc_gpu).shape
torch.Size([100, 2])
# Let's predict over the entire training range [-3, 3]
pred_list, log_probs_f = hamiltorch.predict_model(net, x=x_lin.view(-1, 1), y=y_lin.view(-1, 1), samples=params_hmc_gpu,
                                                  model_loss=model_loss, tau_out=tau_out,
                                                  tau_list=tau_list)
torch.Size([100, 90, 1])
plt.plot(x_lin, pred_list.mean(axis=0).ravel())
# Plot the true function
plt.plot(x_lin, f(x_lin), label='Ground truth', color='C1', linestyle='--')

# Plot standard deviation
plt.fill_between(x_lin,
                 pred_list.mean(axis=0).ravel() - 2 * pred_list.std(axis=0).ravel(),
                 pred_list.mean(axis=0).ravel() + 2 * pred_list.std(axis=0).ravel(),
                 alpha=0.5, label='95% CI', color='C2')

plt.scatter(x_lin, y_lin, label='Data', color='C0')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
<matplotlib.legend.Legend at 0x7f1a94574c70>
from nn_manual_hmc_classification import log_joint, x_moon, y_moon, net_classification
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)
<matplotlib.collections.PathCollection at 0x7efdcec96820>
net_classification
Net_Classification(
(fc1): Linear(in_features=2, out_features=5, bias=True)
(fc2): Linear(in_features=5, out_features=5, bias=True)
(fc3): Linear(in_features=5, out_features=1, bias=True)
)
hamiltorch.util.flatten(net_classification).shape
torch.Size([51])
# number of params in the network
D = hamiltorch.util.flatten(net_classification).shape[0]
log_joint(torch.zeros(D).to(device))
tensor(-740.0131, device='cuda:0')
params_hmc = run_hmc(log_joint, torch.tensor(torch.zeros(D).to(device)), 2000, 0.01, 2)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:21 | 0d:00:00:00 | #################### | 2000/2000 | 93.82
Acceptance Rate 0.96
/tmp/ipykernel_948360/3753994728.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
params_hmc = run_hmc(log_joint, torch.tensor(torch.zeros(D).to(device)), 2000, 0.01, 2)
params_hmc.shape
torch.Size([2000, 51])
plt.plot(params_hmc[:, 2].cpu().numpy())
# Get posterior predictive over the 2D grid
posterior_samples = params_hmc.detach()
# Discard the first 1000 samples as burn-in
posterior_samples = posterior_samples[1000:]
y_preds = []
n_grid = 200
lims = 4
twod_grid = torch.tensor(np.meshgrid(np.linspace(-lims, lims, n_grid), np.linspace(-lims, lims, n_grid))).float().to(device)
with torch.no_grad():
    for theta in posterior_samples:
        params_list = hamiltorch.util.unflatten(net_classification, theta)
        params = net_classification.state_dict()
        for i, (name, _) in enumerate(params.items()):
            params[name] = params_list[i]
        y_pred = torch.func.functional_call(net_classification, params, twod_grid.view(2, -1).T).squeeze()
        y_preds.append(y_pred)
x_moon.shape
torch.Size([1000, 2])
0].shape y_preds[
torch.Size([40000])
logits = torch.stack(y_preds).mean(axis=0).reshape(n_grid, n_grid)
logits
tensor([[-20.4181, -19.7507, -19.0839, ..., 4.1155, 4.1154, 4.1154],
[-20.7736, -20.1060, -19.4386, ..., 4.1154, 4.1154, 4.1154],
[-21.1293, -20.4614, -19.7938, ..., 4.1154, 4.1153, 4.1153],
...,
[-98.6203, -98.0886, -97.5601, ..., -5.8248, -5.5091, -5.2027],
[-99.1825, -98.6539, -98.1300, ..., -6.1079, -5.7854, -5.4731],
[-99.7477, -99.2237, -98.7036, ..., -6.3935, -6.0654, -5.7471]],
device='cuda:0')
probs = torch.sigmoid(logits)
probs
tensor([[1.3568e-09, 2.6447e-09, 5.1520e-09, ..., 9.8394e-01, 9.8394e-01,
9.8394e-01],
[9.5090e-10, 1.8539e-09, 3.6136e-09, ..., 9.8394e-01, 9.8394e-01,
9.8394e-01],
[6.6629e-10, 1.2994e-09, 2.5332e-09, ..., 9.8394e-01, 9.8394e-01,
9.8394e-01],
...,
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.9446e-03, 4.0335e-03,
5.4718e-03],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.2204e-03, 3.0628e-03,
4.1806e-03],
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 1.6697e-03, 2.3164e-03,
3.1817e-03]], device='cuda:0')
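Taking the sigmoid of the mean logits is a convenient summary, but it is not the same as the Monte Carlo posterior predictive, which averages the per-sample probabilities; a hedged alternative (probs_mc is my own name):

# Monte Carlo posterior predictive: average probabilities across posterior samples
probs_mc = torch.sigmoid(torch.stack(y_preds)).mean(axis=0).reshape(n_grid, n_grid)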
# Plot the posterior predictive distribution decision boundary
plt.figure()
plt.contourf(twod_grid[0].cpu().numpy(), twod_grid[1].cpu().numpy(), probs.cpu().numpy(), cmap='bwr', alpha=0.5)
plt.colorbar()
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)
<matplotlib.collections.PathCollection at 0x7efdcd18fe50>
# Plot the variance of the posterior predictive distribution
plt.figure()
plt.contourf(twod_grid[0].cpu().numpy(), twod_grid[1].cpu().numpy(),
             torch.stack(y_preds).std(axis=0).reshape(n_grid, n_grid).cpu().numpy(), cmap='bwr', alpha=0.5)
plt.scatter(x_moon[:, 0].cpu().numpy(), x_moon[:, 1].cpu().numpy(), c=y_moon.cpu().numpy(), cmap='bwr', alpha=0.5)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7efdcd0f4e80>