Skip to content

Use cuDNN for row-scaled NVFP4 grouped GEMM#3042

Draft
zianglih wants to merge 2 commits into
NVIDIA:mainfrom
zianglih:codex/cudnn-row-scale-nvfp4-grouped-gemm
Draft

Use cuDNN for row-scaled NVFP4 grouped GEMM#3042
zianglih wants to merge 2 commits into
NVIDIA:mainfrom
zianglih:codex/cudnn-row-scale-nvfp4-grouped-gemm

Conversation

@zianglih
Copy link
Copy Markdown
Contributor

Summary

  • route row-scaled NVFP4 grouped GEMM through cuDNN grouped GEMM quant
  • remove the per-GEMM fallback for the row-scaled grouped path so unsupported cases fail explicitly
  • tighten NVFP4 grouped GEMM tests to cover the cuDNN wrapper path, a supported functional case, and an unsupported no-fallback case

Required dependency

This PR explicitly requires the corresponding cuDNN Frontend feature in NVIDIA/cudnn-frontend#251. It requires a cudnn-frontend version whose cudnn.grouped_gemm_quant_wrapper_sm100(...) accepts row_scale_tensor; without that cudnn-fe PR feature, this TransformerEngine PR is expected to fail on the row-scaled grouped GEMM path.

Motivation

Related to the row-scaled NVFP4 work in #2931. This PR is intended to land only after TransformerEngine can depend on a cudnn-fe version containing the row-scaled grouped GEMM quant feature.

Validation

  • python3 -m py_compile transformer_engine/pytorch/cpp_extensions/gemm.py tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
  • git diff --check -- transformer_engine/pytorch/cpp_extensions/gemm.py tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
  • pre-commit run --all-files
  • B200 devbox: installed cudnn-fe branch from Add row-scale support to grouped GEMM quant cudnn-frontend#251 and verified row_scale_tensor is in grouped_gemm_quant_wrapper_sm100
  • B200 devbox: built and installed TransformerEngine with NVTE_FRAMEWORK=pytorch NVTE_CUDA_ARCHS=100a MAX_JOBS=4
  • B200 devbox: pytest -q tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py::test_nvfp4_row_scaled_grouped_gemm_uses_cudnn_quant_wrapper --tb=short passed: 2 passed
  • B200 devbox: supported BF16 case passed: test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm[mae_err-default-single_output-no_bias-torch.bfloat16-torch.bfloat16-torch.bfloat16-m_splits4-1024-1024]
  • B200 devbox: unsupported no-fallback case passed: test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm[mae_err-default-list_output-no_bias-torch.float32-torch.float32-torch.float32-m_splits0-128-128]
  • B200 devbox: python3 -m pylint transformer_engine/pytorch/cpp_extensions/gemm.py passed: 10.00/10

zianglih added 2 commits May 22, 2026 21:35
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant