import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
import torch.autograd.functional as F
import torch.distributions as dist

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

import pandas as pd
%matplotlib inline

# Retina display
%config InlineBackend.figure_format = 'retina'
from tueplots import bundles

# plt.rcParams.update(bundles.icml2022())

# Also add despine to the bundle using rcParams: hide the right and top spines.
plt.rcParams["axes.spines.right"] = False
# An empty key ("") is not a valid rc parameter and raises KeyError;
# "despine" means hiding the top spine as well as the right one.
plt.rcParams["axes.spines.top"] = False

# Increase font size to match Beamer template
plt.rcParams["font.size"] = 16
# Make background transparent
plt.rcParams["figure.facecolor"] = "none"
import hamiltorch
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Standard normal target used by the first HMC demo.
gt_distribution = torch.distributions.Normal(0, 1)

# Samples from the ground truth distribution
def sample_gt(n):
    """Draw ``n`` i.i.d. samples from the module-level ``gt_distribution``."""
    sample_shape = torch.Size([n])
    return gt_distribution.sample(sample_shape)

samples = sample_gt(1000)
# Dense grid on which the ground-truth density is drawn.
x_lin = torch.linspace(-3, 3, 1000)
# Density = exp(log-density) of the ground truth on that grid.
y_lin = gt_distribution.log_prob(x_lin).exp()

plt.plot(x_lin, y_lin, label="Ground truth")
# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    """Log-density of ``gt_distribution`` at ``x``; the target handed to hamiltorch."""
    log_density = gt_distribution.log_prob(x)
    return log_density
# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5
# Run hamiltorch's default sampler (the progress output reports HMC with the
# implicit integrator) on the ground-truth log-density.
params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:07 | 0d:00:00:00 | #################### | 5000/5000 | 698.22       
Acceptance Rate 0.99
# Convert hamiltorch's output into a single tensor for plotting.
# NOTE(review): presumably a list of 1-element tensors (one per draw) - confirm.
params_hmc = torch.tensor(params_hmc)
# Trace plot: parameter value at every HMC iteration
plt.plot(params_hmc, label="Trace")
plt.ylabel("Parameter value")
# View only the first 500 samples (early behaviour of the chain).
# NOTE(review): run as a plain script, this draws on the same axes as the full trace above.
plt.plot(params_hmc[:500], label="Trace")
plt.ylabel("Parameter value")
# KDE plot
import seaborn as sns

# `shade=` is deprecated in seaborn >= 0.11; `fill=` is its replacement.
sns.kdeplot(params_hmc.detach().numpy(), label="Samples", fill=True, color="C1")
plt.plot(x_lin, y_lin, label="Ground truth")
plt.xlabel("Parameter value")
plt.legend()
def plot_samples_gif(
    x_lin, y_lin, params_hmc, filename, frames=50, dpi=200, writer=None
):
    """Animate samples one at a time over a density curve and save the result.

    Each frame places a single "x" marker at ``(params_hmc[frame], 0)`` on top
    of the curve ``(x_lin, y_lin)``.

    Args:
        x_lin: x-coordinates of the reference density curve.
        y_lin: y-coordinates of the reference density curve.
        params_hmc: indexable sequence of scalar samples, one per frame.
        filename: output path; its extension selects the file format.
        frames: maximum number of frames; clipped to ``len(params_hmc)``.
        dpi: resolution passed to ``Animation.save``.
        writer: matplotlib movie writer (name or instance); ``None`` lets
            matplotlib choose its configured default.

    Returns:
        The ``FuncAnimation`` (keep a reference so it is not garbage collected).
    """
    fig, ax = plt.subplots()
    ax.plot(x_lin, y_lin, label="Ground truth")
    scatter = ax.scatter([], [], color="C1", marker="x", s=100)

    # Never ask for more frames than there are samples (avoids IndexError).
    n_frames = min(frames, len(params_hmc))

    def update(frame):
        # float() turns a 0-d / 1-element tensor into a plain Python number.
        scatter.set_offsets(np.array([[float(params_hmc[frame]), 0.0]]))
        return (scatter,)

    anim = FuncAnimation(fig, update, frames=n_frames, blit=True)

    # Previously the animation was built but never written, so `filename` was ignored.
    anim.save(filename, writer=writer, dpi=dpi)
    return anim

# sample from Mixture of Gaussians

# Two-component 1-D Gaussian mixture: weights 0.3 / 0.7,
# components N(-2, 1) and N(2, 0.5^2).
mog = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(torch.tensor([0.3, 0.7])),
    component_distribution=dist.Normal(
        torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 0.5])
    ),
)

samples = mog.sample((1000,))
# `fill=` replaces the deprecated `shade=` argument (seaborn >= 0.11).
sns.kdeplot(samples.numpy(), label="Samples", fill=True, color="C1")

# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    """Log-density of the Gaussian mixture ``mog`` at ``x`` (hamiltorch target)."""
    mixture_log_density = mog.log_prob(x)
    return mixture_log_density

# Initial state
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5

# NOTE(review): each trajectory spans step_size * num_steps_per_sample = 1.5,
# which is short relative to the 4-unit gap between the mixture modes.
# Check the trace for switches between modes.
params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 462.18       
Acceptance Rate 0.99
# Convert hamiltorch's output into a single tensor.
# NOTE(review): presumably one 1-element tensor per draw - confirm.
params_hmc = torch.tensor(params_hmc)
# Trace plot of the first 500 draws; look for jumps between the two mixture components.
plt.plot(params_hmc[:500], label="Trace")

# Mixture density evaluated on the current `x_lin` grid.
# NOTE(review): y_lin is computed here but not plotted in the visible code.
y_lin = torch.exp(mog.log_prob(x_lin))

def p_tilde(x):
    """Unnormalised standard-normal density ``N(x; 0, 1) * sqrt(2*pi)``.

    Multiplying the normalised density by its normalising constant
    ``Z = sqrt(2*pi)`` leaves ``exp(-x**2 / 2)``.
    """
    normaliser = torch.tensor(2 * np.pi).sqrt()
    density = dist.Normal(0, 1).log_prob(x).exp()
    return density * normaliser

def p_tilde_log_prob(x):
    """Log of the unnormalised standard-normal density: ``log N(x; 0, 1) + log sqrt(2*pi)``."""
    base = dist.Normal(0, 1).log_prob(x)
    log_normaliser = torch.tensor(2 * np.pi).sqrt().log()
    return base + log_normaliser
# Plot unnormalized distribution
x_lin = torch.linspace(-3, 3, 1000)
y_lin = p_tilde(x_lin)
plt.plot(x_lin, y_lin, label="Unnormalized distribution")
# Plot normalized distribution
plt.plot(
    x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label="Normalized distribution"
)
plt.legend()
# HMC over unnormalized distribution

# Logprob function to be passed to Hamiltorch sampler
def logprob(x):
    """hamiltorch target: log of the *unnormalised* standard-normal density."""
    unnormalised = p_tilde_log_prob(x)
    return unnormalised
x0 = torch.tensor([0.0])
num_samples = 5000
step_size = 0.3
num_steps_per_sample = 5

# HMC only uses log p up to an additive constant (through its gradient and
# the Metropolis ratio), so sampling from p_tilde should still target N(0, 1).
params_hmc = hamiltorch.sample(
    log_prob_func=logprob,
    params_init=x0,
    num_samples=num_samples,
    step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
)

params_hmc = torch.tensor(params_hmc)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:10 | 0d:00:00:00 | #################### | 5000/5000 | 488.07       
Acceptance Rate 0.99
# Trace plot
plt.plot(params_hmc[:500], label="Trace")

# KDE plot on a fresh figure so it does not land on the trace axes.
plt.figure()
# `fill=` replaces the deprecated `shade=` argument (seaborn >= 0.11).
sns.kdeplot(params_hmc.detach().numpy(), label="Samples", fill=True, color="C1")
plt.plot(x_lin, y_lin, label="Unnormalized distribution")
plt.plot(
    x_lin, dist.Normal(0, 1).log_prob(x_lin).exp(), label="Normalized distribution"
)
plt.legend()
# Coin toss

# Uniform Beta(1, 1) prior on the heads probability.
prior = dist.Beta(concentration1=1, concentration0=1)
# Observed tosses: three heads (1.0) followed by two tails (0.0).
data = torch.tensor([1.0] * 3 + [0.0] * 2)
n = data.shape[0]

def log_prior(theta):
    """Log-density of the module-level Beta ``prior`` at ``theta``."""
    log_p = prior.log_prob(theta)
    return log_p

def log_likelihood(theta):
    """Bernoulli log-likelihood of the module-level coin-toss ``data`` given heads-probability ``theta``."""
    per_toss = dist.Bernoulli(probs=theta).log_prob(data)
    return per_toss.sum()

def negative_log_joint(theta):
    """Unnormalised log-posterior ``log p(theta) + log p(data | theta)``.

    NOTE: despite the name, this returns the *positive* log joint. That is the
    quantity passed to hamiltorch as a log-probability. The name is kept so
    existing callers keep working.
    """
    log_p_theta = log_prior(theta)
    log_p_data = log_likelihood(theta)
    return log_p_theta + log_p_data
def run_hmc(logprob, x0, num_samples, step_size, num_steps_per_sample):
    """Run hamiltorch's default sampler and return the draws stacked into one tensor.

    Args:
        logprob: callable mapping a parameter tensor to its (unnormalised)
            log-density.
        x0: initial parameter tensor.
        num_samples: number of draws requested from hamiltorch.
        step_size: leapfrog step size.
        num_steps_per_sample: leapfrog steps per draw.

    Returns:
        ``torch.stack`` of the draws returned by hamiltorch. Presumably shaped
        ``(num_draws, *x0.shape)`` - confirm against hamiltorch.
    """
    params_hmc = hamiltorch.sample(
        log_prob_func=logprob,
        params_init=x0,
        num_samples=num_samples,
        step_size=step_size,
        num_steps_per_sample=num_steps_per_sample,
    )
    return torch.stack(params_hmc)


# Sampling theta directly is expected to fail. Presumably leapfrog steps leave
# [0, 1], where the Beta prior / Bernoulli likelihood reject the value. This is
# why the logit parameterisation is used next.
try:
    params_hmc_theta = run_hmc(negative_log_joint, torch.tensor([0.5]), 5000, 0.3, 5)
except Exception as e:  # deliberate best-effort demo: report the failure and move on
    print(f"HMC directly on theta failed: {e!r}")
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
# Let us work instead with logits: HMC moves freely on the real line, and
# sigmoid(logits) always lands inside the Beta / Bernoulli support.
def log_prior(logits):
    """Log-density of the Beta ``prior`` re-expressed over ``logits``.

    With ``theta = sigmoid(logits)``, the change-of-variables formula gives
    ``log p(logits) = log p_Beta(theta) + log |d theta / d logits|``, where
    ``d theta / d logits = theta * (1 - theta)``. Without this Jacobian term
    the implied prior on theta changes. With the Beta(1, 1) prior and
    3-heads/2-tails data above, the sampled posterior would then be Beta(3, 2)
    instead of the conjugate Beta(4, 3).
    """
    theta = torch.sigmoid(logits)
    # log(theta) + log(1 - theta), computed stably from the logits.
    log_abs_det_jacobian = torch.nn.functional.logsigmoid(
        logits
    ) + torch.nn.functional.logsigmoid(-logits)
    return (prior.log_prob(theta) + log_abs_det_jacobian).sum()


def log_likelihood(logits):
    """Bernoulli log-likelihood of the coin-toss ``data``, parameterised by logits."""
    return dist.Bernoulli(logits=logits).log_prob(data).sum()


def log_joint(logits):
    """Unnormalised log-posterior over ``logits`` (prior incl. Jacobian + likelihood)."""
    return log_prior(logits) + log_likelihood(logits)
params_hmc_logits = run_hmc(
    log_joint,
    torch.tensor([0.0]),
    num_samples=1000,
    step_size=0.3,
    num_steps_per_sample=5,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 278.54       
Acceptance Rate 0.99
# Traces in logit space (top) and mapped back to probability space (bottom).
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(params_hmc_logits[:500], label="Trace")
ax[0].set_ylabel("logits")
ax[1].plot(torch.sigmoid(params_hmc_logits[:500]), label="Trace")
ax[1].set_ylabel("theta")

# Plot posterior KDE using seaborn but clip to [0, 1]; use a fresh figure so
# the density does not land on the trace axes.
plt.figure()
sns.kdeplot(
    torch.sigmoid(params_hmc_logits[:, 0]).detach().numpy(),
    clip=(0, 1),
    label="Posterior samples",
    fill=True,
    color="C1",
)
# True posterior: Beta(1, 1) prior with 3 heads / 2 tails -> Beta(1 + 3, 1 + 2)
x_lin = torch.linspace(0, 1, 1000)
y_lin = dist.Beta(1 + 3, 1 + 2).log_prob(x_lin).exp()
plt.plot(x_lin, y_lin, label="True posterior")
plt.legend()
# Linear regression for 1 dimensional input using HMC

# Synthetic 1-D regression data: y = 2 + 3 x + unit-variance Gaussian noise.
x_lin = torch.linspace(-3, 3, 90)
theta_0_true = torch.tensor([2.0])  # true intercept
theta_1_true = torch.tensor([3.0])  # true slope


def f(x):
    """Noise-free ground-truth line ``theta_0_true + theta_1_true * x``."""
    return theta_0_true + theta_1_true * x


eps = torch.randn_like(x_lin) * 1.0
y_lin = f(x_lin) + eps

plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth")
# Estimate theta_0, theta_1 using HMC, assuming the noise variance is known to be 1
def logprob(theta):
    """Unit-variance Gaussian log-likelihood of ``y_lin`` under the line ``theta[0] + theta[1] * x_lin``.

    ``theta[0]`` is the intercept and ``theta[1]`` the slope.
    """
    intercept, slope = theta[0], theta[1]
    mean = intercept + x_lin * slope
    return dist.Normal(mean, 1).log_prob(y_lin).sum()

def log_prior(theta):
    """Independent N(0, 1) prior on each entry of ``theta``; returns the summed log-density."""
    per_param = dist.Normal(0, 1).log_prob(theta)
    return per_param.sum()

def log_posterior(theta):
    """Unnormalised log-posterior of the regression coefficients: ``log_prior + logprob``."""
    prior_term = log_prior(theta)
    likelihood_term = logprob(theta)
    return prior_term + likelihood_term
params = torch.tensor([0.1, 0.2])  # initial (intercept, slope)
params_hmc_lin_reg = run_hmc(
    log_posterior, params, num_samples=1000, step_size=0.1, num_steps_per_sample=5
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:03 | 0d:00:00:00 | #################### | 1000/1000 | 265.17       
Acceptance Rate 0.83

Acceptance Rate 1.00

# Plot the traces corresponding to the two parameters (intercept, slope)
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    axes[i].plot(param_vals, label="Trace")

# Plot the true values as well. `.item()` passes a plain float, because
# `axhline` expects a scalar y rather than a 1-element array.
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axhline(param_vals.item(), color="C1", label="Ground truth")
axes[0].legend()
# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc_lin_reg.T):
    # `fill=` replaces the deprecated `shade=` argument (seaborn >= 0.11).
    sns.kdeplot(
        param_vals.detach().numpy(), label="Samples", fill=True, color="C1", ax=axes[i]
    )

# Plot the true values as well (`axvline` expects a scalar x)
for i, param_vals in enumerate([theta_0_true, theta_1_true]):
    axes[i].axvline(param_vals.item(), color="C0", label="Ground truth")

# Plot the posterior predictive distribution
plt.figure()
plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth", color="C1", linestyle="--")

# Discard the first 100 samples as burn-in (this is not thinning).
posterior_samples = params_hmc_lin_reg[100:].detach()
# One regression line per posterior draw -> shape (num_draws, len(x_lin)).
y_hat = posterior_samples[:, 0].unsqueeze(1) + x_lin * posterior_samples[
    :, 1
].unsqueeze(1)

# Plot the mean and an approximate 95% band (mean +/- 2 std) of the regression
# line; this band excludes the observation noise.
y_mean = y_hat.mean(dim=0)
y_std = y_hat.std(dim=0)
plt.plot(x_lin, y_mean, label="Mean", color="C2")
plt.fill_between(
    x_lin,
    y_mean - 2 * y_std,
    y_mean + 2 * y_std,
    label="95% CI",
    color="C2",
    alpha=0.3,
)
plt.legend()
# Using a neural network with HMC

class Net(torch.nn.Module):
    """A single ``Linear(1, 1)`` layer, i.e. ``y = w * x + b``."""

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it `self.fc1 = ...` raises AttributeError.
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Apply the linear layer; ``x`` must have a trailing feature dim of size 1."""
        return self.fc1(x)
# Instantiate the single-linear-layer model (the indented line below is pasted cell output).
net = Net()
  (fc1): Linear(in_features=1, out_features=1, bias=True)
# Flatten all model parameters into one vector, then split it back into per-parameter tensors.
theta_params = hamiltorch.util.flatten(net)
params_list = hamiltorch.util.unflatten(net, theta_params)
# Notebook display of the per-parameter shapes (pasted output below: weight [1, 1], then bias [1]).
params_list[0].shape, params_list[1].shape
(torch.Size([1, 1]), torch.Size([1]))
params_init = theta_params.clone().detach()
# NOTE(review): `t` is not used in the visible code.
t = torch.tensor([0.1, 0.2])
# reverse t
# hamiltorch.util.update_model_params_in_place??
# Check the flat -> per-parameter ordering; the printed output shows
# theta[0] -> weight ([[0.1]]) and theta[1] -> bias ([0.2]).
theta = torch.tensor([0.1, 0.2])
params_list = hamiltorch.util.unflatten(net, theta)
print(params_list, params_list[0].shape, params_list[1].shape)
# Presumably writes these values into `net`'s parameters in place (per the helper's name) - confirm.
hamiltorch.util.update_model_params_in_place(net, params_list)
[tensor([[0.1000]]), tensor([0.2000])] torch.Size([1, 1]) torch.Size([1])
tensor([[0.3000]], grad_fn=<AddmmBackward0>)
def log_prior(theta):
    """Sum of independent standard-normal log-densities over the flat parameter vector ``theta``."""
    standard_normal = dist.Normal(0, 1)
    return standard_normal.log_prob(theta).sum()

def log_likelihood(theta):
    """Unit-variance Gaussian log-likelihood of ``y_lin`` under ``net`` evaluated with ``theta``.

    ``net`` is called through ``torch.func.functional_call`` with a state dict
    whose tensors come from ``theta``; ``net``'s own parameters are not modified.
    """
    # Reverse the flat vector before splitting it per parameter.
    # NOTE(review): unflatten fills the weight first, then the bias, so after
    # the flip theta[0] appears to become the bias and theta[1] the weight.
    # This matches the (intercept, slope) order of the linear-regression section.
    per_param = hamiltorch.util.unflatten(net, theta.flip(0))

    # Replace each state-dict entry, in key order, with the matching tensor.
    overrides = net.state_dict()
    for idx, key in enumerate(list(overrides)):
        overrides[key] = per_param[idx]

    preds = torch.func.functional_call(net, overrides, x_lin.unsqueeze(1)).squeeze()
    return dist.Normal(preds, 1).log_prob(y_lin).sum()

def log_joint(theta):
    """Unnormalised log-posterior of the network's flat parameters: prior + likelihood."""
    prior_term = log_prior(theta)
    likelihood_term = log_likelihood(theta)
    return prior_term + likelihood_term

params_hmc = run_hmc(
    log_joint,
    torch.tensor([0.1, 0.2]),
    num_samples=1000,
    step_size=0.1,
    num_steps_per_sample=5,
)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:04 | 0d:00:00:00 | #################### | 1000/1000 | 201.06       
Acceptance Rate 0.81
# Plot the traces corresponding to the two parameters. These are the raw flat
# theta entries; see log_likelihood for how they map to the layer's weight/bias.
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    axes[i].plot(param_vals, label="Trace")

# Plot KDE of the samples for the two parameters
fig, axes = plt.subplots(2, 1, sharex=True)

for i, param_vals in enumerate(params_hmc.T):
    # `fill=` replaces the deprecated `shade=` argument (seaborn >= 0.11).
    sns.kdeplot(
        param_vals.detach().numpy(), label="Samples", fill=True, color="C1", ax=axes[i]
    )
# Plot predictions
plt.figure()
plt.scatter(x_lin, y_lin, label="Data", color="C0")
plt.plot(x_lin, f(x_lin), label="Ground truth", color="C1", linestyle="--")

# Discard the first 100 samples as burn-in.
posterior_samples = params_hmc[100:].detach()

# Evaluate the network once per posterior draw. Calling `net(...)` directly
# would use whatever parameters `net` currently holds and ignore
# `posterior_samples` entirely. Each draw is mapped to network parameters the
# same way log_likelihood does it (flip, then unflatten).
param_names = list(net.state_dict().keys())
with torch.no_grad():
    y_hat = torch.stack(
        [
            torch.func.functional_call(
                net,
                dict(zip(param_names, hamiltorch.util.unflatten(net, sample.flip(0)))),
                x_lin.unsqueeze(1),
            ).squeeze(-1)
            for sample in posterior_samples
        ]
    )

# Plot the mean and an approximate 95% band (mean +/- 2 std) of the network output
y_mean = y_hat.mean(dim=0)
y_std = y_hat.std(dim=0)
plt.plot(x_lin, y_mean, label="Mean", color="C2")
plt.fill_between(
    x_lin,
    y_mean - 2 * y_std,
    y_mean + 2 * y_std,
    label="95% CI",
    color="C2",
    alpha=0.3,
)
plt.legend()

# Now, solve the above using Hamiltorch's MCMC sample_model function
### Bayesian Logistic Regression

from sklearn.datasets import make_moons

# Generate data: 1000 two-moons points with noise=0.1 and a fixed random_state.
x, y = make_moons(n_samples=1000, noise=0.1, random_state=0)

plt.scatter(x[:, 0], x[:, 1], c=y)

# Convert the numpy arrays to float32 tensors.
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()