[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912
[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912tdophung wants to merge 19 commits into
Conversation
Greptile SummaryThis PR adds a large new
Confidence Score: 3/5The no-EP single-device path looks functionally correct and is well-tested; the EP backward path contains at least one known defect (ctx["expert_outputs"] used as the wrong tensor in _combine_bwd) plus a likely shard_map spec mismatch when using TRITON backend with EP and default alignment — both affect users as soon as they combine EP with the TRITON permutation backend. Two defects affect the EP backward path: the previously-flagged ctx["expert_outputs"] wrong-tensor issue (silently wrong gradients for gate/FFN weights under EP) and a newly-identified structural mismatch where _build_dispatch_specs assigns P() for pad_offsets while _dispatch stores None there when align_size=0. The default _MoEBlock._align_size=0 means every TRITON+EP user hits the spec mismatch on first run. The no-EP single-device path is independently verifiable and clean. transformer_engine/jax/moe.py needs the most scrutiny: _combine_bwd's ctx["expert_outputs"] under EP and the _build_dispatch_specs pad_offsets spec mismatch both live there. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["moe() entry\n[B, S, H]"] --> B["gate GEMM\nlogits_2d [T, E]"]
B --> C["fused_topk_with_score_function_fwd\nrouting_map, sparse_probs, saved_scores"]
C -->|"aux_loss_coeff > 0"| D["aux_loss computation\n(all_gather logits under EP)"]
C --> E["_dispatch\n[global permute + A2A + local permute]"]
E -->|"EP active"| E1["ragged_all_to_all fwd\n→ sorted_x [recv_buffer_rows, H]"]
E -->|"no EP"| E2["sorted_x\n[num_real+pad, H]"]
E1 --> F["grouped_gemm x3 + SwiGLU FFN\nexpert_outputs [recv_buffer_rows, H]"]
E2 --> F
F --> G["_combine\n[inv local permute + rev A2A + global unpermute]"]
G --> H["output [B, S, H]\naux_loss scalar"]
H -->|backward| I["_combine_bwd\nd_output → d_expert_outputs\n⚠ uses ctx.expert_outputs wrong under EP"]
I --> J["FFN bwd\nGEMM3/2/1 + activation bwd\nd_sorted_x"]
J --> K["_dispatch_bwd\nrev A2A + global unpermute"]
K --> L["routing bwd\nfused_topk_bwd + gate bwd"]
L --> M["grads dict\n{inputs, gate_kernel, wi_0, wi_1, wo}"]
Reviews (10): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
…int in C++ files, make FP8 works. Tested with current scaling Signed-off-by: JAX Toolbox <jax@nvidia.com>
for more information, see https://pre-commit.ci
|
Want your agent to iterate on Greptile's feedback? Try greploops. |
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class Signed-off-by: tdophung <tdophung@nvidia.com>
| batch_divisor = num_ep * dp_size | ||
| if global_batch_size % batch_divisor != 0: | ||
| raise ValueError( | ||
| f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" | ||
| ) | ||
| recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk |
There was a problem hiding this comment.
Receive buffer undersized when
align_size > 0 + EP are combined
recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)
This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
phu0ngng
left a comment
There was a problem hiding this comment.
I think we should go with exposing GroupMLP VJP first before the MoE module to enable future possible fusions.
…ing None as group_topk, align_size rename, Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
|
Changing back to draft to not spam people's email while I push commits to this branch for the full unrolling of ops in a big VJP. |
Replace the per-primitive custom_vjp boundaries in MoEBlock with a
single jax.custom_vjp covering routing, dispatch, expert FFN, and
combine. Helper functions group permute -> ragged_all_to_all ->
local-permute into a single dispatch / combine pair, with a hand-
derived bwd that mirrors the forward and runs entirely inside the
EP shard_map body.
Add a multi-process (one-GPU-per-process) test suite for the new
unified VJP under a 2x2 (ep, fsdp) mesh:
* tests/jax/test_multiprocess_moe_vjp.py
-- fwd/bwd + aux_loss + PURE_JAX vs TRITON parity at
Mixtral-ish shapes (batch=16, seq=2048, hidden=1024,
intermediate=4096, num_experts=8, topk=2).
* tests/jax/run_multiprocess_moe_vjp.sh
-- launcher; forks one pytest process per visible GPU
(mirrors examples/jax/encoder/run_test_multiprocessing_encoder.sh).
* tests/jax/conftest.py
-- pytest --num-process / --process-id options for the launcher.
* qa/L0_jax_distributed_unittest/test.sh
-- CI hook for the multiprocess smoke.
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Teddy Do <tdophung@nvidia.com>
| # ----------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.mark.triton |
There was a problem hiding this comment.
Will this skip these tests if triton isn't installed? If so, we should do this manually and only skip when backend_name == triton
There was a problem hiding this comment.
See comment on permutation.py about removing triton dependency for pure-JAX permute backend
|
|
||
| Both backends support optional alignment padding (``align_size > 0``) so each | ||
| expert's group size is a multiple of ``align_size``, which is required for | ||
| quantized grouped GEMMs. |
There was a problem hiding this comment.
Now that permutation.py has two backends, I think we should update this file to remove our triton dependency if using the pure-JAX backend.
I think we can do this in the following way:
- Gate imports from triton_extensions here to only apply if triton is installed and defer any errors about missing triton installation until the triton backend is invoked, no errors for the pure-JAX backend.
- Then
transformer_engine/jax/moe.pyandtransformer_engine/jax/flax/moe.pyno longer depend on triton - We can remove the
_inject_moelogic from the test files and also remove the#noqa: F821linter suppresions that_inject_moerequires.
| sorted_x, q_set_w1.x, local_group_sizes, flatten_axis=-1 | ||
| ) | ||
| casted_wi_1 = tex.grouped_quantize(wi_1, q_set_w1.kernel, flatten_axis=-1) | ||
| layer_w1 = tex.grouped_gemm( |
There was a problem hiding this comment.
What does the w mean in layer_w1? When I see w I think weights but this is an intermediate
There was a problem hiding this comment.
yeah that's confusing, it means the intermediate result that was derived from using w1, I changed layer_w0 to gate_proj_out, and layer_w1 to up_proj_out.
| if isinstance(casted_wi_0_rhs_trans, ScaledTensor): | ||
| casted_wi_0_rhs_trans = casted_wi_0_rhs_trans.checkpoint(q_set_w0.kernel) | ||
|
|
||
| # GEMM 2: layer_w1 = sorted_x @ wi_1 |
There was a problem hiding this comment.
It shouldn't be that big of a change to fuse GEMMs 1 and 2, right? Our activation function supports gated activations already and having a single up-proj GEMM will make it easier to fuse act+quant in the future
As for the weight tensors, we could enforce a single contiguous wi weight tensor instead of wi_0 and wi_1. I don't think we need to support both fused and unfused up-proj
| ) | ||
|
|
||
| # ---------------- Build ctx dict ---------------- | ||
| ctx: dict = { |
There was a problem hiding this comment.
There are a lot of fields here, so I agree a dict is better than a tuple. But I think it'd be even easier to read if this was a dataclass
| # must use the per-shard shape rather than the captured global | ||
| # ``x_shape``. | ||
| if ep_active: | ||
| import math as _math # local import keeps the no-EP path zero-overhead. |
There was a problem hiding this comment.
I think we can import math at the top of the file. We don't have to import inline or gate it behind any logic since it's a built-in library
| # the fwd rule. That helper preserves the ORIGINAL positional order | ||
| # of the decorated function: dyn (= diff) args sit at their original | ||
| # positions and static (= nondiff) args fill the remaining slots in | ||
| # nondiff_argnums order. So the fwd rule MUST take args in the |
There was a problem hiding this comment.
This feels pretty verbose. Do we need this comment or was this WIP notes from an agent and we can remove it now.
There was a problem hiding this comment.
it's WIP from an agent, trimmed it down to 2 lines now
Add MXFP8BlockScaling alongside bf16 to test_fwd_and_bwd / test_aux_loss / test_pure_jax_triton_parity. Tests skip the FP8 recipe on GPUs older than sm_100. Tolerance widened from 5e-2 to 3e-1 for MXFP8 parity to absorb block-scale quantization noise. WIP -- DO NOT MERGE into teddy/moe_block until dlcluster verifies all six (recipe, backend) combos pass. Signed-off-by: tdophung <tdophung@nvidia.com>
| # Per-shard compile-time-constant shape info (Python ints / int tuples). | ||
| # See ``_compute_static_shape_info`` and the note in ``_dispatch`` | ||
| # for why these are kwargs rather than state-dict entries. | ||
| num_real_tokens: int, | ||
| padding_size: int, | ||
| post_a2a_buffer_shape: Optional[Tuple[int, int]], | ||
| # EP-only: | ||
| ep_axis: Optional[str], | ||
| shard_id: Optional[jnp.ndarray] = None, | ||
| num_ep: int = 1, | ||
| ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: | ||
| """Inverse of :func:`_combine` on the cotangent. | ||
|
|
||
| Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``. | ||
|
|
||
| ``expert_outputs`` is the *forward* output of the FFN (same value the | ||
| fwd handed to :func:`_combine`). It's required by the TRITON | ||
| combine_bwd kernel; for PURE_JAX we don't need it but accept it for | ||
| a symmetric signature. | ||
| """ | ||
| # Step 3 inverse: global combine bwd. | ||
| d_output_2d = d_output.reshape(-1, d_output.shape[-1]) | ||
| if backend is PermutationBackend.PURE_JAX: | ||
| # The pure-jax combine is: | ||
| # unsort = _sort_activations(expert_outputs, argsort(sorted_indices)) | ||
| # if pad: unsort = unsort[:num_real] | ||
| # reshape -> einsum BKE,BK -> BE -> reshape to BSE | ||
| # Hand-derive the bwd in plain JAX (no custom_vjp involved): | ||
| unsort_indices = jnp.argsort(state["sorted_indices"]) | ||
| topk = num_experts_per_tok | ||
| num_real = num_real_tokens | ||
| padding = padding_size | ||
| # Recover the unsorted intermediate that the fwd produced (we | ||
| # need it for the d_routing_weights pullback). Apply the same | ||
| # gather the fwd did. | ||
| unsort_intermediate = expert_outputs[unsort_indices] | ||
| if padding > 0: | ||
| unsort_intermediate = unsort_intermediate[:num_real] | ||
| # Bwd of einsum/reshape: | ||
| # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K] | ||
| # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K] | ||
| # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] | ||
| rw = state["routing_weights"].reshape(-1, topk) | ||
| intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1) | ||
| rw_cast = rw.astype(intermediate_3d.dtype) | ||
| d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast) | ||
| d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype( | ||
| state["routing_weights"].dtype | ||
| ) | ||
| d_routing_weights = d_routing_weights.reshape(state["routing_weights"].shape) | ||
| d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1) | ||
| # Pad back with zeros if the fwd stripped padding. | ||
| if padding > 0: | ||
| d_unsort_intermediate = jnp.concatenate( | ||
| [ | ||
| d_unsort_intermediate, | ||
| jnp.zeros( | ||
| (padding, d_unsort_intermediate.shape[-1]), | ||
| dtype=d_unsort_intermediate.dtype, | ||
| ), | ||
| ], | ||
| axis=0, | ||
| ) | ||
| # Bwd of the gather is gather-by-original-indices: | ||
| # sorted = unsort[argsort(sorted_indices)] | ||
| # d_sorted = scatter d_unsort via argsort(sorted_indices) | ||
| # = d_unsort[sorted_indices] (gather by original sorted_indices, | ||
| # which is the inverse of argsort(sorted_indices)). | ||
| d_expert_outputs_global = d_unsort_intermediate[state["sorted_indices"]] | ||
| else: | ||
| # TRITON combine bwd: requires fwd_input (expert_outputs). | ||
| num_tokens = state["row_id_map"].shape[0] | ||
| n_experts = (state["row_id_map"].shape[1] - 1) // 2 | ||
| hidden = d_output_2d.shape[-1] | ||
| num_out_tokens = expert_outputs.shape[0] | ||
| if state["pad_offsets"] is not None: | ||
| d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad( | ||
| d_output_2d, | ||
| state["row_id_map"], | ||
| expert_outputs, | ||
| state["merging_probs"], | ||
| state["pad_offsets"], | ||
| num_tokens, | ||
| n_experts, | ||
| num_out_tokens, | ||
| hidden, | ||
| ) | ||
| # The kernel only writes positions tokens map to; padded | ||
| # positions may contain NaN. Replace with zeros (matches | ||
| # ``_token_combine_bwd_rule``). | ||
| d_expert_outputs_global = jnp.where( | ||
| jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global | ||
| ) | ||
| else: | ||
| d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs( | ||
| d_output_2d, | ||
| state["row_id_map"], | ||
| expert_outputs, | ||
| state["merging_probs"], | ||
| num_tokens, | ||
| n_experts, | ||
| num_out_tokens, | ||
| hidden, | ||
| ) | ||
| d_routing_weights = d_merging_probs | ||
|
|
||
| if not ep_active: | ||
| return d_expert_outputs_global, d_routing_weights | ||
|
|
||
| # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward | ||
| # ragged_all_to_all using the SAME forward parameters (sender / | ||
| # receiver roles swap from the reverse direction back to forward). | ||
| in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params( | ||
| state["all_shards_tokens_per_expert"], shard_id, num_ep | ||
| ) | ||
| recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype) | ||
| d_x_send_back = jax.lax.ragged_all_to_all( | ||
| d_expert_outputs_global, | ||
| recv_buf_for_bwd, | ||
| in_off_f, | ||
| send_sz_f, | ||
| out_off_f, | ||
| recv_sz_f, | ||
| axis_name=ep_axis, | ||
| ) | ||
| # Step 1 (EP) inverse: combine fwd applied is_forward=False; the | ||
| # bwd is is_forward=True with the SAME row_id_map. | ||
| recv_buffer_rows, hidden = d_x_send_back.shape | ||
| d_expert_outputs, _ = sort_chunks_by_map( | ||
| d_x_send_back, | ||
| state["local_perm_row_id_map"], | ||
| None, | ||
| recv_buffer_rows, |
There was a problem hiding this comment.
ctx["expert_outputs"] is the wrong tensor for the combine backward under EP
_body_fwd stores the shard-local FFN output (shape [recv_buffer_rows, hidden]) in ctx["expert_outputs"] before calling _combine. However, the global-combine step (step 3) inside _combine operates on a different tensor: the globally-permuted FFN output (shape [num_real_tokens + padding_size, hidden]) produced after the EP inverse-local-permute and reverse-ragged-A2A steps (steps 1–2). _combine_bwd then receives ctx["expert_outputs"] (shard-local, expert-major) and uses it as if it were the globally-permuted tensor.
Under PURE_JAX + EP: unsort_intermediate = expert_outputs[unsort_indices] indexes the shard-local tensor with global argsort indices; the recovered intermediate has wrong content, so d_routing_weights (and the resulting gate-kernel gradient) is silently incorrect.
Under TRITON + EP: expert_outputs is passed directly to unpermute_bwd_with_merging_probs[_and_unpad] with num_out_tokens = expert_outputs.shape[0] = recv_buffer_rows instead of num_real_tokens + padding_size; both d_expert_outputs_global and d_merging_probs are wrong, corrupting expert-weight gradients as well.
The existing distributed tests only assert finiteness and non-zero gradients; because both backends suffer the same mismatch, the backend-parity test passes even when both produce wrong gradients.
Fix: either (a) run the EP steps (inverse local permute + reverse ragged-A2A) before calling _combine / _combine_bwd, save the resulting globally-permuted tensor in ctx["expert_outputs_global_sorted"], and pass it to step 3; or (b) re-execute the EP forward steps inside _combine_bwd to recover the globally-permuted output before delegating to the step-3 backward.
transformer_engine/jax/moe.py: - Hoist 'import math' to module top (was two local imports). - Trim the verbose _moe_fwd_rule arg-order block comment. - Update PermutationBackend docstring: TRITON is the recommended default and is faster than PURE_JAX on current hardware. - Rename layer_w0 / layer_w1 to gate_proj_out / up_proj_out so the names reflect what they are (SwiGLU projection outputs, not weights). - moe() now rejects overlapping EP / FSDP axes up front instead of letting JAX produce a duplicate-axis PartitionSpec. transformer_engine/jax/permutation.py: - Drop reference to the temporary MaxText-fork compute_ragged_all_to_all_params helpers. tests/jax/test_moe_vjp.py, tests/jax/test_multiprocess_moe_vjp.py: - Add a module-level Blackwell (sm_100+) skip; grouped GEMM is Blackwell-only today. - Move the 'triton' pytest marker from the class onto the triton parametrize variant only, so the pure_jax variant still runs in environments without Triton. Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Description
Most of MoE building blocks integration work has been deeply coupled with Maxtext development. Now creating this MoE block to isolate the work from Maxtext and provide more room for experimentation.
MoEBlockis a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Tritonsort_chunks_by_index),grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism viashard_mapThis first iteration will start with ring-of-experts EP, sharding on batch dimention for FSDP, CUBLASLt groupedGEMM and 2 permutation backend: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (
wi_kernel_axes/wo_kernel_axes), permutation backend, A2A vs no-EP (single GPU) path, data-parallelism axes for true FSDP (batch sharded across(ep, fsdp)simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workload), and optional auxiliary load-balancing loss.Fixes #2895
Type of change
Changes
transformer_engine/jax/flax/moe.py--MoEBlockLinen module:gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
transformer_engine/jax/permutation.pywith A2A param helpers (compute_ragged_all_to_all_params,compute_reverse_ragged_all_to_all_params,local_permute_after_a2a,local_unpermute_before_a2a) and the pure-JAXunfused_token_dispatch/unfused_token_combinepathswith custom VJPs.
tests/jax/test_moe_block.py-- single-device shape, backward,cross-backend equivalence, aux-loss, group-topk, JIT determinism.
tests/jax/test_distributed_moe_block.py-- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) anddata_parallelism_axes=("fsdp",)to exercise true FSDP (batch sharded across both axes).Checklist: