import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import pandas as pd
import pyro
# Short alias for Pyro's distribution classes, used by the cells below.
dist = pyro.distributions

# Plot styling: reset seaborn to its defaults, then use the "talk" context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)

# Jupyter-only line magics. They are not valid Python syntax in a plain
# module, so they are kept as comments; uncomment them in a notebook cell.
# %matplotlib inline
# %config InlineBackend.figure_format='retina'
# Basic imports (cell above)
# A standard bivariate normal: zero mean, identity covariance.
X = dist.MultivariateNormal(loc=torch.tensor([0.0, 0.0]), covariance_matrix=torch.eye(2))

# Draw one 2-D sample; a recorded run gave tensor([-0.0687, 0.7461]).
X.sample()
# Sample a synthetic dataset from the generative process x = W @ z + noise,
# in column convention: each of the num_datapoints columns is one point.
data_dim = 2
latent_dim = 1
num_datapoints = 100

# Latent codes z ~ Normal(0, 1), shape (latent_dim, num_datapoints).
z = dist.Normal(
    loc=torch.zeros([latent_dim, num_datapoints]),
    scale=torch.ones([latent_dim, num_datapoints]),
)

# Loading matrix w ~ Normal(0, 5), shape (data_dim, latent_dim).
w = dist.Normal(
    loc=torch.zeros([data_dim, latent_dim]),
    scale=5.0 * torch.ones([data_dim, latent_dim]),
)

w_sample = w.sample()
z_sample = z.sample()

# Observations: unit-scale Gaussian noise around w @ z,
# shape (data_dim, num_datapoints).
x = dist.Normal(loc=w_sample @ z_sample, scale=1)

# 100 independent noise draws for the same (w, z):
# shape (100, data_dim, num_datapoints).
x_sample = x.sample([100])

# Scatter dimension 0 against dimension 1. Each argument has shape
# (100, num_datapoints); plt.scatter flattens them, so all draws are plotted.
plt.scatter(x_sample[:, 0], x_sample[:, 1], alpha=0.2, s=30)
# Generative model for probabilistic PCA (PPCA) in Pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pyro
# Start from an empty Pyro parameter store so re-running this cell does not
# pick up parameters left over from an earlier run.
pyro.clear_param_store()


def ppca_model(data, latent_dim):
    """Probabilistic PCA model: data is Gaussian noise around Z @ W.

    Args:
        data: observed 2-D tensor of shape (N, data_dim); each row is one
            datapoint.
        latent_dim: number of latent dimensions.

    Returns:
        The value of the observed "obs" site. Because that site is
        conditioned on ``data``, this is ``data`` itself.
    """
    N, data_dim = data.shape  # rows are datapoints, columns are features

    # Loading matrix W ~ Normal(0, 5), shape (latent_dim, data_dim).
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([latent_dim, data_dim]),
            scale=5.0 * torch.ones([latent_dim, data_dim]),
        ),
    )

    # Per-datapoint latent codes Z ~ Normal(0, 1), shape (N, latent_dim).
    Z = pyro.sample(
        "Z",
        dist.Normal(
            loc=torch.zeros([N, latent_dim]),
            scale=torch.ones([N, latent_dim]),
        ),
    )

    mean = Z @ W  # shape (N, data_dim)

    # NOTE(review): W, Z and obs keep their element dims as batch dims with
    # no pyro.plate / .to_event(). That is fine for sampling and tracing, but
    # confirm before running SVI with Pyro's shape validation enabled.
    return pyro.sample("obs", dist.Normal(mean, 1.0), obs=data)
# Render the model graph, with distribution annotations, on dummy data:
# 150 datapoints in 2 dimensions and 1 latent dimension.
pyro.render_model(
    ppca_model, model_args=(torch.randn(150, 2), 1), render_distributions=True
)

# ppca_model expects rows = datapoints, but x_sample[0] has shape
# (data_dim, num_datapoints) = (2, 100). Transpose it so the model sees
# 100 points in 2-D. The call returns the observed data, so its shape is
# (100, 2).
ppca_model(x_sample[0].T, 3).shape
# -> torch.Size([100, 2])
from pyro import poutine

# Trace the model inside a size-10 plate at dim -3, so each sample site gets
# an extra batch dimension of size 10. Then inspect the sampled W.
with pyro.plate("samples", 10, dim=-3):
    # Transposed so the model sees rows = datapoints (100 points in 2-D).
    trace = poutine.trace(ppca_model).get_trace(x_sample[0].T, 1)

# W has shape (10, latent_dim=1, data_dim=2); squeeze drops the singleton.
trace.nodes["W"]["value"].squeeze().shape
# -> torch.Size([10, 2])
# Stand-alone shape check of the PPCA priors, outside the model:
# 150 datapoints, 2 latent dimensions, 3 observed dimensions.
data_dim = 3
latent_dim = 2

# Loading matrix W ~ Normal(0, 5), shape (latent_dim, data_dim).
W = pyro.sample(
    "W",
    dist.Normal(
        loc=torch.zeros([latent_dim, data_dim]),
        scale=5.0 * torch.ones([latent_dim, data_dim]),
    ),
)

N = 150

# Latent codes Z ~ Normal(0, 1), shape (N, latent_dim).
Z = pyro.sample(
    "Z",
    dist.Normal(
        loc=torch.zeros([N, latent_dim]),
        scale=torch.ones([N, latent_dim]),
    ),
)

Z.shape, W.shape
# -> (torch.Size([150, 2]), torch.Size([2, 3]))

(Z @ W).shape
# -> torch.Size([150, 3])