import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
Discrete distributions
from tueplots import bundles
plt.rcParams.update(bundles.beamer_moml())
# Also add despine to the bundle using rcParams
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
# Increase font size to match Beamer template
plt.rcParams['font.size'] = 16
# Make background transparent
plt.rcParams['figure.facecolor'] = 'none'
Bernoulli distribution
The probability mass function (PMF) of the Bernoulli distribution is given by
\[ f(x) = p^x (1-p)^{1-x} \]
where \(x \in \{0, 1\}\) and \(p \in [0, 1]\).
bernoulli = torch.distributions.Bernoulli(probs=0.3)
print(bernoulli.probs)
eq_cat = torch.distributions.Categorical(probs=torch.tensor([0.7, 0.3]))
eq_cat.probs
tensor(0.3000)
tensor([0.7000, 0.3000])
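As a quick sanity check (not part of the original notebook), the PMF formula above can be evaluated by hand and compared against log_prob:

import torch

p = 0.3
bernoulli = torch.distributions.Bernoulli(probs=p)
x = 1.0
manual = p ** x * (1 - p) ** (1 - x)                    # PMF formula evaluated by hand
via_torch = bernoulli.log_prob(torch.tensor(x)).exp()   # same value via the distribution object
print(manual, via_torch)                                # both should be 0.3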
# Plot PMF
p_1 = bernoulli.probs.item()
p_0 = 1 - p_1

plt.bar([0, 1], [p_0, p_1], color='C0', edgecolor='k')
plt.ylim(0, 1)
plt.xticks([0, 1], ['0', '1'])
### Careful!
bernoulli = torch.distributions.Bernoulli(logits=-20.0)
bernoulli.probs
tensor(2.0612e-09)
Logits?!
Probabilities range from 0 to 1, while logits range from -inf to +inf. The logit is the inverse of the sigmoid function applied to the probability.
The sigmoid function is defined as:
\[\sigma(x) = \frac{1}{1 + e^{-x}}\]
The inverse of the sigmoid function is defined as:
\[\sigma^{-1}(x) = \log \frac{x}{1 - x}\]
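As a quick check (a small sketch, not from the original notebook), the probs value reported above is exactly the sigmoid of the logit, and applying the inverse sigmoid recovers the logit:

import torch

logit = torch.tensor(-20.0)
prob = torch.sigmoid(logit)                 # ~2.06e-09, matching bernoulli.probs above
logit_back = torch.log(prob / (1 - prob))   # inverse sigmoid (log-odds); may lose precision for tiny probs
print(prob, logit_back)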
### Sampling
bernoulli.sample()
tensor(0.)
bernoulli.sample((2,))
tensor([0., 0.])
data = bernoulli.sample((1000,))
data
tensor([0., 0., 0.,  ..., 0., 0., 0.])
### Count number of 1s
data.sum()
tensor(0.)
All 1000 draws are 0, which is expected: with logits=-20 the success probability is only about 2.1e-9.
### IID sampling
size = 1000
data = torch.empty(size)
for s_num in range(size):
    dist = torch.distributions.Bernoulli(probs=0.3)  # Each sample uses the same distribution (Identical)
    data[s_num] = dist.sample()  # Each sample is independent (Independent)
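The loop makes the IID structure explicit; the same 1000 IID draws can also be obtained with a single vectorized call (a sketch, assuming the same Bernoulli(0.3) distribution):

import torch

data = torch.distributions.Bernoulli(probs=0.3).sample((1000,))  # one call, 1000 IID draws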
### Dependent sampling
size = 1000
### If previous sample was 1, next sample is 1 with probability 0.9
### If previous sample was 1, next sample is 0 with probability 0.1
### If previous sample was 0, next sample is 0 with probability 0.8
### If previous sample was 0, next sample is 1 with probability 0.2
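The cell above only lists the transition probabilities; a minimal sketch of the corresponding dependent (Markov-chain) sampler could look like this, assuming the chain starts with a fair coin flip (the original does not specify the initial state):

import torch

size = 1000
samples = torch.empty(size)

prev = torch.distributions.Bernoulli(probs=0.5).sample()  # assumed starting state
for s_num in range(size):
    # The next sample's distribution depends on the previous one: P(1|1) = 0.9, P(1|0) = 0.2
    p_one = 0.9 if prev == 1 else 0.2
    prev = torch.distributions.Bernoulli(probs=p_one).sample()
    samples[s_num] = prev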
### Categorical distribution
p1 = -0.2
p2 = 0.3
p3 = 0.6
#categorical = torch.distributions.Categorical(probs=torch.tensor([p1, p2, p3]))
#categorical.probs
cat2 = torch.distributions.Categorical(logits=torch.tensor([p1, p2, p3]))
cat2.probs
tensor([0.2052, 0.3383, 0.4566])
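For the Categorical distribution the logits-to-probabilities mapping is the softmax rather than the sigmoid; a quick check (not in the original notebook) that it reproduces cat2.probs:

import torch

logits = torch.tensor([-0.2, 0.3, 0.6])
print(torch.softmax(logits, dim=0))  # should match cat2.probs: [0.2052, 0.3383, 0.4566]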
# Plot PMF from the normalized probabilities (p1, p2, p3 above are logits, not probabilities)
plt.bar([0, 1, 2], cat2.probs.numpy(), color='C0', edgecolor='k')
plt.ylim(0, 1)
plt.xticks([0, 1, 2], ['0', '1', '2'])
### Uniform distribution
uniform = torch.distributions.Uniform(low=0, high=1)
uniform.sample()
tensor(0.4386)
uniform.support
Interval(lower_bound=0.0, upper_bound=1.0)
### Plot PDF
xs = torch.linspace(0.0, 0.99999, 500)
ys = uniform.log_prob(xs).exp()

plt.plot(xs, ys, color='C0')
# Filled area
plt.fill_between(xs, ys, color='C0', alpha=0.2)
### Why log_prob and not prob?
### Normal distribution
normal = torch.distributions.Normal(loc=0, scale=1)
normal.support
Real()
### Plot PDF
xs = torch.linspace(-5, 5, 500)
ys = normal.log_prob(xs).exp()
plt.plot(xs, ys, color='C0')
# Filled area
plt.fill_between(xs, ys, color='C0', alpha=0.2)
xs = torch.linspace(-50, 50, 500)
probs = normal.log_prob(xs).exp()
plt.plot(xs, probs, color='C0')
# Filled area
#plt.fill_between(xs, probs, color='C0', alpha=0.2)
normal.log_prob(torch.tensor(-20)), normal.log_prob(torch.tensor(-40))
normal.log_prob(torch.tensor(-20)).exp(), normal.log_prob(torch.tensor(-40)).exp()
(tensor(0.), tensor(0.))
The log-probabilities themselves (about -200.9 and -800.9) are perfectly representable, but exponentiating them underflows to 0 in float32. This is why torch.distributions exposes log_prob rather than prob.
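A typical place where this matters is the joint likelihood of many IID points: the product of probabilities underflows quickly, while the sum of log-probabilities stays finite. A minimal sketch (assuming standard-normal data, not from the original notebook):

import torch

normal = torch.distributions.Normal(0.0, 1.0)
data = normal.sample((10_000,))

prod_probs = normal.log_prob(data).exp().prod()  # underflows to 0
sum_logprobs = normal.log_prob(data).sum()       # finite and usable
print(prod_probs, sum_logprobs)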
xs = torch.linspace(-50, 50, 500)
logprobs = normal.log_prob(xs)
plt.plot(xs, logprobs, color='C0')
def plot_normal(mu, sigma):
    mu = torch.tensor(mu)
    sigma = torch.tensor(sigma)
    xs = torch.linspace(-40, 40, 1000)
    dist = torch.distributions.Normal(mu, sigma)

    logprobs = dist.log_prob(xs)
    probs = torch.exp(logprobs)
    fig, ax = plt.subplots(nrows=2)
    ax[0].plot(xs, probs)
    ax[0].set_title("Probability")
    ax[1].plot(xs, logprobs)
    ax[1].set_title("Log Probability")
plot_normal(0, 1)
# Interactive slider for plot_normal function
from ipywidgets import interact, FloatSlider
interact(plot_normal, mu=FloatSlider(min=-2, max=2, step=0.1, value=0), sigma=FloatSlider(min=0.1, max=2, step=0.1, value=1))
samples = normal.sample((1000,))
samples[:20]
tensor([-0.0520, -1.0746, 0.1424, -0.1028, -1.0832, 0.2766, 1.3047, -2.3132,
-0.7942, 0.0828, -0.9418, -1.4644, 0.6963, 0.5597, 0.2721, -1.8722,
-1.0237, 1.4067, -0.0434, -1.5735])
_ = plt.hist(samples.numpy(), bins=50, density=True, edgecolor='k')
_ = plt.hist(samples.numpy(), bins=30, density=True, edgecolor='k')
import seaborn as sns
sns.kdeplot(samples.numpy(), bw_adjust=2.1, shade=True)
### IID sampling
n_samples = 1000
samples = []
for i in range(n_samples):
    dist = torch.distributions.Normal(0, 1)  # Using identical distribution over all samples
    # sample is independent of previous samples
    samples.append(dist.sample())
samples = torch.stack(samples)
fig, ax = plt.subplots(nrows=2)
sns.kdeplot(samples.numpy(), bw_adjust=2.0, shade=True, ax=ax[0])
ax[0].set_title("KDE of IID samples")

ax[1].plot(samples.numpy())
ax[1].set_title("IID samples")
### Non-IID sampling (non-identical distribution)
n_samples = 100
samples = []
for i in range(n_samples):
    # Non-identical distribution
    if i % 2:
        dist = torch.distributions.Normal(torch.tensor([2.0]), torch.tensor([0.5]))
    else:
        dist = torch.distributions.Normal(torch.tensor([-2.0]), torch.tensor([0.5]))

    samples.append(dist.sample())
samples = torch.stack(samples)
fig, ax = plt.subplots(nrows=2)
sns.kdeplot(samples.numpy().flatten(), bw_adjust=1.0, shade=True, ax=ax[0])
ax[0].set_title("KDE of non-identical samples")

ax[1].plot(samples.numpy().flatten())
ax[1].set_title("Samples over time")
### Non-IID sampling (dependent sampling)
n_samples = 1000
prev_sample = torch.tensor([10.0])
samples = []
for i in range(n_samples):
    # Each sample is centred on the previous one (a Gaussian random walk), so samples are dependent
    dist = torch.distributions.Normal(prev_sample, 1)
    sample = dist.sample()

    samples.append(sample)
    prev_sample = sample

samples = torch.stack(samples)
fig, ax = plt.subplots(nrows=2)
sns.kdeplot(samples.numpy().flatten(), bw_adjust=2.0, shade=True, ax=ax[0])
ax[0].set_title("KDE of samples")

ax[1].plot(samples.numpy().flatten())
ax[1].set_title("Dependent samples")
### Laplace vs. Normal vs. Student-t distributions
laplace = torch.distributions.Laplace(loc=0, scale=1)
normal = torch.distributions.Normal(loc=0, scale=1)
student_t_1 = torch.distributions.StudentT(df=1)
student_t_2 = torch.distributions.StudentT(df=2)
xs = torch.linspace(-6, 6, 500)
ys_laplace = laplace.log_prob(xs).exp()
plt.plot(xs, ys_laplace, color='C0', label='Laplace')

ys_normal = normal.log_prob(xs).exp()
plt.plot(xs, ys_normal, color='C1', label='Normal')

ys_student_t_1 = student_t_1.log_prob(xs).exp()
plt.plot(xs, ys_student_t_1, color='C2', label='Student T (df=1)')

ys_student_t_2 = student_t_2.log_prob(xs).exp()
plt.plot(xs, ys_student_t_2, color='C3', label='Student T (df=2)')
plt.legend()
zoom = False
if zoom:
    plt.xlim(5, 6)
    plt.ylim(-0.002, 0.02)
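To make the tail comparison concrete without zooming, one can evaluate each density far from the mode, say at x = 6 (a small sketch, not in the original notebook); the heavier-tailed distributions assign noticeably more density there:

import torch

x = torch.tensor(6.0)
for name, dist in [("Normal", torch.distributions.Normal(0, 1)),
                   ("Laplace", torch.distributions.Laplace(0, 1)),
                   ("Student T (df=1)", torch.distributions.StudentT(df=1))]:
    print(name, dist.log_prob(x).exp().item())  # density at x = 6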
### Beta distribution
beta = torch.distributions.Beta(concentration1=2, concentration0=2)
beta.support
Interval(lower_bound=0.0, upper_bound=1.0)
# PDF
xs = torch.linspace(0, 1, 500)
ys = beta.log_prob(xs).exp()
plt.plot(xs, ys, color='C0')
# Filled area
plt.fill_between(xs, ys, color='C0', alpha=0.2)
s = beta.sample()
s
tensor(0.2485)
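A small check (not in the original notebook): the analytic mean of Beta(2, 2) is 0.5, and a Monte Carlo estimate from samples should be close to it:

import torch

beta = torch.distributions.Beta(concentration1=2.0, concentration0=2.0)
print(beta.mean, beta.sample((10_000,)).mean())  # analytic mean vs Monte Carlo estimate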
# Add widget to play with parameters
from ipywidgets import interact
def plot_beta(a, b):
    beta = torch.distributions.Beta(concentration1=a, concentration0=b)
    xs = torch.linspace(0, 1, 500)
    ys = beta.log_prob(xs).exp()
    plt.plot(xs, ys, color='C0')
    # Filled area
    plt.fill_between(xs, ys, color='C0', alpha=0.2)

interact(plot_beta, a=(0.1, 20, 0.1), b=(0.1, 20, 0.1))
### Dirichlet distribution
dirichlet = torch.distributions.Dirichlet(concentration=torch.tensor([2.0, 2.0, 2.0]))
dirichlet.support
Simplex()
s = dirichlet.sample()
print(s, s.sum())
tensor([0.0994, 0.7448, 0.1558]) tensor(1.)
s = dirichlet.sample()
print(s, s.sum())
tensor([0.4175, 0.3628, 0.2197]) tensor(1.)
dirichlet2 = torch.distributions.Dirichlet(concentration=torch.tensor([0.8, 0.1, 0.1]))
s = dirichlet2.sample()
print(s, s.sum())
tensor([9.7325e-01, 7.5464e-10, 2.6754e-02]) tensor(1.)
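Since a Dirichlet sample lies on the probability simplex, it can directly parameterize a Categorical distribution (the Dirichlet is the conjugate prior of the Categorical). A minimal sketch, not from the original notebook:

import torch

dirichlet = torch.distributions.Dirichlet(torch.tensor([2.0, 2.0, 2.0]))
theta = dirichlet.sample()                          # a point on the probability simplex
cat = torch.distributions.Categorical(probs=theta)  # use it as the class probabilities
print(theta, cat.sample((5,)))                      # five draws from the sampled categorical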