import torch
import torch.nn as nn
import torch.distributions as dist
import torch.autograd.functional as F
from torch.func import grad, jacfwd, hessian
from math import factorial
from ipywidgets import interact
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs
from tueplots.bundles import beamer_moml
import matplotlib.pyplot as plt
# Use render mode to run the notebook and save the plots in beamer format
# Use interactive mode to run the notebook and show the plots in notebook-friendly format
= "render" # "interactive" or "render"
mode
if mode == "render":
= 0.6
width =width, rel_height=width * 0.8))
plt.rcParams.update(beamer_moml(rel_width# update marker size
"lines.markersize": 4})
plt.rcParams.update({"figure.facecolor"] = "none"
plt.rcParams[else:
plt.rcdefaults()
1D Taylor approximation
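For reference, the `get_taylor_1d_fn` helper below builds the truncated Taylor expansion of $f$ around $a$ term by term,

$$\tilde{f}(x) = \sum_{i=0}^{k} \frac{f^{(i)}(a)}{i!}\,(x - a)^i,$$

where each derivative $f^{(i)}$ is obtained by repeatedly applying `torch.func.grad`.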
def plt_show(name=None):
    if mode == "interactive":
        plt.show()
    elif mode == "render":
        plt.savefig(f"../figures/laplace-approx/{name}.pdf")
    else:
        raise ValueError(f"Unknown mode: {mode}")
# Plot
f = lambda x: torch.sin(x + 1)

x = torch.linspace(-5, 5, 100)

plt.plot(x, f(x))
plt.xlabel("x")
plt.ylabel("f(x)")
plt_show("sin")
def get_taylor_1d_fn(f, x, a, ord):
    term = f(a).repeat(x.size())
    poly = r"$\tilde{f}(x) = $" + f"{f(a).item():.2f}"
    for i in range(1, ord + 1):
        f = grad(f)
        value = f(a)
        value_str = f"{abs(value.item()):.2f}"
        denominator = factorial(i)
        poly_term = f"(x - {a.item():.2f})^{i}"
        term = term + value * (x - a) ** i / denominator
        if i <= 5:
            poly += (
                f"{' - ' if value < 0 else ' + '}"
                + value_str
                + r"$\frac{"
                + poly_term
                + "}{"
                + f"{i}!"
                + "}"
                + "$"
            )
        elif i == 6:
            poly += " + ..."
    return term.detach().numpy(), poly
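As a quick illustration of the repeated `grad` composition used in the loop above (a minimal sketch, not part of the original notebook), higher-order derivatives can be taken by nesting `torch.func.grad`:

import torch
from torch.func import grad

df = grad(lambda t: torch.sin(t))         # first derivative: cos
d2f = grad(grad(lambda t: torch.sin(t)))  # second derivative: -sin
print(df(torch.tensor(0.0)), d2f(torch.tensor(0.0)))  # expect roughly 1.0 and 0.0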
@interact(ord=(0, 12))
def plot_1d_taylor(ord):
    plt.plot(x, f(x), label="f(x)")
    term, poly = get_taylor_1d_fn(f, x, a=torch.tensor(0.0), ord=ord)
    plt.plot(
        x,
        term,  # get_taylor_1d_fn(f, x, a=torch.tensor(0.0), ord=ord),
        label=f"Taylor approximation\nPolynomial degree: {ord}",
        linestyle="--",
    )
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.title(poly)
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.ylim(-1.5, 1.5)
    plt_show(f"sin-taylor-{ord}")
ND Taylor approximation
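The second-order expansion implemented below replaces the scalar derivatives with the Jacobian $J_f$ and the Hessian $H_f$:

$$\tilde{f}(\mathbf{x}) = f(\mathbf{a}) + J_f(\mathbf{a})\,(\mathbf{x} - \mathbf{a}) + \tfrac{1}{2}\,(\mathbf{x} - \mathbf{a})^\top H_f(\mathbf{a})\,(\mathbf{x} - \mathbf{a}).$$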
Checking that applying the Jacobian twice gives the Hessian
f = lambda x: torch.sin(x[0] + 1) + torch.cos(x[1] - 1)
inp = torch.tensor([1.0, 2.0])

display(jacfwd(jacfwd(f))(inp))
display(hessian(f)(inp))
tensor([[-0.9093, -0.0000],
[-0.0000, -0.5403]])
tensor([[-0.9093, 0.0000],
[ 0.0000, -0.5403]])
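The two results agree (up to the sign of the zeros). A minimal programmatic check, not in the original cells, could be:

assert torch.allclose(jacfwd(jacfwd(f))(inp), hessian(f)(inp))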
# Define a sine wave rotated 45 degrees in the input plane
f = lambda x: torch.sin(x.sum() + 1)

x = torch.linspace(-5, 5, 40)
X1, X2 = torch.meshgrid(x, x)
X = torch.stack([X1.ravel(), X2.ravel()], dim=-1)
Y = torch.vmap(f)(X).reshape(X1.shape)
print(X.shape, Y.shape)
plt.contourf(X1, X2, Y, levels=20, vmin=-1, vmax=1, cmap="coolwarm")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.colorbar(label="$f(x)$", ticks=[-1, -0.5, 0, 0.5, 1])
plt.gca().set_aspect("equal")
plt_show("sin2d")
torch.Size([1600, 2]) torch.Size([40, 40])
def nd_taylor_approx(x):
    x = x.reshape(1, -1)
    term = f(a)
    if ord >= 1:
        jacobian = jacfwd(f)(a)
        term = (term + jacobian @ (x - a).T).squeeze()
    if ord >= 2:
        hess = hessian(f)(a.ravel())
        term = term + (((x - a) @ hess @ (x - a).T) / 2).squeeze()
    if ord >= 3:
        raise NotImplementedError
    return term

nd_taylor_approx = torch.vmap(nd_taylor_approx)

ord = 2
a = torch.tensor([0.0, 0.0], dtype=torch.float64).reshape(1, 2)

T_Y = nd_taylor_approx(X).reshape(X1.shape)
fig = plt.figure()
# add a 3d axis
ax = fig.add_subplot(121, projection="3d")
ax.plot_surface(X1, X2, Y, cmap="coolwarm", vmin=-1.5, vmax=1.5)
ax.set_zlim(-1.5, 1.5)
ax2 = fig.add_subplot(122, projection="3d")
ax2.plot_surface(X1, X2, T_Y, cmap="coolwarm", vmin=-1.5, vmax=1.5)
ax2.set_zlim(-1.5, 1.5)
# mappable = ax[0].contourf(X1, X2, Y.reshape(X1.shape), levels=20)
# fig.colorbar(mappable, ax=ax[0])
# mappable = ax[1].contourf(X1, X2, T_Y, levels=20, vmin=Y.min(), vmax=Y.max())
ax.set_xlabel("$x_1$", labelpad=-2)
ax2.set_xlabel("$x_1$", labelpad=-2)
ax2.set_ylabel("$x_2$", labelpad=-2)
ax.set_ylabel("$x_2$", labelpad=-2)
fig.suptitle(f"Taylor approximation\nPolynomial degree: {ord}")
plt_show(f"sin2d-taylor-{ord}")
10, :], label="$f(10, x_2)$")
plt.plot(x, Y[10, :], label=r"$\tilde{{f}}(10, x_2)$")
plt.plot(x, T_Y[-1.5, 1.5)
plt.ylim("$x_2$")
plt.xlabel("$y$")
plt.ylabel(=(1.05, 1), loc="upper left")
plt.legend(bbox_to_anchor"Cross section at $x_1 = 0$")
plt.title(f"sin2d-taylor-1d-{ord}") plt_show(
# A simple PDF that we take as our target probability function
p = torch.distributions.Normal(0, 1)

logp = lambda x: p.log_prob(x)

# Plot
x = torch.linspace(-4, 4, 100)

plt.plot(x, logp(x).exp())
plt.xlabel("x")
plt.ylabel("p(x)")
plt_show("standard-normal")
Beta prior for coin toss
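The model in this section is the standard Beta–Bernoulli setup: $\theta \sim \mathrm{Beta}(\alpha, \beta)$ and $x_i \mid \theta \sim \mathrm{Bernoulli}(\theta)$, so the unnormalized log posterior (the negative of `neg_log_joint` below) is

$$\log p(\theta \mid D) = \log p(\theta) + \sum_{i=1}^{N} \big[\, x_i \log \theta + (1 - x_i) \log (1 - \theta) \,\big] + \mathrm{const.}$$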
data = torch.tensor([1] * 9 + [0] * 1, dtype=torch.float64)
alpha = 2
beta = 2

def neg_log_prior(theta):
    return -dist.Beta(alpha, beta).log_prob(theta)

def neg_log_likelihood(theta):
    likelihood = torch.where(data == 1, torch.log(theta), torch.log(1 - theta))
    # likelihood = dist.Bernoulli(probs=theta).log_prob(data)
    return -likelihood.sum()

def neg_log_joint(theta):
    return neg_log_prior(theta) + neg_log_likelihood(theta)
theta_grid = torch.linspace(0.01, 0.99, 100)

def plot_prior_and_lik():
    prior = torch.exp(-neg_log_prior(theta_grid))
    fig, ax = plt.subplots()
    ax.plot(theta_grid, prior, label="Prior", color="C0")
    twinx = ax.twinx()
    twinx.plot(
        theta_grid,
        torch.exp(-torch.vmap(neg_log_likelihood)(theta_grid)),
        label="Likelihood",
        color="C1",
    )
    ax.set_xlabel(r"$\theta$")
    ax.set_ylabel(r"$p(\theta)$")
    twinx.set_ylabel(r"$p(D|\theta)$")
    return fig, ax, twinx

fig, ax, twinx = plot_prior_and_lik()
fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("beta-prior-coin-toss")
theta_map = torch.tensor(0.5, requires_grad=True)

def optimize(theta, epochs, lr):
    optimizer = torch.optim.Adam([theta], lr=lr)

    pbar = tqdm(range(epochs))
    losses = []
    for epoch in pbar:
        optimizer.zero_grad()
        loss = neg_log_joint(theta)
        loss.backward()
        optimizer.step()
        pbar.set_description(f"loss={loss.item():.2f}")
        losses.append(loss.item())
    return losses

losses = optimize(theta_map, epochs=500, lr=0.01)
print(theta_map)
plt.plot(losses)
plt.show()
loss=3.61: 100%|██████████| 500/500 [00:00<00:00, 694.99it/s]
tensor(0.8333, requires_grad=True)
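As a sanity check, the Beta–Bernoulli MAP is available in closed form and matches the optimizer's result: with 9 heads out of $N = 10$ tosses and $\alpha = \beta = 2$,

$$\theta_{\mathrm{MAP}} = \frac{\alpha + \sum_i x_i - 1}{\alpha + \beta + N - 2} = \frac{2 + 9 - 1}{2 + 2 + 10 - 2} = \frac{10}{12} \approx 0.833.$$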
fig, ax, twinx = plot_prior_and_lik()
with torch.no_grad():
    ax.vlines(theta_map.item(), *ax.get_ylim(), label="MAP", color="C2")

fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("beta-prior-coin-toss-map")
hess = hessian(neg_log_joint)(theta_map)
posterior_variance = 1 / hess
approx_posterior = dist.Normal(theta_map, posterior_variance**0.5)
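This is the Laplace approximation: a second-order Taylor expansion of the negative log joint around $\theta_{\mathrm{MAP}}$ turns the posterior into a Gaussian whose variance is the inverse Hessian,

$$p(\theta \mid D) \approx \mathcal{N}\!\left(\theta;\, \theta_{\mathrm{MAP}},\, H^{-1}\right), \qquad H = \left.\frac{\partial^2}{\partial \theta^2}\big(-\log p(\theta, D)\big)\right|_{\theta_{\mathrm{MAP}}}.$$

The next cell compares it against the exact posterior, which by Beta–Bernoulli conjugacy is $\mathrm{Beta}\big(\alpha + \sum_i x_i,\; \beta + N - \sum_i x_i\big)$.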
true_posterior = dist.Beta(alpha + data.sum(), beta + len(data) - data.sum())
fig, ax, twinx = plot_prior_and_lik()
with torch.no_grad():
    ax.plot(
        theta_grid,
        torch.exp(approx_posterior.log_prob(theta_grid)),
        label="Laplace Approx Posterior",
        color="C3",
        linestyle="--",
    )
    ax.plot(
        theta_grid,
        torch.exp(true_posterior.log_prob(theta_grid)),
        label="True Posterior",
        color="C4",
    )
    ax.vlines(theta_map.item(), *ax.get_ylim(), label="MAP", color="C2")

fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("beta-prior-coin-toss-laplace")
Normal prior for coin toss
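Here the parameter lives on the real line and is pushed through a sigmoid link: $\theta \sim \mathcal{N}(\mu_0, \sigma_0^2)$ and $x_i \mid \theta \sim \mathrm{Bernoulli}(\mathrm{sigmoid}(\theta))$, so the unnormalized log posterior is

$$\log p(\theta \mid D) = \log \mathcal{N}(\theta; \mu_0, \sigma_0^2) + \sum_{i=1}^{N} \big[\, x_i \log \sigma(\theta) + (1 - x_i) \log (1 - \sigma(\theta)) \,\big] + \mathrm{const.}$$

The prior is no longer conjugate, so the true posterior is recovered below by grid evaluation plus a Monte Carlo estimate of the evidence.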
data = torch.tensor([1] * 9 + [0] * 1, dtype=torch.float64)
prior_mean = torch.tensor(-2.0)
prior_variance = torch.tensor(1.0)

def neg_log_prior(theta):
    return -dist.Normal(prior_mean, prior_variance**0.5).log_prob(theta).squeeze()

def neg_log_likelihood(theta):
    preds = torch.sigmoid(theta)
    likelihood = torch.where(data == 1, torch.log(preds), torch.log(1 - preds))
    return -likelihood.sum()

def neg_log_joint(theta):
    return neg_log_prior(theta) + neg_log_likelihood(theta)
theta_grid = torch.linspace(-5, 5, 100)
fig, ax, twinx = plot_prior_and_lik()
fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("normal-prior-coin-toss")
theta_map = torch.tensor(-2.0, requires_grad=True)
losses = optimize(theta_map, epochs=500, lr=0.01)

print(theta_map)
plt.plot(losses)
plt.show()
loss=9.28: 100%|██████████| 500/500 [00:00<00:00, 729.16it/s]
tensor(0.5143, requires_grad=True)
fig, ax, twinx = plot_prior_and_lik()
with torch.no_grad():
    ax.vlines(theta_map.item(), *ax.get_ylim(), label="MAP", color="C2")

fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("normal-prior-coin-toss-map")
hess = hessian(neg_log_joint)(theta_map)
posterior_variance = 1 / hess
approx_posterior = dist.Normal(theta_map, posterior_variance**0.5)
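The Laplace posterior is again $\mathcal{N}(\theta_{\mathrm{MAP}}, H^{-1})$. Since there is no conjugate closed form here, the next cell normalizes the grid evaluation of the joint with a simple Monte Carlo estimate of the evidence, using samples $\theta_s$ from the prior:

$$p(D) = \int p(D \mid \theta)\, p(\theta)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} p(D \mid \theta_s), \qquad \theta_s \sim p(\theta).$$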
fig, ax, twinx = plot_prior_and_lik()
with torch.no_grad():
    ax.plot(
        theta_grid,
        torch.exp(approx_posterior.log_prob(theta_grid)),
        label="Laplace Approx Posterior",
        color="C3",
        linestyle="--",
    )
    ax.vlines(theta_map.item(), *ax.get_ylim(), label="MAP", color="C2")

# Monte Carlo estimation of the posterior
unnorm_p = torch.exp(-torch.stack([neg_log_joint(theta) for theta in theta_grid]))
print(unnorm_p.shape)
prior = dist.Normal(prior_mean, prior_variance**0.5)
samples = prior.sample((100000,))
lik = torch.exp(-torch.vmap(neg_log_likelihood)(samples))
approx_evidence = lik.mean()
norm_p = unnorm_p / approx_evidence

ax.plot(theta_grid, norm_p, label="True (MC) Posterior", color="C4")

fig.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("normal-prior-coin-toss-laplace")
torch.Size([100])
Multi-Mode
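The target below is a two-component Gaussian mixture, $p(\theta) = 0.7\,\mathcal{N}(\theta; -2, 1) + 0.3\,\mathcal{N}(\theta; 2, 1)$. Because the Laplace approximation is built around a single mode, it can at best capture whichever mode the optimizer converges to and ignores the rest of the probability mass.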
target = dist.MixtureSameFamily(
    torch.distributions.Categorical(torch.tensor([0.7, 0.3])),
    dist.Normal(torch.tensor([-2.0, 2.0]), torch.tensor([1.0, 1.0])),
)

plt.plot(theta_grid, torch.exp(target.log_prob(theta_grid)), label="Target")
plt.xlabel(r"$\theta$")
plt.ylabel(r"$p(\theta)$")
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt_show("mixture-density")
neg_log_joint = lambda x: -target.log_prob(x)
theta_map = torch.tensor(0.0, requires_grad=True)
losses = optimize(theta_map, epochs=500, lr=0.01)
plt.plot(losses)
loss=1.28: 100%|██████████| 500/500 [00:00<00:00, 829.85it/s]
hess = hessian(neg_log_joint)(theta_map)
posterior_variance = 1 / hess
approx_posterior = dist.Normal(theta_map, posterior_variance**0.5)
="Target")
plt.plot(theta_grid, torch.exp(target.log_prob(theta_grid)), labelwith torch.no_grad():
plt.plot(
theta_grid,
torch.exp(approx_posterior.log_prob(theta_grid)),="Laplace\nApproximation",
label="--",
linestyle
)r"$\theta$")
plt.xlabel(r"$p(\theta)$")
plt.ylabel(=(1.05, 1), loc="upper left")
plt.legend(bbox_to_anchor"mixture-density-laplace") plt_show(
Old
# Optimize -logp with a gradient-based optimizer (AdamW)
theta = torch.tensor(4.0, requires_grad=True)
optimizer = torch.optim.AdamW([theta], lr=0.01)

for i in range(2000):
    optimizer.zero_grad()
    loss = -logp(theta)
    if i % 100 == 0:
        print(i, theta.item(), loss.item())
    loss.backward()
    optimizer.step()
0 4.0 8.918938636779785
100 3.0124826431274414 5.4564642906188965
200 2.1752238273620605 3.284738063812256
300 1.497883915901184 2.040766716003418
400 0.9779170751571655 1.397099494934082
500 0.6023826003074646 1.1003708839416504
600 0.3488367199897766 0.9797820448875427
700 0.18942442536354065 0.9368793368339539
800 0.09626010805368423 0.9235715270042419
900 0.045690640807151794 0.9199823141098022
1000 0.02021404542028904 0.9191428422927856
1100 0.008314372971653938 0.9189730882644653
1200 0.003170008771121502 0.9189435243606567
1300 0.0011164519237354398 0.9189391136169434
1400 0.0003617757756728679 0.9189385771751404
1500 0.00010737218690337613 0.9189385175704956
1600 2.9038295906502753e-05 0.9189385175704956
1700 7.114796062523965e-06 0.9189385175704956
1800 1.5690019381509046e-06 0.9189385175704956
1900 3.091245162067935e-07 0.9189385175704956
theta_map = theta.detach()
theta_map
tensor(5.3954e-08)
="True PDF")
plt.plot(x, p.log_prob(x).exp(), label
# Plot theta_map point
plt.scatter(0,
p.log_prob(theta_map).exp(),=r"$\theta_\textrm{MAP}$",
label="C1",
color=10,
zorder
) plt.legend()
hessian = F.hessian(logp, theta_map)  # note: shadows torch.func.hessian imported above
hessian

scale = 1 / torch.sqrt(-hessian)
scale

# Approximate the PDF using the Laplace approximation
approx_p = dist.Normal(theta_map, scale)
approx_p
# Plot original PDF
x = torch.linspace(-10, 10, 100)
plt.plot(x, p.log_prob(x).exp(), label="True PDF")
# Plot Laplace approximation
plt.plot(x, approx_p.log_prob(x).exp(), label="Laplace Approximation", linestyle="-.")
plt.legend()
def laplace_approximation(logp, theta_init, lr=0.01, n_iter=2000):
    # Optimize logp using an optimizer
    theta = torch.tensor(theta_init, requires_grad=True)
    optimizer = torch.optim.AdamW([theta], lr=lr)
    for i in range(n_iter):
        optimizer.zero_grad()
        loss = -logp(theta)
        loss.backward()
        optimizer.step()
    theta_map = theta.detach()
    hessian = F.hessian(logp, theta_map)
    scale = 1 / torch.sqrt(-hessian)
    return dist.Normal(theta_map, scale)

def plot_orig_approx(logp, approx_p, min_x=-10, max_x=10):
    # Plot original PDF (use the passed-in logp rather than the global p)
    x = torch.linspace(min_x, max_x, 500)
    plt.plot(x, logp(x).exp(), label="True PDF")
    # Plot Laplace approximation
    plt.plot(
        x, approx_p.log_prob(x).exp(), label="Laplace Approximation", linestyle="-."
    )
    plt.legend()
# Create a Student's t-distribution
p = dist.StudentT(5, 0, 1)
logp = lambda x: p.log_prob(x)

approx_p = laplace_approximation(logp, 4.0)
plot_orig_approx(logp, approx_p)

p = dist.LogNormal(0, 1)
logp = lambda x: p.log_prob(x)

approx_p = laplace_approximation(logp, 4.0)
plot_orig_approx(logp, approx_p, min_x=0.01)

p = dist.Beta(2, 2)
logp = lambda x: p.log_prob(x)

approx_p = laplace_approximation(logp, 0.5)
plot_orig_approx(logp, approx_p, min_x=0, max_x=1)