Skip to content

[Common] Fix fused router for large top-K and expert counts#2821

Merged
tdophung merged 9 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p2
Apr 16, 2026
Merged

[Common] Fix fused router for large top-K and expert counts#2821
tdophung merged 9 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p2

Conversation

@harryzhou2000
Copy link
Copy Markdown
Member

Description

Fixed fused router support for large topk and num_expert. Now num_expert <=2304 and any topk is supported with reasonable performance.

Current benchmark shows fused topk forward kernel is faster than pytorch at topk=32, which would be around 8x faster than before optimization.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Fix fused MoE router kernels to support large top-K values and large numbers of experts (1024+) by adding a warp-level radix-selection top-K algorithm (O(E), independent of K) alongside the existing naive O(K^2*E) implementation, dispatched at topk >= 16 boundary.
  • Expand dynamic shared memory allocation via cudaFuncSetAttribute in both forward and backward
    kernel launchers to avoid silent failures when expert count exceeds the default 48 KB limit.
  • Rewrite apply_softmax_on_float to use a numerically stable online max+sum accumulation
    (two-pass → single-pass) with NaN-safe warp reduction, eliminating shared-memory round-trips.

Details

Radix top-K selection (utils.h):
Implements a 4-bit radix selection algorithm (8 passes over float32) that finds the K-th largest
value in O(E/32) per warp, independent of K. Phase 1 narrows the bit pattern of the K-th value
via histogram counting; Phase 2 gathers elements into output arrays with deterministic tie-breaking
(value DESC, index ASC) matching torch.topk behavior.
Dispatch logic (fused_topk_with_score_function.cu, fused_score_for_moe_aux_loss.cu):
Template parameter TopkFuncType (Naive/Radix) is selected at launch time based on
topk < 16. Both forward kernels and backward kernels now call cudaFuncSetAttribute to
request the required dynamic shared memory size before launch.
Tests (test_fused_router.py):

  • Add num_experts=1024 to all parametrized test cases.
  • Add _get_tolerances() helper that scales atol/rtol with expert count to account for
    O(N * eps) accumulation divergence between fused and reference implementations.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 1, 2026 14:44
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 1, 2026

Greptile Summary

This PR fixes fused MoE router kernels for large top-K values and expert counts by adding a warp-level radix selection algorithm (O(E), independent of K) alongside the existing O(K²E) naive implementation, dispatched at topk >= 16. It also expands dynamic shared memory via cudaFuncSetAttribute to handle expert counts above the default 48 KB limit, rewrites apply_softmax_on_float with a numerically stable single-pass online accumulation, and updates tests to cover the new Radix path (topk=16, 32) and num_experts=1024.

Previous feedback on unchecked cudaFuncSetAttribute return values, untested Radix paths, and the dead dtype guard in _get_tolerances have all been addressed in this revision.

Confidence Score: 5/5

Safe to merge; the radix algorithm, smem expansion, and softmax rewrite are all correct, and the only remaining finding is a minor per-launch performance suggestion.

All P0/P1 concerns from prior review rounds have been addressed: cudaFuncSetAttribute return values are now checked, Radix path is covered by topk=16,32 tests, and _get_tolerances now raises on non-fp32 dtypes. The sole new finding is a P2 performance note about redundant device attribute queries per kernel launch, which does not affect correctness.

transformer_engine/common/fused_router/utils.h — check_shared_memory_capacity_num_experts queries device attributes on every launch; worth caching.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/utils.h Core of the PR: adds radix_topk_and_mask (warp-level 4-bit radix selection, Phase 1 + Phase 2 gather), the topk_and_mask dispatch template, TopkFuncType enum, check_shared_memory_capacity_num_experts host helper, and a rewritten single-pass numerically-stable apply_softmax_on_float. Algorithm is correct; minor concern on per-launch device attribute queries.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Forward and backward launchers updated: cudaFuncSetAttribute now wrapped with NVTE_CHECK_CUDA, Naive/Radix dispatch added at topk<16 boundary, check_shared_memory_capacity_num_experts called before launch. Looks correct.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Same launcher pattern as the other .cu file: TopkFuncType template parameter added, Naive/Radix dispatch added, cudaFuncSetAttribute checked, backward launcher also updated with smem capacity check. Looks correct.
transformer_engine/common/utils.cuh Adds float_to_ordered_uint, ordered_uint_to_float (unused, pre-existing concern), and warp_allreduce_sum — all correct standalone device utilities used by the new radix algorithm.
tests/pytorch/test_fused_router.py Adds _get_tolerances helper with fp32 guard, topk=16,32 and num_experts=1024 to all parametrize sets, and pytest.skip guards for impossible (topk >= num_experts) configurations. Radix path is now exercised.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Kernel Launcher] --> B{topk lt 16?}
    B -->|Yes| C[TopkFuncType Naive]
    B -->|No| D[TopkFuncType Radix]
    C --> E[cudaFuncSetAttribute wrapped with NVTE_CHECK_CUDA]
    D --> E
    E --> F[check_shared_memory_capacity_num_experts]
    F --> G[Kernel Launch]
    G --> H[NVTE_CHECK_CUDA cudaGetLastError]
    subgraph radix_topk_and_mask
        I[Phase 1 - 8-pass radix selection] --> J[Phase 2a - Gather strictly greater elements]
        J --> K[Phase 2b - Fill ties in ascending index order]
    end
    subgraph apply_softmax_on_float
        L[Pass 1 - Online max and sum per lane] --> M[Warp butterfly reduction with NaN guard]
        M --> N[Pass 2 - Normalize in-place]
    end
Loading

Reviews (6): Last reviewed commit: "warning about dtype for tolerance in tes..." | Re-trigger Greptile

Comment thread tests/pytorch/test_fused_router.py
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread tests/pytorch/test_fused_router.py
Comment thread transformer_engine/common/fused_router/utils.h Outdated
@harryzhou2000 harryzhou2000 marked this pull request as draft April 1, 2026 15:08
@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 1, 2026 15:14
@harryzhou2000 harryzhou2000 changed the title Fix fused router for large top-K and expert counts [Common] Fix fused router for large top-K and expert counts Apr 2, 2026
Comment thread tests/pytorch/test_fused_router.py
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Comment thread transformer_engine/common/fused_router/utils.h Outdated
Comment thread transformer_engine/common/fused_router/utils.h Outdated
Comment thread transformer_engine/common/fused_router/utils.h Outdated
@harryzhou2000 harryzhou2000 marked this pull request as draft April 3, 2026 02:53
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p2 branch 2 times, most recently from ee33ea2 to fab73d1 Compare April 3, 2026 07:30
@harryzhou2000 harryzhou2000 marked this pull request as ready for review April 3, 2026 07:31
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p2 branch from 09b6dfc to 14228cb Compare April 7, 2026 01:44
@harryzhou2000 harryzhou2000 requested a review from tdophung April 7, 2026 01:46
@denera denera self-requested a review April 7, 2026 21:29
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p2 branch from 14228cb to 08c2de9 Compare April 9, 2026 10:57
Copy link
Copy Markdown
Collaborator

@denera denera left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, pending rebase and clean CI

@tdophung
Copy link
Copy Markdown
Collaborator

/te-ci

harryzhou2000 and others added 9 commits April 16, 2026 19:54
…r of experts

- expanding shared memory when needed
- switch to radix topk selection when topk is large
- test_fused_router.py updated with large num experts and tolerances refined for different cases

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu
added dtype dependent eps in tests/pytorch/test_fused_router.py
removed unneeded code in transformer_engine/common/fused_router/utils.h

pr bot suggestions

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
cleaned up raw warp operations
added comments
added shared_memory check
added return code check

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@tdophung
Copy link
Copy Markdown
Collaborator

/te-ci

Copy link
Copy Markdown
Collaborator

@tdophung tdophung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approve pending CI. Seems like the issues seen in CI are not related to this PR

@tdophung tdophung merged commit 1e9e48c into NVIDIA:main Apr 16, 2026
28 of 33 checks passed
@tdophung
Copy link
Copy Markdown
Collaborator

faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
)

* fix: enabling fused _router to be able to handle large topk and number of experts
- expanding shared memory when needed
- switch to radix topk selection when topk is large
- test_fused_router.py updated with large num experts and tolerances refined for different cases

* added topk>=16 in tests/pytorch/test_fused_router.py
added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu
added dtype dependent eps in tests/pytorch/test_fused_router.py
removed unneeded code in transformer_engine/common/fused_router/utils.h

* test_fused_router.py needs to skip topk >= num_experts case

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>


cleaned up raw warp operations
added comments
added shared_memory check
added return code check

* warning about dtype for tolerance in test_fused_router.py

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>

---------

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants