[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute#2907
Conversation
Greptile SummaryThis PR addresses two related correctness bugs in the MoE permutation kernels: int32 overflow in pointer-offset arithmetic for large token counts, and incorrect ordering of
Confidence Score: 3/5The core forward-path overflow fixes are correct, but two 32-bit overflow paths in the backward/unpermute kernels remain unaddressed and will corrupt memory for sufficiently large token counts. The int index = source_token narrowing in moe_permute_kernel stores a newly-widened int64_t blockIdx.x into an int, and the k * num_rows inner-loop product in moe_unpermute_kernel is still evaluated as int32 before being used as a row_id_map index. Both paths produce wrong memory addresses when topK * num_rows approaches 2^31, which is the same class of bug this PR targets. These two paths were called out in earlier review rounds and have not been closed by the current commits. transformer_engine/common/permutation/permutation.cu — specifically the moe_permute_kernel backward path (line 169 index narrowing) and the moe_unpermute_kernel inner loop (line 92 int32 product). Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[routing_map keys plus row_id values] -->|uint32 sort: -1 goes to tail| B[sorted_indices and sorted_row_id]
B --> C{idx less than num_out_tokens?}
C -->|Valid token| D[row_id_map write = idx using int64_t index]
C -->|Dropped sentinel| E[row_id_map write = -1 using int64_t index]
D --> F[moe_permute_kernel FWD source_token as int64_t dest_row as int64_t]
E --> F
F --> G[moe_unpermute_kernel source_token as int64_t source_row as int64_t if -1 then zero output]
Reviews (9): Last reviewed commit: "Merge branch 'main' into fix/moe-permute..." | Re-trigger Greptile |
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int); |
There was a problem hiding this comment.
Negative
num_minus_ones becomes enormous size_t offset
num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:
TORCH_CHECK(num_out_tokens <= num_tokens * topK,
"num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
num_tokens * topK, ")");a909254 to
ce34774
Compare
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
num_tokens * topK still computed as int * int
num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | |
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | |
| static_cast<size_t>(num_minus_ones) * sizeof(int); | |
| const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK; | |
| NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens, | |
| ") must not exceed num_tokens*topK (", total_tokens, ")"); | |
| const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens); | |
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | |
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?
There was a problem hiding this comment.
This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much
cd64aa5 to
b73a1f9
Compare
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This (the greptile command) looks correct. Can you please help cast num_tpkens to int64 before multiplication and - num_out_tokens?
| const int num_minus_ones = num_tokens * topK - num_out_tokens; | ||
| sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + | ||
| static_cast<size_t>(num_minus_ones) * sizeof(int); |
There was a problem hiding this comment.
This is not a huge deal because even with topK=128, youwould need > 16M tokens per rank for the int product to overflow. But better to be consistent, and also, this casting of 1 value on the CPU side probably would not slow dow much
| num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; | ||
| NVTE_CHECK(num_out_tokens <= num_tokens * topK, "num_out_tokens (", num_out_tokens, | ||
| ") must not exceed num_tokens*topK (", num_tokens * topK, ")"); | ||
| const int num_minus_ones = num_tokens * topK - num_out_tokens; |
There was a problem hiding this comment.
This is probably going to introduce a regression for the capacity-drop path. This shift assumes the dropped routes are -1 sentinels at the head of sorted_row_id (cub's signed radix sort), which is true for the EP-mask case this PR targets. But the pre-existing capacity-drop path encodes drops as a large positive expert id that sorts to the tail. For that case, the head is valid low-expert-id rows, and shifting past them drops the wrong tokens.(just fyi, capacity-dropping case means no -1 in indices, num_out_tokens < num_tokens * topK because some expert exceeded capacity))
See in this file tests/pytorch/test_permutation.py, in pytorch_permute_index_map, we have:
sorted_indices[:num_out_tokens] (keeps the head),
so I'd expect test_permutation_index_map[..., num_out_tokens=2039, ...] to fail. We can run the te_ci to confirm it.
There was a problem hiding this comment.
I think another solution to this without doing num_tokens * topk - num_out_tokens (or counting the number of -1 on host side) is to sort the keys as uint32_t instead of int32_t. So, -1 becomes UINT_MAX and sorts to the tail, unifying both capacity-dropping and dropless under the original idx >= num_out_tokens --> drop logic. That removes the need for the prefix shift you did, and the row_id_map pre-fill. This just needs expert_id to be <= UINT_MAX, which I do not think we are reaching there anytime soon
There was a problem hiding this comment.
Thanks for the careful review. Acknowledging the capacity-drop regression concern and the unsigned-sort suggestion below — both make sense. Waiting on the te_ci result you triggered before I push any code change, so we have a concrete signal on what needs to move.
There was a problem hiding this comment.
Here is the CI pipeline: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50478896
It failed in the expected tests
|
/te_ci pytorch |
|
/te-ci pytorch |
|
/te-ci pytorch L0 |
Hi, I can trigger it now. However, for future reference, you could also trigger it by commenting "/te-ci " on the PR. is either pytorch or jax, and can be L0, L1, L2. Alternatively, you can test it out locally also,, by running the related tests in tests/pytorch or tests/jax, depending on the change, or using the scripts in qa/L0_pytorch_unittest/test.sh |
|
/te-ci pytorch |
|
I can see that the CI finished and the failures are not related to your change: https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/50584460 the only test failing was fused_qkv not permutation |
Per reviewer feedback in NVIDIA#2907, promote the int * int multiplications in moe_permute_row_map and its launcher to int64_t. These are not the overflow path this PR was originally fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at unrealistic per-rank token counts (>= 268M at topK=8). The change is purely defensive but keeps the index arithmetic in this kernel consistent with the int64_t source_token / source_row / dest_row widening already applied to moe_unpermute_kernel and moe_permute_kernel. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
|
/te-ci pytorch |
1 similar comment
|
/te-ci pytorch |
|
It seems my "/te-ci pytorch" does not work? @tdophung could you trigger it on your side once you have a moment? The branch is also rebased onto latest main now. |
…permute Two independent bugs in transformer_engine/common/permutation/permutation.cu and the PyTorch extension caller reproduce on main (264da2b) and v2.13: 1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel. `source_token * num_cols` and `source_row * num_cols` are computed with int, so for long-sequence MoE workloads where num_out_tokens * num_cols reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening source_token, source_row and dest_row to int64_t inside the kernels keeps the index arithmetic in 64 bits without changing any public types. 2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of sorted_row_id, not the tail. moe_permute_row_map currently writes -1 only for idx >= num_out_tokens and reads the sentinel prefix as if it were a valid sorted id, producing bogus row_id_map writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances sorted_row_id_ptr past the num_minus_ones prefix and pre-fills row_id_map with -1 via torch::full, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from num_rows*topK blocks to num_out_tokens blocks to match the new valid range. No behaviour change on happy-path routing_map (no -1, no overflow). Reproducers: - 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch. - num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
for more information, see https://pre-commit.ci
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
The MoE permute path was correct for the existing capacity-drop convention (drops encoded as a large positive expert id, sorted to the tail by the signed cub::DeviceRadixSort), but it broke for callers that mark dropped (token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP). With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at the tail. Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so -1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the existing capacity-drop convention. The kernel and caller sides are unchanged - this is a one-place fix that makes both drop conventions land in the existing drop branch. Also widen the loop-carried indices in moe_unpermute_kernel and moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`) to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long- context training (hidden = 7168, topK = 8): once `num_out_tokens * num_cols` reaches 2**31 the int product wraps and the kernel either silently corrupts rows or raises CUDA `illegal memory access`. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Per reviewer feedback in NVIDIA#2907, promote the int * int multiplications in moe_permute_row_map and its launcher to int64_t. These are not the overflow path this PR was originally fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at unrealistic per-rank token counts (>= 268M at topK=8). The change is purely defensive but keeps the index arithmetic in this kernel consistent with the int64_t source_token / source_row / dest_row widening already applied to moe_unpermute_kernel and moe_permute_kernel. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
7d20c2b to
fbe91af
Compare
|
/te-ci pytorch |
…permute (NVIDIA#2907) * [Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute Two independent bugs in transformer_engine/common/permutation/permutation.cu and the PyTorch extension caller reproduce on main (264da2b) and v2.13: 1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel. `source_token * num_cols` and `source_row * num_cols` are computed with int, so for long-sequence MoE workloads where num_out_tokens * num_cols reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening source_token, source_row and dest_row to int64_t inside the kernels keeps the index arithmetic in 64 bits without changing any public types. 2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of sorted_row_id, not the tail. moe_permute_row_map currently writes -1 only for idx >= num_out_tokens and reads the sentinel prefix as if it were a valid sorted id, producing bogus row_id_map writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances sorted_row_id_ptr past the num_minus_ones prefix and pre-fills row_id_map with -1 via torch::full, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from num_rows*topK blocks to num_out_tokens blocks to match the new valid range. No behaviour change on happy-path routing_map (no -1, no overflow). Reproducers: - 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch. - num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds. Signed-off-by: Jingyi Xi <flotherxi@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Guard against invalid num_out_tokens in moe_permute_fwd Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi <flotherxi@gmail.com> * Switch radix sort keys to uint32_t to fix -1 sentinel ordering The MoE permute path was correct for the existing capacity-drop convention (drops encoded as a large positive expert id, sorted to the tail by the signed cub::DeviceRadixSort), but it broke for callers that mark dropped (token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP). With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at the tail. Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so -1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the existing capacity-drop convention. The kernel and caller sides are unchanged - this is a one-place fix that makes both drop conventions land in the existing drop branch. Also widen the loop-carried indices in moe_unpermute_kernel and moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`) to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long- context training (hidden = 7168, topK = 8): once `num_out_tokens * num_cols` reaches 2**31 the int product wraps and the kernel either silently corrupts rows or raises CUDA `illegal memory access`. Signed-off-by: Jingyi Xi <flotherxi@gmail.com> * Widen num_rows * topK products in moe_permute_row_map for consistency Per reviewer feedback in NVIDIA#2907, promote the int * int multiplications in moe_permute_row_map and its launcher to int64_t. These are not the overflow path this PR was originally fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at unrealistic per-rank token counts (>= 268M at topK=8). The change is purely defensive but keeps the index arithmetic in this kernel consistent with the int64_t source_token / source_row / dest_row widening already applied to moe_unpermute_kernel and moe_permute_kernel. Signed-off-by: Jingyi Xi <flotherxi@gmail.com> --------- Signed-off-by: Jingyi Xi <flotherxi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Teddy Do <tdophung@nvidia.com>
Fixes #2908 — full description, repros, and DeepSeek-V3 context there.
Changes
permutation.cu— widensource_token,source_row,dest_rowtoint64_tinsidemoe_unpermute_kernelandmoe_permute_kernelsorow * num_colsstays 64-bit. Simplifymoe_permute_row_mapto only process the valid[0, num_out_tokens)range; launcher grid becomesnum_out_tokensblocks.permutation.cpp— advancesorted_row_id_ptrpast thenum_minus_onessentinel prefix left bycub::DeviceRadixSort(signed ascending), and pre-fillrow_id_mapwith-1viatorch::fullso dropped slots are marked without the kernel ever dereferencing a sentinel.No public API / dtype changes.
+17 / -18lines across the two files.Test plan
routing_map(no-1, offsets within int32) — unchanged.-1-sentinel repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 →max |TE - ref| = 0.0on bf16 (was4.56e0).int32-boundary repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 → no longer raisesillegal memory access; matches reference.tests/pytorch/test_permutation.pyvia CI.