
[#13816][feat] AutoDeploy: Optimize GPT-OSS-120b perf #13867

Draft
taylor-yb-lee wants to merge 14 commits into NVIDIA:main from nv-auto-deploy:taylor/chenghao/gpt-oss-0505

Conversation

@taylor-yb-lee
Collaborator

@coderabbitai summary

Description

  • Added MXFP4 support in AutoDeploy
  • Tested with gpt-oss-120b; GSM8K passes with a score of 90.3

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

nvchenghaoz and others added 13 commits May 5, 2026 17:18
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Wraps torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner -- the
trtllm-gen MXFP4-weight x BF16-activation MoE kernel that PT's
W4A16MXFP4TRTLLMGenFusedMoEMethod uses today on B200 by default.

Op signature: takes pre-shuffled MXFP4 weights, UE8M0 scales, float32
biases, and per-expert SwiGLU params. At forward time it only zero-pads
activations to the kernel's expected H_pad and slices the output
back to valid_hidden_size.

The matching weight-prep helper, transform, and ShardingInfo arrive in
following steps. Op verified to register via torch.library and produce
the expected schema. No graph/transform changes yet -- this op is
inert until step 3 wires it into a transform.

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (step 1 of 6)
Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
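A minimal sketch of the forward-time pad/slice behavior described in the commit above, assuming the remaining kernel arguments are passed through opaquely (the wrapper name `_pad_run_slice` and the kwargs plumbing are illustrative, not the actual op signature):

```python
import torch
import torch.nn.functional as F

def _pad_run_slice(x: torch.Tensor, h_pad: int, valid_hidden_size: int, **kernel_kwargs):
    # Zero-pad the activation hidden dim up to the kernel's expected H_pad.
    pad = h_pad - x.shape[-1]
    if pad > 0:
        x = F.pad(x, (0, pad))  # zero-pads the last (hidden) dimension
    out = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(x, **kernel_kwargs)
    # Slice the padded output back to the model's true hidden size.
    return out[..., :valid_hidden_size]
```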
Mirrors PT MXFP4WeightTRTLLMGenFusedMoEMethod weight-loading path
(quantization.py:4135-4500). Reuses PT helpers maybe_pad_for_mxfp4,
trtllmgen_maybe_get_cached_*_permute_indices, _get_weight_alignment.

Steps: reshape HF [E, 2I, H/32, 16] -> [E, 2I, H/2], pad to alignment
(input_hidden_alignment//2=256 cols, weight_alignment=128 rows), pad
matching scales, shuffle per expert via torch.ops.trtllm.shuffle_matrix,
cast biases to float32. Returns PreparedMXFP4Weights dataclass.

Step-2 scope: tp_size=1 only; TP slicing arrives in step 5.

Smoke-tested on gpt-oss-120b shapes (E=128, I=H=2880) on B200 -- output
shapes match PT byte-for-byte.

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (step 2 of 6)
Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
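A rough sketch of the reshape-and-pad step in the commit above (helper names and the uint8 packing assumption are illustrative; the real helper also pads scales, shuffles per expert via `torch.ops.trtllm.shuffle_matrix`, and casts biases):

```python
import torch
import torch.nn.functional as F

def _round_up(n: int, m: int) -> int:
    return (n + m - 1) // m * m

def _reshape_and_pad(blocks: torch.Tensor,
                     input_hidden_alignment: int = 512,
                     weight_alignment: int = 128) -> torch.Tensor:
    # HF layout: [E, 2I, H/32, 16] uint8 -- each MXFP4 block is 32 fp4
    # values stored in 16 bytes. Flatten blocks into packed bytes: [E, 2I, H/2].
    e, rows, n_blocks, block_bytes = blocks.shape
    w = blocks.reshape(e, rows, n_blocks * block_bytes)
    # Pad columns to input_hidden_alignment//2 packed bytes (256 here) and
    # rows to weight_alignment (128 here), per the commit.
    cols_pad = _round_up(w.shape[-1], input_hidden_alignment // 2) - w.shape[-1]
    rows_pad = _round_up(rows, weight_alignment) - rows
    return F.pad(w, (0, cols_pad, 0, rows_pad))
```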
Cleaner graph integration: the op now takes raw router_weight + bias +
top_k and computes RenormalizeMoeRoutingMethod-style routing internally
(F.linear -> topk -> softmax-of-topk), then dispatches to the kernel
with pre-computed topk_weights / topk_ids.

This makes the upcoming transform (step 3) a single 1:1 op rewrite of
torch_moe_dense_mlp -> trtllm_mxfp4_w4a16_moe_fused without needing a
separate routing op upstream.

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (step 1 of 6)
Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
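A minimal sketch of the internal routing described above (names are illustrative; the fp32 softmax matches the parity fix in a later commit):

```python
import torch
import torch.nn.functional as F

def _renormalize_routing(x, router_weight, router_bias, top_k):
    # RenormalizeMoeRoutingMethod-style routing: linear -> topk -> softmax
    # over only the selected logits (renormalization).
    logits = F.linear(x, router_weight, router_bias)          # [num_tokens, num_experts]
    topk_vals, topk_ids = torch.topk(logits, top_k, dim=-1)   # [num_tokens, top_k]
    topk_weights = torch.softmax(topk_vals.float(), dim=-1).to(x.dtype)
    return topk_weights, topk_ids
```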
Runs in post_load_fusion stage. Picks up triton_mxfp4_moe nodes from
quantize_mxfp4_moe, runs the step-2 weight prep, registers prepared
params on the experts module, and rewrites the call to
auto_deploy::trtllm_mxfp4_w4a16_moe_fused.

Frees the original raw HF-layout MXFP4 params after rewrite.

Step-3 V4 scope: EP=1 (triton_mxfp4_moe without _ep) only. EP variant
is covered by step 5 with MXFP4TRTLLMGenSharding.

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (step 3 of 6)
Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
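An illustrative sketch of the 1:1 node rewrite (the op namespaces and `.default` overload are assumptions; the real transform also runs the step-2 weight prep, registers prepared params on the experts module, and remaps node args to the fused op's signature):

```python
import torch
import torch.fx as fx

def _rewrite_moe_nodes(gm: fx.GraphModule) -> int:
    # Swap every triton_mxfp4_moe call_function node to the fused trtllm op.
    count = 0
    for node in gm.graph.nodes:
        if (node.op == "call_function"
                and node.target is torch.ops.auto_deploy.triton_mxfp4_moe.default):
            node.target = torch.ops.auto_deploy.trtllm_mxfp4_w4a16_moe_fused.default
            count += 1
    gm.graph.lint()
    gm.recompile()
    return count
```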
Reorder positional args in
``Bf16MxE2m1BlockScaleMoERunner.get_valid_tactics`` to match the C++
signature of ``Bf16MxE2m1BlockScaleMoeRunner::getValidConfigs``
(``cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp:516``):
``(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens,
validHiddenSize, validIntermediateSize)``.

Commit 86cfb3e (cubin update + valid_*_size plumbing) added
``valid_hidden_size`` / ``valid_intermediate_size`` params to all three
trtllm-gen MoE runners' Python wrappers. The other two siblings
(``MxE4m3MxE2m1`` line 968, ``E4m3MxE2m1`` line 1274) appended the new
args at the end correctly; only ``Bf16MxE2m1`` placed them in the
middle, so the autotuner was passing ``valid_*`` values into the
``numLocalExperts`` / ``numTokens`` slots and ``local_num_experts`` /
``num_tokens`` into the ``valid_*`` slots. Effect: the cubin filter
saw garbage shape parameters, returned an empty tactic list, and the
autotune cache stayed empty -- so at run time the kernel fell back to
``getDefaultValidConfigIndex`` and asserted "No valid config found
for the given problem shape MNK" on the first MoE call (e.g. AD's
``resize_kv_cache`` memory probe at ``max_num_tokens=8192``).

This Python-only reorder restores parity with the C++ binding; no
recompile needed.

Found while onboarding gpt-oss-120b on AutoDeploy with the
``bf16_mxe2m1`` MoE path; reproduces in any non-tuning-mode call to
the op (e.g. PT's ``MXFP4WeightTRTLLMGenFusedMoEMethod`` users hit it
on the first prefill).

Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
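For clarity, the corrected positional order (a sketch with illustrative variable names; the order itself is the C++ one quoted above):

```python
# Corrected Python-side positional order, mirroring the C++
# Bf16MxE2m1BlockScaleMoeRunner::getValidConfigs signature:
tactics = runner.get_valid_tactics(
    top_k,                    # topK
    hidden_size,              # hiddenSize
    intermediate_size,        # intermediateSize
    num_local_experts,        # numLocalExperts  <- was receiving valid_hidden_size
    num_tokens,               # numTokens        <- was receiving valid_intermediate_size
    valid_hidden_size,        # validHiddenSize
    valid_intermediate_size,  # validIntermediateSize
)
```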
Bring ``prepare_mxfp4_weights_for_trtllm_gen`` and the
``trtllm_mxfp4_w4a16_moe_fused`` op into structural parity with PT's
``MXFP4WeightTRTLLMGenFusedMoEMethod`` (quantization.py:4135) so the
trtllm-gen MoE kernel sees the same byte layout PT exercises:

mxfp4_weight_prep.py changes:
* Per-expert ``I_pad = roundUp(I, weight_alignment) = 2944`` first;
  derive ``2I_pad = 5888`` and ``I/2_pad = 1472`` from that. Previously
  we padded ``2I = 5760`` directly, which is already 128-aligned and
  thus a no-op, leaving w1's effective ``I = 2880`` while w2's column
  padding pushed ``I = 2944`` -- an inconsistent intermediate dim across
  the two gemms.
* W1 hidden axis padded to ``input_hidden_alignment = 512``
  (``H_w1_pad = 3072``), W2 hidden axis padded to
  ``weight_alignment = 128`` (``H_w2_pad = 2944``), matching PT's
  ``create_weights`` (lines 3715-3717 of quantization.py).
* De-interleave gate / up rows from the on-disk row-interleaved
  storage (``gate_up_proj_blocks[:, ::2, :]`` = gate,
  ``[:, 1::2, :]`` = up) and pad each half to ``I_pad`` separately
  before stacking as ``[up | gate]``. PT's chunk-then-copy dance
  (modeling_gpt_oss.py:695-706 + quantization.py:4252-4258) ends up
  with the same physical layout.
* Add ``torch.ops.trtllm.block_scale_interleave`` after
  ``shuffle_matrix`` for both fc1 and fc2 scales -- PT does both ops
  (quantization.py:4382, 4439); skipping the second was a partial
  bug.

trtllm_moe.py change:
* Routing softmax in fp32 instead of bf16 -- matches PT's
  ``RenormalizeMoeRoutingMethod`` which casts to fp32 for the topk
  softmax then back to the activation dtype.

Status: kernel builds and runs cleanly with these changes, and pure
GEMM throughput is at the V4 target (~9.28 ms ITL / ~108 tok/s for
gpt-oss-120b vs V3 Triton's 127.79 ms / 7.96 tok/s -- 13.5x).
However, content correctness is still blocked by an upstream NaN bug
in the trtllm-gen MoE kernel itself: PT's own
``TRTLLMGenFusedMoE.forward`` on gpt-oss-120b at this TRT-LLM commit
also produces NaN logits, so any byte-correct prep cannot rescue
output. Tracking note: re-validate when upstream fix lands; if
correctness is restored, proceed to step 5 (TP-MoE sharding).

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (step 4 of 6),
RESUME_V4.md.
Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
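A minimal sketch of the de-interleave/pad/stack step described above, assuming the flattened ``[E, 2I, H/2]`` packed-byte layout from step 2 (helper names are illustrative):

```python
import torch
import torch.nn.functional as F

def _round_up(n: int, m: int) -> int:
    return (n + m - 1) // m * m

def _split_pad_stack(gate_up: torch.Tensor, weight_alignment: int = 128) -> torch.Tensor:
    # On-disk storage row-interleaves gate and up: even rows = gate, odd = up.
    gate, up = gate_up[:, ::2, :], gate_up[:, 1::2, :]   # each [E, I, H/2]
    i_pad = _round_up(gate.shape[1], weight_alignment)   # 2880 -> 2944 here
    pad = i_pad - gate.shape[1]
    gate = F.pad(gate, (0, 0, 0, pad))                   # pad the row (I) axis
    up = F.pad(up, (0, 0, 0, pad))
    # Stack as [up | gate] so each half keeps a consistent I_pad.
    return torch.cat([up, gate], dim=1)                  # [E, 2*I_pad, H/2]
```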
Mirrors modeling_gpt_oss.py but routes every attention Linear through
``torch.ops.auto_deploy.torch_linear_simple`` with sharding hint kwargs
(``tp_mode``, ``tp_min_local_shape``, ``layer_type``), and inserts
``torch.ops.auto_deploy.view`` (``tp_scaled_dim=2``) for q/k/v/attn_out
reshapes plus a trailing ``torch.ops.auto_deploy.all_reduce`` placeholder
after the rowwise o_proj.  Same pattern qwen3_ir / qwen3_5_moe_ir use.

Sharding strategy emitted into the graph:
  q_proj / k_proj / v_proj -> colwise (+ tp_min_local_shape=head_dim
                              for GQA: 64 Q heads / 8 KV heads at TP=8)
  view (q/k/v/attn_out)    -> tp_scaled_dim=2 (head-count dim)
  o_proj                   -> rowwise + auto_deploy.all_reduce

Out of scope here (matches qwen_ir convention):
  * MoE router + experts stay replicated -- the V4 trtllm-gen MoE op
    (``trtllm_mxfp4_w4a16_moe_fused``) has no ShardableNode yet.  Step 5
    of MOE_TRTLLM_GEN_PLAN.md (V6) registers TP-MoE for that op.
  * lm_head stays as plain nn.Linear -- no canonical sharding-IR pattern
    for col-parallel-then-all-gather in this codebase yet.

Registration:
  * GptOssForCausalLM still registers via ``register_custom_model_cls``
    (last-registration-wins).
  * ``models/custom/__init__.py`` adds modeling_gpt_oss_ir to the
    ``AD_USE_IR_MODELS`` opt-in block, alongside deepseek_ir,
    nemotron_h_ir, qwen3_5_moe_ir.

Validated end-to-end on gpt-oss-120b 8xB200 with the new V5 yaml
(world_size=8, apply_sharding_hints with shard_layers=["mha"],
detect_sharding+sharding_transform_executor disabled): apply_sharding_hints
processed 324 nodes / skipped 37 (the MoE nodes carry layer_type="moe"),
strip_sharding_hints stripped 288 hints, fuse_allreduce_residual_rmsnorm
matched 36 -- attention TP=8 fully wired through.

Refs: cc_reports/gpt-oss-120b/MOE_TRTLLM_GEN_PLAN.md (V5 step),
      RESUME_V4.md (still valid for the trtllm-gen NaN tracking).

Signed-off-by: yeonbokl <yeonbokl@nvidia.com>
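A hypothetical sketch of the IR pattern for one projection; the kwarg names follow the commit message, but the actual `auto_deploy` op signatures in the repo may differ:

```python
import torch

def _sharded_q_projection(hidden_states, q_w, q_b, head_dim, bsz, q_len):
    # Colwise-sharded q_proj with a GQA-safe minimum local shape of head_dim.
    q = torch.ops.auto_deploy.torch_linear_simple(
        hidden_states, q_w, q_b,
        tp_mode="colwise", tp_min_local_shape=head_dim, layer_type="mha",
    )
    # tp_scaled_dim=2 marks the head-count dim so TP sharding rescales it.
    q = torch.ops.auto_deploy.view(q, (bsz, q_len, -1, head_dim), tp_scaled_dim=2)
    return q
```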
Step 5 of MOE_TRTLLM_GEN_PLAN.md: extend the V4 trtllm-gen MoE op with
TP-sharding on the intermediate axis so MoE compute itself splits across
ranks (V5 only sharded attention; MoE was replicated and dominated cost).

prepare_mxfp4_weights_for_trtllm_gen:
  * Add tp_rank arg.
  * Compute TP-aware alignment via _get_weight_alignment so per-rank
    intermediate is itself 128-aligned after pad-before-shard (matches
    PT load_expert_w3_w1_weight / load_expert_w2_weight).
  * Pre-pad intermediate axis to alignment_tp, then slice
    [tp_rank * I_pr, (tp_rank+1) * I_pr] on gate/up rows, scales, biases
    (col-parallel) and on dn_3d cols (row-parallel, /2 for packed mxfp4).
  * Slice down_scales on dim 2 with /scaling_vector_size stride.
  * Clamp valid_intermediate to min(intermediate_size, slice_stop) -
    slice_start.

QuantizeMXFP4MoETrtllmGen transform:
  * Read moe_tp_size / moe_tp_rank / allreduce_strategy from
    shared_config.dist_config.
  * Forward to prepare_mxfp4_weights_for_trtllm_gen.
  * After the V4 op rewrite, when moe_tp_size > 1 insert
    auto_deploy.all_reduce so partial [..., hidden] outputs from each
    rank sum across ranks before the residual add.  fc2_bias is divided
    by tp_size in the prep helper so the post-AR sum reproduces the
    unsharded bias.

Smoke-tested:
  * tp=1 -> fc1=[8, 5888, 1536] valid_I=2880 (no regression).
  * tp=8 rank=0 -> fc1=[8, 768, 1536] valid_I=384.
  * tp=8 rank=7 -> fc1=[8, 768, 1536] valid_I=192 (last rank partial).
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
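A minimal sketch of the col-parallel slice and bias scaling described in the commit above, assuming the intermediate axis was already pre-padded to `alignment_tp` (names are illustrative):

```python
import torch

def _tp_slice_intermediate(gate, up, fc2_bias, tp_rank: int, tp_size: int,
                           i_per_rank: int):
    # Col-parallel slice of the intermediate axis, applied per half so the
    # [up | gate] stacking stays consistent; the same [lo, hi) window is
    # applied to the matching scales and fc1 biases.
    lo, hi = tp_rank * i_per_rank, (tp_rank + 1) * i_per_rank
    gate_local, up_local = gate[:, lo:hi], up[:, lo:hi]
    # fc2 bias is pre-divided by tp_size so the post-all-reduce sum of the
    # per-rank partial [..., hidden] outputs adds the bias exactly once.
    return gate_local, up_local, fc2_bias / tp_size
```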
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…ayout

prepare_mxfp4_weights_for_trtllm_gen padded per-expert biases but never
row-shuffled them, while it did shuffle the weights and scales. The
trtllm-gen bf16_mxe2m1_block_scale_moe_runner kernel adds bias[i] to
post-shuffle output row i of GEMM1/GEMM2, so leaving biases in
pre-shuffle order made the kernel attribute the wrong bias to each row
and the MoE output came out as noise (gpt-oss-120b GSM8K dropped to
2.05% vs the 90.30% reference).

PT's MXFP4WeightTRTLLMGenFusedMoEMethod (quantization.py:4204-4319)
runs the very same row permutation on the bias destination buffer:
load_expert_w3_w1_weight applies the gated-act-gemm interleave +
epilogue-tile reorder to the 1-D [2*I_pad] gated bias, and
load_expert_w2_weight applies the epilogue-tile reorder to the 1-D
[H_pad] down bias. Mirror that in the AD prep helper via two new
_shuffle_per_expert_bias_w3_w1 / _shuffle_per_expert_bias_w2 helpers
so the AD prep stays byte-identical with PT.

Add tests/unittest/auto_deploy/singlegpu/custom_ops/moe/test_mxfp4_weight_prep.py
(3 tests) to pin the invariant: fc1 bias matches a manual gated+TMA
permute, fc2 bias matches the manual TMA permute, and the full prep
output is byte-identical to a per-expert PT-style reference loader
(weights, scales, and biases all checked). Without the fix all three
tests fail (98.8% mismatch on the bias rows); with it they pass.

End-to-end validation on gpt-oss-120b at world_size=1 with
quantize_mxfp4_moe_trtllm_gen enabled:
- GSM8K (test_mxfp4_gsm8k[120b]): 2.05% -> 90.37% (threshold 87.10%,
  reference 90.30%) -> PASS.
- ITL (V4 single-GPU, ISL=1000 OSL=1000 conc=1, 20 reqs): 8.53 ms p50
  / 117.4 tok/s/user with content valid (OSL=1000 verified).

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
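A minimal sketch of the fix, assuming the bias permutation is a gather with the same row map the weights use (the gather direction is an assumption; the real helpers are `_shuffle_per_expert_bias_w3_w1` / `_shuffle_per_expert_bias_w2`):

```python
import torch

def _shuffle_per_expert_bias(bias: torch.Tensor, row_permute: torch.Tensor) -> torch.Tensor:
    # The kernel adds bias[i] to post-shuffle output row i, so biases must be
    # carried through the same row permutation as the weight/scale rows:
    # dst[i] = src[row_permute[i]].
    return bias[row_permute]
```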
…gen MoE

The V4 single-GPU + trtllm-gen MXFP4 MoE path is now the correctness-
validated baseline for gpt-oss-120b on B200 (the previous commit fixes the
weight-prep bias shuffle so the trtllm-gen kernel produces correct
logits). Update examples/auto_deploy/model_registry/configs/
gpt_oss_120b.yaml to that configuration so the standalone AD serving
config matches the live recommendation:

- world_size 4 -> 1 (single GPU; the model fits in 192 GB HBM at
  MXFP4 and there is no AR overhead at BS=1).
- Enable transform `quantize_mxfp4_moe_trtllm_gen` so the post-load
  fusion stage rewrites `triton_mxfp4_moe` to
  `auto_deploy::trtllm_mxfp4_w4a16_moe_fused` and dispatches to
  `torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner` -- the same
  kernel PT exercises via `MXFP4WeightTRTLLMGenFusedMoEMethod`.

Measured on the same standalone serving config (ISL=1000, OSL=1000,
conc=1, 20 reqs, `DISABLE_HARMONY_ADAPTER=1` +
`--use-server-token-count`):

- ITL p50 8.53 ms / 117.4 tok/s/user (vs Triton-MXFP4 baseline 122 ms
  ITL / 8 tok/s/user, ~15x speedup).
- GSM8K accuracy 90.37% (threshold 87.10%, reference 90.30%).

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
