Skip to content

Add MXFP8 attention unit test with linear and rope layers#3033

Open
layalir wants to merge 10 commits into
NVIDIA:mainfrom
layalir:add_linear_mxfp8_unit_test
Open

Add MXFP8 attention unit test with linear and rope layers#3033
layalir wants to merge 10 commits into
NVIDIA:mainfrom
layalir:add_linear_mxfp8_unit_test

Conversation

@layalir
Copy link
Copy Markdown

@layalir layalir commented May 22, 2026

Add a DSv3-shaped MXFP8 attention unit test covering the training path:

  • Adds MLA RoPE utilities for the DSv3 671B attention shape.
  • Adds an end-to-end MXFP8 path: Linear(QKV) -> MLA RoPE -> DotProductAttention -> Linear(out).
  • Exercises MXFP8 forward and backward through TE's real DotProductAttention wrapper.
  • Runs BF16 reference comparison by default.
  • Runs the performance benchmark by default and reports fprop and bprop timing separately from the same benchmark collection.

Validation

Local checks:

  • python -m py_compile tests/pytorch/attention/test_linear_mxfp8_attention.py tests/pytorch/attention/mla_rope_utils.py
  • git diff --check

GB300 dlcluster validation:

  • Job: 1062811
  • GPU: NVIDIA GB300
  • CUDA capability: (10, 3)
  • cuDNN: (9, 21, 1)
  • MXFP8 available: (True, '')
  • Command: python -m pytest tests/pytorch/attention/test_linear_mxfp8_attention.py -v -s
  • Result: 3 passed

Perf output:

[PERF] b=1 s=4096:
  BF16 fprop:  7.219 ms  (567397 tok/s)
  BF16 bprop:  15.179 ms  (269844 tok/s)
  MXFP8 fprop: 4.718 ms  (868181 tok/s)
  MXFP8 bprop: 9.215 ms  (444492 tok/s)
  Fprop speedup: 1.53x
  Bprop speedup: 1.65x

layalir and others added 6 commits May 21, 2026 07:56
Forward/backward Triton kernels for DSv3 671B MLA RoPE, ported from
Megatron-LM fused_mla_yarn_rope_apply.py. Falls back to PyTorch when
Triton is unavailable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Tests: Linear(QKV, MXFP8) -> MLA-RoPE -> DotProductAttention(MXFP8) -> Linear(out, MXFP8)
against a BF16 baseline for accuracy, backward correctness, and performance.

Dimensions: hidden=16384, heads=128, dqk=192 (nope=128+rope=64), dv=128, s=4096, b=1.

Weight quantization is amortized via is_first_microbatch caching
(pre-quantized weights reused each iteration).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR adds a DSv3-shaped MXFP8 attention unit test covering the full training path (Linear(QKV) → MLA RoPE → DotProductAttention → Linear(out)) plus the companion Triton/PyTorch MLA RoPE utilities. The implementation is well-structured and validated on GB300 hardware.

  • mla_rope_utils.py: New Triton kernels (rotary_fwd/bwd_q_kernel, rotary_fwd/bwd_kv_kernel) implement the interleaved-to-NeoX RoPE rotation for DSv3 MLA; a pure-PyTorch fallback is provided when Triton is unavailable. The backward correctly reconstructs the key-positional-embedding gradient by accumulating across all heads and routing it only to head 0's rope slice.
  • test_linear_mxfp8_attention.py: Three test cases exercise accuracy (loose FP8-appropriate tolerances), gradient flow (NaN/Inf checks), and throughput (speedup assertion vs BF16 reference), parameterised over the fixed b=1, s=4096 DSv3 shape.

Confidence Score: 5/5

This PR is safe to merge — it adds test-only files with no changes to library source code, and the Triton kernel logic is correct for the fixed 128-head DSv3 configuration it targets.

Both files are purely test infrastructure; no production library code is modified. The Triton RoPE kernels are logically correct for the specific 128-head count used in the test (128 is divisible by every autotune BLOCK_H candidate, so the partial-block boundary condition never triggers). The overall forward/backward autograd wiring is sound, and the benchmark and accuracy thresholds are appropriate for FP8 arithmetic.

No files require special attention for merging; both are new test-only additions with no impact on library users.

Important Files Changed

Filename Overview
tests/pytorch/attention/mla_rope_utils.py New Triton implementation of MLA RoPE for DSv3 671B attention, providing both forward and backward kernels with a pure-PyTorch fallback; the partial-head-block mask bug (already flagged) is the main latent risk, but is harmless for the fixed 128-head DSv3 shape.
tests/pytorch/attention/test_linear_mxfp8_attention.py End-to-end MXFP8 accuracy, backward, and benchmark tests covering the QKV→RoPE→DPA→out pipeline; three separate fp8_autocast contexts (already flagged) and warmup-backward cache-invalidation (already flagged) are the open concerns.

Sequence Diagram

sequenceDiagram
    participant X as Input x [s,b,H]
    participant QKV as te.Linear QKV
    participant SPL as _split_qkv
    participant RoPE as apply_mla_rope (Triton/PyTorch)
    participant DPA as te.DotProductAttention
    participant OUT as te.Linear out
    participant Y as Output [s,b,H]

    Note over QKV,OUT: fp8_autocast (3 separate contexts)

    X->>QKV: [s,b,H] fp8_autocast 1
    QKV-->>SPL: "qkv [s,b,2*H_QK+H_V]"
    SPL-->>RoPE: q[s,b,h,192], k[s,b,h,192], v[s,b,h,128]
    Note over RoPE: outside fp8_autocast
    RoPE-->>DPA: q_rot, k_rot, v fp8_autocast 2
    DPA-->>OUT: attn_out[s,b,h,128] fp8_autocast 3
    OUT-->>Y: [s,b,H]

    Note over X,Y: Backward (reverse order)
    Y->>OUT: d_attn_out
    OUT->>DPA: dq, dk, dv
    DPA->>RoPE: _MLARoPETriton.backward
    Note over RoPE: rotary_bwd_q_kernel in-place on dq, rotary_bwd_kv_kernel accumulates dEMB
    RoPE->>SPL: dq_in, dk_in, dv_in
    SPL->>QKV: d_qkv slice backward
    QKV->>X: d_x
Loading

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +118 to +119
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_x_nheads + qk_head_dim
mask = x_off < head_num * stride_x_nheads
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Mask ignores block offset for partial last head-block

The bound check x_off < head_num * stride_x_nheads compares only the intra-block head index i against the total head count, but the pointer Q has already been advanced by pid_head * BLOCK_H * stride_x_nheads. For the last block when head_num % BLOCK_H != 0, the absolute head index pid_head * BLOCK_H + i may exceed head_num while the mask still evaluates to True, causing an out-of-bounds load/store. The same pattern appears in rotary_bwd_q_kernel (line 178) and in the accumulation loop inside rotary_bwd_kv_kernel (line 332). For DSv3's 128 heads all BLOCK_H candidates (1–128) evenly divide 128, so the current test is unaffected, but any future caller with a non-aligned head count would silently corrupt memory.

Comment on lines +189 to +202
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
qkv = qkv_linear(x, is_first_microbatch=is_first_microbatch)

q, k, v = _split_qkv(qkv)
q, k, v = apply_mla_rope(q, k, v)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
attn_out = dpa(q, k, v, qkv_format="sbhd")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
out = out_linear(
attn_out.view(x.shape[0], x.shape[1], HIDDEN_SIZE),
is_first_microbatch=is_first_microbatch,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Three separate fp8_autocast scopes may reset per-layer FP8 statistics

The forward pipeline is split into three independent fp8_autocast contexts (QKV linear, DPA, out linear). For current MXFP8 block scaling the effect is benign because scales are computed per-block and don't depend on cross-layer statistics, but TE's FP8GlobalStateManager maintains per-forward-pass amax history used by some recipes. Exiting and re-entering the context between layers means each layer is treated as a separate forward pass for bookkeeping purposes, so any inter-layer scale propagation is lost. A single surrounding fp8_autocast context covering all three layers would match the standard training usage and would more faithfully exercise the end-to-end path.

Comment on lines +374 to +378
with torch.no_grad():
_run_forward_mxfp8(mxfp8_modules, x, fp8_recipe, is_first_microbatch=True)

mxfp8_fprop_ms, mxfp8_bprop_ms = _benchmark_training_step(
_run_forward_mxfp8, mxfp8_modules, x, fp8_recipe, False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Warmup backward may invalidate the is_first_microbatch weight cache before timed iterations

The weight cache is seeded with is_first_microbatch=True inside torch.no_grad(), but the 10 warmup iterations inside _benchmark_training_step call backward() on a new computation graph. TE invalidates or refreshes the cached quantized weight buffer after a backward pass (the cache is keyed to the current "microbatch"). Consequently the timed iterations that pass is_first_microbatch=False may miss the cache on every call and silently fall back to per-iteration weight quantization, making the benchmark measure a different workload than described in the docstring.

@cyanguwa
Copy link
Copy Markdown
Collaborator

Thanks for the contribution! Could you please:

  • fix the DCO (there are instructions in the DCO link)
  • address Greptile comments
  • add test_linear_mxfp8_attention.py to qa/L0_pytorch_unittest/test.sh, similar to this

)
flash_attn_supported, fused_attn_supported_fp8, _ = fp8_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
pytest.skip("No FP8 attention backend available for DSv3 MLA shape.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should check "if not fused_attn_supported_fp8" here. FlashAttention doesn't support MXFP8, so it's kind of relevant here.

_attention_backends["backend_selection_requires_update"] = True


def _split_qkv(qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work for MXFP8 tensors? Should we do qkv.to() first before splitting?

qkv_mxfp8,
qkv_bf16,
atol=2.0,
rtol=0.5,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tolerances are a bit too high, aren't they? For regular MXFP8 tests, I've been using atol = 5e-1, rtol = 5e-2. Would they work for these tests?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants