import torch
import matplotlib.pyplot as plt
= torch.distributions
dist
%matplotlib inline
# retina
%config InlineBackend.figure_format = 'retina'
# Beta(alpha, beta) distribution used as the posterior (`post`) over theta in [0, 1].
alpha = 20
beta = 30
post = dist.Beta(alpha, beta)

# Evaluate the density on a 500-point grid over [0, 1] and plot it.
x_lin = torch.linspace(0., 1, 500)
ys = post.log_prob(x_lin).exp()
plt.plot(x_lin, ys)
# MAP estimate of the Beta(alpha, beta) posterior, i.e. its mode
# (alpha - 1) / (alpha + beta - 2); this closed form holds only for alpha, beta > 1.
theta_map = torch.tensor((alpha - 1) / (alpha + beta - 2))
theta_map  # tensor(0.3958) for Beta(20, 30)

# Posterior density with the MAP marked by a red vertical line.
plt.plot(x_lin, ys)
plt.axvline(theta_map.item(), color='red')
from torch.autograd.functional import hessian


def f(x):
    """Return the posterior log-density at ``x``; its curvature at the MAP sets the Laplace scale."""
    return post.log_prob(x)


# Laplace approximation: a Normal centred at the MAP whose variance is the
# inverse of the negative Hessian of the log-posterior there,
# i.e. scale = (-H)^(-1/2).
scale = 1 / torch.sqrt(-hessian(f, theta_map))
scale

# Log-posterior with the MAP marked, overlaid with the log-density of the
# Normal (Laplace) approximation for comparison.
plt.plot(x_lin, post.log_prob(x_lin))
plt.axvline(theta_map.item(), color='red')
appx = dist.Normal(theta_map, scale)
plt.plot(x_lin, appx.log_prob(x_lin))
import jax
import tensorflow_probability.substrates.jax.distributions as tfd

# The same Beta(alpha, beta) posterior, expressed with TFP on the JAX substrate.
beta_jax = tfd.Beta(alpha, beta)

# jax.hessian takes a *callable* and returns a new function that is then
# evaluated at a point. Writing jax.hessian(beta_jax.log_prob(x)) passes a
# value instead and raises "TypeError: Expected a callable value".
# JAX also rejects torch tensors ("... is not a valid JAX type"), so the
# MAP is handed over as a plain Python float.
neg_hessian_jax = -jax.hessian(beta_jax.log_prob)(theta_map.item())
neg_hessian_jax

# Laplace approximation from the JAX-computed curvature:
# Normal(theta_map, (-H)^(-1/2)). torch.distributions.Normal requires both
# loc and scale, so the scale must be supplied explicitly.
scale_jax = torch.tensor(float(neg_hessian_jax) ** -0.5)
appx_posterior = dist.Normal(theta_map, scale_jax)