[JAX] Support for cuDNN-backed flex attention #2985
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary

This PR adds a score_mod/score_mod_bprop hook path to the JAX FusedAttention backend, routing user-supplied callbacks into cuDNN FE's sdpa and sdpa_backward graph construction.
Confidence Score: 3/5

Mergeable as a new-feature branch, but it carries several design-level risks that should be resolved before wide adoption. The core library wiring is sound and the tests pass numerically. Outstanding risks:

- The C++ registry permanently leaks Python graph objects.
- The GIL is held across a live CUDA stream on every kernel dispatch.
- The graph cache grows without bound.
- The standalone cuDNN handle created in ExecuteScoreModGraph is independent of the one prepared by CudnnHandleInitHandler.
- The softcap test demonstrates a stateful callback pattern that depends on cuDNN calling the forward score_mod callback before the backward one during sdpa_backward graph construction.

Important files changed: transformer_engine/jax/csrc/extensions/attention.cpp (registry lifetime, GIL, duplicate handle) and transformer_engine/jax/cpp_extensions/attention.py (private cuDNN API, unbounded cache, id-based cache keys).
Sequence Diagram

sequenceDiagram
participant User
participant attention.py
participant cpp_ext as cpp_extensions/attention.py
participant CppReg as C++ ScoreModGraphRegistry
participant cuDNN as cuDNN pygraph
participant FFI as XLA FFI (C++)
User->>attention.py: "fused_attn(..., score_mod=fn)"
attention.py->>cpp_ext: make_fused_attn_score_mod_config()
cpp_ext-->>attention.py: config, tensor_operands, bprop_tensor_operands
attention.py->>attention.py: _fused_attn_score_mod (custom_vjp)
note over cpp_ext,cuDNN: JAX trace time
attention.py->>cpp_ext: fused_attn_score_mod_fwd(qkv, tensors, config)
cpp_ext->>cuDNN: "pygraph() + sdpa(score_mod=callback)"
cuDNN->>User: score_mod(graph, score, tensors) called
cuDNN-->>cpp_ext: graph + workspace_size
cpp_ext->>CppReg: register_fused_attn_score_mod_graph(graph, uids)
CppReg-->>cpp_ext: graph_id
cpp_ext->>FFI: ffi_call te_fused_attn_score_mod_forward_ffi
note over FFI,cuDNN: CUDA execution time
FFI->>CppReg: GetScoreModGraphEntry(graph_id)
FFI->>FFI: acquire GIL
FFI->>cuDNN: py_graph._execute_with_ptrs(user_ptrs, workspace, handle)
FFI-->>User: output, softmax_stats
note over cpp_ext,cuDNN: Backward graph build
attention.py->>cpp_ext: fused_attn_score_mod_bwd(...)
cpp_ext->>cuDNN: pygraph() + sdpa_backward(score_mod, score_mod_bprop)
cuDNN->>User: score_mod called then score_mod_bprop called
cpp_ext->>CppReg: register bwd graph
cpp_ext->>FFI: ffi_call te_fused_attn_score_mod_backward_ffi
FFI-->>User: dq, dk, dv
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};
Python reference leak: Py_INCREF without a matching Py_DECREF
ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
    k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
    v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
    o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
    do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)
id()-based cache keys can produce false cache hits after GC
_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
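One possible mitigation, sketched below with illustrative names (the token helper is not part of this PR): replace raw addresses with tokens whose lifetime is tied to the callback object.

```python
import itertools
import weakref

# Illustrative sketch, not part of this PR. Each callback object gets a
# stable token the first time it is seen; tokens are never reused, and the
# WeakKeyDictionary drops an entry when its object is collected, so a new
# object at a recycled address gets a fresh token instead of inheriting
# the old graph. Assumes callback objects are weak-referenceable (plain
# functions and ordinary class instances are).
_callback_tokens = weakref.WeakKeyDictionary()
_token_counter = itertools.count()

def _stable_callback_token(obj):
    token = _callback_tokens.get(obj)
    if token is None:
        token = next(_token_counter)
        _callback_tokens[obj] = token
    return token

def _score_mod_callback_cache_key(self_obj, func):
    # Drop-in replacement for the id()-based key described above.
    return (_stable_callback_token(self_obj), _stable_callback_token(func))
```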
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}
GIL held across a CUDA FFI call boundary
ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.
_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000
_score_mod_graph_cache and C++ registry grow without bound
_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
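One containment option is sketched below with an invented release hook (release_score_mod_graph does not exist in this PR and would need a matching C++ binding on the registry):

```python
from collections import OrderedDict

# Sketch of bounding the module-level cache with LRU eviction. The limit
# and the release hook are illustrative; eviction only helps if the C++
# ScoreModGraphRegistry gains a way to drop its reference too.
_SCORE_MOD_CACHE_MAX = 128
_score_mod_graph_cache = OrderedDict()

def _cached_score_mod_graph(key, build_graph):
    entry = _score_mod_graph_cache.get(key)
    if entry is not None:
        _score_mod_graph_cache.move_to_end(key)
        return entry
    entry = build_graph()  # returns (graph_id, workspace_size)
    _score_mod_graph_cache[key] = entry
    if len(_score_mod_graph_cache) > _SCORE_MOD_CACHE_MAX:
        _, (old_graph_id, _) = _score_mod_graph_cache.popitem(last=False)
        release_score_mod_graph(old_graph_id)  # hypothetical C++ binding
    return entry
```

Note that eviction is only safe if no compiled executable still references the evicted graph_id, which is why an explicit graph-release API paired with cache invalidation may be the better fit here.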
def forward(self, graph, score, tensors):
    import cudnn  # pylint: disable=import-outside-toplevel

    self.before_tanh_activation = graph.div(
        a=score,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )
    self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
    tanh_out = graph.tanh(input=self.before_tanh_activation)
    tanh_out.set_data_type(cudnn.data_type.FLOAT)
    return graph.mul(
        a=tanh_out,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )

def backward(self, graph, dscore, tensors):
    import cudnn  # pylint: disable=import-outside-toplevel

    d_tanh_out = graph.mul(
        a=dscore,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )
    d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
    d_before_tanh_activation = graph.tanh_backward(
        loss=d_tanh_out,
        input=self.before_tanh_activation,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
    return graph.div(
        a=d_before_tanh_activation,
        b=tensors["softcap"],
        compute_data_type=cudnn.data_type.FLOAT,
    )
def _reference_attention(
    query, key, value, scale, *, causal=False, relative_position=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if relative_position:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:
_ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering
backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
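A low-cost hardening step, sketched under the assumption that __init__ initializes self.before_tanh_activation to None, is to make the ordering dependency fail loudly at graph-build time:

```python
def backward(self, graph, dscore, tensors):
    # Sketch only: guard the stateful handoff from forward. If cuDNN ever
    # invoked score_mod_bprop before score_mod within the same graph
    # build, this raises during construction instead of handing
    # tanh_backward a None input that fails later at execution time.
    if self.before_tanh_activation is None:
        raise RuntimeError(
            "_ScoreModSoftcap.backward ran before forward; the stateful "
            "softcap callback requires score_mod to be built first"
        )
    ...  # original backward body unchanged
```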
Description
This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
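For orientation, a minimal usage sketch pieced together from the sequence diagram and the softcap test; the import path, the tensors wiring, and the exact fused_attn signature here are assumptions for illustration, not the finalized API:

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import fused_attn  # assumed import path

def scale_score_mod(graph, score, tensors):
    # Callback shape mirrors the softcap test above: receive the cuDNN
    # graph and the raw attention-score node, add pointwise ops, return
    # the new scores.
    import cudnn  # pylint: disable=import-outside-toplevel
    out = graph.mul(a=score, b=tensors["scale"],
                    compute_data_type=cudnn.data_type.FLOAT)
    out.set_data_type(cudnn.data_type.FLOAT)
    return out

q = k = v = jnp.zeros((2, 128, 8, 64), dtype=jnp.bfloat16)
# Per the sequence diagram, the callback is passed straight to fused_attn;
# other required arguments are elided here.
out = fused_attn(q, k, v, score_mod=scale_score_mod)
```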