Fix GaussianStateSpace.variance shape when batch comes from initial_value#2205
Merged
fehiepsi merged 1 commit intoJun 8, 2026
Conversation
…alue GaussianStateSpace.variance was computed from scale_tril alone, dropping any batch dimension contributed by initial_value. When the batch shape originates from initial_value (e.g. scale_tril (2,2) with initial_value (3,2)), variance came out (num_steps, state_dim) instead of batch_shape + event_shape. Broadcast the result to batch_shape + event_shape. The initial value is deterministic and does not affect variance values, only the shape, so broadcasting is exact. Also add a uniform moment-shape contract check in test_mean_var asserting mean/variance have shape batch_shape + event_shape for every distribution. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
GaussianStateSpace.variancewas computed fromscale_trilalone, which drops any batch dimension contributed byinitial_value. When the batch shape originates frominitial_value(e.g.scale_trilof shape(2, 2)withinitial_valueof shape(3, 2)),variancecame out as(num_steps, state_dim)instead of the requiredbatch_shape + event_shape.Concretely, for
num_steps=5, transition(2, 2),scale_tril(2, 2),initial_value(3, 2):batch_shape = (3,),event_shape = (5, 2)→ expected moment shape(3, 5, 2)variancereturned(5, 2)Fix
Broadcast the variance to
batch_shape + event_shape. The initial value is deterministic and does not affect the variance values, only the shape, so broadcasting is exact.Tests
Added a uniform moment-shape contract check in
test_mean_varasserting thatmean/variance(where implemented) have shapebatch_shape + event_shapefor every distribution. With the fix, the fulltest_mean_varsuite passes (162 passed, 61 skipped, 3 xpassed, 0 failed).🤖 Generated with Claude Code