[Common] Fix fused router for large top-K and expert counts #2821
Greptile Summary
This PR fixes fused MoE router kernels for large top-K values and expert counts by adding a warp-level radix selection algorithm (O(E), independent of K) alongside the existing O(K²E) naive implementation, dispatched at …
Previous feedback on unchecked …
Confidence Score: 5/5
Safe to merge; the radix algorithm, smem expansion, and softmax rewrite are all correct, and the only remaining finding is a minor per-launch performance suggestion.
All P0/P1 concerns from prior review rounds have been addressed: cudaFuncSetAttribute return values are now checked, the Radix path is covered by topk=16,32 tests, and _get_tolerances now raises on non-fp32 dtypes. The sole new finding is a P2 performance note about redundant device attribute queries per kernel launch, which does not affect correctness.
transformer_engine/common/fused_router/utils.h: check_shared_memory_capacity_num_experts queries device attributes on every launch; worth caching.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Kernel Launcher] --> B{topk lt 16?}
B -->|Yes| C[TopkFuncType Naive]
B -->|No| D[TopkFuncType Radix]
C --> E[cudaFuncSetAttribute wrapped with NVTE_CHECK_CUDA]
D --> E
E --> F[check_shared_memory_capacity_num_experts]
F --> G[Kernel Launch]
G --> H[NVTE_CHECK_CUDA cudaGetLastError]
subgraph radix_topk_and_mask
I[Phase 1 - 8-pass radix selection] --> J[Phase 2a - Gather strictly greater elements]
J --> K[Phase 2b - Fill ties in ascending index order]
end
subgraph apply_softmax_on_float
L[Pass 1 - Online max and sum per lane] --> M[Warp butterfly reduction with NaN guard]
M --> N[Pass 2 - Normalize in-place]
end
ee33ea2 to fab73d1
09b6dfc to 14228cb
14228cb to 08c2de9
denera left a comment
LGTM, pending rebase and clean CI
/te-ci
…r of experts
- expanding shared memory when needed
- switch to radix topk selection when topk is large
- test_fused_router.py updated with large num experts and tolerances refined for different cases
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu
added dtype dependent eps in tests/pytorch/test_fused_router.py
removed unneeded code in transformer_engine/common/fused_router/utils.h
pr bot suggestions
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
cleaned up raw warp operations
added comments
added shared_memory check
added return code check
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
08c2de9 to 6c4886b
/te-ci
tdophung left a comment
approve pending CI. Seems like the issues seen in CI are not related to this PR
retriggered CI and all passed: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/48710028
)
* fix: enabling fused _router to be able to handle large topk and number of experts
  - expanding shared memory when needed
  - switch to radix topk selection when topk is large
  - test_fused_router.py updated with large num experts and tolerances refined for different cases
* added topk>=16 in tests/pytorch/test_fused_router.py
  added return value check of cudaFuncSetAttribute in transformer_engine/common/fused_router/fused_topk_with_score_function.cu
  added dtype dependent eps in tests/pytorch/test_fused_router.py
  removed unneeded code in transformer_engine/common/fused_router/utils.h
* test_fused_router.py needs to skip topk >= num_experts case
  Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
  cleaned up raw warp operations
  added comments
  added shared_memory check
  added return code check
* warning about dtype for tolerance in test_fused_router.py
  Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
---------
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
Fixes fused router support for large topk and num_expert. num_expert <= 2304 and any topk are now supported with reasonable performance.
Current benchmarks show the fused topk forward kernel is faster than PyTorch at topk=32, which is around 8x faster than before this optimization.
Type of change
Changes
topk >= 16 boundary.
cudaFuncSetAttribute in both forward and backward kernel launchers to avoid silent failures when expert count exceeds the default 48 KB limit.
apply_softmax_on_float to use a numerically stable online max+sum accumulation (two-pass → single-pass) with NaN-safe warp reduction, eliminating shared-memory round-trips.
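The online max+sum accumulation with a guarded butterfly reduction can be illustrated on the host. The following is a minimal Python sketch, not the kernel code: the names `merge` and `online_softmax` are hypothetical, the 32 "lanes" are simulated with a list, and the guard shown (short-circuiting when both partial maxima are -inf) is only one plausible reading of "NaN-safe".

```python
import math

WARP_SIZE = 32  # lanes per warp (simulated here with a plain list)

def merge(a, b):
    """Combine two (running_max, running_sum) partials into one."""
    (ma, sa), (mb, sb) = a, b
    m = max(ma, mb)
    if m == -math.inf:
        # Both partials are empty: -inf - (-inf) would be NaN, so short-circuit.
        return (-math.inf, 0.0)
    # Rescale each partial sum to the shared maximum before adding.
    return (m, sa * math.exp(ma - m) + sb * math.exp(mb - m))

def online_softmax(scores):
    # Pass 1: each lane walks its strided slice once, tracking max and sum together.
    partials = []
    for lane in range(WARP_SIZE):
        acc = (-math.inf, 0.0)
        for x in scores[lane::WARP_SIZE]:
            acc = merge(acc, (x, 1.0))
        partials.append(acc)
    # Butterfly (XOR) reduction: after log2(32) = 5 steps every lane holds the total.
    offset = WARP_SIZE // 2
    while offset:
        partials = [merge(partials[i], partials[i ^ offset]) for i in range(WARP_SIZE)]
        offset //= 2
    m, s = partials[0]
    # Pass 2: normalize every element against the global max and sum.
    return [math.exp(x - m) / s for x in scores]
```

With fewer than 32 inputs some simulated lanes stay empty, which is exactly the case the guard in `merge` covers. An input consisting only of -inf values is outside the scope of this sketch.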
Details
Radix top-K selection (utils.h):
Implements a 4-bit radix selection algorithm (8 passes over float32) that finds the K-th largest value in O(E/32) per warp, independent of K. Phase 1 narrows the bit pattern of the K-th value via histogram counting; Phase 2 gathers elements into output arrays with deterministic tie-breaking (value DESC, index ASC) matching torch.topk behavior.
Dispatch logic (fused_topk_with_score_function.cu, fused_score_for_moe_aux_loss.cu):
Template parameter TopkFuncType (Naive/Radix) is selected at launch time based on topk < 16. Both forward kernels and backward kernels now call cudaFuncSetAttribute to request the required dynamic shared memory size before launch.
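The two-phase radix selection and the topk < 16 dispatch can be sketched sequentially in Python. This is a host-side illustration under stated assumptions, not the warp-level kernel: `float_to_key`, `radix_topk`, `naive_topk`, and `topk_indices` are hypothetical names, the order-preserving bit transform is the standard trick for radix-sorting floats and is assumed rather than taken from the PR, and the final sort of the output is added only to make results easy to compare.

```python
import struct

RADIX_BITS = 4
NUM_PASSES = 32 // RADIX_BITS   # 8 passes over a 32-bit key
RADIX_TOPK_THRESHOLD = 16       # from the PR: topk < 16 takes the naive path

def float_to_key(x):
    """Map a float32 to a uint32 whose unsigned order matches the float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits & 0xFFFFFFFF) if bits & 0x80000000 else (bits | 0x80000000)

def radix_topk(scores, k):
    assert 1 <= k <= len(scores)
    keys = [float_to_key(s) for s in scores]
    prefix, mask, remaining = 0, 0, k
    # Phase 1: fix the K-th largest key four bits at a time, most significant first.
    for p in range(NUM_PASSES):
        shift = 32 - RADIX_BITS * (p + 1)
        hist = [0] * 16
        for key in keys:
            if (key & mask) == prefix:          # still a candidate for the K-th value
                hist[(key >> shift) & 0xF] += 1
        for digit in range(15, -1, -1):          # walk buckets from largest digit down
            if hist[digit] >= remaining:
                prefix |= digit << shift
                break
            remaining -= hist[digit]
        mask |= 0xF << shift
    kth = prefix
    # Phase 2a: gather strictly greater elements.
    picked = [i for i, key in enumerate(keys) if key > kth]
    # Phase 2b: fill the remaining slots with ties in ascending index order.
    picked += [i for i, key in enumerate(keys) if key == kth][: k - len(picked)]
    # Sorted (value DESC, index ASC) purely for comparison in this sketch.
    return sorted(picked, key=lambda i: (-keys[i], i))

def naive_topk(scores, k):
    """K rounds of argmax; the membership test makes this O(K^2 * E)."""
    picked = []
    for _ in range(k):
        best = max((i for i in range(len(scores)) if i not in picked),
                   key=lambda i: (scores[i], -i))
        picked.append(best)
    return picked

def topk_indices(scores, k):
    """Dispatch on k in the way the PR describes for TopkFuncType."""
    return naive_topk(scores, k) if k < RADIX_TOPK_THRESHOLD else radix_topk(scores, k)
```

Each radix pass touches every element once regardless of k, which is why the cost does not grow with K, whereas the naive path repeats a full scan per selected element.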
Tests (test_fused_router.py):
num_experts=1024 to all parametrized test cases.
_get_tolerances() helper that scales atol/rtol with expert count to account for O(N * eps) accumulation divergence between fused and reference implementations.
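To make the O(N * eps) scaling concrete, here is a hypothetical stand-in for such a helper. It is not the PR's `_get_tolerances()`: the name `get_tolerances_sketch`, the string `dtype` argument, the `safety_factor` constant, and the choice to use the same bound for atol and rtol are all assumptions made for illustration.

```python
FP32_EPS = 2.0 ** -23  # machine epsilon for IEEE-754 float32

def get_tolerances_sketch(num_experts, dtype="float32", safety_factor=4.0):
    """Illustrative only: tolerance grows linearly with the number of experts."""
    if dtype != "float32":
        # The review summary notes that the real helper raises on non-fp32 dtypes.
        raise ValueError(f"tolerances are only defined for float32, got {dtype}")
    bound = safety_factor * num_experts * FP32_EPS
    return {"atol": bound, "rtol": bound}
```

With these illustrative constants, num_experts=1024 yields a bound of 2^-11 (about 4.9e-4), eight times the bound for 128 experts.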
Checklist: