import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
from tueplots import bundles

# Base style: the tueplots Beamer bundle, with local overrides layered on top.
plt.rcParams.update(bundles.beamer_moml())
plt.rcParams.update({
    # Despine: drop the right and top axis spines.
    'axes.spines.right': False,
    'axes.spines.top': False,
    # Larger base font so figures match the Beamer template.
    'font.size': 16,
    # Transparent figure background.
    'figure.facecolor': 'none',
})
## Rejection sampling
import torch.distributions as D

# Target: an equal-weight mixture of two 1-D Gaussians,
# N(-0.5, 0.3^2) and N(1.0, 0.5^2) (scale = standard deviation).
mix = D.Categorical(probs=torch.ones(2))
comp = D.Normal(loc=torch.tensor([-0.5, 1.0]), scale=torch.tensor([0.3, 0.5]))
mog = D.MixtureSameFamily(mixture_distribution=mix, component_distribution=comp)
comp.scale
# tensor([0.3000, 0.5000])
# Plot the mixture of Gaussians on a 100-point grid over [-3, 3]
xs = torch.linspace(-3, 3, 100)
plt.plot(xs, torch.exp(mog.log_prob(xs)))
# Proposal distribution: the standard normal, q(x) = N(0, 1)
q = D.Normal(loc=0, scale=1)
# Let $M$ be a constant such that $M \geq \frac{p(x)}{q(x)} \forall x$.
# Estimate it as the largest ratio p(x)/q(x) over the plotting grid `xs`,
# working with log-densities. Only grid points are checked, so this is the
# maximum on the grid rather than a guaranteed exact supremum.
log_ratio = mog.log_prob(xs) - q.log_prob(xs)
M = log_ratio.max().exp()
M
# tensor(1.9451)

# Grid point at which the ratio p(x)/q(x) -- and hence M -- is attained.
# Index with the argmax rather than a hard-coded position: the previously
# inspected xs[31] (= -1.1212) is not where the maximum occurs.
i_max = torch.argmax(log_ratio)
i_max
# tensor(72)
xs[i_max]
def plot_base():
    """Plot the target mixture-of-Gaussians density with a legend.

    Uses the notebook-level grid ``xs`` and target distribution ``mog``.
    """
    # Plot the mixture of Gaussians. It needs a label: with no labelled artist,
    # plt.legend() only warns "No artists with labels found to put in legend".
    plt.plot(xs, mog.log_prob(xs).exp(), label=r'$\tilde{p}(x)$', color='C0')
    plt.legend()


plot_base()
# Sample from the proposal distribution and accept or reject.
# Accepted shown in green, rejected shown in red.
def plot_sample(x, show_q=False, show_Mq=False, show_sample=False, show_vline=False,
                show_px=False, show_Mqx=False, show_uMqx=False, show_accept=False):
    """Draw and save one frame of the step-by-step rejection-sampling figure.

    Every ``show_*`` flag adds one element on top of the target density. The
    title is overwritten by each enabled element in flag order, so it
    describes the last one. The frame is written to ``../figures/sampling/``
    as both PDF and PNG, with ``x`` and all flags encoded in the filename.

    Reads the notebook-level target ``mog``, proposal ``q`` and envelope
    constant ``M``.

    Args:
        x: the sample from q(x) being illustrated (a 0-d tensor in the calls below).
        show_q: plot the proposal density q(x).
        show_Mq: plot the scaled proposal M q(x) (dashed).
        show_sample: mark x on the horizontal axis.
        show_vline: draw a dashed vertical line through x.
        show_px: mark the target density at x.
        show_Mqx: mark M q(x) at x.
        show_uMqx: draw u ~ U(0, 1) and mark u M q(x).
        show_accept: mark u M q(x) in green if u M q(x) < p(x) (accept), else red.
    """
    # Fresh figure per frame, so consecutive calls in one cell/script do not
    # draw on top of each other.
    plt.figure()

    xs = torch.linspace(-3, 3, 100)
    plt.plot(xs, mog.log_prob(xs).exp(), label=r'$\tilde{p}(x)$', color='C0')
    plt.title(r"Target distribution $\tilde{p}(x)$")

    # Densities at the sample, computed once and shared by the branches below.
    p_x = mog.log_prob(x).exp()
    q_x = q.log_prob(x).exp()

    if show_q:
        # Plot the proposal distribution
        plt.plot(xs, q.log_prob(xs).exp(), label='$q(x)$', color='C1')
        plt.title(r"Proposal distribution $q(x)$")
    if show_Mq:
        # Plot the scaled proposal distribution
        plt.plot(xs, M * q.log_prob(xs).exp(), label='$Mq(x)$', color='C2', linestyle='--')
        plt.title("Scaled proposal distribution Mq(x)")
    if show_sample:
        plt.scatter(x, 0, marker='x', color='k', label=r"$x\sim q(x)$")
        plt.title("Sample from proposal distribution")
    if show_vline:
        plt.axvline(x, color='C3', linestyle='--')
        plt.title("Sample from proposal distribution")
    if show_px:
        plt.scatter(x, p_x, color='C4', label=r"$\tilde{p}(x)$")
        plt.title(r"Evaluate target distribution $\tilde{p}(x)$ at sample x")
    if show_Mqx:
        plt.scatter(x, M * q_x, color='k', label=r"$Mq(x)$")
        plt.title(r"Evaluate scaled proposal distribution $Mq(x)$ at sample x")

    # u is also needed by the accept/reject step. Drawing it only under
    # show_uMqx made show_accept=True on its own fail with UnboundLocalError.
    if show_uMqx or show_accept:
        # Fixed seed so every frame uses the same u (this resets torch's
        # global RNG).
        torch.manual_seed(0)
        u = torch.rand(1)
        uMq_x = u * M * q_x
    if show_uMqx:
        plt.scatter(x, uMq_x, label=r"$uMq(x)$", color='purple')
        plt.title(r"Draw a uniform u between 0 and 1 and evaluate $uMq(x)$ at sample x")
    if show_accept:
        if uMq_x < p_x:
            plt.scatter(x, uMq_x, label=r"Accepted", color='g')
            plt.title(r"Accept sample as $uMq(x)$ $<$ $\tilde{p}(x)$")
        else:
            plt.scatter(x, uMq_x, label=r"Rejected", color='r')
            plt.title(r"Reject sample as $uMq(x)$ $>$ $\tilde{p}(x)$")

    plt.ylim(-.05, 1.0)
    plt.legend()
    fn = f"../figures/sampling/rejection-sampling-{x:0.1f}-{show_q}-{show_Mq}-{show_sample}-{show_vline}-{show_px}-{show_Mqx}-{show_uMqx}-{show_accept}"
    plt.savefig(fn + ".pdf", bbox_inches='tight')
    plt.savefig(fn + ".png", bbox_inches='tight', dpi=600)
# Frames for the sample x = -1.0; each call enables one more element.
x_demo = torch.tensor(-1.0)
plot_sample(x_demo)
# Output of the first call (matplotlib font fallback):
# findfont: Font family ['cursive'] not found. Falling back to DejaVu Sans.
# findfont: Generic family 'cursive' not found because none of the following families were found: Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, Comic Neue, Comic Sans MS, cursive
plot_sample(x_demo, show_q=True)
plot_sample(x_demo, show_q=True, show_Mq=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True, show_vline=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True, show_vline=True,
            show_px=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True, show_vline=True,
            show_px=True, show_Mqx=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True, show_vline=True,
            show_px=True, show_Mqx=True, show_uMqx=True)
plot_sample(x_demo, show_q=True, show_Mq=True, show_sample=True, show_vline=True,
            show_px=True, show_Mqx=True, show_uMqx=True, show_accept=True)

# Complete accept/reject frame for another sample, x = -0.2.
plot_sample(torch.tensor(-0.2), show_q=True, show_Mq=True, show_sample=True, show_vline=True,
            show_px=True, show_Mqx=True, show_uMqx=True, show_accept=True)
# Create an animation out of the .png frames generated above
import glob
import os

# imageio.v2 keeps the current imread/mimsave behaviour without the
# DeprecationWarning that the top-level imageio.imread emits.
import imageio.v2 as imageio

images = []
# Get all the pngs in the figures directory. Order them by modification time
# so frames play in the order plot_sample() wrote them. A plain lexicographic
# sort puts frames for other samples (e.g. x = -0.2) ahead of the x = -1.0
# build-up sequence.
fs = sorted(glob.glob('../figures/sampling/rejection-sampling*.png'), key=os.path.getmtime)
for filename in fs:
    ist = imageio.imread(filename)
    images.append(ist)
    # Print image size to make sure they are all the same size
    print(ist.shape)
# Save with high resolution
imageio.mimsave('../figures/sampling/rejection-sampling.gif', images, duration=0.6)
# save as mp4
# os.system("ffmpeg -i figures/sampling/rejection-sampling.gif figures/sampling/rejection-sampling.mp4")

# Output recorded before switching to imageio.v2:
# DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.
# (1456, 3275, 4)  -- printed once for each of the 14 frames
# Complete accept/reject frame for another sample, x = -0.5.
plot_sample(
    torch.tensor(-0.5),
    show_q=True,
    show_Mq=True,
    show_sample=True,
    show_vline=True,
    show_px=True,
    show_Mqx=True,
    show_uMqx=True,
    show_accept=True,
)
# Notebook-level sample count (the plot_N_samples calls pass N explicitly).
N = 1_000
def plot_N_samples(N=100, seed=0, plot_kde=False):
    """Run rejection sampling with N proposals and save a summary figure.

    Draws N samples x from the proposal ``q`` and accepts each one when
    u * M * q(x) < p~(x), with u ~ U(0, 1). The figure shows p~(x) and M q(x).
    With ``plot_kde=False``, every proposal is scattered at height
    u * M * q(x): green if accepted, red if rejected. With ``plot_kde=True``,
    a seaborn KDE of the accepted samples is drawn instead. The title reports
    the empirical acceptance rate. The figure is saved to
    ``../figures/sampling/rejection-sampling-N{N}-{plot_kde}.pdf`` and that
    path is printed.

    Reads the notebook-level ``mog`` (target), ``q`` (proposal), ``M``
    (envelope constant) and ``xs`` (plotting grid). Seeds torch's global RNG
    with ``seed``.
    """
    # Fresh figure so consecutive calls do not draw on top of each other.
    plt.figure()
    torch.manual_seed(seed)

    # Run the algorithm on N proposals
    samples_from_q = q.sample((N,))
    # Target density and scaled proposal density at each sample
    pxs = mog.log_prob(samples_from_q).exp()
    Mqxs = M * q.log_prob(samples_from_q).exp()
    # One uniform u per sample; u * M q(x) is a uniform height in [0, M q(x)]
    us = torch.rand(N)
    uMqxs = us * Mqxs
    # Accept where that height falls below the target density
    accepted = uMqxs < pxs

    # Plot p and Mq
    plt.plot(xs, mog.log_prob(xs).exp(), label=r'$\tilde{p}(x)$', lw=2)
    plt.plot(xs, M * q.log_prob(xs).exp(), label=r'$Mq(x)$', lw=2)

    if plot_kde:
        import seaborn as sns
        sns.kdeplot(samples_from_q[accepted].numpy(), color='g', label='Density of accepted samples', lw=2)
    else:
        # Each proposal is drawn at its height u * M q(x) (not p(x)):
        # green if accepted, red if rejected.
        plt.scatter(samples_from_q[accepted], uMqxs[accepted], color='g', label='Accepted samples', alpha=0.2, marker='.', s=20)
        plt.scatter(samples_from_q[~accepted], uMqxs[~accepted], color='r', label='Rejected samples', alpha=0.2, marker='.', s=20)
    plt.legend()

    plt.title(f"Rejection sampling with N={N} samples\n Acceptance rate: {accepted.float().mean().item():.2f}")
    fn = f"../figures/sampling/rejection-sampling-N{N}-{plot_kde}.pdf"
    plt.savefig(fn, bbox_inches='tight')
    print(fn)
# Rejection sampling at three sample sizes, all with seed 0. For each N, one
# figure shows the accepted/rejected scatter and one shows the KDE of accepted
# samples. Each call prints the path of the PDF it saved.
plot_N_samples(N=1000, seed=0, plot_kde=False)
# ../figures/sampling/rejection-sampling-N1000-False.pdf
plot_N_samples(N=1000, seed=0, plot_kde=True)
# ../figures/sampling/rejection-sampling-N1000-True.pdf

plot_N_samples(N=10, seed=0, plot_kde=True)
# ../figures/sampling/rejection-sampling-N10-True.pdf
plot_N_samples(N=10, seed=0, plot_kde=False)
# ../figures/sampling/rejection-sampling-N10-False.pdf

plot_N_samples(N=10000, seed=0, plot_kde=True)
# ../figures/sampling/rejection-sampling-N10000-True.pdf
plot_N_samples(N=10000, seed=0, plot_kde=False)
# ../figures/sampling/rejection-sampling-N10000-False.pdf
# Gaussian p and q: zero-mean isotropic Gaussians in DIM dimensions with
# standard deviations sigma_p (target) and sigma_q (proposal).
# NOTE: this rebinds the notebook-level `q`, which the rejection-sampling
# cells used as the 1-D N(0, 1) proposal.
sigma_p, sigma_q = 1.0, 1.1
DIM = 1
p = D.MultivariateNormal(loc=torch.zeros(DIM), covariance_matrix=sigma_p**2 * torch.eye(DIM))
q = D.MultivariateNormal(loc=torch.zeros(DIM), covariance_matrix=sigma_q**2 * torch.eye(DIM))
p
# MultivariateNormal(loc: tensor([0.]), covariance_matrix: tensor([[1.]]))
q
# MultivariateNormal(loc: tensor([0.]), covariance_matrix: tensor([[1.2100]]))
# Plot both densities. The grid is a (100, 1) column (100 one-dimensional
# points) because p and q are MultivariateNormal; this rebinds `xs`.
xs = torch.linspace(-3, 3, 100).unsqueeze(1)
plt.plot(xs, torch.exp(p.log_prob(xs)), label=r'$\tilde{p}(x)$')
plt.plot(xs, torch.exp(q.log_prob(xs)), label=r'$q(x)$')
plt.legend()
plt.savefig('../figures/sampling/rejection-sampling-gaussian-p-q.pdf', bbox_inches='tight')
# Compute the constant M
# Empirical: the largest ratio p(x)/q(x) over the grid xs. The even-length
# symmetric grid does not contain x = 0, so this lands just below 1.1.
M_emp = (p.log_prob(xs) - q.log_prob(xs)).max().exp()
M_emp
# tensor(1.0999)

# Analytic (1-D): for zero-mean Gaussians with sigma_q >= sigma_p the ratio
# p(x)/q(x) peaks at x = 0, where it equals sigma_q / sigma_p.
# NOTE: this rebinds the notebook-level M used by the rejection-sampling plots.
M = sigma_q / sigma_p
M
# 1.1
# Now, vary the dimensionality d. For isotropic zero-mean Gaussians the ratio
# p/q factorises over dimensions, so the bound is M(d) = (sigma_q / sigma_p) ** d.
# A comprehension keeps the loop variable local, so the notebook-level DIM
# (used to build p and q) is not overwritten with the last value, 100.
Ms = {d: (sigma_q / sigma_p) ** d for d in [1, 2, 5, 10, 20, 50, 100]}
Ms
# {1: 1.1,
#  2: 1.2100000000000002,
#  5: 1.6105100000000006,
#  10: 2.5937424601000023,
#  20: 6.727499949325611,
#  50: 117.39085287969579,
#  100: 13780.61233982238}
import pandas as pd

# M as a function of dimensionality, with a log-scaled y axis
ax = pd.Series(Ms).plot(logy=True, marker='o')
ax.set_xlabel("Dimensionality")
ax.set_ylabel("M (log scale)")
plt.savefig('../figures/sampling/rejection-sampling-gaussian-p-q-M.pdf', bbox_inches='tight')
import pandas as pd

# Acceptance rate 1/M versus dimensionality, with a log-scaled y axis
new_series = 1 / pd.Series(Ms)
ax = new_series.plot(logy=True, marker='o')
ax.set_xlabel("Dimensionality")
ax.set_ylabel("Acceptance rate (log scale)")
plt.savefig('../figures/sampling/rejection-sampling-gaussian-p-q-acceptance.pdf', bbox_inches='tight')
mog
# MixtureSameFamily(
#   Categorical(probs: torch.Size([2]), logits: torch.Size([2])),
#   Normal(loc: torch.Size([2]), scale: torch.Size([2])))

# Plot the mixture of Gaussians, now labelled p(x). This also resets `xs` to a
# 1-D grid (the Gaussian p/q cells made it a (100, 1) column).
xs = torch.linspace(-3, 3, 100)
plt.plot(xs, torch.exp(mog.log_prob(xs)), color='C0', label=r'$p(x)$')
plt.legend()
# <matplotlib.legend.Legend at 0x7f555c8e1070>
# Let f(x) = (10 - x^2) / 20
def f(x):
    """Return the test function f(x) = (10 - x**2) / 20 used in the importance-sampling plots.

    Works elementwise on tensors as well as on Python numbers.
    """
    return (10 - x**2) / 20.0
# Target p(x) together with the test function f(x)
plt.plot(xs, torch.exp(mog.log_prob(xs)), color='C0', label=r'$p(x)$')
plt.plot(xs, f(xs), color='C1', label=r'$f(x)$')
plt.legend()
# <matplotlib.legend.Legend at 0x7f555c9e5be0>
# Importance-sampling proposal: standard normal q(x) = N(0, 1).
# This rebinds q, which the Gaussian p/q cells set to a MultivariateNormal.
q = D.Normal(loc=0, scale=1)
# p(x), f(x) and the proposal q(x) on one plot
plt.plot(xs, torch.exp(mog.log_prob(xs)), color='C0', label=r'$p(x)$')
plt.plot(xs, f(xs), color='C1', label=r'$f(x)$')
plt.plot(xs, torch.exp(q.log_prob(xs)), color='C2', label=r'$q(x)$')
plt.legend()
# <matplotlib.legend.Legend at 0x7f555c9482b0>
# Get a sample from q(x). plot_sample below evaluates the importance weight
# w(x) = p(x)/q(x) and f(x) at a given sample.
x = q.sample()


def plot_sample(x):
    """Visualise one importance-sampling draw ``x`` and its weight.

    Plots p(x) (the mixture ``mog``), f(x) and the proposal q(x) over the
    notebook-level grid ``xs``. Marks ``x`` with a cross and a dashed vertical
    line, and puts x, f(x), q(x), w(x) = p(x)/q(x) and f(x)*w(x) in the title.

    NOTE: this redefines the earlier rejection-sampling ``plot_sample``.

    Args:
        x: 0-d tensor, the sample to visualise.
    """
    # Fresh figure per call so consecutive samples are not drawn on one plot.
    plt.figure()
    # Importance weight p(x)/q(x), taken as a difference of log-densities so
    # it stays finite where both densities underflow to zero.
    w = (mog.log_prob(x) - q.log_prob(x)).exp()
    q_x = q.log_prob(x).exp()
    f_x = f(x)

    plt.plot(xs, mog.log_prob(xs).exp(), label=r'$p(x)$', color='C0')
    plt.plot(xs, f(xs), label=r'$f(x)$', color='C1')
    plt.plot(xs, q.log_prob(xs).exp(), label=r'$q(x)$', color='C2')
    # Mark the sample on the x axis and draw a vertical line through it
    plt.scatter(x, 0, marker='x', color='k', label=r"$x\sim q(x)$")
    plt.axvline(x, color='k', linestyle='--')

    # Title: sample x, f(x), q(x), w(x) and the estimator term f(x)*w(x)
    plt.title(f"Sample x={x:0.3f}, f(x)={f_x:0.3f}, q(x)={q_x:0.3f}, w(x)={w:0.3f}, f(x)*w(x)={f_x*w:0.3f}")
    plt.legend()
# Case 1: p(x) and q(x) are comparable and high, so w(x) is close to 1
plot_sample(torch.tensor(-0.9))
# Case 2: p(x) is low and q(x) is high, so w(x) is small
# In rejection sampling this sample would almost certainly be rejected
plot_sample(torch.tensor(-1.5))
# Case 3: p(x) is well above q(x) (near the peak of the left mixture component), so w(x) > 1
# Rejection sampling would almost always accept this sample; the large weight
# compensates for q putting less mass here than p does
plot_sample(torch.tensor(-0.5))
# Case 4: p(x) and q(x) are both low (tail region)
plot_sample(torch.tensor(2.5))
# All curves together, plus the weight function w(x) = p(x)/q(x)
p_curve = torch.exp(mog.log_prob(xs))
q_curve = torch.exp(q.log_prob(xs))
plt.plot(xs, p_curve, color='C0', label=r'$p(x)$')
plt.plot(xs, f(xs), color='C1', label=r'$f(x)$')
plt.plot(xs, q_curve, color='C2', label=r'$q(x)$')
plt.plot(xs, p_curve / q_curve, color='C3', label=r'$w(x)$')
plt.legend()
# <matplotlib.legend.Legend at 0x7f555c547850>