Cumulative Distribution Function (CDF)


Nipun Batra


March 5, 2025

import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn

import pandas as pd
# Retina mode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

Generating random numbers from a categorical distribution

# Probability of each of the four categories; the entries sum to 1.
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
tensor([0.1000, 0.2000, 0.3000, 0.4000])
# Uniform(0, 1) distribution used as the source of randomness.
unif = torch.distributions.uniform.Uniform(0, 1)
# Running sum of the category probabilities (the categorical CDF).
cum_sum_prob = torch.cumsum(probs, dim=0)
tensor([0.1000, 0.3000, 0.6000, 1.0000])
symbols = torch.tensor([1, 2, 3, 4])
sample = unif.sample()
# Walk the cumulative probabilities from left to right: the first bucket
# whose upper edge exceeds the uniform draw is the sampled category.
# NOTE(review): the branch bodies were missing from the export; printing the
# chosen symbol is a reconstruction - confirm against the original notebook.
if cum_sum_prob[0] > sample:
    print(symbols[0])
elif cum_sum_prob[1] > sample:
    print(symbols[1])
elif cum_sum_prob[2] > sample:
    print(symbols[2])
else:
    print(symbols[3])
tensor([1, 2, 3, 4])
# Boolean mask marking every cumulative probability that is >= the draw.
sample <= cum_sum_prob
tensor([False, False, False,  True])
# Keep the symbols whose cumulative probability exceeds the draw and take
# the first one: that is the sampled category.
symbols[sample < cum_sum_prob][0]
### Even more efficient
# Look up the position of `sample` within the sorted cumulative
# probabilities; that position is used as the index of the sampled category.
index = torch.searchsorted(cum_sum_prob, sample)
### Vectorized
# Draw all uniform samples in one call instead of one at a time.
num_samples = 1000
unif_samples = unif.sample((num_samples,))

# One CDF lookup per uniform sample, then translate positions into symbols.
index = torch.searchsorted(cum_sum_prob, unif_samples)
our_samples = symbols[index]
tensor([2, 3, 4, 2, 3, 4, 4, 1, 3, 3, 3, 4, 1, 3, 4, 3, 4, 4, 2, 2, 2, 3, 2, 3,
        1, 3, 4, 1, 1, 3, 4, 3, 2, 4, 4, 3, 3, 3, 1, 4, 3, 4, 2, 2, 4, 3, 3, 3,
        4, 4, 4, 4, 2, 1, 4, 2, 4, 2, 3, 4, 4, 3, 3, 1, 4, 3, 3, 4, 4, 3, 3, 3,
        3, 4, 1, 3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 3, 4, 4, 2, 2, 3, 2, 4, 4, 4, 3,
        3, 2, 3, 2, 2, 4, 3, 1, 1, 2, 1, 4, 3, 2, 2, 4, 4, 3, 4, 4, 1, 3, 3, 4,
        4, 3, 4, 3, 4, 4, 4, 3, 4, 3, 4, 1, 4, 4, 1, 4, 3, 3, 3, 4, 3, 4, 2, 3,
        4, 2, 3, 3, 4, 4, 4, 3, 3, 3, 2, 4, 4, 2, 4, 4, 3, 4, 4, 2, 4, 4, 4, 2,
        4, 3, 4, 4, 3, 2, 4, 3, 4, 4, 3, 3, 3, 2, 3, 3, 3, 1, 4, 3, 3, 4, 3, 3,
        3, 2, 2, 2, 3, 4, 1, 2, 2, 4, 2, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 2, 3, 3,
        4, 3, 3, 4, 3, 3, 2, 1, 2, 2, 2, 4, 4, 2, 4, 2, 4, 4, 4, 4, 4, 2, 3, 3,
        2, 3, 2, 3, 3, 4, 3, 2, 4, 4, 4, 3, 2, 2, 1, 1, 2, 3, 4, 3, 2, 2, 4, 1,
        2, 2, 2, 1, 3, 3, 2, 1, 3, 4, 4, 3, 3, 3, 4, 4, 4, 3, 2, 3, 2, 2, 2, 3,
        3, 2, 4, 1, 2, 4, 4, 2, 2, 4, 3, 3, 3, 2, 1, 3, 4, 3, 4, 3, 4, 1, 3, 3,
        1, 3, 3, 3, 4, 4, 3, 4, 4, 3, 1, 4, 4, 4, 3, 4, 3, 4, 1, 3, 4, 4, 4, 4,
        3, 2, 3, 3, 1, 4, 2, 4, 2, 2, 3, 3, 4, 4, 4, 3, 3, 1, 1, 1, 4, 3, 2, 3,
        2, 4, 1, 3, 4, 2, 3, 4, 4, 4, 4, 2, 4, 4, 3, 4, 3, 2, 2, 4, 3, 4, 4, 4,
        4, 2, 4, 4, 4, 4, 3, 1, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 2, 4, 1, 4, 2,
        3, 3, 4, 1, 4, 4, 2, 2, 4, 1, 3, 2, 2, 3, 2, 3, 3, 3, 4, 4, 3, 1, 3, 4,
        4, 2, 1, 2, 2, 4, 3, 4, 4, 4, 3, 4, 4, 2, 4, 2, 4, 3, 4, 2, 4, 4, 2, 3,
        3, 3, 4, 3, 4, 2, 4, 4, 4, 2, 4, 4, 2, 4, 4, 4, 2, 3, 2, 4, 4, 4, 2, 3,
        2, 2, 4, 4, 3, 2, 2, 1, 4, 2, 4, 1, 4, 2, 2, 2, 1, 4, 2, 4, 4, 1, 4, 3,
        3, 4, 3, 3, 4, 4, 1, 4, 2, 2, 4, 3, 4, 4, 4, 2, 3, 4, 3, 4, 3, 1, 1, 3,
        4, 2, 4, 4, 4, 3, 4, 4, 3, 4, 4, 4, 3, 4, 4, 3, 3, 4, 4, 4, 2, 2, 3, 2,
        4, 4, 4, 4, 3, 2, 3, 2, 4, 4, 4, 2, 2, 2, 3, 4, 2, 4, 4, 4, 1, 2, 3, 3,
        2, 3, 3, 1, 4, 4, 4, 2, 1, 4, 4, 4, 2, 4, 1, 3, 4, 3, 3, 1, 4, 4, 4, 3,
        1, 4, 3, 2, 3, 4, 4, 2, 4, 3, 3, 4, 4, 1, 4, 1, 4, 1, 4, 4, 4, 2, 3, 2,
        2, 2, 4, 4, 2, 3, 4, 4, 2, 4, 1, 3, 2, 2, 2, 1, 3, 4, 4, 3, 2, 1, 2, 4,
        3, 4, 2, 4, 4, 4, 1, 4, 3, 4, 3, 2, 3, 4, 4, 4, 1, 3, 4, 1, 1, 3, 2, 4,
        1, 4, 3, 2, 1, 4, 2, 2, 4, 4, 2, 2, 2, 4, 2, 3, 4, 3, 1, 2, 4, 3, 3, 4,
        4, 4, 4, 4, 1, 4, 1, 2, 2, 3, 4, 4, 3, 3, 4, 2, 2, 4, 3, 4, 2, 4, 3, 3,
        1, 4, 4, 1, 4, 3, 4, 2, 2, 4, 2, 3, 2, 2, 3, 1, 4, 2, 4, 1, 3, 3, 4, 4,
        3, 3, 2, 1, 1, 4, 1, 4, 4, 4, 4, 3, 4, 4, 2, 4, 2, 1, 4, 4, 1, 1, 1, 2,
        2, 4, 3, 4, 2, 3, 3, 4, 3, 3, 2, 4, 4, 2, 4, 3, 3, 1, 4, 4, 3, 3, 4, 2,
        1, 3, 4, 2, 2, 3, 3, 4, 2, 4, 2, 3, 2, 4, 3, 2, 3, 1, 1, 4, 4, 3, 2, 1,
        4, 4, 2, 3, 3, 4, 1, 2, 4, 2, 4, 4, 1, 3, 4, 4, 2, 3, 2, 2, 3, 3, 2, 4,
        4, 4, 3, 1, 3, 2, 4, 4, 4, 3, 4, 4, 4, 2, 2, 4, 1, 3, 2, 2, 1, 1, 3, 4,
        4, 4, 4, 4, 2, 1, 1, 4, 3, 4, 4, 2, 2, 4, 2, 3, 4, 3, 2, 2, 3, 3, 4, 2,
        1, 1, 4, 1, 3, 4, 3, 3, 2, 4, 3, 1, 4, 3, 2, 2, 4, 2, 4, 4, 3, 3, 4, 4,
        2, 3, 2, 2, 3, 4, 4, 2, 3, 3, 4, 4, 3, 1, 2, 3, 3, 2, 1, 3, 4, 4, 4, 4,
        3, 4, 4, 4, 4, 2, 2, 4, 2, 4, 3, 4, 2, 3, 2, 3, 3, 1, 3, 1, 1, 1, 4, 3,
        1, 2, 4, 4, 2, 3, 4, 4, 4, 4, 3, 4, 3, 4, 4, 3, 3, 4, 1, 4, 4, 2, 4, 4,
        4, 2, 4, 3, 4, 4, 4, 1, 4, 4, 1, 3, 3, 3, 3, 4])
# Relative frequency of every sampled symbol, drawn as a bar chart.
samples_series = pd.Series(our_samples.numpy())
samples_series_norm = samples_series.value_counts(normalize=True)
samples_series_norm.plot(kind='bar', rot=0)
# Overlay each true category probability as a dashed red reference line.
for true_prob in probs[:4].tolist():
    plt.axhline(true_prob, color='r', linestyle='--')

### Use CDF function
# Categorical distribution over the four symbols, parameterised by `probs`.
our_dist = torch.distributions.categorical.Categorical(probs)
# Fresh uniform draws to feed through the inverse CDF.
unif_samples = unif.sample((num_samples,))

NotImplementedError                       Traceback (most recent call last)
Cell In[31], line 5
      2 our_dist = torch.distributions.categorical.Categorical(probs)
      3 unif_samples = unif.sample((num_samples,))
----> 5 our_dist_cdf_inverse = our_dist.icdf(unif_samples)
      6 print(our_dist_cdf_inverse)

File ~/miniconda3/lib/python3.9/site-packages/torch/distributions/, in Distribution.icdf(self, value)
    189 def icdf(self, value):
    190     """
    191     Returns the inverse cumulative density/mass function evaluated at
    192     `value`.
    195         value (Tensor):
    196     """
--> 197     raise NotImplementedError

# Element-wise comparison against the CDF: keeps each symbol whose cumulative
# probability is strictly greater than the paired threshold value.
symbols[torch.tensor([0.1, 0.2, 0.3, 0.4]) < cum_sum_prob]
tensor([2, 3, 4])
## Plotting the PDF

def plot_pdf_normal(mu, sigma):
    """Plot the PDF of a Normal(mu, sigma) distribution with matplotlib.

    The density is evaluated at 1000 evenly spaced points in [-10, 10] and
    the curve is labelled with the two parameters.

    Args:
        mu: Mean of the normal distribution.
        sigma: Standard deviation of the normal distribution.
    """
    grid = torch.linspace(-10, 10, 1000)
    normal = torch.distributions.Normal(mu, sigma)
    # The density is recovered by exponentiating the log-density.
    density = torch.exp(normal.log_prob(grid))
    plt.plot(grid, density, label=f"PDF N({mu}, {sigma})")
# Compare three normals: a change in sigma only, then a change in mu as well.
plot_pdf_normal(0, 1)
plot_pdf_normal(0, 2)
plot_pdf_normal(1, 2)

# Simulating data with normal distributed noise

# Noise-free line y = 2x + 1 on 1000 evenly spaced points in [-5, 5].
x_true = torch.linspace(-5, 5, 1000)
y_true = 1 + 2 * x_true
# One standard-normal noise value per point.
eps = torch.distributions.Normal(0, 1).sample(y_true.shape)
y_obs = eps + y_true

# Noisy observations as small translucent red dots, with the true line on top.
plt.scatter(
    x_true,
    y_obs,
    label="Observed data",
    marker='o',
    s=2,
    alpha=0.5,
    color='red',
)
plt.plot(x_true, y_true, label="True data")

Heights and weights data

Dataset link

The dataset contains 25,000 rows and 3 columns. Each row represents a person and the columns represent the person’s index, height, and weight.

# Parse the HTML tables at the given address.
# NOTE(review): the URL argument is empty here - presumably stripped during
# export; restore the dataset link before running.
df = pd.read_html("")
# Work with the first table that was parsed.
store_df = df[0]
# The first row holds the column names: promote it to the header, then drop it.
store_df.columns = store_df.iloc[0]
store_df = store_df.iloc[1:]
# Convert every remaining cell to float.
store_df = store_df.astype(float)
# The index column is not used in the analysis.
store_df = store_df.drop(columns=["Index"])
# Discard rows with missing values.
store_df = store_df.dropna()
Height(Inches) Weight(Pounds)
count 25000.000000 25000.000000
mean 67.993114 127.079421
std 1.901679 11.660898
min 60.278360 78.014760
25% 66.704397 119.308675
50% 67.995700 127.157750
75% 69.272958 134.892850
max 75.152800 170.924000
Height(Inches) Weight(Pounds)
1 65.78331 112.9925
2 71.51521 136.4873
3 69.39874 153.0269
4 68.21660 142.3354
5 67.78781 144.2971
# Two stacked panels sharing the x-axis: a kernel density estimate of the
# heights on top and a 30-bin histogram underneath.
fig, ax = plt.subplots(nrows=2, sharex=True)
store_df["Height(Inches)"].plot.density(ax=ax[0])
store_df["Height(Inches)"].plot.hist(bins=30, ax=ax[1])

# Fit a normal distribution to the heights using the sample mean and the
# sample standard deviation as its parameters.
heights = store_df["Height(Inches)"]
mu, sigma = heights.mean().item(), heights.std().item()

# Evaluate the fitted density on a grid covering 50 to 80 inches.
dist = torch.distributions.Normal(mu, sigma)
x = torch.linspace(50, 80, 1000)
y = torch.exp(dist.log_prob(x))
plt.plot(x, y, label="Fitted PDF")

# Normalised histogram of the data for comparison with the fitted curve.
heights.plot(kind='hist', label="Histogram", density=True, bins=30)



Note: I DO NOT FOLLOW or endorse using a normal distribution to model grades in a class. This is just an exercise to practice the PDF of a normal distribution and show how to use percentiles.

# Simulate marks for 400 students and show their normalised histogram.
marks = torch.distributions.Normal(70, 8).sample((400,))
_ = plt.hist(marks, bins=20, density=True)

# Fit a normal distribution to the simulated marks and draw its PDF.
mu_marks, sigma_marks = marks.mean().item(), marks.std().item()
dist = torch.distributions.Normal(mu_marks, sigma_marks)
x = torch.linspace(30, 110, 1000)
y = dist.log_prob(x).exp()
plt.plot(x, y, label="Fitted PDF", color='gray', lw=2)

# Every grade boundary below is a percentile of the fitted distribution,
# obtained from its inverse CDF. Each band is shaded under the PDF and its
# legend label carries the number of simulated students falling inside it.

# 99th percentile and above get A+
marks_99_per = dist.icdf(torch.tensor(0.99))
num_students_getting_A_plus = marks[marks>marks_99_per].shape[0]
plt.fill_between(x, y, where=x>marks_99_per, alpha=0.5, label=f"A+ ({num_students_getting_A_plus})")

# 90th percentile to 99th percentile get A
marks_90_per = dist.icdf(torch.tensor(0.90))
num_students_getting_A = marks[(marks>marks_90_per) & (marks<marks_99_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_90_per) & (x<marks_99_per), alpha=0.5, label=f"A ({num_students_getting_A})")

# 75th percentile to 90th percentile get B
marks_75_per = dist.icdf(torch.tensor(0.75))
num_students_getting_B = marks[(marks>marks_75_per) & (marks<marks_90_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_75_per) & (x<marks_90_per), alpha=0.5, label=f"B ({num_students_getting_B})")

# 60th percentile to 75th percentile get B-
marks_60_per = dist.icdf(torch.tensor(0.60))
num_students_getting_B_minus = marks[(marks>marks_60_per) & (marks<marks_75_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_60_per) & (x<marks_75_per), alpha=0.5, label=f"B- ({num_students_getting_B_minus})")

# 45th percentile to 60th percentile get C
marks_45_per = dist.icdf(torch.tensor(0.45))
num_students_getting_C = marks[(marks>marks_45_per) & (marks<marks_60_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_45_per) & (x<marks_60_per), alpha=0.5, label=f"C ({num_students_getting_C})")

# 35th percentile to 45th percentile get C-
marks_35_per = dist.icdf(torch.tensor(0.35))
num_students_getting_C_minus = marks[(marks>marks_35_per) & (marks<marks_45_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_35_per) & (x<marks_45_per), alpha=0.5, label=f"C- ({num_students_getting_C_minus})")

# 20th percentile to 35th percentile get D
marks_20_per = dist.icdf(torch.tensor(0.20))
num_students_getting_D = marks[(marks>marks_20_per) & (marks<marks_35_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_20_per) & (x<marks_35_per), alpha=0.5, label=f"D ({num_students_getting_D})")

# 3rd percentile to 20th percentile get E
marks_3_per = dist.icdf(torch.tensor(0.03))
num_students_getting_E = marks[(marks>marks_3_per) & (marks<marks_20_per)].shape[0]
plt.fill_between(x, y, where=(x>marks_3_per) & (x<marks_20_per), alpha=0.5, label=f"E ({num_students_getting_E})")

# 3rd percentile and below get F
num_students_getting_F = marks[marks<marks_3_per].shape[0]
plt.fill_between(x, y, where=x<marks_3_per, alpha=0.5, label=f"F ({num_students_getting_F})")

# The per-grade counts live in the labels, so the legend must be drawn for
# them to be visible.
plt.legend()


Laplace Distribution

Let \(X\) be a random variable that follows a Laplace distribution with mean \(\mu\) and scale \(\lambda\). The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \frac{1}{2\lambda} \exp\left(-\frac{|x-\mu|}{\lambda}\right). \]

# Shared evaluation grid for both densities.
x = torch.linspace(-10, 10, 1000)

# Normal and Laplace distributions with the same (0, 1) parameters.
unit_normal = torch.distributions.Normal(0, 1)
unit_laplace = torch.distributions.Laplace(0, 1)

# Densities are obtained by exponentiating the log-densities.
y_normal = torch.exp(unit_normal.log_prob(x))
y_laplace = torch.exp(unit_laplace.log_prob(x))
plt.plot(x, y_normal, label="Normal")
plt.plot(x, y_laplace, label="Laplace")

Half Normal Distribution

Let \(Y\) follow the normal distribution with mean \(0\) and variance \(\sigma^2\). The half normal distribution is obtained by taking the absolute value of \(Y\). \(X = |Y|\) follows a half normal distribution with mean \(\sqrt{\frac{2}{\pi}}\sigma\) and variance \(\sigma^2(1-\frac{2}{\pi})\).

The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \frac{\sqrt{2}}{\sigma\sqrt{\pi}} \exp\left(-\frac{x^2}{2\sigma^2}\right), \quad x \geq 0. \]


hn = torch.distributions.HalfNormal(1)
# The grid deliberately includes negative values, which lie outside the
# support of the half normal distribution.
x = torch.linspace(-10, 10, 1000)
try:
    y = hn.log_prob(x).exp()
    plt.plot(x, y, label="HalfNormal")
except ValueError as e:
    # Argument validation rejects values outside the support; show the
    # message instead of aborting.
    print(e)
Expected value argument (Tensor of shape (1000,)) to be within the support (GreaterThanEq(lower_bound=0.0)) of the distribution HalfNormal(), but found invalid values:
tensor([-10.0000,  -9.9800,  -9.9600,  -9.9399,  -9.9199,  -9.8999,  -9.8799,
         -9.8599,  -9.8398,  -9.8198,  -9.7998,  -9.7798,  -9.7598,  -9.7397,
         -9.7197,  -9.6997,  -9.6797,  -9.6597,  -9.6396,  -9.6196,  -9.5996,
         -9.5796,  -9.5596,  -9.5395,  -9.5195,  -9.4995,  -9.4795,  -9.4595,
         -9.4394,  -9.4194,  -9.3994,  -9.3794,  -9.3594,  -9.3393,  -9.3193,
         -9.2993,  -9.2793,  -9.2593,  -9.2392,  -9.2192,  -9.1992,  -9.1792,
         -9.1592,  -9.1391,  -9.1191,  -9.0991,  -9.0791,  -9.0591,  -9.0390,
         -9.0190,  -8.9990,  -8.9790,  -8.9590,  -8.9389,  -8.9189,  -8.8989,
         -8.8789,  -8.8589,  -8.8388,  -8.8188,  -8.7988,  -8.7788,  -8.7588,
         -8.7387,  -8.7187,  -8.6987,  -8.6787,  -8.6587,  -8.6386,  -8.6186,
         -8.5986,  -8.5786,  -8.5586,  -8.5385,  -8.5185,  -8.4985,  -8.4785,
         -8.4585,  -8.4384,  -8.4184,  -8.3984,  -8.3784,  -8.3584,  -8.3383,
         -8.3183,  -8.2983,  -8.2783,  -8.2583,  -8.2382,  -8.2182,  -8.1982,
         -8.1782,  -8.1582,  -8.1381,  -8.1181,  -8.0981,  -8.0781,  -8.0581,
         -8.0380,  -8.0180,  -7.9980,  -7.9780,  -7.9580,  -7.9379,  -7.9179,
         -7.8979,  -7.8779,  -7.8579,  -7.8378,  -7.8178,  -7.7978,  -7.7778,
         -7.7578,  -7.7377,  -7.7177,  -7.6977,  -7.6777,  -7.6577,  -7.6376,
         -7.6176,  -7.5976,  -7.5776,  -7.5576,  -7.5375,  -7.5175,  -7.4975,
         -7.4775,  -7.4575,  -7.4374,  -7.4174,  -7.3974,  -7.3774,  -7.3574,
         -7.3373,  -7.3173,  -7.2973,  -7.2773,  -7.2573,  -7.2372,  -7.2172,
         -7.1972,  -7.1772,  -7.1572,  -7.1371,  -7.1171,  -7.0971,  -7.0771,
         -7.0571,  -7.0370,  -7.0170,  -6.9970,  -6.9770,  -6.9570,  -6.9369,
         -6.9169,  -6.8969,  -6.8769,  -6.8569,  -6.8368,  -6.8168,  -6.7968,
         -6.7768,  -6.7568,  -6.7367,  -6.7167,  -6.6967,  -6.6767,  -6.6567,
         -6.6366,  -6.6166,  -6.5966,  -6.5766,  -6.5566,  -6.5365,  -6.5165,
         -6.4965,  -6.4765,  -6.4565,  -6.4364,  -6.4164,  -6.3964,  -6.3764,
         -6.3564,  -6.3363,  -6.3163,  -6.2963,  -6.2763,  -6.2563,  -6.2362,
         -6.2162,  -6.1962,  -6.1762,  -6.1562,  -6.1361,  -6.1161,  -6.0961,
         -6.0761,  -6.0561,  -6.0360,  -6.0160,  -5.9960,  -5.9760,  -5.9560,
         -5.9359,  -5.9159,  -5.8959,  -5.8759,  -5.8559,  -5.8358,  -5.8158,
         -5.7958,  -5.7758,  -5.7558,  -5.7357,  -5.7157,  -5.6957,  -5.6757,
         -5.6557,  -5.6356,  -5.6156,  -5.5956,  -5.5756,  -5.5556,  -5.5355,
         -5.5155,  -5.4955,  -5.4755,  -5.4555,  -5.4354,  -5.4154,  -5.3954,
         -5.3754,  -5.3554,  -5.3353,  -5.3153,  -5.2953,  -5.2753,  -5.2553,
         -5.2352,  -5.2152,  -5.1952,  -5.1752,  -5.1552,  -5.1351,  -5.1151,
         -5.0951,  -5.0751,  -5.0551,  -5.0350,  -5.0150,  -4.9950,  -4.9750,
         -4.9550,  -4.9349,  -4.9149,  -4.8949,  -4.8749,  -4.8549,  -4.8348,
         -4.8148,  -4.7948,  -4.7748,  -4.7548,  -4.7347,  -4.7147,  -4.6947,
         -4.6747,  -4.6547,  -4.6346,  -4.6146,  -4.5946,  -4.5746,  -4.5546,
         -4.5345,  -4.5145,  -4.4945,  -4.4745,  -4.4545,  -4.4344,  -4.4144,
         -4.3944,  -4.3744,  -4.3544,  -4.3343,  -4.3143,  -4.2943,  -4.2743,
         -4.2543,  -4.2342,  -4.2142,  -4.1942,  -4.1742,  -4.1542,  -4.1341,
         -4.1141,  -4.0941,  -4.0741,  -4.0541,  -4.0340,  -4.0140,  -3.9940,
         -3.9740,  -3.9540,  -3.9339,  -3.9139,  -3.8939,  -3.8739,  -3.8539,
         -3.8338,  -3.8138,  -3.7938,  -3.7738,  -3.7538,  -3.7337,  -3.7137,
         -3.6937,  -3.6737,  -3.6537,  -3.6336,  -3.6136,  -3.5936,  -3.5736,
         -3.5536,  -3.5335,  -3.5135,  -3.4935,  -3.4735,  -3.4535,  -3.4334,
         -3.4134,  -3.3934,  -3.3734,  -3.3534,  -3.3333,  -3.3133,  -3.2933,
         -3.2733,  -3.2533,  -3.2332,  -3.2132,  -3.1932,  -3.1732,  -3.1532,
         -3.1331,  -3.1131,  -3.0931,  -3.0731,  -3.0531,  -3.0330,  -3.0130,
         -2.9930,  -2.9730,  -2.9530,  -2.9329,  -2.9129,  -2.8929,  -2.8729,
         -2.8529,  -2.8328,  -2.8128,  -2.7928,  -2.7728,  -2.7528,  -2.7327,
         -2.7127,  -2.6927,  -2.6727,  -2.6527,  -2.6326,  -2.6126,  -2.5926,
         -2.5726,  -2.5526,  -2.5325,  -2.5125,  -2.4925,  -2.4725,  -2.4525,
         -2.4324,  -2.4124,  -2.3924,  -2.3724,  -2.3524,  -2.3323,  -2.3123,
         -2.2923,  -2.2723,  -2.2523,  -2.2322,  -2.2122,  -2.1922,  -2.1722,
         -2.1522,  -2.1321,  -2.1121,  -2.0921,  -2.0721,  -2.0521,  -2.0320,
         -2.0120,  -1.9920,  -1.9720,  -1.9520,  -1.9319,  -1.9119,  -1.8919,
         -1.8719,  -1.8519,  -1.8318,  -1.8118,  -1.7918,  -1.7718,  -1.7518,
         -1.7317,  -1.7117,  -1.6917,  -1.6717,  -1.6517,  -1.6316,  -1.6116,
         -1.5916,  -1.5716,  -1.5516,  -1.5315,  -1.5115,  -1.4915,  -1.4715,
         -1.4515,  -1.4314,  -1.4114,  -1.3914,  -1.3714,  -1.3514,  -1.3313,
         -1.3113,  -1.2913,  -1.2713,  -1.2513,  -1.2312,  -1.2112,  -1.1912,
         -1.1712,  -1.1512,  -1.1311,  -1.1111,  -1.0911,  -1.0711,  -1.0511,
         -1.0310,  -1.0110,  -0.9910,  -0.9710,  -0.9510,  -0.9309,  -0.9109,
         -0.8909,  -0.8709,  -0.8509,  -0.8308,  -0.8108,  -0.7908,  -0.7708,
         -0.7508,  -0.7307,  -0.7107,  -0.6907,  -0.6707,  -0.6507,  -0.6306,
         -0.6106,  -0.5906,  -0.5706,  -0.5506,  -0.5305,  -0.5105,  -0.4905,
         -0.4705,  -0.4505,  -0.4304,  -0.4104,  -0.3904,  -0.3704,  -0.3504,
         -0.3303,  -0.3103,  -0.2903,  -0.2703,  -0.2503,  -0.2302,  -0.2102,
         -0.1902,  -0.1702,  -0.1502,  -0.1301,  -0.1101,  -0.0901,  -0.0701,
         -0.0501,  -0.0300,  -0.0100,   0.0100,   0.0300,   0.0500,   0.0701,
          0.0901,   0.1101,   0.1301,   0.1502,   0.1702,   0.1902,   0.2102,
          0.2302,   0.2503,   0.2703,   0.2903,   0.3103,   0.3303,   0.3504,
          0.3704,   0.3904,   0.4104,   0.4304,   0.4505,   0.4705,   0.4905,
          0.5105,   0.5305,   0.5506,   0.5706,   0.5906,   0.6106,   0.6306,
          0.6507,   0.6707,   0.6907,   0.7107,   0.7307,   0.7508,   0.7708,
          0.7908,   0.8108,   0.8308,   0.8509,   0.8709,   0.8909,   0.9109,
          0.9309,   0.9510,   0.9710,   0.9910,   1.0110,   1.0310,   1.0511,
          1.0711,   1.0911,   1.1111,   1.1311,   1.1512,   1.1712,   1.1912,
          1.2112,   1.2312,   1.2513,   1.2713,   1.2913,   1.3113,   1.3313,
          1.3514,   1.3714,   1.3914,   1.4114,   1.4314,   1.4515,   1.4715,
          1.4915,   1.5115,   1.5315,   1.5516,   1.5716,   1.5916,   1.6116,
          1.6316,   1.6517,   1.6717,   1.6917,   1.7117,   1.7317,   1.7518,
          1.7718,   1.7918,   1.8118,   1.8318,   1.8519,   1.8719,   1.8919,
          1.9119,   1.9319,   1.9520,   1.9720,   1.9920,   2.0120,   2.0320,
          2.0521,   2.0721,   2.0921,   2.1121,   2.1321,   2.1522,   2.1722,
          2.1922,   2.2122,   2.2322,   2.2523,   2.2723,   2.2923,   2.3123,
          2.3323,   2.3524,   2.3724,   2.3924,   2.4124,   2.4324,   2.4525,
          2.4725,   2.4925,   2.5125,   2.5325,   2.5526,   2.5726,   2.5926,
          2.6126,   2.6326,   2.6527,   2.6727,   2.6927,   2.7127,   2.7327,
          2.7528,   2.7728,   2.7928,   2.8128,   2.8328,   2.8529,   2.8729,
          2.8929,   2.9129,   2.9329,   2.9530,   2.9730,   2.9930,   3.0130,
          3.0330,   3.0531,   3.0731,   3.0931,   3.1131,   3.1331,   3.1532,
          3.1732,   3.1932,   3.2132,   3.2332,   3.2533,   3.2733,   3.2933,
          3.3133,   3.3333,   3.3534,   3.3734,   3.3934,   3.4134,   3.4334,
          3.4535,   3.4735,   3.4935,   3.5135,   3.5335,   3.5536,   3.5736,
          3.5936,   3.6136,   3.6336,   3.6537,   3.6737,   3.6937,   3.7137,
          3.7337,   3.7538,   3.7738,   3.7938,   3.8138,   3.8338,   3.8539,
          3.8739,   3.8939,   3.9139,   3.9339,   3.9540,   3.9740,   3.9940,
          4.0140,   4.0340,   4.0541,   4.0741,   4.0941,   4.1141,   4.1341,
          4.1542,   4.1742,   4.1942,   4.2142,   4.2342,   4.2543,   4.2743,
          4.2943,   4.3143,   4.3343,   4.3544,   4.3744,   4.3944,   4.4144,
          4.4344,   4.4545,   4.4745,   4.4945,   4.5145,   4.5345,   4.5546,
          4.5746,   4.5946,   4.6146,   4.6346,   4.6547,   4.6747,   4.6947,
          4.7147,   4.7347,   4.7548,   4.7748,   4.7948,   4.8148,   4.8348,
          4.8549,   4.8749,   4.8949,   4.9149,   4.9349,   4.9550,   4.9750,
          4.9950,   5.0150,   5.0350,   5.0551,   5.0751,   5.0951,   5.1151,
          5.1351,   5.1552,   5.1752,   5.1952,   5.2152,   5.2352,   5.2553,
          5.2753,   5.2953,   5.3153,   5.3353,   5.3554,   5.3754,   5.3954,
          5.4154,   5.4354,   5.4555,   5.4755,   5.4955,   5.5155,   5.5355,
          5.5556,   5.5756,   5.5956,   5.6156,   5.6356,   5.6557,   5.6757,
          5.6957,   5.7157,   5.7357,   5.7558,   5.7758,   5.7958,   5.8158,
          5.8358,   5.8559,   5.8759,   5.8959,   5.9159,   5.9359,   5.9560,
          5.9760,   5.9960,   6.0160,   6.0360,   6.0561,   6.0761,   6.0961,
          6.1161,   6.1361,   6.1562,   6.1762,   6.1962,   6.2162,   6.2362,
          6.2563,   6.2763,   6.2963,   6.3163,   6.3363,   6.3564,   6.3764,
          6.3964,   6.4164,   6.4364,   6.4565,   6.4765,   6.4965,   6.5165,
          6.5365,   6.5566,   6.5766,   6.5966,   6.6166,   6.6366,   6.6567,
          6.6767,   6.6967,   6.7167,   6.7367,   6.7568,   6.7768,   6.7968,
          6.8168,   6.8368,   6.8569,   6.8769,   6.8969,   6.9169,   6.9369,
          6.9570,   6.9770,   6.9970,   7.0170,   7.0370,   7.0571,   7.0771,
          7.0971,   7.1171,   7.1371,   7.1572,   7.1772,   7.1972,   7.2172,
          7.2372,   7.2573,   7.2773,   7.2973,   7.3173,   7.3373,   7.3574,
          7.3774,   7.3974,   7.4174,   7.4374,   7.4575,   7.4775,   7.4975,
          7.5175,   7.5375,   7.5576,   7.5776,   7.5976,   7.6176,   7.6376,
          7.6577,   7.6777,   7.6977,   7.7177,   7.7377,   7.7578,   7.7778,
          7.7978,   7.8178,   7.8378,   7.8579,   7.8779,   7.8979,   7.9179,
          7.9379,   7.9580,   7.9780,   7.9980,   8.0180,   8.0380,   8.0581,
          8.0781,   8.0981,   8.1181,   8.1381,   8.1582,   8.1782,   8.1982,
          8.2182,   8.2382,   8.2583,   8.2783,   8.2983,   8.3183,   8.3383,
          8.3584,   8.3784,   8.3984,   8.4184,   8.4384,   8.4585,   8.4785,
          8.4985,   8.5185,   8.5385,   8.5586,   8.5786,   8.5986,   8.6186,
          8.6386,   8.6587,   8.6787,   8.6987,   8.7187,   8.7387,   8.7588,
          8.7788,   8.7988,   8.8188,   8.8388,   8.8589,   8.8789,   8.8989,
          8.9189,   8.9389,   8.9590,   8.9790,   8.9990,   9.0190,   9.0390,
          9.0591,   9.0791,   9.0991,   9.1191,   9.1391,   9.1592,   9.1792,
          9.1992,   9.2192,   9.2392,   9.2593,   9.2793,   9.2993,   9.3193,
          9.3393,   9.3594,   9.3794,   9.3994,   9.4194,   9.4394,   9.4595,
          9.4795,   9.4995,   9.5195,   9.5395,   9.5596,   9.5796,   9.5996,
          9.6196,   9.6396,   9.6597,   9.6797,   9.6997,   9.7197,   9.7397,
          9.7598,   9.7798,   9.7998,   9.8198,   9.8398,   9.8599,   9.8799,
          9.8999,   9.9199,   9.9399,   9.9600,   9.9800,  10.0000])
x = torch.linspace(-10, 10, 1000)
hn = torch.distributions.HalfNormal(1)
normal =  torch.distributions.Normal(0, 1)

# Evaluate the half normal density only where x is positive; every other
# grid point keeps the value 0 it was initialised with.
x_mask = x>0
y = torch.zeros_like(x)
y[x_mask] = torch.exp(hn.log_prob(x[x_mask]))
plt.plot(x, y, label="Half Normal")

# Standard normal density on the same grid for comparison.
y_norm = torch.exp(normal.log_prob(x))
plt.plot(x, y_norm, label="Normal")

# Evaluate the standard normal far into its tails, both in log space and in
# probability space.
dist = torch.distributions.Normal(0, 1)
x_lin = torch.linspace(-17, 17, 1000)
log_probs = dist.log_prob(x_lin)
probs = log_probs.exp()

# Top panel: log-density. Bottom panel: density. Both share the x-axis.
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(x_lin, log_probs)
ax[0].set_title("Log Prob")
ax[1].plot(x_lin, probs)
ax[1].set_title("Prob")
Text(0.5, 1.0, 'Prob')

# Compare the two left-most grid points in both representations: the
# probabilities print as 0 while the log-probabilities remain finite.
print(probs[0], probs[1])
print(log_probs[0], log_probs[1])
tensor(0.) tensor(0.)
tensor(-145.4189) tensor(-144.8409)

Log Normal Distribution

Let \(Y \sim \mathcal{N}(\mu, \sigma^2)\) be a normally distributed random variable.

Let us define a new random variable \(X = \exp(Y)\).

We can say that log of \(X\) is normally distributed, i.e., \(\log(X) \sim \mathcal{N}(\mu, \sigma^2)\).

We can also say that \(X\) is log-normally distributed.

The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \frac{1}{x\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(\log(x)-\mu)^2}{2\sigma^2}\right). \]

We can derive the PDF of \(X\) using the change of variables formula. (will be covered later in the course)

mu, sigma = 1.0, 1.0

# Log-normal and normal distributions sharing the same two parameters.
log_normal = torch.distributions.LogNormal(mu, sigma)
normal = torch.distributions.Normal(mu, sigma)

x = torch.linspace(-10, 10, 1000)

# The log-normal density is evaluated only for x above 0.001; every other
# grid point keeps the value 0 it was initialised with.
x_non_neg_mask = x > 0.001
y = torch.zeros_like(x)
y[x_non_neg_mask] = torch.exp(log_normal.log_prob(x[x_non_neg_mask]))
plt.plot(x, y, label="PDF LogNormal(1, 1)")

plt.plot(x, torch.exp(normal.log_prob(x)), label="PDF Normal(1, 1)")





import kagglehub

# Download latest version
# The call returns the local path of the downloaded "datasnaek/chess" files.
path = kagglehub.dataset_download("datasnaek/chess")

# Show where the dataset files ended up.
print("Path to dataset files:", path)
Path to dataset files: /Users/nipun/.cache/kagglehub/datasets/datasnaek/chess/versions/1
import os
df = pd.read_csv(os.path.join(path, "games.csv"))
id rated created_at last_move_at turns victory_status winner increment_code white_id white_rating black_id black_rating moves opening_eco opening_name opening_ply
0 TZJHLljE False 1.504210e+12 1.504210e+12 13 outoftime white 15+2 bourgris 1500 a-00 1191 d4 d5 c4 c6 cxd5 e6 dxe6 fxe6 Nf3 Bb4+ Nc3 Ba5... D10 Slav Defense: Exchange Variation 5
1 l1NXvwaE True 1.504130e+12 1.504130e+12 16 resign black 5+10 a-00 1322 skinnerua 1261 d4 Nc6 e4 e5 f4 f6 dxe5 fxe5 fxe5 Nxe5 Qd4 Nc6... B00 Nimzowitsch Defense: Kennedy Variation 4
2 mIICvQHh True 1.504130e+12 1.504130e+12 61 mate white 5+10 ischia 1496 a-00 1500 e4 e5 d3 d6 Be3 c6 Be2 b5 Nd2 a5 a4 c5 axb5 Nc... C20 King's Pawn Game: Leonardis Variation 3
3 kWKvrqYL True 1.504110e+12 1.504110e+12 61 mate white 20+0 daniamurashov 1439 adivanov2009 1454 d4 d5 Nf3 Bf5 Nc3 Nf6 Bf4 Ng4 e3 Nc6 Be2 Qd7 O... D02 Queen's Pawn Game: Zukertort Variation 3
4 9tXo1AUZ True 1.504030e+12 1.504030e+12 95 mate white 30+3 nik221107 1523 adivanov2009 1469 e4 e5 Nf3 d6 d4 Nc6 d5 Nb4 a3 Na6 Nc3 Be7 b4 N... C41 Philidor Defense 5
# Histogram (50 bins) of the number of turns per game.
df["turns"].plot.hist(bins=50)

# Same histogram after taking the natural logarithm of the turn counts.
np.log(df["turns"]).plot.hist(bins=50)

# Log of turns seems to be normally distributed, so estimate the parameters
# of that normal from the log-transformed turn counts. Both statistics must
# come from the same transform (np.log) to describe one distribution.
log_turns = df["turns"].apply(np.log)
mu, sigma = log_turns.mean(), log_turns.std()
print(mu, sigma)
3.9070571274448245 0.6822030192719669
# Plot PDF of the fitted log-normal distribution

x = torch.linspace(0.001, 300, 1000)

# Build and evaluate the distribution with autograd disabled; the values are
# only used for plotting.
with torch.no_grad():
    log_normal = torch.distributions.LogNormal(mu, sigma)
    y = log_normal.log_prob(x).exp()

plt.plot(x, y, label="Fitted PDF")
# Normalised histogram of the observed turn counts for comparison.
plt.hist(df["turns"], bins=50, density=True, alpha=0.5, label="Histogram")

Gamma distribution

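Let \(X\) be a random variable that follows a gamma distribution with shape parameter \(k\) and scale parameter \(\theta\). Note that `torch.distributions.Gamma(concentration, rate)` is parameterized by the shape \(k\) and the rate \(\beta = 1/\theta\), not by the scale. The probability density function (PDF) of \(X\) is given by: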

\[ f_X(x) = \frac{1}{\Gamma(k)\theta^k} x^{k-1} \exp\left(-\frac{x}{\theta}\right). \]

where \(\Gamma(k)\) is the gamma function defined as:

\[ \Gamma(k) = \int_0^\infty x^{k-1} e^{-x} dx. \]

# Gamma distribution with concentration 2 and rate 1.
gamma_dist = torch.distributions.Gamma(concentration=2, rate=1)

# Evaluate the density on a strictly positive grid and plot it.
x = torch.linspace(0.001, 10, 1000)
y = torch.exp(gamma_dist.log_prob(x))
plt.plot(x, y, label="PDF Gamma(2, 1)")

# Fit a gamma distribution to the data by minimising the mean negative
# log-likelihood with Adam. `alpha` is the concentration, `beta` the rate.
alpha, beta = torch.tensor([1.0], requires_grad=True), torch.tensor([1.0], requires_grad=True)

optimizer = torch.optim.Adam([alpha, beta], lr=0.01)

x = torch.tensor(df["turns"].values, dtype=torch.float32)

for i in range(1000):
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()
    # Rebuild the distribution so it uses the current parameter values.
    gamma_dist = torch.distributions.Gamma(alpha, beta)
    loss = -gamma_dist.log_prob(x).mean()
    # Back-propagate and update alpha and beta.
    loss.backward()
    optimizer.step()

print(alpha.item(), beta.item())
2.315873384475708 0.03829348832368851
# Rebuild the distribution from detached copies of the learnt parameters so
# that plotting does not involve autograd.
learnt_gamma_dist = torch.distributions.Gamma(alpha.detach(), beta.detach())
x = torch.linspace(0.001, 300, 1000)
y = learnt_gamma_dist.log_prob(x).exp()
plt.plot(x, y, label="Fitted PDF")
# Normalised histogram of the observed turn counts for comparison.
plt.hist(df["turns"], bins=50, density=True, alpha=0.5, label="Histogram")

Uniform Distribution

Let \(X\) be a random variable that follows a uniform distribution on the interval \([a, b]\). The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \begin{cases} \frac{1}{b-a} & \text{if } x \in [a, b], \\ 0 & \text{otherwise}. \end{cases} \]


We can say that \(X \sim \text{Uniform}(a, b)\).

# Lower and upper end of the interval.
a = 0.0
b = 2.0
# Uniform distribution with low = a and high = b.
dist = torch.distributions.Uniform(a, b)
Interval(lower_bound=0.0, upper_bound=2.0)
dist.high, dist.low
(tensor(2.), tensor(0.))
# The grid deliberately extends beyond both ends of the [0, 2] support.
x_range = torch.linspace(-1, 3, 1000)
try:
    y = dist.log_prob(x_range).exp()
except ValueError as e:
    # Argument validation rejects values outside the support; show the
    # message instead of aborting.
    print(e)
Expected value argument (Tensor of shape (1000,)) to be within the support (Interval(lower_bound=0.0, upper_bound=2.0)) of the distribution Uniform(low: 0.0, high: 2.0), but found invalid values:
tensor([-1.0000e+00, -9.9600e-01, -9.9199e-01, -9.8799e-01, -9.8398e-01,
        -9.7998e-01, -9.7598e-01, -9.7197e-01, -9.6797e-01, -9.6396e-01,
        -9.5996e-01, -9.5596e-01, -9.5195e-01, -9.4795e-01, -9.4394e-01,
        -9.3994e-01, -9.3594e-01, -9.3193e-01, -9.2793e-01, -9.2392e-01,
        -9.1992e-01, -9.1592e-01, -9.1191e-01, -9.0791e-01, -9.0390e-01,
        -8.9990e-01, -8.9590e-01, -8.9189e-01, -8.8789e-01, -8.8388e-01,
        -8.7988e-01, -8.7588e-01, -8.7187e-01, -8.6787e-01, -8.6386e-01,
        -8.5986e-01, -8.5586e-01, -8.5185e-01, -8.4785e-01, -8.4384e-01,
        -8.3984e-01, -8.3584e-01, -8.3183e-01, -8.2783e-01, -8.2382e-01,
        -8.1982e-01, -8.1582e-01, -8.1181e-01, -8.0781e-01, -8.0380e-01,
        -7.9980e-01, -7.9580e-01, -7.9179e-01, -7.8779e-01, -7.8378e-01,
        -7.7978e-01, -7.7578e-01, -7.7177e-01, -7.6777e-01, -7.6376e-01,
        -7.5976e-01, -7.5576e-01, -7.5175e-01, -7.4775e-01, -7.4374e-01,
        -7.3974e-01, -7.3574e-01, -7.3173e-01, -7.2773e-01, -7.2372e-01,
        -7.1972e-01, -7.1572e-01, -7.1171e-01, -7.0771e-01, -7.0370e-01,
        -6.9970e-01, -6.9570e-01, -6.9169e-01, -6.8769e-01, -6.8368e-01,
        -6.7968e-01, -6.7568e-01, -6.7167e-01, -6.6767e-01, -6.6366e-01,
        -6.5966e-01, -6.5566e-01, -6.5165e-01, -6.4765e-01, -6.4364e-01,
        -6.3964e-01, -6.3564e-01, -6.3163e-01, -6.2763e-01, -6.2362e-01,
        -6.1962e-01, -6.1562e-01, -6.1161e-01, -6.0761e-01, -6.0360e-01,
        -5.9960e-01, -5.9560e-01, -5.9159e-01, -5.8759e-01, -5.8358e-01,
        -5.7958e-01, -5.7558e-01, -5.7157e-01, -5.6757e-01, -5.6356e-01,
        -5.5956e-01, -5.5556e-01, -5.5155e-01, -5.4755e-01, -5.4354e-01,
        -5.3954e-01, -5.3554e-01, -5.3153e-01, -5.2753e-01, -5.2352e-01,
        -5.1952e-01, -5.1552e-01, -5.1151e-01, -5.0751e-01, -5.0350e-01,
        -4.9950e-01, -4.9550e-01, -4.9149e-01, -4.8749e-01, -4.8348e-01,
        -4.7948e-01, -4.7548e-01, -4.7147e-01, -4.6747e-01, -4.6346e-01,
        -4.5946e-01, -4.5546e-01, -4.5145e-01, -4.4745e-01, -4.4344e-01,
        -4.3944e-01, -4.3544e-01, -4.3143e-01, -4.2743e-01, -4.2342e-01,
        -4.1942e-01, -4.1542e-01, -4.1141e-01, -4.0741e-01, -4.0340e-01,
        -3.9940e-01, -3.9540e-01, -3.9139e-01, -3.8739e-01, -3.8338e-01,
        -3.7938e-01, -3.7538e-01, -3.7137e-01, -3.6737e-01, -3.6336e-01,
        -3.5936e-01, -3.5536e-01, -3.5135e-01, -3.4735e-01, -3.4334e-01,
        -3.3934e-01, -3.3534e-01, -3.3133e-01, -3.2733e-01, -3.2332e-01,
        -3.1932e-01, -3.1532e-01, -3.1131e-01, -3.0731e-01, -3.0330e-01,
        -2.9930e-01, -2.9530e-01, -2.9129e-01, -2.8729e-01, -2.8328e-01,
        -2.7928e-01, -2.7528e-01, -2.7127e-01, -2.6727e-01, -2.6326e-01,
        -2.5926e-01, -2.5526e-01, -2.5125e-01, -2.4725e-01, -2.4324e-01,
        -2.3924e-01, -2.3524e-01, -2.3123e-01, -2.2723e-01, -2.2322e-01,
        -2.1922e-01, -2.1522e-01, -2.1121e-01, -2.0721e-01, -2.0320e-01,
        -1.9920e-01, -1.9520e-01, -1.9119e-01, -1.8719e-01, -1.8318e-01,
        -1.7918e-01, -1.7518e-01, -1.7117e-01, -1.6717e-01, -1.6316e-01,
        -1.5916e-01, -1.5516e-01, -1.5115e-01, -1.4715e-01, -1.4314e-01,
        -1.3914e-01, -1.3514e-01, -1.3113e-01, -1.2713e-01, -1.2312e-01,
        -1.1912e-01, -1.1512e-01, -1.1111e-01, -1.0711e-01, -1.0310e-01,
        -9.9099e-02, -9.5095e-02, -9.1091e-02, -8.7087e-02, -8.3083e-02,
        -7.9079e-02, -7.5075e-02, -7.1071e-02, -6.7067e-02, -6.3063e-02,
        -5.9059e-02, -5.5055e-02, -5.1051e-02, -4.7047e-02, -4.3043e-02,
        -3.9039e-02, -3.5035e-02, -3.1031e-02, -2.7027e-02, -2.3023e-02,
        -1.9019e-02, -1.5015e-02, -1.1011e-02, -7.0070e-03, -3.0030e-03,
         1.0010e-03,  5.0050e-03,  9.0090e-03,  1.3013e-02,  1.7017e-02,
         2.1021e-02,  2.5025e-02,  2.9029e-02,  3.3033e-02,  3.7037e-02,
         4.1041e-02,  4.5045e-02,  4.9049e-02,  5.3053e-02,  5.7057e-02,
         6.1061e-02,  6.5065e-02,  6.9069e-02,  7.3073e-02,  7.7077e-02,
         8.1081e-02,  8.5085e-02,  8.9089e-02,  9.3093e-02,  9.7097e-02,
         1.0110e-01,  1.0511e-01,  1.0911e-01,  1.1311e-01,  1.1712e-01,
         1.2112e-01,  1.2513e-01,  1.2913e-01,  1.3313e-01,  1.3714e-01,
         1.4114e-01,  1.4515e-01,  1.4915e-01,  1.5315e-01,  1.5716e-01,
         1.6116e-01,  1.6517e-01,  1.6917e-01,  1.7317e-01,  1.7718e-01,
         1.8118e-01,  1.8519e-01,  1.8919e-01,  1.9319e-01,  1.9720e-01,
         2.0120e-01,  2.0521e-01,  2.0921e-01,  2.1321e-01,  2.1722e-01,
         2.2122e-01,  2.2523e-01,  2.2923e-01,  2.3323e-01,  2.3724e-01,
         2.4124e-01,  2.4525e-01,  2.4925e-01,  2.5325e-01,  2.5726e-01,
         2.6126e-01,  2.6527e-01,  2.6927e-01,  2.7327e-01,  2.7728e-01,
         2.8128e-01,  2.8529e-01,  2.8929e-01,  2.9329e-01,  2.9730e-01,
         3.0130e-01,  3.0531e-01,  3.0931e-01,  3.1331e-01,  3.1732e-01,
         3.2132e-01,  3.2533e-01,  3.2933e-01,  3.3333e-01,  3.3734e-01,
         3.4134e-01,  3.4535e-01,  3.4935e-01,  3.5335e-01,  3.5736e-01,
         3.6136e-01,  3.6537e-01,  3.6937e-01,  3.7337e-01,  3.7738e-01,
         3.8138e-01,  3.8539e-01,  3.8939e-01,  3.9339e-01,  3.9740e-01,
         4.0140e-01,  4.0541e-01,  4.0941e-01,  4.1341e-01,  4.1742e-01,
         4.2142e-01,  4.2543e-01,  4.2943e-01,  4.3343e-01,  4.3744e-01,
         4.4144e-01,  4.4545e-01,  4.4945e-01,  4.5345e-01,  4.5746e-01,
         4.6146e-01,  4.6547e-01,  4.6947e-01,  4.7347e-01,  4.7748e-01,
         4.8148e-01,  4.8549e-01,  4.8949e-01,  4.9349e-01,  4.9750e-01,
         5.0150e-01,  5.0551e-01,  5.0951e-01,  5.1351e-01,  5.1752e-01,
         5.2152e-01,  5.2553e-01,  5.2953e-01,  5.3353e-01,  5.3754e-01,
         5.4154e-01,  5.4555e-01,  5.4955e-01,  5.5355e-01,  5.5756e-01,
         5.6156e-01,  5.6557e-01,  5.6957e-01,  5.7357e-01,  5.7758e-01,
         5.8158e-01,  5.8559e-01,  5.8959e-01,  5.9359e-01,  5.9760e-01,
         6.0160e-01,  6.0561e-01,  6.0961e-01,  6.1361e-01,  6.1762e-01,
         6.2162e-01,  6.2563e-01,  6.2963e-01,  6.3363e-01,  6.3764e-01,
         6.4164e-01,  6.4565e-01,  6.4965e-01,  6.5365e-01,  6.5766e-01,
         6.6166e-01,  6.6567e-01,  6.6967e-01,  6.7367e-01,  6.7768e-01,
         6.8168e-01,  6.8569e-01,  6.8969e-01,  6.9369e-01,  6.9770e-01,
         7.0170e-01,  7.0571e-01,  7.0971e-01,  7.1371e-01,  7.1772e-01,
         7.2172e-01,  7.2573e-01,  7.2973e-01,  7.3373e-01,  7.3774e-01,
         7.4174e-01,  7.4575e-01,  7.4975e-01,  7.5375e-01,  7.5776e-01,
         7.6176e-01,  7.6577e-01,  7.6977e-01,  7.7377e-01,  7.7778e-01,
         7.8178e-01,  7.8579e-01,  7.8979e-01,  7.9379e-01,  7.9780e-01,
         8.0180e-01,  8.0581e-01,  8.0981e-01,  8.1381e-01,  8.1782e-01,
         8.2182e-01,  8.2583e-01,  8.2983e-01,  8.3383e-01,  8.3784e-01,
         8.4184e-01,  8.4585e-01,  8.4985e-01,  8.5385e-01,  8.5786e-01,
         8.6186e-01,  8.6587e-01,  8.6987e-01,  8.7387e-01,  8.7788e-01,
         8.8188e-01,  8.8589e-01,  8.8989e-01,  8.9389e-01,  8.9790e-01,
         9.0190e-01,  9.0591e-01,  9.0991e-01,  9.1391e-01,  9.1792e-01,
         9.2192e-01,  9.2593e-01,  9.2993e-01,  9.3393e-01,  9.3794e-01,
         9.4194e-01,  9.4595e-01,  9.4995e-01,  9.5395e-01,  9.5796e-01,
         9.6196e-01,  9.6597e-01,  9.6997e-01,  9.7397e-01,  9.7798e-01,
         9.8198e-01,  9.8599e-01,  9.8999e-01,  9.9399e-01,  9.9800e-01,
         1.0020e+00,  1.0060e+00,  1.0100e+00,  1.0140e+00,  1.0180e+00,
         1.0220e+00,  1.0260e+00,  1.0300e+00,  1.0340e+00,  1.0380e+00,
         1.0420e+00,  1.0460e+00,  1.0501e+00,  1.0541e+00,  1.0581e+00,
         1.0621e+00,  1.0661e+00,  1.0701e+00,  1.0741e+00,  1.0781e+00,
         1.0821e+00,  1.0861e+00,  1.0901e+00,  1.0941e+00,  1.0981e+00,
         1.1021e+00,  1.1061e+00,  1.1101e+00,  1.1141e+00,  1.1181e+00,
         1.1221e+00,  1.1261e+00,  1.1301e+00,  1.1341e+00,  1.1381e+00,
         1.1421e+00,  1.1461e+00,  1.1502e+00,  1.1542e+00,  1.1582e+00,
         1.1622e+00,  1.1662e+00,  1.1702e+00,  1.1742e+00,  1.1782e+00,
         1.1822e+00,  1.1862e+00,  1.1902e+00,  1.1942e+00,  1.1982e+00,
         1.2022e+00,  1.2062e+00,  1.2102e+00,  1.2142e+00,  1.2182e+00,
         1.2222e+00,  1.2262e+00,  1.2302e+00,  1.2342e+00,  1.2382e+00,
         1.2422e+00,  1.2462e+00,  1.2503e+00,  1.2543e+00,  1.2583e+00,
         1.2623e+00,  1.2663e+00,  1.2703e+00,  1.2743e+00,  1.2783e+00,
         1.2823e+00,  1.2863e+00,  1.2903e+00,  1.2943e+00,  1.2983e+00,
         1.3023e+00,  1.3063e+00,  1.3103e+00,  1.3143e+00,  1.3183e+00,
         1.3223e+00,  1.3263e+00,  1.3303e+00,  1.3343e+00,  1.3383e+00,
         1.3423e+00,  1.3463e+00,  1.3504e+00,  1.3544e+00,  1.3584e+00,
         1.3624e+00,  1.3664e+00,  1.3704e+00,  1.3744e+00,  1.3784e+00,
         1.3824e+00,  1.3864e+00,  1.3904e+00,  1.3944e+00,  1.3984e+00,
         1.4024e+00,  1.4064e+00,  1.4104e+00,  1.4144e+00,  1.4184e+00,
         1.4224e+00,  1.4264e+00,  1.4304e+00,  1.4344e+00,  1.4384e+00,
         1.4424e+00,  1.4464e+00,  1.4505e+00,  1.4545e+00,  1.4585e+00,
         1.4625e+00,  1.4665e+00,  1.4705e+00,  1.4745e+00,  1.4785e+00,
         1.4825e+00,  1.4865e+00,  1.4905e+00,  1.4945e+00,  1.4985e+00,
         1.5025e+00,  1.5065e+00,  1.5105e+00,  1.5145e+00,  1.5185e+00,
         1.5225e+00,  1.5265e+00,  1.5305e+00,  1.5345e+00,  1.5385e+00,
         1.5425e+00,  1.5465e+00,  1.5506e+00,  1.5546e+00,  1.5586e+00,
         1.5626e+00,  1.5666e+00,  1.5706e+00,  1.5746e+00,  1.5786e+00,
         1.5826e+00,  1.5866e+00,  1.5906e+00,  1.5946e+00,  1.5986e+00,
         1.6026e+00,  1.6066e+00,  1.6106e+00,  1.6146e+00,  1.6186e+00,
         1.6226e+00,  1.6266e+00,  1.6306e+00,  1.6346e+00,  1.6386e+00,
         1.6426e+00,  1.6466e+00,  1.6507e+00,  1.6547e+00,  1.6587e+00,
         1.6627e+00,  1.6667e+00,  1.6707e+00,  1.6747e+00,  1.6787e+00,
         1.6827e+00,  1.6867e+00,  1.6907e+00,  1.6947e+00,  1.6987e+00,
         1.7027e+00,  1.7067e+00,  1.7107e+00,  1.7147e+00,  1.7187e+00,
         1.7227e+00,  1.7267e+00,  1.7307e+00,  1.7347e+00,  1.7387e+00,
         1.7427e+00,  1.7467e+00,  1.7508e+00,  1.7548e+00,  1.7588e+00,
         1.7628e+00,  1.7668e+00,  1.7708e+00,  1.7748e+00,  1.7788e+00,
         1.7828e+00,  1.7868e+00,  1.7908e+00,  1.7948e+00,  1.7988e+00,
         1.8028e+00,  1.8068e+00,  1.8108e+00,  1.8148e+00,  1.8188e+00,
         1.8228e+00,  1.8268e+00,  1.8308e+00,  1.8348e+00,  1.8388e+00,
         1.8428e+00,  1.8468e+00,  1.8509e+00,  1.8549e+00,  1.8589e+00,
         1.8629e+00,  1.8669e+00,  1.8709e+00,  1.8749e+00,  1.8789e+00,
         1.8829e+00,  1.8869e+00,  1.8909e+00,  1.8949e+00,  1.8989e+00,
         1.9029e+00,  1.9069e+00,  1.9109e+00,  1.9149e+00,  1.9189e+00,
         1.9229e+00,  1.9269e+00,  1.9309e+00,  1.9349e+00,  1.9389e+00,
         1.9429e+00,  1.9469e+00,  1.9510e+00,  1.9550e+00,  1.9590e+00,
         1.9630e+00,  1.9670e+00,  1.9710e+00,  1.9750e+00,  1.9790e+00,
         1.9830e+00,  1.9870e+00,  1.9910e+00,  1.9950e+00,  1.9990e+00,
         2.0030e+00,  2.0070e+00,  2.0110e+00,  2.0150e+00,  2.0190e+00,
         2.0230e+00,  2.0270e+00,  2.0310e+00,  2.0350e+00,  2.0390e+00,
         2.0430e+00,  2.0470e+00,  2.0511e+00,  2.0551e+00,  2.0591e+00,
         2.0631e+00,  2.0671e+00,  2.0711e+00,  2.0751e+00,  2.0791e+00,
         2.0831e+00,  2.0871e+00,  2.0911e+00,  2.0951e+00,  2.0991e+00,
         2.1031e+00,  2.1071e+00,  2.1111e+00,  2.1151e+00,  2.1191e+00,
         2.1231e+00,  2.1271e+00,  2.1311e+00,  2.1351e+00,  2.1391e+00,
         2.1431e+00,  2.1471e+00,  2.1512e+00,  2.1552e+00,  2.1592e+00,
         2.1632e+00,  2.1672e+00,  2.1712e+00,  2.1752e+00,  2.1792e+00,
         2.1832e+00,  2.1872e+00,  2.1912e+00,  2.1952e+00,  2.1992e+00,
         2.2032e+00,  2.2072e+00,  2.2112e+00,  2.2152e+00,  2.2192e+00,
         2.2232e+00,  2.2272e+00,  2.2312e+00,  2.2352e+00,  2.2392e+00,
         2.2432e+00,  2.2472e+00,  2.2513e+00,  2.2553e+00,  2.2593e+00,
         2.2633e+00,  2.2673e+00,  2.2713e+00,  2.2753e+00,  2.2793e+00,
         2.2833e+00,  2.2873e+00,  2.2913e+00,  2.2953e+00,  2.2993e+00,
         2.3033e+00,  2.3073e+00,  2.3113e+00,  2.3153e+00,  2.3193e+00,
         2.3233e+00,  2.3273e+00,  2.3313e+00,  2.3353e+00,  2.3393e+00,
         2.3433e+00,  2.3473e+00,  2.3514e+00,  2.3554e+00,  2.3594e+00,
         2.3634e+00,  2.3674e+00,  2.3714e+00,  2.3754e+00,  2.3794e+00,
         2.3834e+00,  2.3874e+00,  2.3914e+00,  2.3954e+00,  2.3994e+00,
         2.4034e+00,  2.4074e+00,  2.4114e+00,  2.4154e+00,  2.4194e+00,
         2.4234e+00,  2.4274e+00,  2.4314e+00,  2.4354e+00,  2.4394e+00,
         2.4434e+00,  2.4474e+00,  2.4515e+00,  2.4555e+00,  2.4595e+00,
         2.4635e+00,  2.4675e+00,  2.4715e+00,  2.4755e+00,  2.4795e+00,
         2.4835e+00,  2.4875e+00,  2.4915e+00,  2.4955e+00,  2.4995e+00,
         2.5035e+00,  2.5075e+00,  2.5115e+00,  2.5155e+00,  2.5195e+00,
         2.5235e+00,  2.5275e+00,  2.5315e+00,  2.5355e+00,  2.5395e+00,
         2.5435e+00,  2.5475e+00,  2.5516e+00,  2.5556e+00,  2.5596e+00,
         2.5636e+00,  2.5676e+00,  2.5716e+00,  2.5756e+00,  2.5796e+00,
         2.5836e+00,  2.5876e+00,  2.5916e+00,  2.5956e+00,  2.5996e+00,
         2.6036e+00,  2.6076e+00,  2.6116e+00,  2.6156e+00,  2.6196e+00,
         2.6236e+00,  2.6276e+00,  2.6316e+00,  2.6356e+00,  2.6396e+00,
         2.6436e+00,  2.6476e+00,  2.6517e+00,  2.6557e+00,  2.6597e+00,
         2.6637e+00,  2.6677e+00,  2.6717e+00,  2.6757e+00,  2.6797e+00,
         2.6837e+00,  2.6877e+00,  2.6917e+00,  2.6957e+00,  2.6997e+00,
         2.7037e+00,  2.7077e+00,  2.7117e+00,  2.7157e+00,  2.7197e+00,
         2.7237e+00,  2.7277e+00,  2.7317e+00,  2.7357e+00,  2.7397e+00,
         2.7437e+00,  2.7477e+00,  2.7518e+00,  2.7558e+00,  2.7598e+00,
         2.7638e+00,  2.7678e+00,  2.7718e+00,  2.7758e+00,  2.7798e+00,
         2.7838e+00,  2.7878e+00,  2.7918e+00,  2.7958e+00,  2.7998e+00,
         2.8038e+00,  2.8078e+00,  2.8118e+00,  2.8158e+00,  2.8198e+00,
         2.8238e+00,  2.8278e+00,  2.8318e+00,  2.8358e+00,  2.8398e+00,
         2.8438e+00,  2.8478e+00,  2.8519e+00,  2.8559e+00,  2.8599e+00,
         2.8639e+00,  2.8679e+00,  2.8719e+00,  2.8759e+00,  2.8799e+00,
         2.8839e+00,  2.8879e+00,  2.8919e+00,  2.8959e+00,  2.8999e+00,
         2.9039e+00,  2.9079e+00,  2.9119e+00,  2.9159e+00,  2.9199e+00,
         2.9239e+00,  2.9279e+00,  2.9319e+00,  2.9359e+00,  2.9399e+00,
         2.9439e+00,  2.9479e+00,  2.9520e+00,  2.9560e+00,  2.9600e+00,
         2.9640e+00,  2.9680e+00,  2.9720e+00,  2.9760e+00,  2.9800e+00,
         2.9840e+00,  2.9880e+00,  2.9920e+00,  2.9960e+00,  3.0000e+00])
# Evaluate the density of `dist` on the plotting grid, but only at points
# inside [a, b]; every point outside that interval keeps a density of zero.
# (`dist`, `a`, `b` and `x_range` come from earlier cells.)
x_range_mask = torch.logical_and(x_range >= a, x_range <= b)
y = torch.zeros_like(x_range)
y[x_range_mask] = torch.exp(dist.log_prob(x_range[x_range_mask]))

plt.plot(x_range, y)

Modeling quantisation error using uniform distribution

Quantization error is the error that arises when representing continuous signals with discrete signals.

NOTE: I am using a simplified convention here. Study DSP for a rigorous treatment (with N and T used in the equations).

Let the original signal, as represented in the computer, be \(y(t)\). We will quantize this signal to \(x(t)\). The quantization error is given by:

\[ e(t) = y(t) - x(t). \]

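We will quantize the signal to \(x(t)\) such that \(x(t)\) can take on \(N\) discrete values. Let \(\Delta\) be the quantization step size, i.e. the spacing between adjacent levels. When each sample is rounded to its nearest level, the quantization error is bounded by:

\[ -\Delta/2 \le e(t) \le \Delta/2. \]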

If the quantization error is uniformly distributed between \(-\Delta/2\) and \(\Delta/2\), then we can model the quantization error as a uniform random variable.

\[ e(t) \sim \text{Uniform}(-\Delta/2, \Delta/2). \]

# Sample y(t) = sin(t) at 5000 evenly spaced points on [-2, 8]
x_t = torch.linspace(-2, 8, steps=5000)
y_t = x_t.sin()
plt.plot(x_t, y_t, label="y_t")

# Split the observed amplitude range into N (=10) equal steps,
# which gives N + 1 level values including both end points
N = 10
min_y_t, max_y_t = torch.min(y_t), torch.max(y_t)
delta = (max_y_t - min_y_t) / N  # spacing between adjacent levels
y_bins = torch.linspace(min_y_t, max_y_t, N + 1)

# Mark every level with a dashed horizontal line and print its value at x = 3
for y_level in y_bins:
    plt.axhline(y_level, color='gray', linestyle='--')
    plt.text(3, y_level + 0.01, f"{y_level:.2f}")

# For x = 3, find the quantization level closest to y_t
# (float literal so torch.sin always receives a floating-point tensor)
y_t_x_3 = torch.sin(torch.tensor(3.0))

plt.plot(x_t, y_t, label="y_t")
plt.axvline(3, color='red', linestyle='--')
plt.axhline(y_t_x_3, color='red', linestyle='--')

# Draw the N + 1 levels as faint horizontal lines, labelled at x = 3
for y_level in y_bins:
    plt.axhline(y_level, color='gray', linestyle='--', alpha=0.2)
    plt.text(3, y_level+.01, f"{y_level:.2f}")
# Pick the level with the smallest absolute distance to the value.
# torch.searchsorted would instead return the first level >= the value
# (the upper edge of the bin), which is not always the closest level.
bin_idx = (y_bins - y_t_x_3).abs().argmin()
plt.axhline(y_bins[bin_idx], color='green', linestyle='--', label="Closest level")

# Inspect the shapes of the sampled signal and of the level grid
y_t.shape, y_bins.shape
(torch.Size([5000]), torch.Size([11]))
# Nearest-level quantization: broadcast to a (levels, samples) grid of
# absolute distances and keep, for each sample, the index of the nearest level
bins = torch.argmin(torch.abs(y_t.unsqueeze(0) - y_bins.unsqueeze(1)), dim=0)

# Replace every sample by the value of its nearest level
y_binned = y_bins[bins]
plt.plot(x_t, y_t, label="y_t")
plt.plot(x_t, y_binned, label="y_binned")

# Quantization error for every sample
plt.plot(y_t - y_binned)

# Empirical distribution of the quantization error, normalised to a density
quantization_error = y_t - y_binned
_ = plt.hist(quantization_error, density=True)
plt.xlabel("Error in binning")

# Theoretical Uniform(-delta/2, delta/2) density for comparison
half_step = delta / 2
theoretical_uniform = torch.distributions.Uniform(-half_step, half_step)
x = torch.linspace(-0.1, 0.1, steps=1000)
x_mask = torch.logical_and(x >= -half_step, x <= half_step)
y = torch.zeros_like(x)  # density stays zero outside [-delta/2, delta/2]
y[x_mask] = torch.exp(theoretical_uniform.log_prob(x[x_mask]))
plt.plot(x, y, label="Theoretical PDF")

Amount of bits saved

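Originally, each sample of the signal was represented using \(B=32\) bits. After quantization to \(N = 10\) levels, each sample can be represented using \(B_q = \log_2(10) \approx 3.32\) bits (in practice \(\lceil \log_2(10) \rceil = 4\) bits per sample). The amount of bits saved is given by: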

\[ (B - B_q) \times \text{number of samples}. \]

from pydub import AudioSegment
import numpy as np

# Decode the MP3 file into a pydub AudioSegment
audio = AudioSegment.from_mp3("vlog-music.mp3")
sample_rate = audio.frame_rate

# Raw samples as a float32 NumPy array
# NOTE(review): pydub presumably interleaves channels for multi-channel
# audio — confirm the file is mono before treating this as one waveform
samples = np.array(audio.get_array_of_samples(), dtype=np.float32)

print(f"Samples shape: {samples.shape}")
print(f"Sample rate: {sample_rate}")
Samples shape: (5430528,)
Sample rate: 44100
from IPython.display import Audio

# Plot 2nd second to 5th second
filtered_audio  = samples[sample_rate*2:sample_rate*5]
fig, ax = plt.subplots(figsize=(20, 5))
# Actually draw the selected clip; the figure was previously left empty
ax.plot(filtered_audio, label="Original audio")


# Quantize with N = 10 steps, i.e. N + 1 = 11 levels spanning the clip's range
min_audio = filtered_audio.min()
max_audio = filtered_audio.max()

N = 10

# NumPy (not torch) because the samples are a NumPy array and the
# quantization step rebuilds the same grid with np.linspace
audio_bins = np.linspace(min_audio, max_audio, N+1)
# Plotting audio bins
for audio_bin in audio_bins:
    plt.axhline(audio_bin, color='gray', linestyle='--', alpha=0.5)

# Quantize the audio: N + 1 evenly spaced levels over the clip's amplitude range
audio_bins = np.linspace(min_audio, max_audio, num=N + 1)

# Distance from every sample to every level (levels along axis 0); the
# position of the smallest distance is the nearest level for that sample
bins = np.argmin(
    np.abs(filtered_audio[np.newaxis, :] - audio_bins[:, np.newaxis]),
    axis=0,
)
audio_binned = np.take(audio_bins, bins)

# Original clip (faded) with its quantized version drawn on top
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(filtered_audio, label="Original audio", alpha=0.2)
ax.plot(audio_binned, label="Quantized audio")

# Play the original (unquantized) clip
Audio(filtered_audio, rate=sample_rate)
# Play the quantized audio
Audio(audio_binned, rate=sample_rate)

Beta Distribution

Let \(X\) be a random variable that follows a beta distribution with parameters \(\alpha\) and \(\beta\). The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \frac{\Gamma(\alpha + \beta)}{\Gamma(\alpha)\Gamma(\beta)} x^{\alpha-1} (1-x)^{\beta-1}, \]

where \(\Gamma(\cdot)\) is the gamma function given as:

\[ \Gamma(z) = \int_0^\infty t^{z-1} e^{-t} dt. \]

We can say that \(X \sim \text{Beta}(\alpha, \beta)\).

# Beta distribution with both parameters (alpha and beta) set to 2
beta_dist = torch.distributions.Beta(2, 2)
Interval(lower_bound=0.0, upper_bound=1.0)

The beta distribution is used to model random variables that are constrained to lie within a fixed interval. For example, the probability of success in a Bernoulli trial is a random variable that lies in the interval \([0, 1]\). We can model this probability using a beta distribution.

# Grid strictly inside (0, 1), avoiding the end points of the support
x_lin = torch.linspace(0.001, 0.999, steps=1000)

# One PDF curve for each (alpha, beta) combination with both values in {1, 2}
shape_pairs = [(p, q) for p in range(1, 3) for q in range(1, 3)]
for alpha, beta in shape_pairs:
    beta_dist = torch.distributions.Beta(alpha, beta)
    y = torch.exp(beta_dist.log_prob(x_lin))
    plt.plot(x_lin, y, label=f"PDF Beta({alpha}, {beta})")


Exponential Distribution

Let \(X\) be a random variable that follows an exponential distribution with rate parameter \(\lambda\). The probability density function (PDF) of \(X\) is given by:

\[ f_X(x) = \begin{cases} \lambda \exp(-\lambda x) & \text{if } x \geq 0, \\ 0 & \text{otherwise}. \end{cases} \]


We can say that \(X \sim \text{Exponential}(\lambda)\).

The exponential distribution may be viewed as a continuous counterpart of the geometric distribution, which describes the number of Bernoulli trials necessary for a discrete process to change state. In contrast, the exponential distribution describes the time for a continuous process to change state.

# Exponential distribution with rate parameter lambda = 5
l = 5.0
dist = torch.distributions.Exponential(rate=l)

# Plotting the PDF on a grid of positive x values
x_range = torch.linspace(0.001, 10, steps=1000)
y = torch.exp(dist.log_prob(x_range))
plt.plot(x_range, y)