---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
233 except ValueError as e:
File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/normal.py:73, in Normal.log_prob(self, value)
72 if self._validate_args:
---> 73 self._validate_sample(value)
74 # compute the variance
File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/distribution.py:276, in Distribution._validate_sample(self, value)
275 if i != 1 and j != 1 and i != j:
--> 276 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
277 format(actual_shape, expected_shape))
278 try:
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([200]).
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Input In [117], in <module>
10 # do gradient steps
11 for step in range(n_steps):
---> 12 svi.step(x, y)
13 if step%110==0:
14 print(pyro.param("theta_0").item(), pyro.param("theta_1").item())
File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)
File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File ~/miniforge3/lib/python3.9/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
72 guide_trace = prune_subsample_sites(guide_trace)
73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
76 guide_trace.compute_score_parts()
77 if is_validation_enabled():
File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:236, in Trace.compute_log_prob(self, site_filter)
234 _, exc_value, traceback = sys.exc_info()
235 shapes = self.format_shapes(last_site=site["name"])
--> 236 raise ValueError(
237 "Error while computing log_prob at site '{}':\n{}\n{}".format(
238 name, exc_value, shapes
239 )
240 ).with_traceback(traceback) from e
241 site["unscaled_log_prob"] = log_p
242 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
File ~/miniforge3/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
228 if "log_prob" not in site:
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
233 except ValueError as e:
234 _, exc_value, traceback = sys.exc_info()
File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/normal.py:73, in Normal.log_prob(self, value)
71 def log_prob(self, value):
72 if self._validate_args:
---> 73 self._validate_sample(value)
74 # compute the variance
75 var = (self.scale ** 2)
File ~/miniforge3/lib/python3.9/site-packages/torch/distributions/distribution.py:276, in Distribution._validate_sample(self, value)
274 for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
275 if i != 1 and j != 1 and i != j:
--> 276 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
277 format(actual_shape, expected_shape))
278 try:
279 support = self.support
ValueError: Error while computing log_prob at site 'obs':
Value is not broadcastable with batch_shape+event_shape: torch.Size([100]) vs torch.Size([200]).
Trace Shapes:
Param Sites:
theta_0 1
theta_1 1
Sample Sites:
obs dist 200 |
value 100 |