Skip to content

Fix GaussianStateSpace.variance shape when batch comes from initial_value#2205

Merged
fehiepsi merged 1 commit into
pyro-ppl:masterfrom
tillahoffmann:fix-gaussian-state-space-variance-shape
Jun 8, 2026
Merged

Fix GaussianStateSpace.variance shape when batch comes from initial_value#2205
fehiepsi merged 1 commit into
pyro-ppl:masterfrom
tillahoffmann:fix-gaussian-state-space-variance-shape

Conversation

@tillahoffmann

Copy link
Copy Markdown
Collaborator

Problem

GaussianStateSpace.variance was computed from scale_tril alone, which drops any batch dimension contributed by initial_value. When the batch shape originates from initial_value (e.g. scale_tril of shape (2, 2) with initial_value of shape (3, 2)), variance came out as (num_steps, state_dim) instead of the required batch_shape + event_shape.

Concretely, for num_steps=5, transition (2, 2), scale_tril (2, 2), initial_value (3, 2):

  • batch_shape = (3,), event_shape = (5, 2) → expected moment shape (3, 5, 2)
  • variance returned (5, 2)

Fix

Broadcast the variance to batch_shape + event_shape. The initial value is deterministic and does not affect the variance values, only the shape, so broadcasting is exact.

Tests

Added a uniform moment-shape contract check in test_mean_var asserting that mean/variance (where implemented) have shape batch_shape + event_shape for every distribution. With the fix, the full test_mean_var suite passes (162 passed, 61 skipped, 3 xpassed, 0 failed).

🤖 Generated with Claude Code

…alue

GaussianStateSpace.variance was computed from scale_tril alone, dropping any
batch dimension contributed by initial_value. When the batch shape originates
from initial_value (e.g. scale_tril (2,2) with initial_value (3,2)), variance
came out (num_steps, state_dim) instead of batch_shape + event_shape.

Broadcast the result to batch_shape + event_shape. The initial value is
deterministic and does not affect variance values, only the shape, so
broadcasting is exact.

Also add a uniform moment-shape contract check in test_mean_var asserting
mean/variance have shape batch_shape + event_shape for every distribution.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Qazalbash Qazalbash requested review from fehiepsi and juanitorduz June 8, 2026 19:22

@juanitorduz juanitorduz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fehiepsi fehiepsi merged commit 662efd5 into pyro-ppl:master Jun 8, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants