Cumulative Distribution Function (CDF)

ML
Author

Nipun Batra

Published

March 5, 2025

import matplotlib.pyplot as plt
import numpy as np
print(np.__version__)
import torch 
import torch.nn as nn

import pandas as pd
# Retina mode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
2.2.3

Generating random numbers from a categorical distribution

# Target categorical distribution: one probability per outcome.
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(probs)
tensor([0.1000, 0.2000, 0.3000, 0.4000])
# Uniform distribution with bounds 0 and 1; its draws drive the
# inverse-CDF lookups in the cells that follow.
unif = torch.distributions.Uniform(0, 1)
print(unif.sample())
tensor(0.9147)
# Running total of the probabilities, i.e. the CDF evaluated at each outcome.
cum_sum_prob = probs.cumsum(dim=0)
print(cum_sum_prob)
tensor([0.1000, 0.3000, 0.6000, 1.0000])
# Outcome labels; symbols[i] is paired with the i-th cumulative probability.
symbols = torch.tensor([1, 2, 3, 4])
print(symbols)

# A single draw from `unif` decides which CDF bucket we land in.
sample = unif.sample()
print(sample)

# Walk the first three CDF thresholds in order and print the symbol of the
# first threshold that exceeds the draw; if none does, the draw belongs to
# the last bucket.
for threshold, symbol in zip(cum_sum_prob[:3], symbols[:3]):
    if threshold > sample:
        print(symbol)
        break
else:
    print(symbols[3])
tensor([1, 2, 3, 4])
tensor(0.1879)
tensor(2)
# Element-wise mask: True where the cumulative probability is at least the draw.
# NOTE(review): the neighbouring cells compare strictly (`>` / `<`); the two
# forms only differ when the draw equals a CDF value exactly.
sample <= cum_sum_prob
tensor([False,  True,  True,  True])
# Keep the symbols whose cumulative probability strictly exceeds the draw;
# the first one kept is the sampled symbol.
exceeds_draw = sample < cum_sum_prob
symbols[exceeds_draw][0]
tensor(2)
### Even more efficient
# searchsorted returns the insertion position of the draw in the sorted CDF,
# i.e. the first index whose cumulative probability is >= the draw
# (default right=False), which is the index of the sampled symbol.
index = torch.searchsorted(cum_sum_prob, sample)
print(symbols[index])
tensor(2)
### Vectorized
# Draw all uniform variates at once as a 1-D tensor of length num_samples.
num_samples = 1000
unif_samples = unif.sample((num_samples,))

# One searchsorted call maps every draw to its CDF bucket; indexing `symbols`
# with the resulting index tensor yields one categorical sample per draw.
index = torch.searchsorted(cum_sum_prob, unif_samples)
our_samples = symbols[index]
print(our_samples)
tensor([1, 3, 3, 1, 2, 3, 2, 3, 1, 2, 4, 1, 3, 2, 2, 4, 4, 3, 3, 4, 2, 4, 4, 3,
        4, 4, 2, 4, 4, 4, 2, 4, 4, 1, 4, 3, 3, 3, 4, 4, 2, 3, 3, 1, 4, 4, 4, 3,
        3, 4, 1, 4, 4, 1, 3, 4, 4, 3, 4, 2, 2, 3, 1, 3, 4, 2, 4, 4, 3, 4, 4, 4,
        2, 3, 3, 3, 4, 4, 4, 2, 2, 3, 3, 3, 4, 2, 2, 4, 2, 3, 4, 3, 1, 4, 4, 2,
        4, 4, 4, 4, 3, 4, 4, 1, 3, 3, 1, 2, 1, 1, 2, 3, 4, 3, 3, 3, 3, 2, 4, 4,
        4, 3, 1, 3, 1, 2, 4, 4, 3, 4, 1, 2, 3, 4, 4, 2, 2, 3, 2, 2, 3, 4, 4, 4,
        4, 4, 2, 2, 4, 3, 2, 4, 4, 3, 4, 3, 3, 2, 4, 4, 3, 3, 2, 4, 4, 2, 4, 4,
        2, 4, 3, 3, 2, 4, 4, 3, 4, 2, 4, 4, 3, 3, 4, 2, 2, 4, 3, 1, 2, 4, 4, 3,
        3, 4, 3, 4, 2, 4, 3, 4, 3, 2, 4, 3, 3, 4, 2, 1, 4, 3, 2, 4, 1, 3, 4, 4,
        3, 2, 4, 4, 3, 1, 4, 2, 4, 4, 3, 4, 3, 4, 4, 4, 3, 3, 4, 4, 4, 4, 3, 3,
        3, 3, 4, 4, 3, 3, 1, 4, 4, 3, 4, 4, 3, 3, 2, 4, 4, 4, 2, 2, 4, 4, 4, 3,
        2, 4, 4, 2, 2, 4, 4, 4, 4, 2, 2, 2, 4, 4, 4, 4, 2, 3, 3, 3, 1, 4, 3, 4,
        4, 1, 4, 1, 4, 4, 3, 3, 3, 3, 4, 4, 1, 3, 3, 4, 3, 4, 4, 1, 3, 1, 4, 2,
        3, 3, 3, 1, 1, 4, 4, 4, 4, 4, 4, 2, 3, 3, 3, 3, 4, 4, 2, 4, 3, 2, 2, 4,
        4, 4, 4, 3, 4, 3, 4, 3, 4, 2, 3, 3, 2, 4, 3, 4, 4, 3, 4, 4, 1, 3, 3, 2,
        4, 4, 3, 4, 1, 3, 3, 1, 4, 3, 3, 4, 4, 4, 4, 4, 3, 4, 4, 4, 2, 3, 4, 4,
        2, 2, 3, 3, 2, 3, 3, 4, 4, 1, 2, 3, 4, 4, 4, 3, 4, 3, 3, 3, 4, 3, 4, 4,
        4, 1, 3, 4, 2, 2, 2, 4, 1, 4, 4, 2, 4, 2, 4, 1, 3, 2, 3, 3, 4, 4, 3, 2,
        3, 4, 4, 4, 3, 4, 3, 2, 4, 3, 3, 4, 3, 2, 4, 2, 4, 3, 2, 3, 4, 4, 4, 2,
        4, 3, 4, 3, 3, 4, 2, 3, 2, 3, 1, 3, 4, 3, 4, 4, 3, 1, 3, 3, 4, 3, 1, 3,
        1, 3, 1, 1, 4, 4, 3, 4, 4, 2, 1, 3, 2, 3, 1, 3, 4, 3, 4, 1, 4, 3, 2, 3,
        3, 2, 3, 4, 2, 1, 3, 2, 4, 3, 2, 1, 3, 2, 4, 1, 3, 3, 3, 4, 3, 4, 4, 4,
        1, 4, 4, 4, 4, 2, 2, 4, 4, 4, 4, 3, 1, 2, 4, 4, 3, 4, 2, 4, 3, 2, 3, 3,
        3, 2, 4, 4, 2, 1, 4, 4, 4, 2, 3, 3, 2, 1, 1, 2, 4, 3, 4, 3, 2, 2, 3, 4,
        2, 4, 3, 4, 4, 4, 3, 4, 4, 3, 3, 2, 4, 1, 4, 2, 4, 2, 4, 4, 4, 2, 2, 4,
        2, 4, 2, 3, 3, 4, 2, 2, 1, 4, 3, 4, 4, 2, 2, 2, 4, 1, 3, 4, 4, 4, 3, 3,
        2, 2, 2, 3, 4, 2, 4, 4, 4, 3, 2, 4, 4, 4, 3, 4, 3, 3, 1, 1, 2, 3, 3, 1,
        2, 3, 4, 2, 2, 1, 4, 4, 4, 2, 4, 4, 3, 2, 3, 4, 1, 4, 4, 2, 3, 4, 4, 3,
        4, 3, 4, 3, 4, 3, 4, 1, 4, 3, 4, 4, 3, 1, 4, 4, 3, 2, 3, 4, 3, 3, 3, 2,
        3, 2, 4, 3, 3, 3, 3, 2, 3, 2, 4, 3, 3, 4, 2, 4, 3, 1, 4, 2, 4, 3, 4, 3,
        2, 1, 1, 4, 3, 2, 3, 2, 3, 2, 3, 4, 4, 4, 2, 4, 4, 3, 4, 2, 3, 1, 4, 1,
        4, 3, 4, 3, 4, 4, 3, 4, 1, 1, 2, 3, 4, 3, 4, 3, 2, 3, 3, 4, 4, 1, 3, 4,
        2, 4, 2, 2, 4, 3, 4, 4, 3, 4, 4, 3, 3, 3, 4, 4, 4, 2, 2, 3, 4, 3, 3, 2,
        4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 2, 4, 3, 2, 2, 2, 4, 4, 3, 3, 1, 3, 3, 4,
        2, 4, 4, 4, 2, 4, 2, 4, 1, 3, 2, 3, 3, 4, 4, 4, 3, 3, 4, 1, 4, 3, 3, 4,
        4, 1, 2, 2, 2, 2, 4, 3, 3, 4, 2, 1, 3, 2, 3, 3, 1, 4, 1, 4, 3, 4, 3, 4,
        1, 3, 4, 4, 2, 4, 4, 3, 3, 2, 4, 4, 4, 3, 2, 3, 4, 3, 4, 4, 4, 3, 4, 4,
        3, 4, 3, 3, 4, 4, 4, 2, 4, 3, 4, 1, 1, 3, 4, 4, 4, 3, 3, 4, 2, 4, 4, 3,
        3, 3, 3, 3, 3, 4, 4, 2, 4, 4, 4, 1, 4, 4, 4, 3, 4, 3, 2, 2, 4, 2, 3, 3,
        4, 2, 1, 4, 3, 4, 2, 3, 3, 2, 4, 1, 4, 4, 3, 2, 4, 3, 1, 1, 3, 4, 4, 4,
        4, 3, 4, 3, 3, 3, 3, 1, 1, 3, 2, 3, 2, 4, 2, 3, 3, 3, 2, 3, 4, 2, 2, 4,
        2, 2, 4, 1, 1, 1, 4, 2, 4, 3, 4, 3, 4, 3, 1, 4])
# Empirical relative frequency of each sampled symbol, ordered by symbol value
# so the bars appear in a stable left-to-right order.
samples_series = pd.Series(our_samples)
samples_series_norm = samples_series.value_counts(normalize=True).sort_index()
samples_series_norm.plot(kind='bar', rot=0)

# Overlay every target probability as a dashed horizontal reference line.
# Iterating over the tensor's values (instead of a hard-coded range(4)) keeps
# the overlay correct if `probs` changes length.
for target_prob in probs.tolist():
    plt.axhline(target_prob, color='r', linestyle='--')

## Generalised implementation when .icdf() is available

def inverse_cdf_sampling(distribution, sample_size=10000):
    """Draw samples from ``distribution`` by inverse-transform sampling.

    Uniform variates are generated with ``torch.rand`` and pushed through
    the distribution's quantile function (``icdf``).

    Args:
        distribution: Object exposing an ``icdf(tensor)`` method, e.g. a
            ``torch.distributions`` instance that implements it.
        sample_size: Size argument forwarded to ``torch.rand``
            (defaults to 10000).

    Returns:
        The result of ``distribution.icdf`` applied to the uniform draws.
    """
    return distribution.icdf(torch.rand(sample_size))
# Normal distribution with loc 0 and scale 1 (standard normal); draw 1000
# samples from it through the inverse-CDF helper.
X = torch.distributions.Normal(0, 1)
samples = inverse_cdf_sampling(X, 1000)
### Use CDF function
our_dist = torch.distributions.Normal(0, 1)
# NOTE(review): despite its name, `unif_samples` receives draws from `our_dist`
# produced by inverse_cdf_sampling (not uniform values), and nothing visible in
# this notebook reads it afterwards. Given the heading, something like
# `our_dist.cdf(samples)` may have been intended — TODO confirm.
unif_samples = inverse_cdf_sampling(our_dist, 1000)

# Density-normalised histogram (50 bins) of `samples`, the draws made earlier
# with inverse_cdf_sampling.
plt.hist(samples.numpy(), bins=50, density=True)
(array([0.01627178, 0.00813589, 0.00813589, 0.        , 0.01627178,
        0.01627178, 0.01627178, 0.04881535, 0.05695124, 0.04881535,
        0.08135891, 0.06508713, 0.10576658, 0.13831015, 0.14644604,
        0.13831015, 0.16271782, 0.23594084, 0.29289207, 0.38238687,
        0.36611509, 0.34170742, 0.27662029, 0.39865866, 0.3579792 ,
        0.39865866, 0.38238687, 0.37425098, 0.52069702, 0.33357153,
        0.37425098, 0.2684844 , 0.30916326, 0.2440772 , 0.13831041,
        0.20339688, 0.195261  , 0.14644632, 0.13831041, 0.0976305 ,
        0.05695113, 0.04067953, 0.05695135, 0.04881525, 0.04881525,
        0.00813591, 0.00813591, 0.        , 0.        , 0.01627181]),
 array([-3.15819168, -3.03527951, -2.91236734, -2.78945518, -2.66654301,
        -2.54363084, -2.42071867, -2.2978065 , -2.17489433, -2.05198216,
        -1.92907   , -1.80615783, -1.68324566, -1.56033349, -1.43742132,
        -1.31450915, -1.19159698, -1.06868482, -0.94577265, -0.82286048,
        -0.69994831, -0.57703614, -0.45412397, -0.33121181, -0.20829964,
        -0.08538747,  0.0375247 ,  0.16043687,  0.28334904,  0.40626121,
         0.52917337,  0.65208554,  0.77499771,  0.89791012,  1.02082205,
         1.14373398,  1.26664639,  1.38955879,  1.51247072,  1.63538265,
         1.75829506,  1.88120747,  2.0041194 ,  2.12703133,  2.24994373,
         2.37285614,  2.49576807,  2.61868   ,  2.74159241,  2.86450481,
         2.98741674]),
 <BarContainer object of 50 artists>)