[PyTorch] Enable head dim 256 for FA4 #2932
Conversation
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary: This PR enables head_dim=256 for FA4 by delegating head-dimension validation to FA4's own v4_validate_head_dims.
Confidence Score: 4/5. Safe to merge for users on FA4 4.0.0b11, but the import handling deserves attention. The functional logic — delegating head-dim validation to FA4's own validator — lives in transformer_engine/pytorch/attention/dot_product_attention/backends.py, specifically the combined import block at lines 167–171 where the FA4 symbols are pulled in.

Important Files Changed: transformer_engine/pytorch/attention/dot_product_attention/backends.py
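For illustration, a guarded import along these lines would explain the concern (the module path and symbol names below are assumptions, not FA4's confirmed API):

```python
# Hypothetical combined guarded import. Because the block is combined, an FA4
# build that predates the validator fails the whole import, not just the one
# missing symbol, and every v4_* name falls back to None.
try:
    from flash_attn.cute.interface import (
        flash_attn_func as v4_flash_attn_func,
        _validate_head_dims as v4_validate_head_dims,
    )

    v4_is_installed = True
except ImportError:
    v4_flash_attn_func = None
    v4_validate_head_dims = None
    v4_is_installed = False
```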
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend called] --> B{use_flash_attention_4\nAND v4_is_installed\nAND v4_validate_head_dims not None?}
B -- No --> Z[Skip FA4 head-dim validation]
B -- Yes --> C[Compute _fa4_alignment\n= 16 // element_size of qkv_dtype]
C --> D[Call v4_validate_head_dims\nhead_dim_qk, head_dim_v,\ncompute_capability_major, alignment]
D -- AssertionError --> E[Log: unsupported head dims\nuse_flash_attention_4 = False]
D -- Passes --> F{use_flash_attention_4\nAND is_training\nAND head_dim_qk != head_dim_v\nAND head_dim_qk >= 128\nAND SM100-110?}
F -- Yes --> G[Check dK_reduce_ncol misalignment\ngcd-based TMEM layout check]
G -- Misaligned --> H[Log: SM100 backward bug\nuse_flash_attention_4 = False]
G -- OK --> I[FA4 enabled]
F -- No --> I
E --> Z
Z --> J[Continue with other backend selection]
I --> J
```
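A minimal sketch of that gate, assuming the guarded import above and taking the validator's signature from the diagram (this is not TE's actual backends.py code):

```python
import torch

def fa4_head_dims_supported(
    head_dim_qk: int,
    head_dim_v: int,
    qkv_dtype: torch.dtype,
    compute_capability_major: int,
) -> bool:
    """Sketch of the FA4 head-dim gate; mirrors the flowchart, not TE's code."""
    if v4_validate_head_dims is None:
        # Validator not exported by this FA4 build: skip FA4 head-dim validation.
        return True
    # 16-byte alignment expressed in elements:
    # fp16/bf16 (2 bytes) -> 8 elements, fp8 (1 byte) -> 16 elements.
    fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
    try:
        # Signature taken from the flowchart above; FA4 raises AssertionError
        # for unsupported head-dim combinations.
        v4_validate_head_dims(
            head_dim_qk, head_dim_v, compute_capability_major, fa4_alignment
        )
    except AssertionError:
        return False  # real code logs "unsupported head dims" and falls back
    return True
```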
Reviews (3): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3

@vcherepanov-nv @KshitijLakhani Please review.
Review thread on this diff excerpt from backends.py:

```python
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
```
Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?
I double-checked: this is a bug in FA4. The kernels produce wrong results for these shapes even though v4_validate_head_dims allows them, so we have to filter them out manually.
Raised an issue with FA4: Dao-AILab/flash-attention#2552
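For readers following along, a hedged reconstruction of that manual filter (the dK_reduce_ncol formula and tile sizing here are guesses for illustration; the real values come from FA4's SM100 backward tile layout in flash_bwd_sm100.py):

```python
import math

def sm100_dv_reads_misaligned(head_dim_qk: int, head_dim_v: int) -> bool:
    # Assumed TMEM sizing: FA4 derives its dK-reduction column count from the
    # SM100 backward tile layout, not from this illustrative formula.
    dK_reduce_ncol = math.gcd(head_dim_qk, 64)
    tile_hdimv = head_dim_v  # assumed: the dV tile spans the full V head dim
    # Condition quoted from the diff comment above: dV loads are misaligned
    # when half the dV tile is not a multiple of the dK-reduction column count.
    return (tile_hdimv // 2) % dK_reduce_ncol != 0
```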
LGTM |
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
Description
Requires FA4 version 4.0.0b11.
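A quick way to enforce that requirement at startup (assuming the distribution is still published as flash-attn):

```python
from importlib.metadata import version
from packaging.version import Version

# Fail fast if the installed flash-attn build predates the required FA4 beta.
if Version(version("flash-attn")) < Version("4.0.0b11"):
    raise RuntimeError("FA4 >= 4.0.0b11 is required for head_dim=256 support")
```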
Type of change

Changes
Please list the changes introduced in this PR:
- Delegate FA4 head-dimension validation to FA4's own v4_validate_head_dims, enabling head_dim up to 256.
- Manually disable FA4 for training shapes with head_dim_qk != head_dim_v that hit an SM100 backward TMEM misalignment bug (Dao-AILab/flash-attention#2552).
Checklist: