import torch
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from ipywidgets import interact, FloatSlider
def visualize_bivariate_gaussian(mu_0, mu_1, cov_00, cov_01, cov_10, cov_11):
    """Plot a 2-D Gaussian density as a filled contour and a 3-D surface.

    The density is evaluated on a fixed 200 x 200 grid covering
    [-3, 3] x [-3, 3].  The left panel shows filled contours with a colorbar;
    the right panel shows a 3-D surface of the same values.  Suitable for
    driving from ``ipywidgets.interact`` sliders, but can be called directly.

    Args:
        mu_0, mu_1: Components of the mean vector.
        cov_00, cov_01, cov_10, cov_11: Entries of the 2 x 2 covariance
            matrix in row-major order.  The matrix must be symmetric
            (``cov_01 == cov_10``) and positive-definite; otherwise
            ``torch.distributions.MultivariateNormal`` rejects it (with
            torch's default argument validation this is a ``ValueError``).

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    # Use the default floating dtype so integer arguments (e.g. a direct call
    # with 0/1 values) do not produce int64 tensors, which the Cholesky-based
    # MultivariateNormal cannot factorise.  Float inputs are unaffected.
    dtype = torch.get_default_dtype()
    mean = torch.tensor([mu_0, mu_1], dtype=dtype)
    covariance_matrix = torch.tensor([[cov_00, cov_01],
                                      [cov_10, cov_11]], dtype=dtype)
    bivariate_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix)

    # Evaluation grid.  indexing='ij' is torch's historical default; passing it
    # explicitly keeps the same grid while silencing the meshgrid UserWarning.
    # matplotlib accepts these 2-D coordinate arrays because X, Y and Z share
    # one shape.
    N = 200
    theta_0 = torch.linspace(-3, 3, N)
    theta_1 = torch.linspace(-3, 3, N)
    Theta_0, Theta_1 = torch.meshgrid(theta_0, theta_1, indexing='ij')
    pos = torch.stack((Theta_0, Theta_1), dim=2)  # (N, N, 2) grid of points
    density = torch.exp(bivariate_dist.log_prob(pos))  # (N, N) pdf values

    # Convert once for matplotlib.
    grid_x = Theta_0.numpy()
    grid_y = Theta_1.numpy()
    grid_z = density.numpy()

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot's
    # get_cmap remains available across versions.
    custom_cmap = plt.get_cmap('viridis')

    # Build exactly one 2-D and one 3-D axes.  The previous approach of
    # plt.subplots(1, 2) followed by add_subplot(122, projection='3d')
    # left an empty 2-D axes underneath the 3-D one.
    fig = plt.figure(figsize=(12, 6))
    ax_contour = fig.add_subplot(1, 2, 1)
    ax_surface = fig.add_subplot(1, 2, 2, projection='3d')

    # Shared title suffix describing the distribution's parameters.
    params_text = f' μ = {mean.tolist()} , Covariance = {covariance_matrix.tolist()} '

    contour = ax_contour.contourf(grid_x, grid_y, grid_z, cmap=custom_cmap, levels=20)
    fig.colorbar(contour, ax=ax_contour)
    ax_contour.set_xlabel('θ \u2080 ')
    ax_contour.set_ylabel('θ \u2081 ')
    ax_contour.set_title('Bivariate Gaussian Contour \n' + params_text)
    ax_contour.set_aspect('equal')  # keep circles circular for isotropic cov

    ax_surface.plot_surface(grid_x, grid_y, grid_z, cmap=custom_cmap)
    ax_surface.set_xlabel('θ \u2080 ')
    ax_surface.set_ylabel('θ \u2081 ')
    ax_surface.set_zlabel('Density')
    ax_surface.set_title('Bivariate Gaussian Surface \n' + params_text)

    plt.tight_layout()
    plt.show()