
Disable the RHT fusion for non-SM100 family devices #2968

Merged
timmoon10 merged 4 commits into NVIDIA:main from ptrendx:pr_fix_rht_fusion on May 12, 2026

Conversation


ptrendx (Member) commented May 8, 2026

Description

Disable the RHT fusion for non-SM100 class devices (the kernel uses too much shared memory to be runnable on, e.g., sm120).

Fixes #2956

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a check on the SM architecture when testing for fusion eligibility.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 May 8, 2026 00:07

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR fixes a runtime failure on SM120 (GB20x) devices by restricting the RHT cast fusion kernel to SM100-family hardware (sm_arch values 100–110), where the required MMA shared memory fits within device limits.

  • Adds sm_arch() >= 100 && sm_arch() <= 110 to the eligible_for_rht_cast_fusion guard in NVFP4Quantizer::quantize_impl, effectively disabling the fused path on SM120 and other non-SM100 architectures.
  • Includes common/util/cuda_runtime.h (previously not included in this translation unit) to expose the sm_arch() utility.

Confidence Score: 5/5

Safe to merge — the change is a narrow, additive guard that only disables a fusion path on hardware where it was already broken.

The fix is a single boolean condition added to an eligibility check; it cannot regress correct behavior on SM100 devices and safely falls back to the non-fused path on everything else. The header inclusion and the sm_arch() API are both well-established in this codebase.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/quantizer.cpp | Adds an SM architecture range check (100–110) to eligible_for_rht_cast_fusion, restricting the RHT fusion kernel to SM100-family Blackwell devices; sm_arch() is called twice in the expression. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_impl called] --> B{Check eligibility for RHT cast fusion}
    B --> C{dtype == BFloat16?}
    C -- No --> F[eligible = false]
    C -- Yes --> D{rows % 64 == 0 AND cols % 128 == 0?}
    D -- No --> F
    D -- Yes --> E{sm_arch in range 100..110?}
    E -- No --> F
    E -- Yes --> G[eligible = true]
    G --> H[Use fused RHT cast kernel]
    F --> I[Use non-fused path]
```


```diff
 bool eligible_for_rht_cast_fusion =
-    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
+    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
+    transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
```

P2 The upper bound <= 110 is tighter than the stated intent ("non-SM100 family"). Using < 120 more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so < 120 reads as clearly intentional.

Suggested change

```diff
-    transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
+    transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120;
```
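For reference, the merged guard can be mirrored from the Python side to predict which path a given tensor and device will take. This is a minimal sketch, not a TE API: rht_cast_fusion_eligible is an illustrative name, and the sm_arch encoding (10 * major + minor) is assumed from the diff above.

```python
import torch

def rht_cast_fusion_eligible(x: torch.Tensor) -> bool:
    """Illustrative mirror of the C++ guard above (not a TE API)."""
    rows, cols = x.shape
    major, minor = torch.cuda.get_device_capability(x.device)
    sm_arch = 10 * major + minor  # assumed encoding: 100 for SM100, 120 for SM120
    return (
        x.dtype == torch.bfloat16
        and rows % 64 == 0
        and cols % 128 == 0
        and 100 <= sm_arch <= 110  # SM100 family only; SM120 takes the non-fused path
    )
```

With the reviewer's suggested bound, the last condition would read 100 <= sm_arch < 120 instead.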

timmoon10 previously approved these changes May 8, 2026
@osubotin

Empirical verification on RTX 5080 (sm_120) — confirms PR fix works

We filed the original failure as #2956 (cycle #289 stream X.5.2, May 2 2026), and have now verified this PR end-to-end on an RTX 5080.

| Dim (M=K=N) | TE 2.14.1 (unpatched) | TE 2.14.1 + this PR |
| --- | --- | --- |
| 64×64×64 | OK | OK |
| 96×96×96 | OK | OK |
| 128×128×128 | FAIL: `RuntimeError: row_cast_col_hadamard_transform_cast_fusion.cu:1200 in function row_col_rht_gemm_ntt_w_sfc: CUDA Error: invalid argument` | OK |
| 192×192×192 | FAIL: AcceleratorError, CUDA invalid argument | OK |
| 256×256×256 | FAIL: invalid argument | OK |
| 384×384×384 | FAIL: invalid argument | OK |
| 512×512×512 | FAIL: invalid argument | OK |
| 1024×1024×1024 | FAIL: invalid argument | OK |
| 1024×4096×1024 (production) | FAIL: invalid argument | OK |

Numerical correctness: the separate-ops fallback produces NVFP4 output within the expected noise floor. Max relative error vs the bf16 reference is 0.124–0.163 across all shapes; with NVFP4 E2M1 plus per-block scaling, ~12–16% relative error vs bf16 is the expected quantization noise, not a correctness bug. The working 64×64×64 baseline sits in the same band (rel_err = 0.147), so the patch does not introduce additional drift.
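A sketch of one way to take this measurement (illustrative; it assumes the same te.Linear is run once in plain bf16 as the reference and once under the NVFP4 recipe, and the clamp floor is an arbitrary choice to avoid division by near-zero outputs):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

M, K, N = 128, 128, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    y_ref = layer(x)  # plain bf16 forward as reference
    with te.fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):
        y_q = layer(x)  # quantized NVFP4 forward

# Elementwise max relative error, guarding against division by near-zero values
rel_err = ((y_q.float() - y_ref.float()).abs() / y_ref.float().abs().clamp_min(1e-3)).max()
print(f"max rel err vs bf16: {rel_err.item():.3f}")
```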

Reproducer:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

recipe = NVFP4BlockScaling()
for M, K, N in [(64, 64, 64), (128, 128, 128), (1024, 4096, 1024)]:
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)  # unpatched 2.14.1 hits "invalid argument" here for dims >= 128
    torch.cuda.synchronize()  # surface async CUDA launch errors before printing
    print(f"OK M={M} K={K} N={N}")
```

On 2.14.1 unpatched: dim 64 OK, dim 128 onwards FAIL with invalid argument. With this PR applied: all shapes OK.

Hardware: RTX 5080 16GB, sm_120, CUDA 13.0.2, PyTorch 2.11.0+cu130, Linux WSL2 Ubuntu 24.04. TE was built editable from a fresh clone of 2.14.1+366798e with the PR diff applied via git apply. Note: TE 2.14.1's quantizer.cpp does not transitively include common/util/cuda_runtime.h (which declares transformer_engine::cuda::sm_arch()), so we had to add that include locally for the PR to compile against the 2.14.1 base. On main this is presumably already pulled in elsewhere, but it is worth a sanity check that the PR builds cleanly against the actual base it will merge into.

Performance: separate-ops fallback ms/op for the patched path is within 5% of the working baseline shapes (e.g., 0.176 ms at 128×128×128 vs 0.184 ms at 96×96×96 baseline). We did not benchmark TFLOPS or compare against fouroversix; our project (Volkov VLM training, RTX 5080) uses fouroversix for production NVFP4 on sm_120 because of its fused-kernel sm_120 GEMM (cycle #321 Action 2 + cycle #372 Action 2 three-way validation = bit-identical at max_delta=0.000 vs the TE NVFP4 reference).
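A sketch of how the ms/op figure above might be measured (illustrative; it assumes CUDA-event timing with a warmup loop and the 128×128×128 shape, not a methodology mandated by TE):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

M, K, N = 128, 128, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")
recipe = NVFP4BlockScaling()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    for _ in range(10):   # warmup: trigger lazy init before timing
        layer(x)
    torch.cuda.synchronize()
    start.record()
    for _ in range(100):
        layer(x)
    end.record()

torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / 100:.3f} ms/op")
```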

LGTM for SM120 functional correctness. The reviewer's earlier suggestion of a < 120 upper bound is also clean: it is functionally identical to <= 110 since SM111/SM112 don't exist, just slightly more readable as "everything below sm_120".

Looking forward to seeing this in 2.15; that graduates TE-NVFP4 from RED to AMBER on consumer Blackwell as a fallback option for users without fouroversix.


ptrendx commented May 11, 2026

/te-ci pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

ptrendx commented May 11, 2026

/te-ci pytorch

timmoon10 merged commit d5e7087 into NVIDIA:main, May 12, 2026
21 of 24 checks passed