[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication #2912

Open
tdophung wants to merge 19 commits into NVIDIA:main from tdophung:teddy/moe_block

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Apr 21, 2026

Description

Most of the MoE building-block integration work has been deeply coupled with MaxText development. This PR creates an MoE block to isolate that work from MaxText and leave more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes / wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for the DSv3 workload), and an optional auxiliary load-balancing loss.

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device shape, backward,
    cross-backend equivalence, aux-loss, group-topk, JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
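
The single-device part of the pipeline in the first bullet (top-k routing, global permute, expert FFN, global combine) can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in moe.py: the names `topk_route`, `permute_by_expert`, and `combine` are hypothetical, the scoring is a plain softmax followed by top-k, the "expert" is the identity function, and the EP / ragged-A2A steps are left out.

```python
import math

def topk_route(logits, k):
    """Per token: softmax over experts, then keep the k highest-scoring experts."""
    routes = []
    for row in logits:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        top = sorted(range(len(row)), key=lambda e: -probs[e])[:k]
        routes.append((top, [probs[e] for e in top]))
    return routes

def permute_by_expert(tokens, routes, num_experts):
    """Replicate each token k times and stable-sort the copies by expert id (argsort)."""
    k = len(routes[0][0])
    flat_experts = [e for ids, _ in routes for e in ids]  # length T * k
    sorted_indices = sorted(range(len(flat_experts)), key=lambda i: flat_experts[i])
    sorted_x = [tokens[i // k] for i in sorted_indices]   # copy i belongs to token i // k
    group_sizes = [flat_experts.count(e) for e in range(num_experts)]
    return sorted_x, sorted_indices, group_sizes

def combine(expert_out, sorted_indices, routes):
    """Undo the sort, then take the routing-weighted sum over each token's k copies."""
    k = len(routes[0][0])
    unsorted = [None] * len(expert_out)
    for pos, i in enumerate(sorted_indices):
        unsorted[i] = expert_out[pos]
    out = []
    for t, (_, weights) in enumerate(routes):
        hidden = len(unsorted[t * k])
        out.append([sum(weights[j] * unsorted[t * k + j][h] for j in range(k))
                    for h in range(hidden)])
    return out

tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
logits = [[2.0, 0.0, 1.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]]
routes = topk_route(logits, k=2)
sorted_x, sorted_indices, group_sizes = permute_by_expert(tokens, routes, num_experts=3)
# Identity "expert": the output is each token scaled by the sum of its top-k weights.
output = combine(sorted_x, sorted_indices, routes)
```

In the real block the sorted rows would feed the grouped GEMMs, with `group_sizes` telling the kernel where each expert's rows begin and end.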

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
@greptile-apps
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR adds a large new MoEBlock functional module to TransformerEngine's JAX backend, wiring together the fused router, pluggable permutation backends (PURE_JAX / Triton), grouped_gemm expert FFNs, and ragged-all-to-all EP via a single jax.custom_vjp boundary — deliberately avoiding nested custom-VJPs so FP8 ScaledTensors can survive the EP wire in future work.

  • transformer_engine/jax/moe.py (2027 lines): new file, single-VJP functional MoE entry point with _body_fwd/_body_bwd, spec builders, and top-level moe() public API.
  • transformer_engine/jax/permutation.py: adds pure-JAX dispatch/combine with custom VJPs, EP ragged-A2A helpers, and a public routing_map_to_selected_experts converter.
  • transformer_engine/jax/flax/moe.py: thin Flax-Linen wrapper that registers params and delegates to moe().
  • New multiprocess and single-device tests; sharding.py gains ep_resource + get_active_resource_axis.

Confidence Score: 3/5

The no-EP single-device path looks functionally correct and is well-tested; the EP backward path contains at least one known defect (ctx["expert_outputs"] used as the wrong tensor in _combine_bwd) plus a likely shard_map spec mismatch when using TRITON backend with EP and default alignment — both affect users as soon as they combine EP with the TRITON permutation backend.

Two defects affect the EP backward path: the previously-flagged ctx["expert_outputs"] wrong-tensor issue (silently wrong gradients for gate/FFN weights under EP) and a newly-identified structural mismatch where _build_dispatch_specs assigns P() for pad_offsets while _dispatch stores None there when align_size=0. The default _MoEBlock._align_size=0 means every TRITON+EP user hits the spec mismatch on first run. The no-EP single-device path is independently verifiable and clean.

transformer_engine/jax/moe.py needs the most scrutiny: _combine_bwd's ctx["expert_outputs"] under EP and the _build_dispatch_specs pad_offsets spec mismatch both live there.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/moe.py | New 2027-line file implementing the single-VJP MoE functional entry; contains dead routing-gradient code in _body_bwd (lines 1295-1296 overwritten immediately), and the previously-flagged ctx["expert_outputs"] wrong-tensor issue in _combine_bwd under EP. |
| transformer_engine/jax/permutation.py | Adds pure-JAX custom_vjp dispatch/combine, EP ragged-A2A helpers, and public routing_map_to_selected_experts; hardcoded align_size=128 in _token_combine_bwd_rule and assert-after-computation guard remain as previously flagged; pad_offsets=None vs P() spec mismatch for TRITON+EP needs attention. |
| transformer_engine/jax/flax/moe.py | Thin Flax-Linen wrapper that registers gate/wi/wo params and delegates all logic to the new moe() functional entry; clean and minimal as intended. |
| transformer_engine/jax/sharding.py | Adds ep_resource field to MeshResource and the get_active_resource_axis() helper; straightforward additive change with no issues. |
| transformer_engine/common/triton/permutation.py | Refactors repeated autotune config lists into a shared _permutation_autotune_configs() helper; purely cosmetic deduplication with no behavioral change. |
| tests/jax/test_multiprocess_moe_vjp.py | EP=2 x FSDP=2 distributed backward test; exercises PURE_JAX backend thoroughly, but the TRITON backend test (marked pytest.mark.triton) would be the first to trigger the pad_offsets=None spec mismatch bug. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["moe() entry\n[B, S, H]"] --> B["gate GEMM\nlogits_2d [T, E]"]
    B --> C["fused_topk_with_score_function_fwd\nrouting_map, sparse_probs, saved_scores"]
    C -->|"aux_loss_coeff > 0"| D["aux_loss computation\n(all_gather logits under EP)"]
    C --> E["_dispatch\n[global permute + A2A + local permute]"]
    E -->|"EP active"| E1["ragged_all_to_all fwd\n→ sorted_x [recv_buffer_rows, H]"]
    E -->|"no EP"| E2["sorted_x\n[num_real+pad, H]"]
    E1 --> F["grouped_gemm x3 + SwiGLU FFN\nexpert_outputs [recv_buffer_rows, H]"]
    E2 --> F
    F --> G["_combine\n[inv local permute + rev A2A + global unpermute]"]
    G --> H["output [B, S, H]\naux_loss scalar"]
    H -->|backward| I["_combine_bwd\nd_output → d_expert_outputs\n⚠ uses ctx.expert_outputs wrong under EP"]
    I --> J["FFN bwd\nGEMM3/2/1 + activation bwd\nd_sorted_x"]
    J --> K["_dispatch_bwd\nrev A2A + global unpermute"]
    K --> L["routing bwd\nfused_topk_bwd + gate bwd"]
    L --> M["grads dict\n{inputs, gate_kernel, wi_0, wi_1, wo}"]

Reviews (10): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 Compare May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
nvjax and others added 2 commits May 7, 2026 15:18
…int in C++ files, make FP8 works. Tested with current scaling

Signed-off-by: JAX Toolbox <jax@nvidia.com>
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class

Signed-off-by: tdophung <tdophung@nvidia.com>
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment on lines +909 to +914
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Contributor

P1 Receive buffer undersized when align_size > 0 + EP are combined

recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:

recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)

This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
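
The sizing argument above can be checked with a short worked example. This is a sketch, not code from the PR: the function names are hypothetical, and the concrete values (batch=16, seq=2048, topk=2, num_experts=8, num_ep=2, dp_size=2, align_size=128) are assumptions chosen to resemble the 2x2 test mesh.

```python
def unpadded_recv_rows(global_batch_size, dp_size, sequence_length, topk):
    """Buffer size as computed in the snippet under review (no alignment padding)."""
    return (global_batch_size // dp_size) * sequence_length * topk

def padded_recv_rows(global_batch_size, dp_size, sequence_length, topk,
                     num_experts, align_size):
    """Size proposed in the comment: each expert group may gain up to align_size - 1 rows."""
    extra = num_experts * (align_size - 1) if align_size > 0 else 0
    return unpadded_recv_rows(global_batch_size, dp_size, sequence_length, topk) + extra

def worst_case_per_shard(global_batch_size, dp_size, num_ep, sequence_length, topk,
                         num_experts, align_size):
    """num_ep senders, each contributing its local rows plus per-expert padding."""
    experts_per_shard = num_experts // num_ep
    local_rows = (global_batch_size // (num_ep * dp_size)) * sequence_length * topk
    pad = experts_per_shard * (align_size - 1) if align_size > 0 else 0
    return num_ep * (local_rows + pad)

current = unpadded_recv_rows(16, 2, 2048, 2)                 # 32768
proposed = padded_recv_rows(16, 2, 2048, 2, 8, 128)          # 33784
worst = worst_case_per_shard(16, 2, 2, 2048, 2, 8, 128)      # 33784
shortfall = worst - current                                  # 1016 == 8 * 127
```

With these numbers the worst-case formula and the proposed buffer size agree, and the current buffer is short by num_experts * (align_size - 1) rows.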

Collaborator

@phu0ngng phu0ngng left a comment

I think we should go with exposing GroupMLP VJP first before the MoE module to enable future possible fusions.

tdophung added 3 commits May 12, 2026 15:53
…ing None as group_topk, align_size rename,

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as draft May 15, 2026 16:46
@tdophung
Collaborator Author

Changing back to draft to not spam people's email while I push commits to this branch for the full unrolling of ops in a big VJP.

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Replace the per-primitive custom_vjp boundaries in MoEBlock with a
single jax.custom_vjp covering routing, dispatch, expert FFN, and
combine. Helper functions group permute -> ragged_all_to_all ->
local-permute into a single dispatch / combine pair, with a hand-
derived bwd that mirrors the forward and runs entirely inside the
EP shard_map body.

Add a multi-process (one-GPU-per-process) test suite for the new
unified VJP under a 2x2 (ep, fsdp) mesh:

  * tests/jax/test_multiprocess_moe_vjp.py
    -- fwd/bwd + aux_loss + PURE_JAX vs TRITON parity at
       Mixtral-ish shapes (batch=16, seq=2048, hidden=1024,
       intermediate=4096, num_experts=8, topk=2).
  * tests/jax/run_multiprocess_moe_vjp.sh
    -- launcher; forks one pytest process per visible GPU
       (mirrors examples/jax/encoder/run_test_multiprocessing_encoder.sh).
  * tests/jax/conftest.py
    -- pytest --num-process / --process-id options for the launcher.
  * qa/L0_jax_distributed_unittest/test.sh
    -- CI hook for the multiprocess smoke.

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 21, 2026 22:11
Signed-off-by: Teddy Do <tdophung@nvidia.com>
Comment thread tests/jax/test_moe_vjp.py Outdated
# -----------------------------------------------------------------------------


@pytest.mark.triton
Collaborator

Will this skip these tests if triton isn't installed? If so, we should do this manually and only skip when backend_name == triton
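
One way to skip manually, sketched with the standard library only. The helper names `triton_installed` and `skip_reason` are hypothetical, and the backend strings "triton" and "pure_jax" are assumed to match the test's parametrize ids.

```python
import importlib.util

def triton_installed():
    """True if a module named 'triton' can be located, without importing it."""
    return importlib.util.find_spec("triton") is not None

def skip_reason(backend_name, have_triton=None):
    """Return a skip message for this backend, or None if the test should run."""
    if have_triton is None:
        have_triton = triton_installed()
    if backend_name == "triton" and not have_triton:
        return "Triton is not installed; skipping the triton backend variant"
    return None
```

Inside a test body this could be used as `reason = skip_reason(backend_name)` followed by `pytest.skip(reason)` when `reason` is not None, so the pure-JAX variant always runs.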

Collaborator

See comment on permutation.py about removing triton dependency for pure-JAX permute backend

Comment thread tests/jax/test_moe_vjp.py

Both backends support optional alignment padding (``align_size > 0``) so each
expert's group size is a multiple of ``align_size``, which is required for
quantized grouped GEMMs.
Collaborator

Now that permutation.py has two backends, I think we should update this file to remove our triton dependency if using the pure-JAX backend.

I think we can do this in the following way:

  1. Gate imports from triton_extensions here to only apply if triton is installed and defer any errors about missing triton installation until the triton backend is invoked, no errors for the pure-JAX backend.
  2. Then transformer_engine/jax/moe.py and transformer_engine/jax/flax/moe.py no longer depend on triton
  3. We can remove the _inject_moe logic from the test files and also remove the #noqa: F821 linter suppressions that _inject_moe requires.
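
A minimal sketch of step 1, assuming a deferred-error pattern. The module name `hypothetical_triton_extensions`, the helper `_require_triton`, and the `token_dispatch` signature are all placeholders for illustration and do not reflect the actual permutation.py API.

```python
import importlib

# Hypothetical module path standing in for the Triton-backed extensions.
_TRITON_MODULE = "hypothetical_triton_extensions"

try:
    _triton_ext = importlib.import_module(_TRITON_MODULE)
    _triton_import_error = None
except ImportError as err:
    # Remember the failure, but do not raise at import time.
    _triton_ext = None
    _triton_import_error = err

def _require_triton():
    """Raise only when the triton backend is actually requested."""
    if _triton_ext is None:
        raise ImportError(
            "The triton permutation backend requires Triton to be installed"
        ) from _triton_import_error
    return _triton_ext

def token_dispatch(tokens, expert_ids, backend="pure_jax"):
    """Sort tokens by expert id; only the triton branch touches the optional module."""
    if backend == "triton":
        ext = _require_triton()
        return ext.token_dispatch(tokens, expert_ids)  # hypothetical call
    order = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])
    return [tokens[i] for i in order]
```

With this shape, importing the module never fails on a machine without Triton, and the error message appears at the first triton-backend call.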

Comment thread transformer_engine/jax/moe.py Outdated
sorted_x, q_set_w1.x, local_group_sizes, flatten_axis=-1
)
casted_wi_1 = tex.grouped_quantize(wi_1, q_set_w1.kernel, flatten_axis=-1)
layer_w1 = tex.grouped_gemm(
Collaborator

What does the w mean in layer_w1? When I see w I think weights but this is an intermediate

Collaborator Author

Yeah, that's confusing. It means the intermediate result derived from using w1. I changed layer_w0 to gate_proj_out and layer_w1 to up_proj_out.

Comment thread transformer_engine/jax/moe.py Outdated
if isinstance(casted_wi_0_rhs_trans, ScaledTensor):
casted_wi_0_rhs_trans = casted_wi_0_rhs_trans.checkpoint(q_set_w0.kernel)

# GEMM 2: layer_w1 = sorted_x @ wi_1
Collaborator

It shouldn't be that big of a change to fuse GEMMs 1 and 2, right? Our activation function supports gated activations already and having a single up-proj GEMM will make it easier to fuse act+quant in the future

As for the weight tensors, we could enforce a single contiguous wi weight tensor instead of wi_0 and wi_1. I don't think we need to support both fused and unfused up-proj
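
The equivalence behind this suggestion can be sketched in plain Python for a single expert. The assumptions here are mine: `wi` is laid out as the `wi_0` (gate) columns followed by the `wi_1` (up) columns, the activation is SiLU-gated (SwiGLU), and the helper names are hypothetical. The grouped, per-expert dimension is omitted.

```python
import math

def matmul(a, b):
    """Dense [M, K] x [K, N] product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def silu(v):
    return v / (1.0 + math.exp(-v))

def swiglu_unfused(x, wi_0, wi_1):
    """Two GEMMs: gate projection and up projection, then the gated activation."""
    gate_proj_out = matmul(x, wi_0)
    up_proj_out = matmul(x, wi_1)
    return [[silu(g) * u for g, u in zip(g_row, u_row)]
            for g_row, u_row in zip(gate_proj_out, up_proj_out)]

def swiglu_fused(x, wi):
    """One GEMM against wi = [wi_0 | wi_1], then split the result in half."""
    both = matmul(x, wi)
    half = len(wi[0]) // 2
    return [[silu(g) * u for g, u in zip(row[:half], row[half:])] for row in both]

x = [[1.0, -2.0], [0.5, 3.0]]
wi_0 = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
wi_1 = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.5]]
wi = [r0 + r1 for r0, r1 in zip(wi_0, wi_1)]  # concatenate along the output dimension

unfused = swiglu_unfused(x, wi_0, wi_1)
fused = swiglu_fused(x, wi)
```

Both paths produce the same values; the fused form leaves a single up-projection output for a later act+quant fusion to consume.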

)

# ---------------- Build ctx dict ----------------
ctx: dict = {
Collaborator

There are a lot of fields here, so I agree a dict is better than a tuple. But I think it'd be even easier to read if this was a dataclass
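
A sketch of what the dataclass form might look like. The class name `MoEContext` and its fields are illustrative only; they are not the actual ctx keys in moe.py.

```python
from dataclasses import dataclass, fields
from typing import Optional, Tuple

@dataclass(frozen=True)
class MoEContext:
    """Hypothetical residual container; field names are illustrative only."""
    sorted_x: list
    expert_outputs: list
    group_sizes: Tuple[int, ...]
    routing_weights: list
    pad_offsets: Optional[list] = None  # stays None when no alignment padding is used

ctx = MoEContext(
    sorted_x=[[1.0]],
    expert_outputs=[[2.0]],
    group_sizes=(1,),
    routing_weights=[[1.0]],
)
field_names = [f.name for f in fields(ctx)]
```

A misspelled field then fails at construction with a TypeError instead of surfacing as a KeyError inside the backward rule. If the object is used as custom_vjp residuals it would also need to be a JAX pytree; recent JAX versions offer `jax.tree_util.register_dataclass` for that, which is worth verifying against the JAX version TE targets.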

Comment thread transformer_engine/jax/moe.py Outdated
# must use the per-shard shape rather than the captured global
# ``x_shape``.
if ep_active:
import math as _math # local import keeps the no-EP path zero-overhead.
Collaborator

I think we can import math at the top of the file. We don't have to import inline or gate it behind any logic since it's a built-in library

Comment thread transformer_engine/jax/moe.py Outdated
# the fwd rule. That helper preserves the ORIGINAL positional order
# of the decorated function: dyn (= diff) args sit at their original
# positions and static (= nondiff) args fill the remaining slots in
# nondiff_argnums order. So the fwd rule MUST take args in the
Collaborator

This feels pretty verbose. Do we need this comment, or was this WIP notes from an agent that we can remove now?

Collaborator Author

it's WIP from an agent, trimmed it down to 2 lines now

Add MXFP8BlockScaling alongside bf16 to test_fwd_and_bwd /
test_aux_loss / test_pure_jax_triton_parity. Tests skip the FP8
recipe on GPUs older than sm_100. Tolerance widened from 5e-2 to
3e-1 for MXFP8 parity to absorb block-scale quantization noise.

WIP -- DO NOT MERGE into teddy/moe_block until dlcluster verifies
all six (recipe, backend) combos pass.

Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +568 to +700
    # Per-shard compile-time-constant shape info (Python ints / int tuples).
    # See ``_compute_static_shape_info`` and the note in ``_dispatch``
    # for why these are kwargs rather than state-dict entries.
    num_real_tokens: int,
    padding_size: int,
    post_a2a_buffer_shape: Optional[Tuple[int, int]],
    # EP-only:
    ep_axis: Optional[str],
    shard_id: Optional[jnp.ndarray] = None,
    num_ep: int = 1,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Inverse of :func:`_combine` on the cotangent.

    Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``.

    ``expert_outputs`` is the *forward* output of the FFN (same value the
    fwd handed to :func:`_combine`). It's required by the TRITON
    combine_bwd kernel; for PURE_JAX we don't need it but accept it for
    a symmetric signature.
    """
    # Step 3 inverse: global combine bwd.
    d_output_2d = d_output.reshape(-1, d_output.shape[-1])
    if backend is PermutationBackend.PURE_JAX:
        # The pure-jax combine is:
        # unsort = _sort_activations(expert_outputs, argsort(sorted_indices))
        # if pad: unsort = unsort[:num_real]
        # reshape -> einsum BKE,BK -> BE -> reshape to BSE
        # Hand-derive the bwd in plain JAX (no custom_vjp involved):
        unsort_indices = jnp.argsort(state["sorted_indices"])
        topk = num_experts_per_tok
        num_real = num_real_tokens
        padding = padding_size
        # Recover the unsorted intermediate that the fwd produced (we
        # need it for the d_routing_weights pullback). Apply the same
        # gather the fwd did.
        unsort_intermediate = expert_outputs[unsort_indices]
        if padding > 0:
            unsort_intermediate = unsort_intermediate[:num_real]
        # Bwd of einsum/reshape:
        # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K]
        # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K]
        # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E]
        rw = state["routing_weights"].reshape(-1, topk)
        intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1)
        rw_cast = rw.astype(intermediate_3d.dtype)
        d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast)
        d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype(
            state["routing_weights"].dtype
        )
        d_routing_weights = d_routing_weights.reshape(state["routing_weights"].shape)
        d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1)
        # Pad back with zeros if the fwd stripped padding.
        if padding > 0:
            d_unsort_intermediate = jnp.concatenate(
                [
                    d_unsort_intermediate,
                    jnp.zeros(
                        (padding, d_unsort_intermediate.shape[-1]),
                        dtype=d_unsort_intermediate.dtype,
                    ),
                ],
                axis=0,
            )
        # Bwd of the gather is gather-by-original-indices:
        # sorted = unsort[argsort(sorted_indices)]
        # d_sorted = scatter d_unsort via argsort(sorted_indices)
        # = d_unsort[sorted_indices] (gather by original sorted_indices,
        # which is the inverse of argsort(sorted_indices)).
        d_expert_outputs_global = d_unsort_intermediate[state["sorted_indices"]]
    else:
        # TRITON combine bwd: requires fwd_input (expert_outputs).
        num_tokens = state["row_id_map"].shape[0]
        n_experts = (state["row_id_map"].shape[1] - 1) // 2
        hidden = d_output_2d.shape[-1]
        num_out_tokens = expert_outputs.shape[0]
        if state["pad_offsets"] is not None:
            d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad(
                d_output_2d,
                state["row_id_map"],
                expert_outputs,
                state["merging_probs"],
                state["pad_offsets"],
                num_tokens,
                n_experts,
                num_out_tokens,
                hidden,
            )
            # The kernel only writes positions tokens map to; padded
            # positions may contain NaN. Replace with zeros (matches
            # ``_token_combine_bwd_rule``).
            d_expert_outputs_global = jnp.where(
                jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global
            )
        else:
            d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs(
                d_output_2d,
                state["row_id_map"],
                expert_outputs,
                state["merging_probs"],
                num_tokens,
                n_experts,
                num_out_tokens,
                hidden,
            )
        d_routing_weights = d_merging_probs

    if not ep_active:
        return d_expert_outputs_global, d_routing_weights

    # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward
    # ragged_all_to_all using the SAME forward parameters (sender /
    # receiver roles swap from the reverse direction back to forward).
    in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params(
        state["all_shards_tokens_per_expert"], shard_id, num_ep
    )
    recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype)
    d_x_send_back = jax.lax.ragged_all_to_all(
        d_expert_outputs_global,
        recv_buf_for_bwd,
        in_off_f,
        send_sz_f,
        out_off_f,
        recv_sz_f,
        axis_name=ep_axis,
    )
    # Step 1 (EP) inverse: combine fwd applied is_forward=False; the
    # bwd is is_forward=True with the SAME row_id_map.
    recv_buffer_rows, hidden = d_x_send_back.shape
    d_expert_outputs, _ = sort_chunks_by_map(
        d_x_send_back,
        state["local_perm_row_id_map"],
        None,
        recv_buffer_rows,
Contributor

P1 ctx["expert_outputs"] is the wrong tensor for the combine backward under EP

_body_fwd stores the shard-local FFN output (shape [recv_buffer_rows, hidden]) in ctx["expert_outputs"] before calling _combine. However, the global-combine step (step 3) inside _combine operates on a different tensor: the globally-permuted FFN output (shape [num_real_tokens + padding_size, hidden]) produced after the EP inverse-local-permute and reverse-ragged-A2A steps (steps 1–2). _combine_bwd then receives ctx["expert_outputs"] (shard-local, expert-major) and uses it as if it were the globally-permuted tensor.

Under PURE_JAX + EP: unsort_intermediate = expert_outputs[unsort_indices] indexes the shard-local tensor with global argsort indices; the recovered intermediate has wrong content, so d_routing_weights (and the resulting gate-kernel gradient) is silently incorrect.

Under TRITON + EP: expert_outputs is passed directly to unpermute_bwd_with_merging_probs[_and_unpad] with num_out_tokens = expert_outputs.shape[0] = recv_buffer_rows instead of num_real_tokens + padding_size; both d_expert_outputs_global and d_merging_probs are wrong, corrupting expert-weight gradients as well.

The existing distributed tests only assert finiteness and non-zero gradients; because both backends suffer the same mismatch, the backend-parity test passes even when both produce wrong gradients.

Fix: either (a) run the EP steps (inverse local permute + reverse ragged-A2A) before calling _combine / _combine_bwd, save the resulting globally-permuted tensor in ctx["expert_outputs_global_sorted"], and pass it to step 3; or (b) re-execute the EP forward steps inside _combine_bwd to recover the globally-permuted output before delegating to the step-3 backward.
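
Because finiteness and non-zero checks cannot catch this class of bug, a reference comparison is one option for the tests. The sketch below is plain Python with hypothetical helper names and toy data; it checks the step-3 formula d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] against a central finite difference, then shows that feeding a mismatched intermediate still yields finite, non-zero, but wrong values.

```python
def combine_fwd(intermediate, weights):
    """output[b][e] = sum_k intermediate[b][k][e] * weights[b][k]"""
    return [[sum(intermediate[b][k][e] * weights[b][k] for k in range(len(weights[b])))
             for e in range(len(intermediate[b][0]))] for b in range(len(weights))]

def combine_bwd_weights(d_output, intermediate):
    """d_weights[b][k] = sum_e d_output[b][e] * intermediate[b][k][e]"""
    return [[sum(d_output[b][e] * intermediate[b][k][e] for e in range(len(d_output[b])))
             for k in range(len(intermediate[b]))] for b in range(len(d_output))]

def numeric_d_weights(d_output, intermediate, weights, eps=1e-6):
    """Central finite difference of loss = sum(d_output * output) w.r.t. weights."""
    def loss(w):
        out = combine_fwd(intermediate, w)
        return sum(d_output[b][e] * out[b][e]
                   for b in range(len(out)) for e in range(len(out[b])))
    grads = []
    for b in range(len(weights)):
        row = []
        for k in range(len(weights[b])):
            plus = [r[:] for r in weights]
            minus = [r[:] for r in weights]
            plus[b][k] += eps
            minus[b][k] -= eps
            row.append((loss(plus) - loss(minus)) / (2 * eps))
        grads.append(row)
    return grads

intermediate = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                [[7.0, 8.0, 9.0], [1.0, 0.0, 2.0]]]   # [B=2, K=2, E=3]
weights = [[0.6, 0.4], [0.7, 0.3]]
d_output = [[1.0, 0.5, -1.0], [2.0, 1.0, 0.0]]

analytic = combine_bwd_weights(d_output, intermediate)
numeric = numeric_d_weights(d_output, intermediate, weights)
# Same formula fed a mismatched tensor (rows swapped): finite and non-zero, yet incorrect.
wrong = combine_bwd_weights(d_output, intermediate[::-1])
```

A distributed test could apply the same idea by comparing EP gradients against a single-device reference run rather than against the other backend.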

tdophung and others added 2 commits May 21, 2026 17:40
transformer_engine/jax/moe.py:
- Hoist 'import math' to module top (was two local imports).
- Trim the verbose _moe_fwd_rule arg-order block comment.
- Update PermutationBackend docstring: TRITON is the recommended
  default and is faster than PURE_JAX on current hardware.
- Rename layer_w0 / layer_w1 to gate_proj_out / up_proj_out so the
  names reflect what they are (SwiGLU projection outputs, not weights).
- moe() now rejects overlapping EP / FSDP axes up front instead of
  letting JAX produce a duplicate-axis PartitionSpec.

transformer_engine/jax/permutation.py:
- Drop reference to the temporary MaxText-fork
  compute_ragged_all_to_all_params helpers.

tests/jax/test_moe_vjp.py, tests/jax/test_multiprocess_moe_vjp.py:
- Add a module-level Blackwell (sm_100+) skip; grouped GEMM is
  Blackwell-only today.
- Move the 'triton' pytest marker from the class onto the
  triton parametrize variant only, so the pure_jax variant
  still runs in environments without Triton.

Signed-off-by: tdophung <tdophung@nvidia.com>

Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block

4 participants