import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.autograd.functional as F
import torch.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import pandas as pd

%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
from tueplots import bundles

plt.rcParams.update(bundles.beamer_moml())
# plt.rcParams.update(bundles.icml2022())
# Also add despine to the bundle using rcParams
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
# Increase font size to match Beamer template
plt.rcParams["font.size"] = 16
# Make background transparent
plt.rcParams["figure.facecolor"] = "none"
import hamiltorch

hamiltorch.set_random_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
gt_distribution = torch.distributions.Normal(0, 1)

# Samples from the ground truth distribution
def sample_gt(n):
    return gt_distribution.sample((n,))

samples = sample_gt(1000)
x_lin = torch.linspace(-3, 3, 1000)
y_lin = torch.exp(gt_distribution.log_prob(x_lin))

plt.plot(x_lin, y_lin, label="Ground truth")
plt.legend()
<matplotlib.legend.Legend at 0x7fb356939c40>
# Logprob function to be passed to the hamiltorch sampler
def logprob(x):
    return gt_distribution.log_prob(x)

logprob(torch.tensor([0.0]))
tensor([-0.9189])
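# Quick check (an added cell): the standard normal log-density at 0 is
# -0.5 * log(2 * pi), matching the value above.
-0.5 * np.log(2 * np.pi)  # ≈ -0.9189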
# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
hamiltorch.set_random_seed(123)
params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:07 | 0d:00:00:00 | #################### | 5000/5000 | 698.22
Acceptance Rate 0.99
params_hmc = torch.tensor(params_hmc)
# Trace plot
plt.plot(params_hmc, label="Trace")
plt.xlabel("Iteration")
plt.ylabel("Parameter value")
Text(0, 0.5, 'Parameter value')
# View the first 500 samples
plt.plot(params_hmc[:500], label="Trace")
plt.xlabel("Iteration")
plt.ylabel("Parameter value")
Text(0, 0.5, 'Parameter value')
# KDE plot
import seaborn as sns

plt.figure()
sns.kdeplot(params_hmc.detach().numpy(), label="Samples", shade=True, color="C1")
plt.plot(x_lin, y_lin, label="Ground truth")
plt.xlabel("Parameter value")
plt.ylabel("Density")
plt.legend()
<matplotlib.legend.Legend at 0x7fb353810dc0>
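# A further sanity check (an added cell, output not shown in the original):
# the sample moments should be close to those of the N(0, 1) target.
print(params_hmc.mean().item(), params_hmc.std().item())  # expect ≈ 0 and ≈ 1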
def plot_samples_gif(x_lin, y_lin, params_hmc, filename, frames=50):
    fig, ax = plt.subplots()
    plt.plot(x_lin, y_lin, label="Ground truth")
    scatter = ax.scatter([], [], color="C1", marker="x", s=100)

    # Function to update the animation
    def update(frame):
        scatter.set_offsets(np.array([[params_hmc[frame], 0]]))
        return (scatter,)

    # Create the animation
    anim = FuncAnimation(fig, update, frames=frames, blit=True)

    # Save the animation as a GIF or video file (you can change the filename and format)
    anim.save(filename, dpi=200)
plot_samples_gif(
    x_lin, y_lin, params_hmc, "../figures/sampling/mcmc/hamiltorch-samples-normal.gif"
)
# Sample from a mixture of Gaussians
mog = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.3, 0.7])),
    component_distribution=dist.Normal(
        torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 0.5])
    ),
)
samples = mog.sample((1000,))
sns.kdeplot(samples.numpy(), label="Samples", shade=True, color="C1")
<AxesSubplot:ylabel='Density'>
# Logprob function to be passed to the hamiltorch sampler
def logprob(x):
    return mog.log_prob(x)

logprob(torch.tensor([0.0]))
tensor([-4.1114])
# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
hamiltorch.set_random_seed(123)

params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 462.18
Acceptance Rate 0.99
params_hmc = torch.tensor(params_hmc)
# Trace plot
plt.plot(params_hmc[:500], label="Trace")
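# The trace hops between the two modes. As a rough check (an added cell,
# assuming the modes are well separated at x = 0), the fraction of samples on
# each side should recover the mixture weights:
print((params_hmc < 0).float().mean())  # expect ≈ 0.3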
y_lin = torch.exp(mog.log_prob(x_lin))
plot_samples_gif(
    x_lin,
    y_lin,
    params_hmc,
    "../figures/sampling/mcmc/hamiltorch-samples-mog.gif",
    frames=300,
)
def p_tilde(x):
    # normalising constant for the standard normal distribution
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x).exp() * Z


def p_tilde_log_prob(x):
    # normalising constant for the standard normal distribution
    Z = torch.sqrt(torch.tensor(2 * np.pi))
    return dist.Normal(0, 1).log_prob(x) + torch.log(Z)
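# Numerical sanity check (an added cell): integrating p_tilde over a wide grid
# should recover the normalising constant Z = sqrt(2 * pi) ≈ 2.5066.
xs = torch.linspace(-10, 10, 10001)
print(torch.trapz(p_tilde(xs), xs))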
# Plot the unnormalized distribution
x_lin = torch.linspace(-3, 3, 1000)
y_lin = p_tilde(x_lin)
plt.plot(x_lin, y_lin, label="Unnormalized distribution")
# Plot the normalized distribution
plt.plot(
    x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label="Normalized distribution"
)
plt.legend()
<matplotlib.legend.Legend at 0x7fb3378375e0>
# HMC over the unnormalized distribution
# Logprob function to be passed to the hamiltorch sampler
def logprob(x):
    return p_tilde_log_prob(x)


# HMC
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
hamiltorch.set_random_seed(123)

params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)
params_hmc = torch.tensor(params_hmc)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 488.07
Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc[:500], label="Trace")
# KDE plot
sns.kdeplot(params_hmc.detach().numpy(), label="Samples", shade=True, color="C1")
plt.plot(x_lin, y_lin, label="Unnormalized distribution")
plt.plot(
    x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label="Normalized distribution"
)
plt.legend()
<matplotlib.legend.Legend at 0x7fb3377450d0>
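# Why this works: HMC only needs the log density up to an additive constant,
# because the constant drops out of the gradient that drives the dynamics.
# A quick check (an added cell):
x = torch.tensor([0.7], requires_grad=True)
g_unnorm = torch.autograd.grad(p_tilde_log_prob(x).sum(), x)[0]
x = torch.tensor([0.7], requires_grad=True)
g_norm = torch.autograd.grad(dist.Normal(0, 1).log_prob(x).sum(), x)[0]
print(g_unnorm, g_norm)  # identical, both ≈ -0.7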
# Coin toss
prior = dist.Beta(1, 1)
data = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])
n = len(data)


def log_prior(theta):
    return prior.log_prob(theta)


def log_likelihood(theta):
    return dist.Bernoulli(theta).log_prob(data).sum()


# hamiltorch expects the log joint itself (not its negative)
def log_joint(theta):
    return log_prior(theta) + log_likelihood(theta)
def run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample):
    torch.manual_seed(12)
    params_hmc = hamiltorch.sample(
        log_prob_func=logprob,
        params_init=x0,
        num_samples=num_samples,
        step_size=step_size,
        num_steps_per_sample=num_steps_per_sample,
    )
    return torch.stack(params_hmc)
try:
    params_hmc_theta = run_hmc(log_joint, torch.tensor([0.5]), 5000, 0.3, 5)
except Exception as e:
    print(e)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
Expected value argument (Tensor of shape (1,)) to be within the support (Interval(lower_bound=0.0, upper_bound=1.0)) of the distribution Beta(), but found invalid values:
tensor([-0.1017], requires_grad=True)
# The leapfrog steps can carry theta outside the Beta support (0, 1), which
# raises the error above. Let us work instead with unconstrained logits.
def log_prior(logits):
    return prior.log_prob(torch.sigmoid(logits)).sum()


def log_likelihood(logits):
    return dist.Bernoulli(logits=logits).log_prob(data).sum()


def log_joint(logits):
    return log_prior(logits) + log_likelihood(logits)


params_hmc_logits = run_hmc(log_joint, torch.tensor([0.0]), 1000, 0.3, 5)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 278.54
Acceptance Rate 0.99
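# Caveat (added note): strictly, reparametrising theta = sigmoid(logits) calls
# for a change-of-variables correction, adding log |d theta / d logits| =
# log sigmoid(logits) + log(1 - sigmoid(logits)) to the log prior. With a
# Beta(1, 1) prior the cell above effectively places a flat prior on the
# logits instead. A corrected log prior would look like this (a sketch, not
# the cell used for the run above):
def log_prior_corrected(logits):
    theta = torch.sigmoid(logits)
    log_jacobian = torch.log(theta) + torch.log(1 - theta)
    return (prior.log_prob(theta) + log_jacobian).sum()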
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(params_hmc_logits[:500], label="Trace")
ax[1].plot(torch.sigmoid(params_hmc_logits[:500]), label="Trace")
params_hmc_logits[:, 0]
tensor([ 0.0000e+00,  5.2157e-01,  1.3223e+00,  1.0867e+00, -5.2568e-01,
        -2.0241e-01,  2.1629e+00,  1.4570e+00,  6.9446e-01,  1.0426e+00,
        ...,
         6.9619e-02,  2.8588e-02, -3.6439e-01,  7.4889e-01,  9.1645e-01])
# Plot the posterior KDE using seaborn, clipped to [0, 1]
sns.kdeplot(
    torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(),
    label="Samples",
    shade=True,
    color="C1",
    clip=(0, 1),
)
# True posterior: Beta(alpha + #heads, beta + #tails) by conjugacy
x_lin = torch.linspace(0, 1, 1000)
y_lin = dist.Beta(1 + 3, 1 + 2).log_prob(x_lin).exp()
plt.plot(x_lin, y_lin, label="True posterior")
plt.legend()
<matplotlib.legend.Legend at 0x7fb3376f4ac0>
# Linear regression for 1-dimensional input using HMC
torch.manual_seed(123)
x_lin = torch.linspace(-3, 3, 90)
theta_0_true = torch.tensor([2.0])
theta_1_true = torch.tensor([3.0])
f = lambda x: theta_0_true + theta_1_true * x
eps = torch.randn_like(x_lin) * 1.0
y_lin = f(x_lin) + eps

plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth")
plt.xlabel("x")
plt.ylabel("y")
Text(0, 0.5, 'y')
# Estimate theta_0, theta_1 using HMC, assuming the noise variance is known to be 1
def log_likelihood(theta):
    y_pred = theta[0] + x_lin * theta[1]
    return dist.Normal(y_pred, 1).log_prob(y_lin).sum()


def log_prior(theta):
    return dist.Normal(0, 1).log_prob(theta).sum()


def log_posterior(theta):
    return log_prior(theta) + log_likelihood(theta)
params = torch.tensor([0.1, 0.2])
print(params.dtype)
params_hmc_lin_reg = run_hmc(log_posterior, params, 1000, 0.1, 5)
torch.float32
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 265.17
Acceptance Rate 0.83
params_hmc_lin_reg
tensor([[0.1000, 0.2000],
[1.7248, 0.6831],
[2.1176, 5.2166],
...,
[1.9162, 3.1325],
[2.1559, 2.9292],
[1.8733, 3.0645]])
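# A quick summary (an added cell): after discarding the first 100 samples as
# burn-in, the posterior mean should sit close to the generating parameters
# theta_0 = 2, theta_1 = 3.
print(params_hmc_lin_reg[100:].mean(dim=0))  # expect ≈ tensor([2., 3.])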
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    axes[i].plot(param_vals, label="Trace")
    axes[i].set_xlabel("Iteration")
    axes[i].set_ylabel(rf"$\theta_{i}$")

# Plot the true values as well
for i, true_val in enumerate([theta_0_true, theta_1_true]):
    axes[i].axhline(true_val.numpy(), color="C1", label="Ground truth")
    axes[i].legend()
# Plot KDEs of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    sns.kdeplot(
        param_vals.detach().numpy(), label="Samples", shade=True, color="C1", ax=axes[i]
    )
    axes[i].set_ylabel(rf"$\theta_{i}$")

# Plot the true values as well
for i, true_val in enumerate([theta_0_true, theta_1_true]):
    axes[i].axvline(true_val.numpy(), color="C0", label="Ground truth")
    axes[i].legend()
# Plot the posterior predictive distribution
plt.figure()
plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth", color="C1", linestyle="--")
plt.xlabel("x")
plt.ylabel("y")

# Get posterior samples, discarding the first 100 as burn-in
posterior_samples = params_hmc_lin_reg[100:].detach()
y_hat = posterior_samples[:, 0].unsqueeze(1) + x_lin * posterior_samples[
    :, 1
].unsqueeze(1)

# Plot the mean and a 95% confidence interval
plt.plot(x_lin, y_hat.mean(axis=0), label="Mean", color="C2")
plt.fill_between(
    x_lin,
    y_hat.mean(axis=0) - 2 * y_hat.std(axis=0),
    y_hat.mean(axis=0) + 2 * y_hat.std(axis=0),
    alpha=0.5,
    label="95% CI",
    color="C2",
)
plt.legend()
<matplotlib.legend.Legend at 0x7fb336ba0d30>
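# Note (an added sketch): the band above captures uncertainty in the mean
# function only. The full posterior predictive also includes the known
# observation noise (sigma = 1); adding these lines to the plotting cell above
# would draw the wider interval.
pred_std = torch.sqrt(y_hat.std(axis=0) ** 2 + 1.0)  # epistemic + noise variance
plt.fill_between(
    x_lin,
    y_hat.mean(axis=0) - 2 * pred_std,
    y_hat.mean(axis=0) + 2 * pred_std,
    alpha=0.2,
    label="95% predictive interval",
    color="C3",
)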
# Using a neural network with HMC
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        x = self.fc1(x)
        return x


net = Net()
net
Net(
(fc1): Linear(in_features=1, out_features=1, bias=True)
)
hamiltorch.util.flatten(net).shape
torch.Size([2])
theta_params = hamiltorch.util.flatten(net)
theta_params.shape
torch.Size([2])
params_list = hamiltorch.util.unflatten(net, theta_params)
params_list[0].shape, params_list[1].shape
(torch.Size([1, 1]), torch.Size([1]))
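# Round-trip check (an added cell, assuming flatten concatenates parameters in
# model order): re-flattening the unflattened pieces should give back
# theta_params.
torch.allclose(torch.cat([p.reshape(-1) for p in params_list]), theta_params)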
params_init = theta_params.clone().detach()
params_init.dtype
torch.float32
t = torch.tensor([0.1, 0.2])
# reverse t
t.flip(0)
tensor([0.2000, 0.1000])
# hamiltorch.util.update_model_params_in_place??
theta = torch.tensor([0.1, 0.2])
print(theta.shape)
params_list = hamiltorch.util.unflatten(net, theta)
print(params_list, params_list[0].shape, params_list[1].shape)
hamiltorch.util.update_model_params_in_place(net, params_list)
net.state_dict()
net(torch.tensor([[1.0]]))
torch.Size([2])
[tensor([[0.1000]]), tensor([0.2000])] torch.Size([1, 1]) torch.Size([1])
tensor([[0.3000]], grad_fn=<AddmmBackward0>)
def log_prior(theta):
    return dist.Normal(0, 1).log_prob(theta).sum()


def log_likelihood(theta):
    # Flip so that theta = (bias, slope) matches the (weight, bias) ordering
    # that hamiltorch.util.unflatten expects
    theta = theta.flip(0)
    params_list = hamiltorch.util.unflatten(net, theta)

    ### In-place call
    # hamiltorch.util.update_model_params_in_place(net, params_list)
    # y_pred = net(x_lin.unsqueeze(1)).squeeze()

    ### Functional call
    params = net.state_dict()
    for i, (name, _) in enumerate(params.items()):
        params[name] = params_list[i]
    y_pred = torch.func.functional_call(net, params, x_lin.unsqueeze(1)).squeeze()
    return dist.Normal(y_pred, 1).log_prob(y_lin).sum()


def log_joint(theta):
    log_prior_val = log_prior(theta)
    log_likelihood_val = log_likelihood(theta)
    return log_prior_val + log_likelihood_val
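# Using torch.func.functional_call keeps the forward pass a pure function of
# theta, so gradients flow from the log joint back to the flat parameter
# vector, which is exactly what hamiltorch differentiates. A quick check
# (an added cell):
theta = torch.tensor([0.1, 0.2], requires_grad=True)
print(torch.autograd.grad(log_joint(theta), theta)[0].shape)  # torch.Size([2])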
params_hmc = run_hmc(log_joint, torch.tensor([0.1, 0.2]), 1000, 0.1, 5)
params_hmc
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent | Time remain.| Progress | Samples | Samples/sec
0d:00:00:04 | 0d:00:00:00 | #################### | 1000/1000 | 201.06
Acceptance Rate 0.81
tensor([[0.1000, 0.2000],
[1.6531, 0.5802],
[2.0851, 5.3383],
...,
[1.9668, 3.0019],
[2.0264, 3.0395],
[2.0264, 3.0395]])
# Plot the traces corresponding to the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    axes[i].plot(param_vals, label="Trace")
    axes[i].set_xlabel("Iteration")
    axes[i].set_ylabel(rf"$\theta_{i}$")
# Plot KDEs of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    sns.kdeplot(
        param_vals.detach().numpy(), label="Samples", shade=True, color="C1", ax=axes[i]
    )
    axes[i].set_ylabel(rf"$\theta_{i}$")
# Plot predictions
plt.figure()
plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth", color="C1", linestyle="--")
plt.xlabel("x")
plt.ylabel("y")

# Get posterior samples, discarding the first 100 as burn-in
posterior_samples = params_hmc[100:].detach()
with torch.no_grad():
    y_hat = net(x_lin.unsqueeze(1))

# Note: net still holds whatever parameters were last set in place, so this
# draws a single prediction line rather than the posterior mean
plt.plot(x_lin, y_hat.ravel(), label="Prediction", color="C2")
y_hat.ravel()
tensor([-1.4220, -1.3707, -1.3194, -1.2681, -1.2168, -1.1656, -1.1143, -1.0630,
-1.0117, -0.9604, -0.9091, -0.8579, -0.8066, -0.7553, -0.7040, -0.6527,
-0.6014, -0.5501, -0.4989, -0.4476, -0.3963, -0.3450, -0.2937, -0.2424,
-0.1912, -0.1399, -0.0886, -0.0373, 0.0140, 0.0653, 0.1165, 0.1678,
0.2191, 0.2704, 0.3217, 0.3730, 0.4242, 0.4755, 0.5268, 0.5781,
0.6294, 0.6807, 0.7320, 0.7832, 0.8345, 0.8858, 0.9371, 0.9884,
1.0397, 1.0909, 1.1422, 1.1935, 1.2448, 1.2961, 1.3474, 1.3986,
1.4499, 1.5012, 1.5525, 1.6038, 1.6551, 1.7064, 1.7576, 1.8089,
1.8602, 1.9115, 1.9628, 2.0141, 2.0653, 2.1166, 2.1679, 2.2192,
2.2705, 2.3218, 2.3730, 2.4243, 2.4756, 2.5269, 2.5782, 2.6295,
2.6808, 2.7320, 2.7833, 2.8346, 2.8859, 2.9372, 2.9885, 3.0397,
3.0910, 3.1423])
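# To get an actual posterior predictive (an added sketch, reusing the
# functional call from log_likelihood above), push each posterior sample
# through the network and summarise; the plot line belongs in the prediction
# cell above.
preds = []
with torch.no_grad():
    for theta_s in posterior_samples:
        params_list = hamiltorch.util.unflatten(net, theta_s.flip(0))
        params = net.state_dict()
        for i, (name, _) in enumerate(params.items()):
            params[name] = params_list[i]
        preds.append(
            torch.func.functional_call(net, params, x_lin.unsqueeze(1)).squeeze()
        )
preds = torch.stack(preds)
plt.plot(x_lin, preds.mean(axis=0), label="Posterior mean", color="C3")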
# Now, solve the above using Hamiltorch's MCMC sample_model function
### Bayesian Logistic Regression
from sklearn.datasets import make_moons

# Generate data
x, y = make_moons(n_samples=1000, noise=0.1, random_state=0)

plt.scatter(x[:, 0], x[:, 1], c=y)

x = torch.tensor(x).float()
y = torch.tensor(y).float()