[speechlm2] SALMAutomodel: THD (packed sequence) and context parallel support #15679
Draft
Adds an opt-in packed-sequence (THD) training/validation path so that ``SALMAutomodel`` can feed a Nemotron-V3 LLM via ``cu_seqlens``-aware varlen attention instead of the right-padded BSHD layout. Padding overhead drops from O(B*max_T) to O(rounding) per minibatch, which is substantial for variable-length speech inputs. Activated by ``model.packed_sequences: True`` in the YAML; the BSHD path is unchanged when the flag is unset. Generation/inference still uses BSHD (it doesn't go through ``prepare_inputs``).
Pieces:
* ``parts/packed_sequences.py`` — concatenates per-utterance text + audio embeddings into a single flat ``[T_total, H]`` sequence with a ``cu_seqlens`` index, applies the per-utterance next-token shift, and rounds each utterance's flat length up to a multiple of ``2*cp_size`` so the same packing also satisfies TE's CP DualChunkSwap contract. Output shape mirrors Automodel's canonical THD layout (``components/distributed/thd_utils.process_input_for_thd`` / ``cp_utils._shard_thd_chunk_for_te`` — 2D, no leading batch dim), so no extra squeeze/unsqueeze hops are needed. A minimal sketch of the packing contract appears after the test list below.
* ``parts/cp_helpers.py`` — ``get_cp_mesh`` reads the CP submesh out of the device mesh and returns ``(None, 1, 0)`` when CP is inactive. Used by ``prepare_packed_llm_inputs`` to short-circuit the CP-shard path. Further CP plumbing for the BSHD path lands in a follow-up.
* ``models/salm_automodel.py`` — ``forward`` accepts ``**llm_kwargs`` (THD metadata: ``qkv_format``, ``cu_seqlens``, ``position_ids``, ``max_seqlen``); ``prepare_inputs`` returns the THD dict early when ``packed_sequences`` is set; ``training_step``/``validation_step`` splat ``llm_kwargs`` into the forward call and use a shape-generic ``logits.reshape(-1, V)`` so the same code handles both BSHD ``(B, T, V)`` and THD ``(T, V)`` outputs.
Tests:
* ``test_salm_packed_sequences.py`` — covers shape contracts, ``cu_seqlens`` invariants, the per-utterance next-token shift, audio-frame label masking, the ``cp_size``/``tp_size`` rounding, and a TE preprocessor regression test that pins down the ``cu_seqlens`` + ``max_seqlen`` (singular) contract. Includes BSHD-vs-THD pair-equivalence checks: the set of supervised ``(input_embedding, target_token_id)`` pairs reaching the cross-entropy must be identical between the two layouts on the same batch.
* ``test_salm_cp_helpers.py`` — three CPU tests for ``get_cp_mesh`` covering the no-mesh, ``cp_size == 1``, and missing-``cp``-axis paths.
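Below is a minimal, CPU-only sketch of the packing contract described above, not the PR's ``parts/packed_sequences.py``: it only shows how per-utterance embeddings are flattened into the 2D THD layout with a ``cu_seqlens`` index and how each length is rounded up to a multiple of ``2*cp_size``. The next-token shift, label masking, and the ``tp_size`` rounding factor are omitted, and all names are illustrative.

```python
# Sketch only (illustrative names), assuming per-utterance embeddings of shape [T_i, H].
import torch


def pack_utterances(per_utt_embs: list[torch.Tensor], cp_size: int = 1) -> dict:
    multiple = 2 * cp_size
    chunks, cu_seqlens = [], [0]
    for emb in per_utt_embs:
        t, h = emb.shape
        # Round the utterance length up to a multiple of 2*cp_size (zero-pad the tail).
        t_padded = ((t + multiple - 1) // multiple) * multiple
        padded = torch.zeros(t_padded, h, dtype=emb.dtype, device=emb.device)
        padded[:t] = emb
        chunks.append(padded)
        cu_seqlens.append(cu_seqlens[-1] + t_padded)
    packed = torch.cat(chunks, dim=0)  # [T_total, H], no leading batch dim (THD)
    cu = torch.tensor(cu_seqlens, dtype=torch.int32)
    seq_lens = cu[1:] - cu[:-1]
    return {
        "input_embs": packed,
        "cu_seqlens": cu,
        "max_seqlen": int(seq_lens.max()),
        "position_ids": torch.cat([torch.arange(int(n)) for n in seq_lens]),
        "qkv_format": "thd",
    }


# Example: three utterances of lengths 5, 9, 3 with cp_size=2 -> lengths round to 8, 12, 4.
batch = [torch.randn(t, 8) for t in (5, 9, 3)]
out = pack_utterances(batch, cp_size=2)
print(out["input_embs"].shape, out["cu_seqlens"].tolist())  # torch.Size([24, 8]) [0, 8, 20, 24]
```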
Adds context-parallelism support to ``SALMAutomodel`` so that large
Nemotron-V3 LLMs with hybrid Mamba/attention layers can train on long
audio sequences across multiple GPUs. Builds on the THD packed-sequence
path from the previous commit; the BSHD path is also supported but the
THD path is the recommended configuration under CP.
Activated by ``cp_size > 1`` in the strategy config (e.g.
``AutomodelParallelStrategy(cp_size=2, ...)``); the existing TP
truncation path is folded into the CP padding so single-axis runs are
unchanged.
Pieces:
* ``parts/cp_helpers.py`` — extends the module with two CP-aware
helpers used by ``SALMAutomodel.prepare_inputs``:
- ``shard_bshd_for_cp`` pads the BSHD seq dim to a multiple of ``2*cp_size*tp_size`` and partitions along the seq dim using TE's ``thd_get_partitioned_indices`` (the same DualChunkSwap pattern Automodel's ``Config 1`` reference test uses); a pure-PyTorch illustration of this partitioning appears right after this list.
- ``encode_audio_with_cp_distribution`` distributes the audio encoder forward across CP ranks instead of recomputing it ``cp_size`` times. Right-pads the audio batch with zero-audio dummies so every rank participates in FSDP all-gather (and AC fires uniformly), then all-gathers the variable-length embedding tensors back so each rank reconstructs the full ordered list; a rough outline is sketched after the test list below.
* ``models/salm_automodel.py`` — ``prepare_inputs`` derives the CP
mesh once via ``get_cp_mesh``, swaps the audio encoder call to the
CP-distributed version, and (for the BSHD branch) inserts a
``shard_bshd_for_cp`` step before the TP-truncation fallback. Under
CP the BSHD path also drops the padding mask before passing the
batch to the LLM (TE's fused-attention CP path supports ``causal``
but not ``padding_causal``); this is documented as a known
limitation, and the durable fix is the THD packed-sequence path.
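For intuition, here is a pure-PyTorch illustration of the DualChunkSwap-style sequence partitioning that ``shard_bshd_for_cp`` obtains from TE's ``thd_get_partitioned_indices``. This is not the PR's helper: it only shows the pad-to-``2*cp_size*tp_size`` step and the mirrored chunk-pair assignment (rank ``r`` keeps chunks ``r`` and ``2*cp_size-1-r``) for a single BSHD tensor; names are illustrative.

```python
# Illustrative only; the real helper also shards the attention mask and targets.
import torch


def dual_chunk_swap_slice(x_bshd: torch.Tensor, cp_size: int, cp_rank: int,
                          tp_size: int = 1) -> torch.Tensor:
    """x_bshd: [B, S, H]. Pads S to a multiple of 2*cp_size*tp_size, then returns
    the two sequence chunks owned by `cp_rank`."""
    b, s, h = x_bshd.shape
    multiple = 2 * cp_size * tp_size
    s_padded = ((s + multiple - 1) // multiple) * multiple
    if s_padded != s:
        pad = torch.zeros(b, s_padded - s, h, dtype=x_bshd.dtype, device=x_bshd.device)
        x_bshd = torch.cat([x_bshd, pad], dim=1)
    # Split the sequence dim into 2*cp_size equal chunks.
    chunks = x_bshd.chunk(2 * cp_size, dim=1)
    # DualChunkSwap: rank r keeps chunk r plus the mirrored chunk 2*cp_size - 1 - r,
    # which balances causal-attention work across CP ranks.
    own = [chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]]
    return torch.cat(own, dim=1)  # [B, S_padded / cp_size, H]


# Example: S=10 is padded to 12 for cp_size=2; rank 0 keeps chunks 0 and 3.
x = torch.randn(2, 10, 4)
print(dual_chunk_swap_slice(x, cp_size=2, cp_rank=0).shape)  # torch.Size([2, 6, 4])
```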
Tests:
* ``test_salm_cp_helpers.py`` — adds a ``_PerceptionStub`` and CPU
fallback tests for ``encode_audio_with_cp_distribution``
(``cp_mesh is None`` and ``B_aud == 0`` paths). The
``cp_size > 1`` paths in ``shard_bshd_for_cp`` and
``encode_audio_with_cp_distribution`` require ``transformer_engine_torch``
and a real ``torch.distributed`` process group respectively;
exercised by 2-GPU smoke tests.
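For reference, a rough outline of the distribute/encode/all-gather idea behind ``encode_audio_with_cp_distribution``. The work-assignment scheme (round-robin here), the dummy-padding details, and the use of ``all_gather_object`` are assumptions made for illustration; the real helper operates on the SALM perception module inside an initialized process group.

```python
# Sketch only: assumes `encoder` is a plain callable mapping one audio tensor to one
# embedding tensor and that a torch.distributed group `cp_group` is already initialized.
import torch
import torch.distributed as dist


def encode_audio_distributed(audio_utts, encoder, cp_group):
    if len(audio_utts) == 0:
        return []  # B_aud == 0 path: nothing to encode
    cp_size = dist.get_world_size(cp_group)
    cp_rank = dist.get_rank(cp_group)
    # Right-pad with zero-audio dummies so every rank encodes the same number of
    # items (keeps FSDP all-gathers and activation checkpointing uniform).
    n = len(audio_utts)
    n_padded = ((n + cp_size - 1) // cp_size) * cp_size
    padded = list(audio_utts) + [torch.zeros_like(audio_utts[0])] * (n_padded - n)
    # Round-robin assignment: rank r encodes utterances r, r + cp_size, r + 2*cp_size, ...
    local = [encoder(padded[i]) for i in range(cp_rank, n_padded, cp_size)]
    # Gather every rank's variable-length embeddings and restore the original order.
    gathered = [None] * cp_size
    dist.all_gather_object(gathered, local, group=cp_group)
    ordered = [None] * n_padded
    for rank, embs in enumerate(gathered):
        for j, emb in enumerate(embs):
            ordered[rank + j * cp_size] = emb
    return ordered[:n]  # drop the dummy entries
```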
Adds two new subsections under "AutomodelParallelStrategy (SALMAutomodel)" in the training-and-scaling guide:
* "Packed Sequences (THD)" — explains the layout, when it helps (variable-length speech batches), and the YAML knob (``model.packed_sequences: true`` plus ``attn: te``).
* "Context Parallelism (CP)" — explains the strategy knob (``cp_size > 1``), the BSHD vs THD pairing, and the recommended configuration. Documents the BSHD-under-CP padding-mask drop as a known limitation, with THD as the durable fix. Calls out the TransformerEngine 2.14 cuDNN-backend bug on certain GPU architectures (notably Blackwell sm_120) that returns correct THD forward activations but gradients amplified 8x-960x per layer, and the ``NVTE_FUSED_ATTN=0`` workaround that forces FlashAttention dispatch (which is gradient-correct on the same shapes).
Adds a matching "Packed sequences (THD)" entry to the SALMAutomodel config reference, with a cross-reference to the training-and-scaling guide for the CP pairing.
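The ``NVTE_FUSED_ATTN=0`` workaround documented above is just an environment override, e.g. exported by the launcher or set early in the training script before TransformerEngine selects an attention backend:

```python
import os

# Skip TE's cuDNN fused-attention backend and force FlashAttention dispatch,
# which is gradient-correct on the affected THD shapes.
os.environ["NVTE_FUSED_ATTN"] = "0"
```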
Catches three configuration combinations that produce silent NaN
gradients or hangs at training time, and raises an informative error
(or warns where the bug is architecture-specific) before the user
spends ~7 minutes on model load only to watch their loss go NaN.
The check is a pure function in ``parts/parallel.py`` and runs from
``SALMAutomodel.on_fit_start`` once the device mesh is wired up; a rough sketch of its shape follows the case list below.
Cases:
1. ``model.packed_sequences=false`` (BSHD) under ``cp_size > 1`` —
hard error. TE's fused-attention CP path rejects ``padding_causal``
so the right-pad mask is dropped, which lets pad K/V leak into
real-token attention through the causal mask and produces NaN
gradients after step 1. There is no supported workaround; the
error message points users to ``packed_sequences: true``.
2. ``model.packed_sequences=true`` (THD) with
``automodel_backend.attn != "te"`` — hard error. THD packing emits
a 2D ``[T_total, H]`` layout for TE varlen FlashAttention; the
SDPA THD code path in the Automodel branch transposes assuming 4D
BSHD inputs and breaks.
3. ``model.packed_sequences=true`` + ``attn="te"`` +
``NVTE_FUSED_ATTN != "0"`` — TE 2.14's cuDNN fused-attention
backward kernel amplifies THD/padding_causal gradients 8x-960x per
layer on Blackwell sm_120; the resulting ``inf`` gradients drive
the optimizer to NaN. We have no way to be certain the bug only
affects sm_120, so this is a ``warnings.warn`` on other arches and
a hard ``ValueError`` on sm_120 (where the failure is reproduced).
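The three cases above translate roughly into a check like the following. This is a hedged sketch, not the actual ``parts/parallel.validate_parallelism_compatibility``: the signature, argument names, and error wording are assumptions.

```python
import os
import warnings

import torch


def validate_parallelism_compatibility(packed_sequences: bool, cp_size: int, attn: str) -> None:
    # Case 1: BSHD under CP drops the padding mask (TE's CP path has no padding_causal),
    # letting pad K/V leak into real-token attention -> NaN gradients. No workaround.
    if cp_size > 1 and not packed_sequences:
        raise ValueError(
            "BSHD (packed_sequences=false) is not supported with cp_size > 1; "
            "set model.packed_sequences: true."
        )
    if not packed_sequences:
        return
    # Case 2: THD packing emits a 2D [T_total, H] layout for TE varlen attention;
    # the SDPA THD path assumes 4D BSHD inputs and breaks.
    if attn != "te":
        raise ValueError(
            f"packed_sequences=true requires automodel_backend.attn='te'; got {attn!r}."
        )
    # Case 3: TE 2.14's cuDNN fused-attention backward amplifies THD gradients on
    # Blackwell sm_120; hard error there, warn on other architectures.
    if os.environ.get("NVTE_FUSED_ATTN") != "0":
        msg = (
            "THD + TE without NVTE_FUSED_ATTN=0 may produce amplified backward "
            "gradients; set NVTE_FUSED_ATTN=0 to force FlashAttention dispatch."
        )
        capability = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
        if capability == (12, 0):
            raise ValueError(msg)
        warnings.warn(msg)
```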
Tests:
* ``test_salm_parallelism_validation.py`` — 19 unit tests covering
every (BSHD, THD) x (cp=1, cp>1) x (attn ∈ {te, sdpa, flex}) x
(NVTE_FUSED_ATTN ∈ {None, "", "0", "1", "true"}) x
(device_capability ∈ {(9,0), (12,0), None}) combination that
matters. Pure-function tests — no Lightning, no model, no device
mesh required.
```python
if cp_size > 1:
    sharded = shard_bshd_for_cp(input_embs, attention_mask, target_ids, cp_mesh, tp_size=tp_size)
    input_embs = sharded["input_embs"]
    attention_mask = sharded["attention_mask"]
```
Important
The "Update branch" button must only be pressed in very rare occasions. An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Adds opt-in THD packed-sequence training and full context-parallel support to
SALMAutomodel, with a fit-start validator that catches known-bad configs.
Collection: speechlm2
Changelog
* Packed sequences (THD) (``model.packed_sequences: true``) — opt-in packed THD training/validation path with ``cu_seqlens`` varlen attention. New ``parts/packed_sequences.py`` + ``parts/cp_helpers.get_cp_mesh``.
* Context parallelism (``cp_size > 1``) — sequence-shard hybrid Mamba/attention LLMs across GPUs. Requires TransformerEngine. New ``parts/cp_helpers.shard_bshd_for_cp`` + ``encode_audio_with_cp_distribution``. Recommended pairing is THD + CP.
* ``prepare_inputs`` derives the CP mesh once, distributes audio encoding, and (for the BSHD branch) inserts the TE ``DualChunkSwap`` shard before the TP-truncation fallback.
* ``parts/parallel.validate_parallelism_compatibility`` runs at ``on_fit_start`` and raises errors for three known-bad combos: BSHD+CP>1 (NaN at step 2 from pad-K/V leak), THD+non-TE attention (Automodel's SDPA THD path is not ready), and THD+TE without ``NVTE_FUSED_ATTN=0`` (cuDNN backward gradient amplification on Blackwell sm_120 — hard error on sm_120, warn elsewhere).
* ``training_and_scaling.rst`` and ``configs.rst`` document the new ``packed_sequences`` flag, the CP knob, the BSHD+CP-not-supported warning, and the ``NVTE_FUSED_ATTN=0`` workaround.
Usage
# Add a code snippet demonstrating how to use this
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information