
[speechlm2] SALMAutomodel: THD (packed sequence) and context parallel support #15679

Draft
pzelasko wants to merge 4 commits into main from speechlm2-automodel-context-parallel

Conversation


pzelasko (Collaborator) commented May 7, 2026

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

Adds opt-in THD packed-sequence training and full context-parallel support to SALMAutomodel, with a fit-start validator that catches known-bad configs.

Collection: speechlm2

Changelog

  • THD packed sequences (model.packed_sequences: true) — concatenates per-utterance text + audio embeddings into a single flat [T_total, H] sequence with cu_seqlens instead of padding to [B, T_max, H]. Packed sequences still carry some padding, since each utterance is rounded up to a multiple of 2*cp_size for CP. New parts/packed_sequences.py + parts/cp_helpers.get_cp_mesh.
  • Context parallelism (cp_size > 1) — shards sequences across GPUs for hybrid Mamba/attention LLMs. Requires TransformerEngine. New parts/cp_helpers.shard_bshd_for_cp + encode_audio_with_cp_distribution. The recommended pairing is THD + CP. prepare_inputs derives the CP mesh once, distributes audio encoding, and (for the BSHD branch) inserts the TE DualChunkSwap shard before the TP-truncation fallback.
  • Config validator — parts/parallel.validate_parallelism_compatibility runs at on_fit_start and raises errors for three known-bad combos: BSHD+CP>1 (NaN at step 2 from pad-K/V leak), THD+non-TE attention (Automodel's SDPA THD path is not ready), and THD+TE without NVTE_FUSED_ATTN=0 (cuDNN backward gradient amplification on Blackwell sm_120 — a hard error on sm_120, a warning elsewhere).
  • Docs — training_and_scaling.rst and configs.rst document the new packed_sequences flag, the CP knob, the BSHD+CP-not-supported warning, and the NVTE_FUSED_ATTN=0 workaround.

Usage

  • The sketch below enables THD packing and CP together. Only the knobs themselves (packed_sequences, automodel_backend.attn, cp_size, NVTE_FUSED_ATTN) come from this PR; the scaffolding around them is illustrative.
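```python
# Hedged usage sketch: the OmegaConf scaffolding and the commented-out strategy
# import path are placeholders, not verbatim NeMo API.
import os

from omegaconf import OmegaConf

# THD + TE on TE 2.14 needs FlashAttention dispatch (see the validator notes below):
os.environ["NVTE_FUSED_ATTN"] = "0"

cfg = OmegaConf.create({
    "model": {
        "packed_sequences": True,             # opt into the THD packed-sequence path
        "automodel_backend": {"attn": "te"},  # THD requires TE varlen attention
    }
})

# CP knob from the changelog; the import path below is an assumption:
# from nemo.collections.speechlm2 import AutomodelParallelStrategy
# strategy = AutomodelParallelStrategy(cp_size=2)  # pair CP with THD, not BSHD
```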

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and re-add the label.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g. Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

pzelasko added 4 commits May 7, 2026 13:38

Adds an opt-in packed-sequence (THD) training/validation path so that
``SALMAutomodel`` can feed a Nemotron-V3 LLM via ``cu_seqlens``-aware
varlen attention instead of the right-padded BSHD layout. Padding
overhead drops from O(B*max_T) to O(rounding) per minibatch, which is
substantial for variable-length speech inputs.

Activated by ``model.packed_sequences: True`` in the YAML; the BSHD
path is unchanged when the flag is unset. Generation and inference still
use BSHD (they don't go through ``prepare_inputs``).

Pieces:

* ``parts/packed_sequences.py`` — concatenates per-utterance text +
  audio embeddings into a single flat ``[T_total, H]`` sequence with a
  ``cu_seqlens`` index, applies the per-utt next-token shift, and
  rounds each utterance's flat length up to a multiple of ``2*cp_size``
  so the same packing also satisfies TE's CP DualChunkSwap contract.
  Output shape mirrors Automodel's canonical THD layout
  (``components/distributed/thd_utils.process_input_for_thd`` /
  ``cp_utils._shard_thd_chunk_for_te`` — 2D, no leading batch dim) so
  no extra squeeze/unsqueeze hops are needed (a condensed sketch of the
  packing contract follows this list).
* ``parts/cp_helpers.py`` — ``get_cp_mesh`` reads the CP submesh out of
  the device mesh and returns ``(None, 1, 0)`` when CP is inactive.
  Used by ``prepare_packed_llm_inputs`` to short-circuit the CP-shard
  path. Further CP plumbing for the BSHD path lands in a follow-up.
* ``models/salm_automodel.py`` — ``forward`` accepts ``**llm_kwargs``
  (THD metadata: ``qkv_format``, ``cu_seqlens``, ``position_ids``,
  ``max_seqlen``); ``prepare_inputs`` returns the THD dict early when
  ``packed_sequences`` is set; ``training_step``/``validation_step``
  splat ``llm_kwargs`` into the forward call and use a
  shape-generic ``logits.reshape(-1, V)`` so the same code handles
  both BSHD ``(B, T, V)`` and THD ``(T, V)`` outputs.
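
For illustration, a condensed sketch of the packing contract (the real
``parts/packed_sequences.py`` additionally applies the per-utterance
next-token shift and label masking; names here are simplified):

```python
import torch

def pack_thd(per_utt_embs, cp_size=1):
    """Flatten per-utterance [T_i, H] embeddings into one [T_total, H] tensor,
    rounding each T_i up to a multiple of 2*cp_size (TE's DualChunkSwap
    contract), and build the int32 cu_seqlens index TE varlen attention expects."""
    multiple = 2 * cp_size
    chunks, offsets = [], [0]
    for emb in per_utt_embs:
        t, h = emb.shape
        pad = (-t) % multiple                  # round T_i up to a multiple of 2*cp_size
        if pad:
            emb = torch.cat([emb, emb.new_zeros(pad, h)])
        chunks.append(emb)
        offsets.append(offsets[-1] + emb.shape[0])
    packed = torch.cat(chunks, dim=0)          # [T_total, H], no leading batch dim
    cu_seqlens = torch.tensor(offsets, dtype=torch.int32)
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return packed, cu_seqlens, max_seqlen
```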

Tests:

* ``test_salm_packed_sequences.py`` — covers shape contracts, cu_seqlens
  invariants, per-utt next-token shift, audio-frame label masking, the
  ``cp_size``/``tp_size`` rounding, and the TE preprocessor regression
  test that pins down the ``cu_seqlens`` + ``max_seqlen`` (singular)
  contract. Includes BSHD-vs-THD pair-equivalence checks: the set of
  supervised ``(input_embedding, target_token_id)`` pairs reaching the
  cross-entropy must be identical between the two layouts on the same
  batch.
* ``test_salm_cp_helpers.py`` — three CPU tests for ``get_cp_mesh``
  covering the no-mesh, ``cp_size==1``, and missing-``cp``-axis paths.

Adds context-parallelism support to ``SALMAutomodel`` so that large
Nemotron-V3 LLMs with hybrid Mamba/attention layers can train on long
audio sequences across multiple GPUs. Builds on the THD packed-sequence
path from the previous commit; the BSHD path is also supported but the
THD path is the recommended configuration under CP.

Activated by ``cp_size > 1`` in the strategy config (e.g.
``AutomodelParallelStrategy(cp_size=2, ...)``); the existing TP
truncation path is folded into the CP padding so single-axis runs are
unchanged.

Pieces:

* ``parts/cp_helpers.py`` — extends the module with two CP-aware
  helpers used by ``SALMAutomodel.prepare_inputs``:
  - ``shard_bshd_for_cp`` pads the BSHD seq dim to a multiple of
    ``2*cp_size*tp_size`` and partitions along the seq dim using TE's
    ``thd_get_partitioned_indices`` (the same DualChunkSwap pattern
    Automodel's ``Config 1`` reference test uses); see the sketch after
    this list.
  - ``encode_audio_with_cp_distribution`` distributes the audio
    encoder forward across CP ranks instead of recomputing it
    ``cp_size`` times. Right-pads the audio batch with zero-audio
    dummies so every rank participates in FSDP all-gather (and AC
    fires uniformly), then all-gathers the variable-length embedding
    tensors back so each rank reconstructs the full ordered list.
* ``models/salm_automodel.py`` — ``prepare_inputs`` derives the CP
  mesh once via ``get_cp_mesh``, swaps the audio encoder call to the
  CP-distributed version, and (for the BSHD branch) inserts a
  ``shard_bshd_for_cp`` step before the TP-truncation fallback. Under
  CP the BSHD path also drops the padding mask before passing the
  batch to the LLM (TE's fused-attention CP path supports ``causal``
  but not ``padding_causal``); this is documented as a known
  limitation; the durable fix is the THD packed-sequence path.
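
For illustration, the chunk pattern behind the BSHD shard (the PR's
``shard_bshd_for_cp`` delegates to TE's ``thd_get_partitioned_indices``;
this standalone sketch shows the same DualChunkSwap idea):

```python
import torch

def shard_seq_dualchunk(x, cp_rank, cp_size):
    """Split the (already padded) seq dim into 2*cp_size chunks and give rank r
    chunks r and 2*cp_size-1-r, so causal-attention work is balanced across ranks."""
    assert x.shape[1] % (2 * cp_size) == 0     # seq dim padded to a multiple of 2*cp_size
    chunks = x.chunk(2 * cp_size, dim=1)       # x: [B, T, H]
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=1)
```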

Tests:

* ``test_salm_cp_helpers.py`` — adds a ``_PerceptionStub`` and CPU
  fallback tests for ``encode_audio_with_cp_distribution``
  (``cp_mesh is None`` and ``B_aud == 0`` paths). The
  ``cp_size > 1`` paths in ``shard_bshd_for_cp`` and
  ``encode_audio_with_cp_distribution`` require ``transformer_engine_torch``
  and a real ``torch.distributed`` process group respectively;
  exercised by 2-GPU smoke tests. A sketch of the distribution pattern
  follows.
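
A sketch of the distribution pattern under stated assumptions (the helper
name and the ``perception`` callable returning ``(embs, emb_lens)`` are
illustrative; the real ``encode_audio_with_cp_distribution`` differs in detail):

```python
import torch
import torch.distributed as dist

def encode_audio_distributed(perception, audio, audio_lens, cp_group):
    """Each CP rank encodes a slice of the audio batch, padded with zero-audio
    dummies so every rank runs the encoder (keeping FSDP all-gather and
    activation checkpointing uniform), then the variable-length embeddings are
    gathered back so every rank holds the full ordered list."""
    world = dist.get_world_size(cp_group)
    rank = dist.get_rank(cp_group)
    B = audio.shape[0]
    per_rank = -(-B // world)                  # ceil(B / world)
    pad = per_rank * world - B
    if pad:                                    # zero-audio dummies
        audio = torch.cat([audio, audio.new_zeros(pad, *audio.shape[1:])])
        audio_lens = torch.cat([audio_lens, audio_lens.new_ones(pad)])
    sl = slice(rank * per_rank, (rank + 1) * per_rank)
    embs, emb_lens = perception(audio[sl], audio_lens[sl])
    gathered = [None] * world
    dist.all_gather_object(gathered, (embs.cpu(), emb_lens.cpu()), group=cp_group)
    out = []
    for e, lens in gathered:                   # ranks hold contiguous slices, so order is kept
        out.extend(e[i, : lens[i]] for i in range(e.shape[0]))
    return out[:B]                             # drop the dummy entries
```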

Adds two new subsections under "AutomodelParallelStrategy (SALMAutomodel)"
in the training-and-scaling guide:

* "Packed Sequences (THD)" — explains the layout, when it helps
  (variable-length speech batches), and the YAML knob
  (``model.packed_sequences: true`` plus ``attn: te``).
* "Context Parallelism (CP)" — explains the strategy knob
  (``cp_size > 1``), the BSHD vs THD pairing, and the recommended
  configuration. Documents the BSHD-under-CP padding-mask drop as a
  known limitation, with THD as the durable fix.

Calls out the TransformerEngine 2.14 cuDNN-backend bug on certain GPU
architectures (notably Blackwell sm_120) that returns correct THD
forward activations but amplifies gradients 8x-960x per layer, and the
``NVTE_FUSED_ATTN=0`` workaround that forces FlashAttention dispatch
(which is gradient-correct on the same shapes).

Adds a matching "Packed sequences (THD)" entry to the SALMAutomodel
config reference, with a cross-reference to the training-and-scaling
guide for the CP pairing.

Catches three configuration combinations that produce silent NaN
gradients or hangs at training time, and raises an informative error
(or warns where the bug is architecture-specific) before the user
spends ~7 minutes on model load only to watch their loss go NaN.

The check is a pure function in ``parts/parallel.py`` and runs from
``SALMAutomodel.on_fit_start`` once the device mesh is wired up.

Cases (a condensed sketch in code follows the list):

1. ``model.packed_sequences=false`` (BSHD) under ``cp_size > 1`` —
   hard error. TE's fused-attention CP path rejects ``padding_causal``
   so the right-pad mask is dropped, which lets pad K/V leak into
   real-token attention through the causal mask and produces NaN
   gradients after step 1. There is no supported workaround; the
   error message points users to ``packed_sequences: true``.

2. ``model.packed_sequences=true`` (THD) with
   ``automodel_backend.attn != "te"`` — hard error. THD packing emits
   a 2D ``[T_total, H]`` layout for TE varlen FlashAttention; the
   SDPA THD code path in the Automodel branch transposes assuming 4D
   BSHD inputs and breaks.

3. ``model.packed_sequences=true`` + ``attn="te"`` +
   ``NVTE_FUSED_ATTN != "0"`` — TE 2.14's cuDNN fused-attention
   backward kernel amplifies THD/padding_causal gradients 8x-960x per
   layer on Blackwell sm_120; the resulting ``inf`` gradients drive
   the optimizer to NaN. We have no way to be certain the bug only
   affects sm_120, so this is a ``warnings.warn`` on other arches and
   a hard ``ValueError`` on sm_120 (where the failure is reproduced).
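
A condensed sketch of the check under these assumptions (the real
``parts/parallel.py`` reads its inputs from the model/strategy config;
the signature here is illustrative):

```python
import os
import warnings

import torch

def validate_parallelism_compatibility(packed_sequences, cp_size, attn):
    if not packed_sequences and cp_size > 1:   # case 1: BSHD under CP
        raise ValueError(
            "BSHD + cp_size>1 is unsupported: TE's CP path drops the padding "
            "mask and pad K/V leak into attention. Set packed_sequences: true.")
    if packed_sequences and attn != "te":      # case 2: THD needs TE varlen attention
        raise ValueError("packed_sequences=true requires attn='te'; the SDPA "
                         "THD path assumes 4D BSHD inputs.")
    if packed_sequences and os.environ.get("NVTE_FUSED_ATTN") != "0":
        msg = ("TE 2.14 cuDNN fused attention amplifies THD backward gradients "
               "on Blackwell sm_120; set NVTE_FUSED_ATTN=0 to force FlashAttention.")
        cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
        if cap == (12, 0):                     # case 3: hard error where the failure reproduces
            raise ValueError(msg)
        warnings.warn(msg)                     # warn on other/unknown architectures
```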

Tests:

* ``test_salm_parallelism_validation.py`` — 19 unit tests covering
  every (BSHD, THD) x (cp=1, cp>1) x (attn ∈ {te, sdpa, flex}) x
  (NVTE_FUSED_ATTN ∈ {None, "", "0", "1", "true"}) x
  (device_capability ∈ {(9,0), (12,0), None}) combination that
  matters. Pure-function tests — no Lightning, no model, no device
  mesh required.

copy-pr-bot (bot) commented May 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

From the BSHD branch of ``prepare_inputs``:

```python
if cp_size > 1:
    # Shard the padded sequence dim across CP ranks (DualChunkSwap pattern).
    sharded = shard_bshd_for_cp(input_embs, attention_mask, target_ids, cp_mesh, tp_size=tp_size)
    input_embs = sharded["input_embs"]
    attention_mask = sharded["attention_mask"]
```