import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
Discrete distributions
from tueplots import bundles
plt.rcParams.update(bundles.beamer_moml())
# Also add despine to the bundle using rcParams
'axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams[
# Increase font size to match Beamer template
'font.size'] = 16
plt.rcParams[# Make background transparent
'figure.facecolor'] = 'none' plt.rcParams[
Bernoulli distribution
The PDF of the Bernoulli distribution is given by
\[ f(x) = p^x (1-p)^{1-x} \]
where \(x \in \{0, 1\}\) and \(p \in [0, 1]\).
= torch.distributions.Bernoulli(probs=0.3)
bernoulli bernoulli.probs
tensor(0.3000)
# Plot PDF
= bernoulli.probs.item()
p_0 = 1 - p_0
p_1
0, 1], [p_0, p_1], color='C0', edgecolor='k')
plt.bar([0, 1)
plt.ylim(0, 1], ['0', '1']) plt.xticks([
([<matplotlib.axis.XTick at 0x7f4c5e1f1790>,
<matplotlib.axis.XTick at 0x7f4c5e1f1760>],
[Text(0, 0, '0'), Text(1, 0, '1')])
### Careful!
= torch.distributions.Bernoulli(logits=0.3)
bernoulli bernoulli.probs
tensor(0.5744)
Logits?!
Probs range from 0 to 1, logits range from -inf to inf. Logits are the inverse of the sigmoid function.
The sigmoid function is defined as:
\[\sigma(x) = \frac{1}{1 + e^{-x}}\]
The inverse of the sigmoid function is defined as:
\[\sigma^{-1}(x) = \log \frac{x}{1 - x}\]
### Sampling
bernoulli.sample()
tensor(1.)
10,)) bernoulli.sample((
tensor([0., 1., 0., 0., 1., 1., 0., 1., 0., 0.])
= bernoulli.sample((1000,))
data data
tensor([1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
0., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0.,
0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1.,
1., 0., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0.,
1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1.,
0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1.,
1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0.,
1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0.,
1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 0.,
0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.,
1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0.,
1., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1.,
1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0.,
0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1.,
1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1.,
0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 0.,
1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1.,
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1.,
1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1.,
1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0.,
1., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.,
1., 0., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 1.,
0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1.,
1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1.,
0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1.,
0., 0., 0., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1.,
1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0.,
0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
0., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0.,
1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1.,
1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0.,
1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0.,
1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0.,
1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,
1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0.,
1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0.,
1., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1.,
1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0.,
1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.,
1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1.,
1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0.,
0., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0.,
1., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1.,
1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0.,
0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1.,
0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1.,
1., 1., 0., 0., 1., 1., 1., 0., 0., 1.])
### Count number of 1s
sum() data.
tensor(586.)
### IID sampling
= 1000
size = torch.empty(size)
data for s_num in range(size):
= torch.distributions.Bernoulli(probs=0.3) # Each sample uses the same distribution (Identical)
dist = dist.sample() # Each sample is independent (Independent) data[s_num]
### Dependent sampling
= 1000
size
### If previous sample was 1, next sample is 1 with probability 0.9
### If previous sample was 1, next sample is 0 with probability 0.1
### If previous sample was 0, next sample is 0 with probability 0.8
### If previous sample was 0, next sample is 1 with probability 0.2
### Categorical distribution
= 0.2
p1 = 0.3
p2 = 0.5
p3
= torch.distributions.Categorical(probs=torch.tensor([p1, p2, p3]))
categorical categorical.probs
tensor([0.2000, 0.3000, 0.5000])
# Plot PDF
0, 1, 2], [p1, p2, p3], color='C0', edgecolor='k')
plt.bar([0, 1)
plt.ylim(0, 1, 2], ['0', '1', '2']) plt.xticks([
([<matplotlib.axis.XTick at 0x7f4c5c094ac0>,
<matplotlib.axis.XTick at 0x7f4c5c094a90>,
<matplotlib.axis.XTick at 0x7f4c5c094970>],
[Text(0, 0, '0'), Text(1, 0, '1'), Text(2, 0, '2')])
### Uniform distribution
= torch.distributions.Uniform(low=0, high=1) uniform
uniform.sample()
tensor(0.0553)
uniform.support
Interval(lower_bound=0.0, upper_bound=1.0)
### Plot PDF
= torch.linspace(0.0, 0.99999, 500)
xs = uniform.log_prob(xs).exp()
ys
='C0')
plt.plot(xs, ys, color# Filled area
='C0', alpha=0.2) plt.fill_between(xs, ys, color
<matplotlib.collections.PolyCollection at 0x7f4c5c01a940>
### Why log_prob? and not prob?
### Normal distribution
= torch.distributions.Normal(loc=0, scale=1) normal
normal.support
Real()
### Plot PDF
= torch.linspace(-5, 5, 500)
xs = normal.log_prob(xs).exp()
ys ='C0')
plt.plot(xs, ys, color# Filled area
='C0', alpha=0.2) plt.fill_between(xs, ys, color
<matplotlib.collections.PolyCollection at 0x7fe120a2f4c0>
= torch.linspace(-50, 50, 500)
xs = normal.log_prob(xs).exp()
probs ='C0')
plt.plot(xs, probs, color# Filled area
#plt.fill_between(xs, probs, color='C0', alpha=0.2)
-20)), normal.log_prob(torch.tensor(-40))
normal.log_prob(torch.tensor(
-20)).exp(), normal.log_prob(torch.tensor(-40)).exp() normal.log_prob(torch.tensor(
(tensor(0.), tensor(0.))
= torch.linspace(-50, 50, 500)
xs = normal.log_prob(xs)
logprobs ='C0') plt.plot(xs, logprobs, color
def plot_normal(mu, sigma):
= torch.tensor(mu)
mu = torch.tensor(sigma)
sigma = torch.linspace(-40, 40, 1000)
xs = torch.distributions.Normal(mu, sigma)
dist
= dist.log_prob(xs)
logprobs = torch.exp(logprobs)
probs = plt.subplots(nrows=2)
fig, ax 0].plot(xs, probs)
ax[0].set_title("Probability")
ax[1].plot(xs, logprobs)
ax[1].set_title("Log Probability")
ax[
0, 1) plot_normal(
# Interactive slider for plot_normal function
from ipywidgets import interact, FloatSlider
=FloatSlider(min=-2, max=2, step=0.1, value=0), sigma=FloatSlider(min=0.1, max=2, step=0.1, value=1)) interact(plot_normal, mu
<function __main__.plot_normal(mu, sigma)>
= normal.sample((1000,))
samples 20] samples[:
tensor([ 1.8764, 0.4868, -0.7966, -0.8190, 1.4538, 0.0766, -2.0262, 0.9965,
-1.1971, -0.4764, -2.1042, 0.2489, -0.2859, 1.1970, -0.7265, -0.8898,
-0.4592, -0.3581, -0.7239, -0.0790])
= plt.hist(samples.numpy(), bins=50, density=True, edgecolor='k') _
= plt.hist(samples.numpy(), bins=30, density=True, edgecolor='k') _
import seaborn as sns
=2.1, shade=True) sns.kdeplot(samples.numpy(), bw_adjust
<AxesSubplot:ylabel='Density'>
### IID sampling
= 1000
n_samples = []
samples for i in range(n_samples):
= torch.distributions.Normal(0, 1) # Using identical distribution over all samples
dist # sample is independent of previous samples
samples.append(dist.sample())
= torch.stack(samples)
samples
= plt.subplots(nrows=2)
fig, ax =2.0, shade=True, ax=ax[0])
sns.kdeplot(samples.numpy(), bw_adjust0].set_title("KDE of IID samples")
ax[
1].plot(samples.numpy())
ax[1].set_title("IID samples") ax[
Text(0.5, 1.0, 'IID samples')
### Non-IID sampling (non-identical distribution)
= 100
n_samples = []
samples for i in range(n_samples):
# Non-indentical distribution
if i%2:
= torch.distributions.Normal(torch.tensor([2.0]), torch.tensor([0.5]))
dist else:
= torch.distributions.Normal(torch.tensor([-2.0]), torch.tensor([0.5]))
dist
samples.append(dist.sample())
= torch.stack(samples)
samples
= plt.subplots(nrows=2)
fig, ax =1.0, shade=True, ax=ax[0])
sns.kdeplot(samples.numpy().flatten(), bw_adjust0].set_title("KDE of non-identical samples")
ax[
1].plot(samples.numpy().flatten())
ax[1].set_title("Samples over time") ax[
Text(0.5, 1.0, 'Samples over time')
### Non-IID sampling (dependent sampling)
= 1000
n_samples = torch.tensor([10.0])
prev_sample = []
samples for i in range(n_samples):
= torch.distributions.Normal(prev_sample, 1)
dist = dist.sample()
sample
samples.append(sample)= sample
prev_sample
= torch.stack(samples)
samples = plt.subplots(nrows=2)
fig, ax =2.0, shade=True, ax=ax[0])
sns.kdeplot(samples.numpy().flatten(), bw_adjust0].set_title("KDE of samples")
ax[
1].plot(samples.numpy().flatten())
ax[1].set_title("IID samples") ax[
Text(0.5, 1.0, 'IID samples')
### Laplace distribution v/s Normal distribution
= torch.distributions.Laplace(loc=0, scale=1)
laplace = torch.distributions.Normal(loc=0, scale=1)
normal = torch.distributions.StudentT(df=1)
student_t_1 = torch.distributions.StudentT(df=2)
student_t_2
= torch.linspace(-6, 6, 500)
xs = laplace.log_prob(xs).exp()
ys_laplace ='C0', label='Laplace')
plt.plot(xs, ys_laplace, color
= normal.log_prob(xs).exp()
ys_normal ='C1', label='Normal')
plt.plot(xs, ys_normal, color
= student_t_1.log_prob(xs).exp()
ys_student_t_1 ='C2', label='Student T (df=1)')
plt.plot(xs, ys_student_t_1, color
= student_t_2.log_prob(xs).exp()
ys_student_t_2 ='C3', label='Student T (df=2)')
plt.plot(xs, ys_student_t_2, color
plt.legend()
= False
zoom if zoom:
5, 6)
plt.xlim(-0.002, 0.02) plt.ylim(
### Beta distribution
= torch.distributions.Beta(concentration1=2, concentration0=2)
beta beta.support
Interval(lower_bound=0.0, upper_bound=1.0)
# PDF
= torch.linspace(0, 1, 500)
xs = beta.log_prob(xs).exp()
ys ='C0')
plt.plot(xs, ys, color# Filled area
='C0', alpha=0.2) plt.fill_between(xs, ys, color
<matplotlib.collections.PolyCollection at 0x7f5125cf0430>
= beta.sample()
s s
tensor(0.3356)
# Add widget to play with parameters
from ipywidgets import interact
def plot_beta(a, b):
= torch.distributions.Beta(concentration1=a, concentration0=b)
beta = torch.linspace(0, 1, 500)
xs = beta.log_prob(xs).exp()
ys ='C0')
plt.plot(xs, ys, color# Filled area
='C0', alpha=0.2)
plt.fill_between(xs, ys, color
=(0.1, 10, 0.1), b=(0.1, 10, 0.1)) interact(plot_beta,a
<function __main__.plot_beta(a, b)>
### Dirichlet distribution
= torch.distributions.Dirichlet(concentration=torch.tensor([2.0, 2.0, 2.0]))
dirichlet dirichlet.support
Simplex()
= dirichlet.sample()
s print(s, s.sum())
tensor([0.2924, 0.3254, 0.3821]) tensor(1.)
= dirichlet.sample()
s print(s, s.sum())
tensor([0.3898, 0.1071, 0.5030]) tensor(1.)
= torch.distributions.Dirichlet(concentration=torch.tensor([0.8, 0.1, 0.1]))
dirichlet2 = dirichlet2.sample()
s print(s, s.sum())
tensor([0.8190, 0.0111, 0.1699]) tensor(1.)