import matplotlib.pyplot as plt
import numpy as np
print(np.__version__)
import torch
import torch.nn as nn
import pandas as pd
# Retina mode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
1.26.4
Nipun Batra
March 18, 2025
import matplotlib.pyplot as plt
import numpy as np
print(np.__version__)
import torch
import torch.nn as nn
import pandas as pd
# Retina mode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
1.26.4
# Standard normal N(0, 1): histogram of random draws vs. the analytic density
dist = torch.distributions.Normal(0, 1)
x = dist.sample(torch.Size([1000]))
plt.hist(x.numpy(), density=True, bins=50)

# Density on a dense grid, obtained by exponentiating the log-density
x_range = torch.linspace(-3, 3, 1000)
y = torch.exp(dist.log_prob(x_range))
plt.plot(x_range.numpy(), y.numpy())
tensor([-1.9083, 0.3758, 0.0051, 0.5140, 0.9852, -0.5989, 0.5222, -0.7744,
0.9462, -1.7868])
# Bivariate standard normal: zero mean, identity covariance (uncorrelated axes)
dist_2d_normal = torch.distributions.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.eye(2),
)
# To try correlated components instead, use:
#     covariance_matrix=torch.tensor([[1.0, 0.5], [0.5, 1.0]])
dist_2d_normal.sample(torch.Size([10]))
tensor([[ 0.0438, -0.0310],
[ 0.0487, -0.3790],
[-0.7872, 0.9880],
[ 1.0010, -0.9025],
[ 0.5449, 0.1047],
[ 1.6466, 0.0925],
[ 0.9357, 0.2228],
[-1.2721, 2.5194],
[-0.3306, -0.1152],
[ 1.2249, -1.7330]])
# Plot 2D normal distribution surface plot of PDF
# Kept for older matplotlib versions, where importing Axes3D registers the '3d' projection.
from mpl_toolkits.mplot3d import Axes3D
x = torch.linspace(-3, 3, 100)
y = torch.linspace(-3, 3, 100)
# Explicit indexing='ij' keeps torch's historical behaviour and silences the
# "torch.meshgrid: ... pass the indexing argument" UserWarning.
X, Y = torch.meshgrid(x, y, indexing='ij')
# (100, 100, 2): each grid point becomes a 2-vector, the event shape of dist_2d_normal
xy = torch.stack([X, Y], 2)
z = dist_2d_normal.log_prob(xy).exp()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X.numpy(), Y.numpy(), z.numpy())
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('PDF')
fig.tight_layout()
/Users/nipun/base/lib/python3.12/site-packages/torch/functional.py:539: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3638.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
| | Height(Inches) | Weight(Pounds) |
---|---|---|
1 | 65.78331 | 112.9925 |
2 | 71.51521 | 136.4873 |
3 | 69.39874 | 153.0269 |
4 | 68.21660 | 142.3354 |
5 | 67.78781 | 144.2971 |
# Scatter the raw height/weight observations.
# NOTE(review): `data` is not defined in this view; presumably an (N, 2) array of
# [height, weight] loaded in an earlier cell — confirm.
plt.scatter(
    data[:, 0],
    data[:, 1],
    color='k',
    facecolors='k',
    alpha=0.1,
)
plt.xlabel("Height")
plt.ylabel("Weight")
Text(0, 0.5, 'Weight')
# plot the PDF over a grid spanning the height/weight range of the data
x = torch.linspace(50, 80, 100)    # height axis
y = torch.linspace(80, 280, 100)   # weight axis
# Explicit indexing='ij' keeps torch's historical behaviour without the UserWarning.
X, Y = torch.meshgrid(x, y, indexing='ij')
xy = torch.stack([X, Y], 2)  # (100, 100, 2) grid of [height, weight] points
# NOTE(review): `dist` must be a 2D distribution over [height, weight] here
# (presumably a Gaussian fitted to `data` in a cell not shown). The 1D
# Normal(0, 1) bound to `dist` earlier would yield a (100, 100, 2) array
# instead of the (100, 100) surface the plot below expects — confirm.
z = dist.log_prob(xy).exp()
import plotly.graph_objects as go

# Interactive 3D surface of the PDF; hovering shows height, weight and density.
fig = go.Figure(
    data=[
        go.Surface(
            x=X,
            y=Y,
            z=z,
            colorscale="viridis",
            hovertemplate="Height: %{x:0.2f}<br>Weight: %{y:0.2f}<br>PDF: %{z:0.5f}<extra></extra>",
        )
    ]
)

# Large fixed canvas with minimal margins.
# Note: since width and height are both given explicitly, plotly's autosize
# (which only fills in an undefined width/height) does not change the size.
fig.update_layout(
    autosize=True,
    width=1200,
    height=700,
    margin={"l": 0, "r": 0, "t": 40, "b": 0},
    title="2D Gaussian PDF",
    scene={
        "xaxis_title": "Height",
        "yaxis_title": "Weight",
        "zaxis_title": "PDF",
    },
)

fig.show()
# Uniform(0, 1): histogram of random draws vs. the flat analytic density
dist_uniform = torch.distributions.Uniform(0, 1)
x = dist_uniform.sample(torch.Size([1000]))
plt.hist(x.numpy(), density=True, bins=50)

x_range = torch.linspace(0, 1, 1000)
y = torch.exp(dist_uniform.log_prob(x_range))
plt.plot(x_range.numpy(), y.numpy())
# Two independent Uniform(0, 1) coordinates, represented as a batch of two
# scalar distributions (batch_shape == (2,), event_shape == ()).
dist_uniform_2d = torch.distributions.Uniform(torch.zeros(2), torch.ones(2))
dist_uniform_2d.sample(torch.Size([10]))
tensor([[0.5493, 0.3478],
[0.7661, 0.2568],
[0.7199, 0.2975],
[0.9114, 0.2916],
[0.0045, 0.4948],
[0.0156, 0.7434],
[0.6856, 0.1037],
[0.4446, 0.1913],
[0.1995, 0.5009],
[0.0716, 0.6085]])
## Important:
## f(x, y) = f(x) * f(y) for independent random variables
## log(f(x, y)) = log(f(x)) + log(f(y))
# Build a grid inside the unit square for this cell. The previous height/weight
# grid (x in [50, 80], y in [80, 280]) lies outside the Uniform(0, 1) support:
# with torch's default argument validation, log_prob raises ValueError there,
# and even without validation the density would be 0 everywhere.
# (Points exactly at 1 get density 0, since torch's Uniform excludes its upper bound.)
u = torch.linspace(0, 1, 100)
Xu, Yu = torch.meshgrid(u, u, indexing='ij')
xyu = torch.stack([Xu, Yu], 2)  # (100, 100, 2)

# dist_uniform_2d has batch_shape (2,), so log_prob yields one log-density per
# coordinate; summing over the last axis gives the joint log-density.
z1 = dist_uniform_2d.log_prob(xyu).sum(-1).exp()
# Same joint density computed as the product of the two 1D marginals
z2 = dist_uniform.log_prob(Xu).exp() * dist_uniform.log_prob(Yu).exp()
assert torch.allclose(z1, z2)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(Xu.numpy(), Yu.numpy(), z1.numpy())
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('PDF')
Text(0.5, 0, 'PDF')