import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch
# Start from an empty global Pyro parameter store.
pyro.get_param_store().clear()
def ppca_model(data, latent_dim):
    """Probabilistic PCA generative model.

    Samples a loading matrix ``W`` and per-row latent codes ``Z`` from
    zero-mean Normal priors, then scores ``data`` under
    ``Normal(Z @ W, 1.0)`` at the observed site ``"obs"``.

    Args:
        data: 2-D tensor of observations with shape ``(N, data_dim)``.
        latent_dim: Number of latent dimensions.

    Returns:
        The value of the observed ``"obs"`` sample site (normally ``data``
        itself, since the site is conditioned via ``obs=data``).
    """
    N, data_dim = data.shape
    # Build prior parameters with data's dtype/device so the model also
    # works for float64 or GPU-resident inputs (identical for float32 CPU).
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=data.new_zeros((latent_dim, data_dim)),
            scale=5.0 * data.new_ones((latent_dim, data_dim)),
        ),
    )
    Z = pyro.sample(
        "Z",
        dist.Normal(
            loc=data.new_zeros((N, latent_dim)),
            scale=data.new_ones((N, latent_dim)),
        ),
    )
    # (N, latent_dim) @ (latent_dim, data_dim) -> (N, data_dim)
    mean = Z @ W
    # NOTE(review): all sites carry batch dimensions without pyro.plate or
    # .to_event(); confirm this is intended for the inference algorithm used.
    return pyro.sample("obs", dist.Normal(mean, 1.0), obs=data)
# Render the PPCA model's dependency graph, using random 150x2 data and a
# 1-D latent space as example inputs (presumably only the shapes matter
# for the rendered structure). Left as a bare expression so a notebook
# displays the returned graph.
pyro.render_model(
    ppca_model,
    model_args=(torch.randn((150, 2)), 1),
    render_distributions=True,
)