[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
hxbai wants to merge 11 commits into
Conversation
Greptile Summary
This PR makes the offset of ClampedSwiGLU configurable across the common C++ core and the PyTorch and JAX bindings.
Confidence Score: 5/5. Safe to merge; the change is purely additive with a correct backward-compatible default of 1.0. All CUDA kernels, PyTorch bindings, and JAX FFI handlers correctly propagate the new parameter. The old public C symbols are preserved unchanged. The fusion guard is consistent with the pre-existing alpha guard. No files require special attention; all changed files are internally consistent.
Reviews (11). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
 * \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
 * \param[in] stream            CUDA stream used for the operation.
 */
nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.
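To make the hazard concrete, here is a minimal sketch of the mismatch (the signatures mirror the header, but the caller, the typedef, and the example argument values are assumptions for illustration):

```cpp
#include <cuda_runtime.h>

typedef void *NVTETensor;  // assumed opaque handle, mirroring the public header

// Declaration the old binary was compiled against (five parameters):
extern "C" void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output,
                                    float limit, float alpha, cudaStream_t stream);

void old_caller(NVTETensor in, NVTETensor out, cudaStream_t stream) {
  // This five-argument call is baked into a pre-compiled binary. If the process
  // later loads the new six-parameter definition, the callee reads
  // glu_linear_offset (and, depending on the calling convention, stream) from
  // argument slots this call never populated: no compile or link error, just
  // unspecified values at runtime.
  nvte_clamped_swiglu(in, out, /*limit=*/7.0f, /*alpha=*/1.702f, stream);
}
```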
timmoon10 left a comment:
The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:
TransformerEngine/transformer_engine/pytorch/ops/_common.py
Lines 180 to 183 in df0025b
/te-ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?
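A hedged sketch of that shape (the _v2 names follow this suggestion; nothing below is the actual TE API, and the types are assumed from the public header):

```cpp
// New entry point carrying the extra parameter:
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit,
                            float alpha, float glu_linear_offset, cudaStream_t stream);

// The old symbol keeps its original signature and semantics by delegating with
// the backward-compatible offset of 1.0, so existing binaries keep working:
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit,
                         float alpha, cudaStream_t stream) {
  nvte_clamped_swiglu_v2(input, output, limit, alpha,
                         /*glu_linear_offset=*/1.0f, stream);
}
```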
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
/te-ci
/te-ci
jberchtold-nvidia left a comment:
Overall looks pretty good from the JAX side, thanks for adding the JAX changes too! Left a couple small comments
     ::xla::ffi::StructMember<float>("limit"),
-    ::xla::ffi::StructMember<float>("alpha"));
+    ::xla::ffi::StructMember<float>("alpha"),
+    ::xla::ffi::StructMember<float>("glu_linear_offset"));
Can we add a default value for users with HLO from a previous version? Would glu_linear_offset=1 be the same as the current behavior on main?
Yes, glu_linear_offset=1 is consistent with the current behavior.
Could you point me to how to add the default value on HLO? Thanks.
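For concreteness, a scalar reference of the activation (a sketch assuming the GPT-OSS formulation; the function name and defaults are illustrative, and glu_linear_offset = 1.0f reproduces today's behavior):

```cpp
#include <algorithm>
#include <cmath>

float clamped_swiglu_ref(float x_glu, float x_linear, float limit = 7.0f,
                         float alpha = 1.702f, float glu_linear_offset = 1.0f) {
  x_glu = std::min(x_glu, limit);                  // gate clamped from above only
  x_linear = std::clamp(x_linear, -limit, limit);  // linear clamped on both sides
  const float sig = 1.0f / (1.0f + std::exp(-alpha * x_glu));  // sigmoid(alpha * gate)
  return x_glu * sig * (x_linear + glu_linear_offset);  // offset enters only here
}
```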
@hxbai I had thought this would be easy to add a default value for, but I realized it's a different case: in TE/JAX we have previously supported default values in XLA FFIs for function argument attributes, not struct fields.
I reached out to the XLA team and heard that using std::optional may be supported. Can you try this?
```cpp
struct XXXX {
  ...
  std::optional<float> glu_linear_offset;
};
```
Then, when reading the value: glu_linear_offset_value = glu_linear_offset.value_or(1.0f)
If it doesn't work, then let me know we can keep it without a default and I'll approve the PR from the JAX side. Thanks!
There was a problem hiding this comment.
Tests failed due to the changes; it seems std::optional is not supported here, so I reverted. Is that OK?
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
/te-ci
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
for more information, see https://pre-commit.ci
Description
The previous ClampedSwiGLU implementation follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without alpha and offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
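In formula form (a hedged reconstruction from the GPT-OSS reference, reading "without alpha and offset" as alpha = 1 and offset = 0):

$$y = \min(x_{\mathrm{glu}}, \mathrm{limit}) \cdot \sigma\!\left(\alpha \cdot \min(x_{\mathrm{glu}}, \mathrm{limit})\right) \cdot \left(\mathrm{clamp}(x_{\mathrm{lin}}, -\mathrm{limit}, \mathrm{limit}) + \mathrm{offset}\right)$$

GPT-OSS corresponds to offset = 1 (with alpha = 1.702), while the DeepSeek-V4 variant sets alpha = 1 and offset = 0.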
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: