Padding free bufix ljl#202

Draft
meichangsu1 wants to merge 3 commits into modelscope:main from meichangsu1:padding_free_bufix_ljl

@meichangsu1
Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR fixes padding-free and sequence-parallel compatibility issues with newer Transformers versions.

Main changes:

  • Support forward methods whose kwargs are incompatible across Transformers versions.
  • Add compatibility handling for Qwen3.5 GatedDeltaNet padding-free training.
  • Skip Twinkle GDN padding-free patch when Transformers version natively supports Qwen3.5 cu_seq_lens_q.
  • Adapt sequence-parallel causal mask patching for both old and new Transformers mask APIs.
  • Restore global query length for SDPA no-cache prefill path under sequence parallel.
  • Avoid passing removed or unsupported kwargs such as cache_position to newer forward methods.

This fixes a case where, under sequence parallel (SP), SDPA mask creation could see the local shard length instead of the global sequence length, producing an incorrect loss compared with flash_attention_2.
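To make the shard-length issue concrete, here is a toy sketch in plain Python. It is not the Twinkle or Transformers implementation: `causal_mask`, the 8-token sequence, and the 4-way split are all hypothetical, and it assumes the attention kernel ends up operating on the full sequence while mask creation only sees one rank's shard.

```python
def causal_mask(q_length, kv_length, q_offset=0):
    """Boolean causal mask as nested lists.

    Query row i sits at absolute position q_offset + i and may attend
    to key column j only when j <= q_offset + i.
    """
    return [[j <= q_offset + i for j in range(kv_length)] for i in range(q_length)]


# Hypothetical setup: 8 tokens split evenly over 4 sequence-parallel ranks.
global_len = 8
sp_world_size = 4
local_len = global_len // sp_world_size  # 2 tokens per rank

# Mask built from the local shard length: only 2 x 2.
local_mask = causal_mask(local_len, local_len)

# Mask built from the restored global query length: 8 x 8.
global_mask = causal_mask(global_len, global_len)

print(len(local_mask), len(local_mask[0]))    # shape seen with the local length
print(len(global_mask), len(global_mask[0]))  # shape needed for the full sequence
```

Under these assumptions the locally built mask has the wrong shape for full-sequence attention, which is the kind of mismatch the restored global query length is meant to avoid.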

Experiment results

Validation performed:

.venv/bin/python -m py_compile src/twinkle/patch/gdn_padding_free.py

qq_30035749 added 3 commits May 22, 2026 19:16
Add `_call_with_supported_kwargs` utility to filter out unsupported keyword arguments when calling forward methods, preventing errors from incompatible function signatures. This fixes issues where `origin_forward` methods may not accept all passed kwargs.
…handling

- Add `_call_with_supported_kwargs` and `_call_create_causal_mask` helpers to filter unsupported kwargs
- Rename `cache_position` parameter to `q_length` in flash_attention_mask and sdpa_mask for clarity
- Fix device detection in sdpa_mask when `q_length` is not a tensor
- Ensure compatibility with models that don't accept `cache_position` in causal mask functions
…ill path

In sequence parallel training, when newer Transformers versions pass q_length/q_offset instead of cache_position, the causal mask creation may still see the local shard length. This change restores the global query length for the no-cache prefill path while keeping cache/sliding paths with their upstream offsets.

Also refactor GDN padding-free detection to use transformers version check instead of source inspection, supporting transformers >= 5.9.0.
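A version gate of the kind described in the last commit could look roughly like the sketch below. The helper names `parse_release` and `has_native_gdn_padding_free` are hypothetical and not taken from the PR; only the 5.9.0 threshold comes from the commit message. Note that this simplified parser ignores pre-release suffixes, so `5.9.0.dev0` is treated as 5.9.0.

```python
import re

# Threshold taken from the commit message; everything else here is illustrative.
MIN_NATIVE_VERSION = (5, 9, 0)


def parse_release(version):
    """Extract the leading numeric release, e.g. '5.9.0.dev0' -> (5, 9, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component (e.g. '5.9') is treated as 0.
    return tuple(int(part or 0) for part in match.groups())


def has_native_gdn_padding_free(version):
    """True when the given version string is at or above the threshold."""
    return parse_release(version) >= MIN_NATIVE_VERSION
```

In practice the argument would be `transformers.__version__`. Compared with inspecting source code, a version comparison is cheap and does not depend on how the upstream function body happens to be written.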
@meichangsu1 meichangsu1 marked this pull request as draft May 22, 2026 11:23
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces compatibility helpers and updates masking and forward logic to support varying transformers library versions, specifically handling changes in argument signatures such as cache_position. The review feedback primarily focuses on optimizing performance by caching inspect.signature results and moving signature checks out of hot-path function definitions to avoid significant overhead during the model's forward pass.

Comment on lines +32 to +36
def _call_with_supported_kwargs(fn, *args, **kwargs):
    signature = inspect.signature(fn)
    if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
        kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
    return fn(*args, **kwargs)

high

Calling inspect.signature is a relatively expensive operation. Since _call_with_supported_kwargs is used within the model's forward pass (e.g., during mask creation), calling it repeatedly can introduce significant overhead. Consider caching the signature of the function to improve performance.

Suggested change
- def _call_with_supported_kwargs(fn, *args, **kwargs):
-     signature = inspect.signature(fn)
-     if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
-         kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
-     return fn(*args, **kwargs)
+ def _call_with_supported_kwargs(fn, *args, **kwargs):
+     if not hasattr(_call_with_supported_kwargs, '_cache'):
+         _call_with_supported_kwargs._cache = {}
+     sig = _call_with_supported_kwargs._cache.get(fn)
+     if sig is None:
+         sig = _call_with_supported_kwargs._cache[fn] = inspect.signature(fn)
+     if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()):
+         kwargs = {key: value for key, value in kwargs.items() if key in sig.parameters}
+     return fn(*args, **kwargs)

Comment on lines +113 to +115
def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
    origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin']
    origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters

high

The signature inspection and the lookup of the original SDPA function should be performed once outside of the sdpa_mask definition. This avoids redundant computation during every forward pass of the model.

Suggested change
- def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
-     origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin']
-     origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters
+ origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
+ origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters
+ def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
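
The general pattern of inspecting once and reusing the result on every call can be sketched with a closure. This is an illustration only: `make_sdpa_mask`, `old_origin`, and `new_origin` are hypothetical, and `list(range(q_length))` merely stands in for a position tensor such as `torch.arange(q_length)` so that the example runs without PyTorch.

```python
import inspect


def make_sdpa_mask(origin_sdpa):
    """Build a mask wrapper around origin_sdpa, inspecting its signature only once."""
    # Evaluated a single time, when the wrapper is created (not on every forward pass).
    origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters

    def sdpa_mask(batch_size, q_length, kv_length, **kwargs):
        if origin_uses_cache_position:
            # Older-style API: expects explicit positions rather than a length.
            cache_position = list(range(q_length))
            return origin_sdpa(batch_size, cache_position, kv_length, **kwargs)
        # Newer-style API: accepts the query length directly.
        return origin_sdpa(batch_size, q_length, kv_length, **kwargs)

    return sdpa_mask


# Hypothetical stand-ins for the two upstream signatures.
def old_origin(batch_size, cache_position, kv_length):
    return ('old', batch_size, len(cache_position), kv_length)


def new_origin(batch_size, q_length, kv_length):
    return ('new', batch_size, q_length, kv_length)


print(make_sdpa_mask(old_origin)(2, 4, 4))
print(make_sdpa_mask(new_origin)(2, 4, 4))
```

The branch inside `sdpa_mask` is then a plain boolean check, so the per-call cost no longer includes `inspect.signature`.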

Comment on lines +153 to +155
return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs)

return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs)

high

The calls to origin_sdpa should use _call_with_supported_kwargs to ensure compatibility across different Transformers versions. This is particularly important for the device argument, which may not be accepted by older versions of the SDPA mask function.

Suggested change
- return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs)
- return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs)
+ return _call_with_supported_kwargs(origin_sdpa, batch_size, cache_position, kv_length, *args, **kwargs)
+ return _call_with_supported_kwargs(origin_sdpa, batch_size, q_length, kv_length, *args, device=device, **kwargs)

Comment on lines +181 to +186
if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters:
    cache_position_or_past_key_values = torch.arange(
        0,
        input_embeds.shape[1],
        device=input_embeds.device,
    )

high

Similar to the optimization suggested for sdpa_mask, the signature of masking_utils.origin_create_causal_mask should be inspected once outside the create_causal_mask function to avoid overhead in the forward pass.

Comment on lines +39 to +43
def _call_with_supported_kwargs(fn, *args, **kwargs):
    signature = inspect.signature(fn)
    if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
        kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
    return fn(*args, **kwargs)

high

Caching the signature in _call_with_supported_kwargs is recommended here as well, as this function is called within the forward pass of the GatedDeltaNet layers.

Suggested change
- def _call_with_supported_kwargs(fn, *args, **kwargs):
-     signature = inspect.signature(fn)
-     if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
-         kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
-     return fn(*args, **kwargs)
+ def _call_with_supported_kwargs(fn, *args, **kwargs):
+     if not hasattr(_call_with_supported_kwargs, '_cache'):
+         _call_with_supported_kwargs._cache = {}
+     sig = _call_with_supported_kwargs._cache.get(fn)
+     if sig is None:
+         sig = _call_with_supported_kwargs._cache[fn] = inspect.signature(fn)
+     if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()):
+         kwargs = {key: value for key, value in kwargs.items() if key in sig.parameters}
+     return fn(*args, **kwargs)
