Add MXFP8 attention unit test with linear and rope layers#3033
Add MXFP8 attention unit test with linear and rope layers#3033layalir wants to merge 10 commits into
Conversation
Forward/backward Triton kernels for DSv3 671B MLA RoPE, ported from Megatron-LM fused_mla_yarn_rope_apply.py. Falls back to PyTorch when Triton is unavailable. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Tests: Linear(QKV, MXFP8) -> MLA-RoPE -> DotProductAttention(MXFP8) -> Linear(out, MXFP8) against a BF16 baseline for accuracy, backward correctness, and performance. Dimensions: hidden=16384, heads=128, dqk=192 (nope=128+rope=64), dv=128, s=4096, b=1. Weight quantization is amortized via is_first_microbatch caching (pre-quantized weights reused each iteration). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds a DSv3-shaped MXFP8 attention unit test covering the full training path (
Confidence Score: 5/5This PR is safe to merge — it adds test-only files with no changes to library source code, and the Triton kernel logic is correct for the fixed 128-head DSv3 configuration it targets. Both files are purely test infrastructure; no production library code is modified. The Triton RoPE kernels are logically correct for the specific 128-head count used in the test (128 is divisible by every autotune BLOCK_H candidate, so the partial-block boundary condition never triggers). The overall forward/backward autograd wiring is sound, and the benchmark and accuracy thresholds are appropriate for FP8 arithmetic. No files require special attention for merging; both are new test-only additions with no impact on library users. Important Files Changed
Sequence DiagramsequenceDiagram
participant X as Input x [s,b,H]
participant QKV as te.Linear QKV
participant SPL as _split_qkv
participant RoPE as apply_mla_rope (Triton/PyTorch)
participant DPA as te.DotProductAttention
participant OUT as te.Linear out
participant Y as Output [s,b,H]
Note over QKV,OUT: fp8_autocast (3 separate contexts)
X->>QKV: [s,b,H] fp8_autocast 1
QKV-->>SPL: "qkv [s,b,2*H_QK+H_V]"
SPL-->>RoPE: q[s,b,h,192], k[s,b,h,192], v[s,b,h,128]
Note over RoPE: outside fp8_autocast
RoPE-->>DPA: q_rot, k_rot, v fp8_autocast 2
DPA-->>OUT: attn_out[s,b,h,128] fp8_autocast 3
OUT-->>Y: [s,b,H]
Note over X,Y: Backward (reverse order)
Y->>OUT: d_attn_out
OUT->>DPA: dq, dk, dv
DPA->>RoPE: _MLARoPETriton.backward
Note over RoPE: rotary_bwd_q_kernel in-place on dq, rotary_bwd_kv_kernel accumulates dEMB
RoPE->>SPL: dq_in, dk_in, dv_in
SPL->>QKV: d_qkv slice backward
QKV->>X: d_x
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| x_off = tl.arange(0, BLOCK_H)[:, None] * stride_x_nheads + qk_head_dim | ||
| mask = x_off < head_num * stride_x_nheads |
There was a problem hiding this comment.
Mask ignores block offset for partial last head-block
The bound check x_off < head_num * stride_x_nheads compares only the intra-block head index i against the total head count, but the pointer Q has already been advanced by pid_head * BLOCK_H * stride_x_nheads. For the last block when head_num % BLOCK_H != 0, the absolute head index pid_head * BLOCK_H + i may exceed head_num while the mask still evaluates to True, causing an out-of-bounds load/store. The same pattern appears in rotary_bwd_q_kernel (line 178) and in the accumulation loop inside rotary_bwd_kv_kernel (line 332). For DSv3's 128 heads all BLOCK_H candidates (1–128) evenly divide 128, so the current test is unaffected, but any future caller with a non-aligned head count would silently corrupt memory.
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| qkv = qkv_linear(x, is_first_microbatch=is_first_microbatch) | ||
|
|
||
| q, k, v = _split_qkv(qkv) | ||
| q, k, v = apply_mla_rope(q, k, v) | ||
|
|
||
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| attn_out = dpa(q, k, v, qkv_format="sbhd") | ||
|
|
||
| with te.fp8_autocast(enabled=True, fp8_recipe=recipe): | ||
| out = out_linear( | ||
| attn_out.view(x.shape[0], x.shape[1], HIDDEN_SIZE), | ||
| is_first_microbatch=is_first_microbatch, | ||
| ) |
There was a problem hiding this comment.
Three separate
fp8_autocast scopes may reset per-layer FP8 statistics
The forward pipeline is split into three independent fp8_autocast contexts (QKV linear, DPA, out linear). For current MXFP8 block scaling the effect is benign because scales are computed per-block and don't depend on cross-layer statistics, but TE's FP8GlobalStateManager maintains per-forward-pass amax history used by some recipes. Exiting and re-entering the context between layers means each layer is treated as a separate forward pass for bookkeeping purposes, so any inter-layer scale propagation is lost. A single surrounding fp8_autocast context covering all three layers would match the standard training usage and would more faithfully exercise the end-to-end path.
| with torch.no_grad(): | ||
| _run_forward_mxfp8(mxfp8_modules, x, fp8_recipe, is_first_microbatch=True) | ||
|
|
||
| mxfp8_fprop_ms, mxfp8_bprop_ms = _benchmark_training_step( | ||
| _run_forward_mxfp8, mxfp8_modules, x, fp8_recipe, False |
There was a problem hiding this comment.
Warmup backward may invalidate the
is_first_microbatch weight cache before timed iterations
The weight cache is seeded with is_first_microbatch=True inside torch.no_grad(), but the 10 warmup iterations inside _benchmark_training_step call backward() on a new computation graph. TE invalidates or refreshes the cached quantized weight buffer after a backward pass (the cache is keyed to the current "microbatch"). Consequently the timed iterations that pass is_first_microbatch=False may miss the cache on every call and silently fall back to per-iteration weight quantization, making the benchmark measure a different workload than described in the docstring.
|
Thanks for the contribution! Could you please:
|
| ) | ||
| flash_attn_supported, fused_attn_supported_fp8, _ = fp8_backends | ||
| if flash_attn_supported + fused_attn_supported_fp8 < 1: | ||
| pytest.skip("No FP8 attention backend available for DSv3 MLA shape.") |
There was a problem hiding this comment.
I think we should check "if not fused_attn_supported_fp8" here. FlashAttention doesn't support MXFP8, so it's kind of relevant here.
| _attention_backends["backend_selection_requires_update"] = True | ||
|
|
||
|
|
||
| def _split_qkv(qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
There was a problem hiding this comment.
Does this work for MXFP8 tensors? Should we do qkv.to() first before splitting?
| qkv_mxfp8, | ||
| qkv_bf16, | ||
| atol=2.0, | ||
| rtol=0.5, |
There was a problem hiding this comment.
These tolerances are a bit too high, aren't they? For regular MXFP8 tests, I've been using atol = 5e-1, rtol = 5e-2. Would they work for these tests?
Add a DSv3-shaped MXFP8 attention unit test covering the training path:
Linear(QKV) -> MLA RoPE -> DotProductAttention -> Linear(out).DotProductAttentionwrapper.Validation
Local checks:
python -m py_compile tests/pytorch/attention/test_linear_mxfp8_attention.py tests/pytorch/attention/mla_rope_utils.pygit diff --checkGB300 dlcluster validation:
1062811(10, 3)(9, 21, 1)(True, '')python -m pytest tests/pytorch/attention/test_linear_mxfp8_attention.py -v -s3 passedPerf output: