
[PyTorch] Enable head dim 256 for FA4 #2932

Open
yaox12 wants to merge 4 commits into NVIDIA:main from yaox12:xiny/headdim256_fa

Conversation


@yaox12 yaox12 commented Apr 27, 2026

Description

Requires FA4 version 4.0.0b11.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 Compare April 27, 2026 09:31

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR enables head_dim=256 for FA4 by delegating head-dimension validation to FA4's own _validate_head_dims instead of maintaining a parallel lookup table, and adds a new SM100-gated test for the dedicated head_dim=256 kernel. The FA4 version used in CI is bumped from 4.0.0b8 to 4.0.0b11 across the board.

  • backends.py / utils.py: Replace the hand-coded _fa4_hdim_ok table with a call to FlashAttentionUtils.v4_validate_head_dims, imported directly from flash_attn.cute.interface._validate_head_dims. The MLA backward-kernel workaround is updated to use a plain if after the try/except rather than the old elif, keeping the same logic.
  • test_attention.py: Adds test_dpa_fa4_hdim256 with an explicit skipif guard on device_compute_capability not in ((10, 0), (10, 3)), correctly addressing the silent-fallback concern from the previous review; also removes now-unnecessary get_cudnn_version guards from all other FA4 tests. (A hedged sketch of such an architecture gate follows this list.)
  • qa/test.sh: Updates FA4 version from 4.0.0b8 to 4.0.0b11 for both SM90 and SM100+ CI runners.
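
To make the architecture gating described in the test bullet above concrete, here is a minimal sketch of such a skipif guard. The helper name device_compute_capability and the reason string are illustrative assumptions, not copied from the PR; only the ((10, 0), (10, 3)) capability tuple comes from the summary.

```python
# Hedged sketch of an SM100-gated test; device_compute_capability() is a
# stand-in for whatever helper test_attention.py actually uses.
import pytest
import torch


def device_compute_capability():
    """Return (major, minor) of the current CUDA device, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


@pytest.mark.skipif(
    device_compute_capability() not in ((10, 0), (10, 3)),
    reason="FA4 head_dim=256 kernel is only exercised on SM100/SM103",
)
def test_dpa_fa4_hdim256():
    # Would run DotProductAttention with head_dim=256 on the FA4 backend
    # and compare against a reference implementation.
    ...
```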

Confidence Score: 4/5

Safe to merge for users on FA4 4.0.0b11, but importing _validate_head_dims inside the shared FA4 import block will hard-crash backends.py module load for anyone still on an older FA4 release rather than gracefully marking FA4 as unavailable.

The functional logic — delegating head-dim validation to FA4's own _validate_head_dims, the SM100 MLA workaround restructuring, the dedicated hdim256 test with proper architecture gating — is all correct. The outstanding concern is that _validate_head_dims is imported in the same multi-symbol block as the two primary FA4 entry points: if any user has FA4 installed at a version that exports flash_attn_func and flash_attn_varlen_func but not _validate_head_dims, the whole backends.py module fails to load, silently breaking all TE functionality that depends on it rather than just disabling FA4.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — specifically the combined import block at lines 167–171 where _validate_head_dims is imported alongside the two primary FA4 functions.
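
One way to avoid that hard crash, assuming the FA4 entry points live in flash_attn.cute.interface as the summary states, would be to guard the newer symbol separately. This is a hypothetical alternative sketch, not the code in the PR, and the flag name is illustrative:

```python
# Hypothetical defensive-import sketch, not the code in this PR: guard
# _validate_head_dims separately so an older FA4 only disables validation
# instead of failing the whole backends.py import.
try:
    from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func

    fa4_is_installed = True
    try:
        # May be missing in FA4 releases older than the one this PR targets.
        from flash_attn.cute.interface import _validate_head_dims
    except ImportError:
        _validate_head_dims = None
except ImportError:
    fa4_is_installed = False
    flash_attn_func = flash_attn_varlen_func = _validate_head_dims = None
```

With a guard like this, an older FA4 would simply skip the new validation path instead of taking down the whole module import, which is the behavior the review contrasts against the current combined import block.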

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Adds _validate_head_dims to the same multi-symbol from … import block; if this symbol is absent in an installed FA4 older than 4.0.0b11, the whole import raises ImportError and crashes the entire TE module load, breaking all FA4 tests and any TE user who has not yet upgraded.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Replaces the manual architecture-dependent head-dim table with a delegation to FA4's own validator; the MLA workaround is updated correctly; v4_installation_steps is bumped to 4.0.0b11; the type annotation for v4_validate_head_dims uses Callable without Optional (see the annotation sketch after this list).
  • tests/pytorch/attention/test_attention.py: Adds test_dpa_fa4_hdim256 with a correct SM100-only skipif guard; removes stale cuDNN version guards from all FA4 test functions; addresses the previous silent-fallback concern.
  • qa/L3_pytorch_FA_versions_test/test.sh: Straightforward FA4 version bump from 4.0.0b8 to 4.0.0b11 for both SM90 and SM100+ CI lanes.
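
The annotation note in the utils.py item can be made concrete with a short sketch. Since v4_validate_head_dims is explicitly checked against None in the flowchart below, Optional[Callable] would be the safer spelling; the attribute placement on a FlashAttentionUtils-like class is an assumption for illustration only.

```python
# Illustrative sketch only: the real FlashAttentionUtils class is not
# reproduced here. The point is that a symbol which may legitimately be None
# should be annotated Optional[Callable], not bare Callable.
from typing import Callable, Optional


class FlashAttentionUtils:
    v4_is_installed: bool = False
    # Bare `Callable` would be wrong whenever FA4 is absent or too old:
    v4_validate_head_dims: Optional[Callable[..., None]] = None
```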

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{use_flash_attention_4\nAND v4_is_installed\nAND v4_validate_head_dims not None?}
    B -- No --> Z[Skip FA4 head-dim validation]
    B -- Yes --> C[Compute _fa4_alignment\n= 16 // element_size of qkv_dtype]
    C --> D[Call v4_validate_head_dims\nhead_dim_qk, head_dim_v,\ncompute_capability_major, alignment]
    D -- AssertionError --> E[Log: unsupported head dims\nuse_flash_attention_4 = False]
    D -- Passes --> F{use_flash_attention_4\nAND is_training\nAND head_dim_qk != head_dim_v\nAND head_dim_qk >= 128\nAND SM100-110?}
    F -- Yes --> G[Check dK_reduce_ncol misalignment\ngcd-based TMEM layout check]
    G -- Misaligned --> H[Log: SM100 backward bug\nuse_flash_attention_4 = False]
    G -- OK --> I[FA4 enabled]
    F -- No --> I
    E --> Z
    Z --> J[Continue with other backend selection]
    I --> J
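
The head-dim branch of the flowchart can be read as roughly the following helper. This is a hedged reconstruction from the diagram: the validator's argument order (head_dim_qk, head_dim_v, major, alignment) and the logger usage follow the summary and flowchart, not necessarily the exact code in utils.py.

```python
# Hedged reconstruction of the flowchart's head-dim validation branch; the
# validator signature is taken from the diagram, not verified against FA4.
import logging
import torch

logger = logging.getLogger(__name__)


def fa4_head_dims_ok(validate_fn, head_dim_qk, head_dim_v,
                     compute_capability_major, qkv_dtype):
    """Return False when FA4's own validator rejects these head dims."""
    if validate_fn is None:
        # No validator available: skip FA4 head-dim validation entirely.
        return True
    # Per the diagram, alignment = 16 // element size, i.e. head dims are
    # expected to land on a 16-byte boundary for the given dtype.
    alignment = 16 // torch.tensor([], dtype=qkv_dtype).element_size()
    try:
        validate_fn(head_dim_qk, head_dim_v, compute_capability_major, alignment)
    except AssertionError:
        logger.debug("Disabling FlashAttention 4: unsupported head dims")
        return False
    return True
```

Calling it as fa4_head_dims_ok(FlashAttentionUtils.v4_validate_head_dims, 256, 256, 10, torch.bfloat16) would correspond to the new head_dim=256 path this PR enables on SM100.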

Reviews (3): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."

Comment thread: tests/pytorch/attention/test_attention.py (outdated)
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from 3b3f7d0 to 9a93156 Compare May 6, 2026 02:44
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from ae74e44 to 8aa5242 Compare May 6, 2026 02:55

yaox12 commented May 6, 2026

/te-ci pytorch L3

@yaox12 yaox12 marked this pull request as ready for review May 6, 2026 02:59

yaox12 commented May 6, 2026

@vcherepanov-nv @KshitijLakhani Please review.

@KshitijLakhani KshitijLakhani requested a review from mk-61 May 8, 2026 06:34
Comment thread: tests/pytorch/attention/test_attention.py (outdated)
Comment thread: tests/pytorch/attention/test_attention.py
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Collaborator

Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?

Member Author

I double-checked that this is a bug in FA4. The kernels produce wrong results on these shapes, but the shapes are allowed by v4_validate_head_dims, so we have to filter them out manually.
Raised an issue with FA4: Dao-AILab/flash-attention#2552
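
For readers following along, a heavily hedged sketch of the kind of manual filter being described is below. The SM100-110 gate and the head-dim conditions come from the flowchart above; the gcd-based derivation of dK_reduce_ncol is not reproduced, so the misalignment result is taken as an input rather than computed.

```python
# Hypothetical sketch of the manual shape filter discussed above; not the
# actual implementation. The (tile_hdimv // 2) % dK_reduce_ncol != 0 condition
# is quoted from the diff comment, but how dK_reduce_ncol is derived (a
# gcd-based TMEM layout computation) is not reproduced here.
def should_disable_fa4_mla_bwd(is_training, head_dim_qk, head_dim_v,
                               compute_capability, dk_reduce_misaligned):
    """Return True when the SM100 FA4 backward bug applies and FA4 must be skipped."""
    if not (is_training and head_dim_qk != head_dim_v and head_dim_qk >= 128):
        return False
    if not (10, 0) <= compute_capability <= (11, 0):
        # Only SM100-110 backward kernels are affected; the dedicated
        # (256, 256) kernel uses its own TMEM layout and is not.
        return False
    return dk_reduce_misaligned
```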

@vcherepanov-nv
Collaborator

LGTM

yaox12 added 2 commits May 10, 2026 22:28

yaox12 commented May 11, 2026

/te-ci pytorch L3
