[TRTLLM-12339][feat] Support T5 encoder-decoder models in the PyTorch backend #13870

cascade812 wants to merge 19 commits into
Conversation
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
📝 Walkthrough

Adds encoder–decoder support with cross-attention, dual KV caches, scheduler/executor updates (encoder-first), expanded APIs/config, new T5/BART models, backend adjustments, and extensive tests/docs.

Changes

Encoder–Decoder cross-attention and dual KV pools
Sequence Diagram(s)

sequenceDiagram
autonumber
participant Client
participant BaseLLM
participant Executor
participant Encoder
participant Decoder
participant KVSelf
participant KVCross
Client->>BaseLLM: generate(encoder inputs/ids)
BaseLLM->>Executor: generate_async(...)
Executor->>Encoder: forward_encoder(batch)
Encoder-->>Executor: hidden_states, seq_lens
Executor->>Decoder: context with cross metadata
Decoder->>KVCross: project/read encoder KV
Decoder->>KVSelf: update self-KV
Decoder-->>Client: tokens (+optional encoder_output)
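The encoder-first flow in the diagram can be sketched as a minimal Python loop. All names below (`Request`, `forward_encoder`, `forward_decoder_step`, the dict-based KV pools) are illustrative stand-ins for the diagram's participants, not the PR's actual classes:

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins mirroring the diagram, not TensorRT-LLM classes.
@dataclass
class Request:
    encoder_ids: list
    encoder_output: list = None
    decoder_ids: list = field(default_factory=list)

def forward_encoder(batch):
    # Encoder runs first, once per request; its hidden states are kept so
    # the decoder's cross-attention can later project them into the cross
    # KV pool.
    for req in batch:
        req.encoder_output = [tok * 2 for tok in req.encoder_ids]  # dummy compute

def forward_decoder_step(req, cross_kv, self_kv):
    # The first decoder call projects encoder output into the cross KV
    # pool; subsequent calls reuse that cached projection and only append
    # to the self-attention pool.
    if id(req) not in cross_kv:
        cross_kv[id(req)] = list(req.encoder_output)
    next_tok = sum(cross_kv[id(req)]) % 100  # dummy "generation"
    self_kv.setdefault(id(req), []).append(next_tok)
    return next_tok

batch = [Request(encoder_ids=[1, 2, 3])]
cross_kv, self_kv = {}, {}
forward_encoder(batch)  # encoder-first, before any decoder step
tokens = [forward_decoder_step(batch[0], cross_kv, self_kv) for _ in range(3)]
print(tokens)
```

The key property the sketch demonstrates is that the cross pool is written exactly once per request while the self pool grows each step — the reason the real implementation keeps two separate KV cache managers.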
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~180 minutes

✨ Finishing Touches
🧪 Generate unit tests (beta)
Actionable comments posted: 18
Note
Due to the large number of review comments, comments of Critical and Major severity were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (8)
tensorrt_llm/_torch/attention_backend/vanilla.py (1)
401-415: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Disable causal masking on the flash-attn cross-attention path to match SDPA fallback behavior.
The SDPA fallback disables causal masking when Q/K lengths differ in cross-attention (`sdpa_is_causal = is_causal and (end_q - start_q) == (end_k - start_k)`), but the flash-attn path unconditionally applies `causal=True` whenever `attention_mask` is `CAUSAL`. Since `no_kv_cache_forward()` defaults to `CAUSAL`, cross-attention callers that forget to override the mask will get inconsistent behavior: a triangular encoder mask on fp16/bf16 (flash-attn) but none on other dtypes (SDPA fallback).

Suggested fix
```diff
 attn_output_unpad = flash_attn_varlen_func(
     q,
     k,
     v,
     cu_seqlens_q,
     cu_seqlens_k,
     max_seqlen_q,
     max_seqlen_k,
     dropout_p=0.0,
     softmax_scale=softmax_scale,
-    causal=attention_mask == PredefinedAttentionMask.CAUSAL,
+    causal=(attention_mask == PredefinedAttentionMask.CAUSAL) and not is_cross,
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
 )
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/attention_backend/vanilla.py` around lines 401 - 415, The flash-attn call unconditionally enables causal when attention_mask == PredefinedAttentionMask.CAUSAL, causing triangular masking on cross-attention unlike the SDPA fallback; change the causal argument in the flash_attn_varlen_func call to only be true when the mask is causal AND Q/K lengths match (e.g., causal=(attention_mask == PredefinedAttentionMask.CAUSAL and max_seqlen_q == max_seqlen_k) or equivalent check using cu_seqlens), so flash_attn_varlen_func follows the same sdpa_is_causal rule as the SDPA fallback.

tensorrt_llm/_torch/pyexecutor/_util.py (1)
734-776: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep the cross pool out of the draft budget split.
`_split_kv_cache_budget_for_draft()` now uses `_get_kv_size_per_token()`, which includes cross-pool bytes for encoder-decoder models. But `build_managers()` calls this after `_split_kv_cache_budget_for_cross()`, so the draft share is computed from self + cross + draft while the budget being split is already self-only. In enc-dec + separate-draft configs, that over-allocates the draft side and shrinks the self pool twice.

One way to fix it
```diff
-        total_kv = self._get_kv_size_per_token()
         target_kv = self._kv_cache_manager_cls.get_cache_size_per_token(
             self._model_engine.model.model_config,
             self._mapping,
             tokens_per_block=self._tokens_per_block)
-        draft_kv = total_kv - target_kv
+        draft_kv = self._get_draft_kv_size_per_token()
+        total_kv = target_kv + draft_kv
```

Compute the draft size directly from the draft layout, instead of inferring it from the full per-token total.
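The double-counting is easiest to see with plain arithmetic. The per-token byte sizes below are made-up stand-ins, and the proportional split is a simplification of the manager's real accounting, not the actual implementation:

```python
# Hypothetical per-token KV byte sizes (illustrative numbers only).
self_kv, cross_kv, draft_kv = 100, 40, 20

# Budget remaining AFTER the cross split has already been taken out,
# i.e. it only has to cover the self + draft pools.
budget = 1_200_000  # bytes

# Buggy path: the draft bytes are inferred from the full per-token total,
# which still includes the cross-pool bytes, so cross leaks into draft.
total_kv = self_kv + cross_kv + draft_kv
inferred_draft = total_kv - self_kv  # 60 instead of the true 20
buggy_draft_share = budget * inferred_draft // (self_kv + inferred_draft)

# Fixed path: take the draft bytes from the draft layout directly.
fixed_draft_share = budget * draft_kv // (self_kv + draft_kv)

print(buggy_draft_share, fixed_draft_share)
```

With these numbers the buggy split hands the draft pool 450,000 bytes instead of 200,000 — more than double, all of it taken from the self pool, which was already shrunk once by the cross split.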
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/pyexecutor/_util.py` around lines 734 - 776, _split_kv_cache_budget_for_draft currently derives draft_kv by subtracting target_kv from total_kv returned by _get_kv_size_per_token(), but _get_kv_size_per_token() includes cross-pool bytes for encoder-decoder models (already removed by _split_kv_cache_budget_for_cross), causing the draft share to be over-allocated; fix by computing draft_kv directly from the draft layout instead of inferring it from total_kv — call the KV cache manager's get_cache_size_per_token (via _kv_cache_manager_cls.get_cache_size_per_token) with the draft mapping/layout (or appropriate tokens_per_block for the draft) to obtain draft_kv, then use that draft_kv and target_kv to split total_budget in _split_kv_cache_budget_for_draft (and leave _get_kv_size_per_token and build_managers unchanged).

tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py (1)
374-606: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Register these new T5 encoder-decoder integration tests in the functional QA list.
These new PyTorch encoder-decoder tests must be added to `tests/integration/test_lists/qa/llm_function_core.txt` to ensure they run in scheduled QA. Per coding guidelines, new encoder-decoder/T5 features should be placed in the correct functional test lists used by scheduled runs rather than relying only on ad-hoc CI.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py` around lines 374 - 606, Add the two new test functions (test_t5_pytorch_generate_encoder_decoder_end_to_end and test_t5_pytorch_generate_encoder_decoder_cuda_graph_mixed_encoder_lengths_batch) to the functional QA test list so they run in scheduled QA; open the QA list file (llm_function_core.txt) and append entries for these test names (one per line) following the existing list format, ensuring any required markers/tags used by the scheduler are included and consistent with other encoder-decoder/T5 tests.

tensorrt_llm/executor/request.py (1)
1-1: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add the required NVIDIA copyright/license header.
Line 1 starts directly with imports in a modified `.py` source file. Please add the standard NVIDIA copyright + Apache-2.0 header block.

As per coding guidelines:
All source files (.cpp, .h, .cu, .py) should contain an NVIDIA copyright header with the year of latest modification and Apache 2.0 license notice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/executor/request.py` at line 1, Add the standard NVIDIA copyright + Apache-2.0 header block at the very top of tensorrt_llm/executor/request.py (before the existing import statements); ensure the header includes the correct year of latest modification and the full Apache-2.0 notice and SPDX identifiers so the file complies with the project’s source-file licensing guidelines.

tensorrt_llm/executor/base_worker.py (1)
1-1: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Missing required source-file header in modified Python file.
Line 1 should include the NVIDIA copyright + Apache-2.0 license notice header.
As per coding guidelines:
All source files (.cpp, .h, .cu, .py) should contain an NVIDIA copyright header with the year of latest modification and Apache 2.0 license notice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/executor/base_worker.py` at line 1, Add the required NVIDIA copyright and Apache-2.0 license header at the top of the modified Python file (tensorrt_llm/executor/base_worker.py) so the file begins with the standard multi-line header before any imports (e.g., before the existing "import copy" line); include the current year of latest modification and the full Apache-2.0 notice as per project guidelines.

tensorrt_llm/executor/result.py (1)
1-1: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add NVIDIA copyright/license header to this modified source file.
Line 1 currently starts with imports; header block is required.
As per coding guidelines:
All source files (.cpp, .h, .cu, .py) should contain an NVIDIA copyright header with the year of latest modification and Apache 2.0 license notice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/executor/result.py` at line 1, The file tensorrt_llm/executor/result.py is missing the required NVIDIA copyright/license header; add the standard NVIDIA copyright block (including the latest modification year and the Apache-2.0 license notice) at the very top of the file before any imports (before the existing import asyncio), formatted as a Python comment block so linters remain happy; ensure the header text matches the project's canonical NVIDIA/Apache-2.0 header template.

tensorrt_llm/executor/executor.py (2)
1-1: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Modified Python source is missing the required NVIDIA header.
Please add the standard copyright/license header block at the top.
As per coding guidelines:
All source files (.cpp, .h, .cu, .py) should contain an NVIDIA copyright header with the year of latest modification and Apache 2.0 license notice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/executor/executor.py` at line 1, Add the standard NVIDIA copyright/header block (with the year of latest modification and the Apache 2.0 license notice) at the very top of tensorrt_llm/executor/executor.py before any imports (e.g., before the existing "import atexit"); ensure the header follows the project's canonical format used in other .py files and includes the copyright year and the full Apache-2.0 license statement.
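Four of the comments above flag the same missing header. A small pre-commit-style check would catch such misses mechanically; the sketch below is hypothetical, and its marker strings are a generic Apache-2.0 template that may differ from the repo's canonical header wording:

```python
# Hedged sketch of a header check; EXPECTED_MARKERS is an assumption about
# the header's contents, not the project's canonical template.
EXPECTED_MARKERS = ("NVIDIA", "Apache License, Version 2.0")

def has_license_header(source: str, max_scan_lines: int = 15) -> bool:
    """Return True if every expected marker appears near the top of the file."""
    head = "\n".join(source.splitlines()[:max_scan_lines])
    return all(marker in head for marker in EXPECTED_MARKERS)

good = (
    "# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n"
    '# Licensed under the Apache License, Version 2.0 (the "License");\n'
    "import copy\n"
)
bad = "import copy\n"  # starts directly with imports, like the flagged files
print(has_license_header(good), has_license_header(bad))
```

Wiring a check like this into CI (or a pre-commit hook) would turn these four recurring review comments into an automatic failure instead of a manual catch.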
119-166: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Sync API is now out of parity with async for encoder-decoder inputs.
`generate_async(...)` accepts `encoder_input_token_ids`, but `generate(...)` cannot accept/forward it. That leaves synchronous callers without a path to pass encoder inputs for encoder-decoder models.

Suggested fix
```diff
 def generate(
     self,
     prompt_token_ids: Union[List[int], List[List[int]]],
     sampling_params: Union[SamplingParams, List[SamplingParams]],
     query_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] = None,
+    encoder_input_token_ids: Optional[Union[torch.Tensor, np.ndarray, list,
+                                            List[Union[torch.Tensor, np.ndarray, list]]]] = None,
     lora_request: Optional[Union[LoRARequest, List[LoRARequest]]] = None,
     prompt_adapter_request: Optional[Union[
         PromptAdapterRequest, List[PromptAdapterRequest]]] = None,
     disaggregated_params: Optional[DisaggregatedParams] = None,
 ) -> Union[GenerationResult, List[GenerationResult]]:
@@
     for i, p in enumerate(prompt_token_ids):
@@
+        if isinstance(encoder_input_token_ids, list) and len(prompt_token_ids) > 1:
+            enc_ids = encoder_input_token_ids[i]
+        else:
+            enc_ids = encoder_input_token_ids
         future = self.generate_async(
             p,
             sampling_params=sp,
             query_token_ids=query_token_ids,
+            encoder_input_token_ids=enc_ids,
             lora_request=lora_req,
             prompt_adapter_request=pa_req,
             streaming=False,
             disaggregated_params=disaggregated_params)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/executor/executor.py` around lines 119 - 166, The synchronous API generate(...) is missing support for encoder_input_token_ids while generate_async(...) accepts it; update the generate function signature to add an optional encoder_input_token_ids: Optional[Union[torch.Tensor, np.ndarray, list]] argument and forward that value into the GenerationRequest constructor (the same way generate_async does) so synchronous callers can pass encoder-decoder inputs; keep type hints consistent with generate_async, update any docstring/comments, and run/adjust related tests that construct GenerationRequest via generate to ensure parity.
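The per-prompt dispatch in the suggested fix can be isolated and sanity-checked on its own. `select_encoder_ids` below is a hypothetical helper written for illustration, not part of the executor API:

```python
from typing import Any, List, Optional

def select_encoder_ids(encoder_input_token_ids: Optional[Any],
                       num_prompts: int, i: int) -> Optional[Any]:
    """Mirror the suggested generate() logic: for a multi-prompt batch, a list
    of encoder inputs is indexed per prompt; a single input (or None) is
    shared across all prompts."""
    if isinstance(encoder_input_token_ids, list) and num_prompts > 1:
        return encoder_input_token_ids[i]
    return encoder_input_token_ids

# Batched case: one encoder-input list per prompt.
batch_enc = [[1, 2], [3, 4]]
picked = [select_encoder_ids(batch_enc, 2, i) for i in range(2)]

# Single-prompt case: the flat list is passed through unchanged.
shared = select_encoder_ids([9, 9], 1, 0)
print(picked, shared)
```

One caveat of this heuristic (inherited from the suggested diff): a flat token list for a multi-prompt batch would be indexed element-wise rather than shared, so callers of a batched sync API would need to pass one encoder sequence per prompt.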
🟡 Minor comments (5)
legacy_enc_dec_architecture.md (1)
123-126: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language tag to the fenced code block at Line 123.
This block is missing a fence language and can trip markdown lint (MD040).
Suggested fix
````diff
-```
+```text
 out/<tpX>/encoder/rank0.engine, config.json
 out/<tpX>/decoder/rank0.engine, config.json
````

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@legacy_enc_dec_architecture.md` around lines 123 - 126, The fenced code block containing the lines "out/<tpX>/encoder/rank0.engine, config.json" and "out/<tpX>/decoder/rank0.engine, config.json" is missing a language tag and triggers MD040; update the opening fence from ``` to ```text (i.e., add the "text" language tag) so the block is explicitly marked and linting passes while preserving the existing content.

tests/unittest/_torch/executor/test_py_scheduler.py (1)

2333-2411: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Rename the unused tuple slots in these encoder-init tests.

Ruff is already flagging the repeated `disagg` / `paused` unpacking here. Use `_`, `_disagg`, or `_paused` in the cases that only assert on `fitting` so the new coverage stays lint-clean.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/executor/test_py_scheduler.py` around lines 2333 - 2411, Tests unpack the triple returned by scheduler.schedule_request into unused names (disagg, paused) causing Ruff lint failures; update the unpacking in the listed encoder-init tests (e.g., test_guaranteed_no_evict_encoder_does_not_consume_cross_pool, test_guaranteed_no_evict_skips_encoder_without_cross_pool, test_guaranteed_no_evict_encoder_does_not_consume_self_pool, test_max_utilization_admits_encoder_with_cross_pool, test_max_utilization_skips_encoder_without_cross_pool, test_max_utilization_encoder_not_evictable_victim) to use throwaway names for unused slots (for example: fitting, _, _ or fitting, _disagg, _paused) so only fitting is asserted and linter warnings go away.

tensorrt_llm/_torch/attention_backend/trtllm_gen.py (1)

1064-1068: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace the non-ASCII `α` in this comment.

Ruff flags this as `RUF003`, which can block the file's lint pass. Use plain ASCII (`5a` / `stage-5a`) instead. As per coding guidelines, "When modifying Group A Python files, ruff will format and lint the entire file with full rule set; new violations will block the commit."

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1064 - 1068, The comment in trtllm_gen.py describing the "Cross-attention context phase" contains a non-ASCII Greek letter `α` (in "5α"), which triggers RUF003; update that comment to use plain ASCII such as `5a` or `stage-5a` (leave surrounding text and variable references like `cross_kv_input`, `cache_seq_lens`, and `store_encoder_kv_cache` unchanged) so the file passes ruff linting.

encoder_decoder_porting_guide.md (1)

278-280: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Disambiguate the BART baseline model.

This row names `google/t5-base` but describes the BART-base checkpoint, so the benchmark baseline is not reproducible as written. Please spell out the actual BART model ID here.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@encoder_decoder_porting_guide.md` around lines 278 - 280, The table row under "Model" is incorrect: it lists `google/t5-base` but describes a BART-base checkpoint; update the Model cell to the actual Hugging Face BART ID (e.g., `facebook/bart-base`) and adjust the description to match (remove or separate any reference to T5), keeping `google/flan-t5-large` as the second size class if intended; ensure the row now clearly shows `facebook/bart-base` as the BART baseline and `google/flan-t5-large` as the T5 alternative.

tensorrt_llm/llmapi/llm.py (1)

602-610: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix the hanging-indent lint errors in these multiline calls.

Flake8 is already flagging E121 on these continuations, so this file will fail lint as-is.

Also applies to: 694-700, 915-923

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/llmapi/llm.py` around lines 602 - 610, The multiline call to self._preprocess has hanging-indent style that triggers E121; update the argument continuation lines so they either align under the first character after the opening parenthesis or use a consistent hanging indent (e.g., indent all continuation lines by one additional level) to satisfy flake8/PEP8; specifically fix the call to _preprocess in llm.py and the other similar multiline calls around the other _preprocess/related invocation sites in the file so the continuation lines align with the first argument or with a consistent hanging indent, removing the current misaligned indentation that causes E121.

🧹 Nitpick comments (6)

tensorrt_llm/_torch/pyexecutor/py_executor.py (2)

3148-3161: 💤 Low value

Consider adding `strict=True` to zip for explicit length enforcement.

The assertion at line 3148 already ensures lengths match, but adding `strict=True` makes the invariant explicit at the iteration site and would catch any future refactoring that removes the assertion.

♻️ Optional fix

```diff
- for req, seq_len in zip(encoder_requests, encoder_seq_lens):
+ for req, seq_len in zip(encoder_requests, encoder_seq_lens, strict=True):
```

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 3148 - 3161, The loop over encoder_requests and encoder_seq_lens should enforce equal lengths at the iteration site — update the for loop to use zip(encoder_requests, encoder_seq_lens, strict=True) so mismatched lengths raise immediately; keep the preceding assertions or remove them as you prefer, and ensure the loop still assigns req.py_encoder_output, sets req.py_skip_cross_kv_projection = False, and req.state = LlmRequestState.CONTEXT_INIT as before.

3097-3111: 💤 Low value

Consider catching more specific exceptions in encoder forward.

The broad `except Exception` catches all errors, which is consistent with the existing pattern in `_forward_step` (line 3623) and `_sample_async` (line 3724). However, per coding guidelines, catching specific exceptions is preferred. If maintaining consistency with existing code is the priority, this is acceptable.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 3097 - 3111, The current broad except Exception around the encoder forward (the block that calls self.encoder_stream.wait_stream and self.model_engine.forward_encoder with encoder_requests) should be narrowed to catch specific expected exceptions (e.g., torch.cuda.CudaError / RuntimeError for CUDA/torch issues and ValueError or TypeError for input problems) and let truly unexpected errors propagate; update the except clause to catch those specific exception types, keep the existing traceback.print_exc(), logger.error(...) and call to self._handle_errors( error_msg, requests=encoder_requests) for the handled cases, and only fall back to a generic except Exception if you intentionally want the existing global behavior.

tests/unittest/_torch/modeling/test_modeling_enc_dec.py (1)

655-756: ⚡ Quick win

Benchmark the cached generation hop too.

The docstring says this smoke test covers one context cross-attention call followed by one generation call, but the timed loop only exercises `skip_cross_kv_projection=False`. That leaves the cached cross-KV read path unmeasured, which is the part most likely to regress in the V1 dual-pool lane. As per coding guidelines, "Note if only functional correctness is tested for a change where a performance regression would not be caught."

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modeling/test_modeling_enc_dec.py` around lines 655 - 756, The timed loop in TestCrossAttentionDualPoolSmokeBenchmark.test_v1_dual_pool_cross_attention_smoke only exercises the context path (skip_cross_kv_projection=False) and never measures the cached-generation hop; update the test to also time the generation cross-attention call path (skip_cross_kv_projection=True) using the same setup: add warmup invocation(s) for the generation path, perform a separate timed loop (or include both calls per iteration) that calls cross_attn with skip_cross_kv_projection=True (using the same context_metadata/context_cross_metadata and kv_managers), compute ms_per_iter for the generation hop (or both hops separately), print its latency/tokens/s, and add a matching loose assert (similar to the existing self.assertLess) to catch regressions on the cached read path in addition to the current context-path check.

tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py (2)

1062-1109: ⚡ Quick win

Add failure-path coverage for the cross-pool step.

These new tests only cover the happy path where `cross_kv_cache_manager.prepare_context()` and `resize_context()` succeed. The dual-pool admission logic is just as sensitive to either call returning `False`; without that coverage, a regression can still admit a request with only one pool prepared or leak token budget across retries. As per coding guidelines, tests under `tests/**` should cover happy path, important edge cases, and failure modes relevant to the feature or fix.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py` around lines 1062 - 1109, Add tests covering failure paths for the cross-pool step by mocking cross_kv_cache_manager.prepare_context and/or cross_kv_cache_manager.resize_context to return False and asserting the scheduler (sched.schedule_request) does not admit a request with only one pool prepared or leak token budget across retries; create variants for (a) prepare_context returns False, (b) resize_context returns False, and (c) both return False, using the existing helpers make_encoder_scheduler, make_kv_cache_manager, make_ctx_request and verifying calls to self_mgr.prepare_context/resize_context and cross_mgr.prepare_context/resize_context (call counts) and the schedule_request outcome (e.g., no context_requests admitted or appropriate deferred state), matching the existing test style in test_encoder_then_context_defers_cross_pool_to_context and test_later_context_chunk_reuses_cross_pool_without_resizing.

1471-1481: ⚡ Quick win

Add the ready-event counterpart to this filter test.

This only proves the blocked case. A matching case with `py_encoder_output_ready_event.query()` returning `True` would catch an accidental inversion, or an implementation that filters any request carrying the event object. As per coding guidelines, tests under `tests/**` should cover happy path, important edge cases, and failure modes relevant to the feature or fix.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py` around lines 1471 - 1481, Add a complementary "ready" case to test_context_init_waiting_on_encoder_event_is_filtered: create a ctx via make_ctx_request(0) whose py_encoder_output_ready_event.query() returns True (use a ready event helper), call sched.schedule_request([ready_ctx, gen_req], set()), and assert that the context request is kept (ids(out.context_requests) == [0]) while generation_requests remains [1]; reference the test function test_context_init_waiting_on_encoder_event_is_filtered, make_ctx_request, py_encoder_output_ready_event.query(), and sched.schedule_request to locate where to add the new assertion.

tensorrt_llm/_torch/modules/cross_attention.py (1)

60-74: ⚡ Quick win

Add return annotations to the new methods.

`__init__` and `create_weights` are new in this file, but neither declares `-> None`, so the file misses the repo’s function-annotation requirement. As per coding guidelines, "Always annotate functions; make the return type `None` if the function does not return anything."

Also applies to: 166-170

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/cross_attention.py` around lines 60 - 74, The constructors and new methods lack return type annotations: add explicit "-> None" return annotations to the __init__ method and to create_weights (and any other new methods around that area, e.g., the method at lines referenced in the comment) so they satisfy the repo rule "Always annotate functions"; update the signatures for __init__ in class CrossAttention (and create_weights) to end with "-> None" while leaving the rest of the signature unchanged.

🤖 Prompt for all review comments with AI agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp`:
- Around line 421-450: The code is forwarding crossKvCacheManager into
prefillWithChunkedContextsAlreadyExecuting even when the cross cache uses
variable-window blocks; compute a crossSkippingRelevant flag (e.g. bool
crossSkippingRelevant = crossKvCacheManager &&
!crossKvCacheManager->getBlockManager().isVariableWindow()) and only pass
crossKvCacheManager into prefillWithChunkedContextsAlreadyExecuting when that
flag is true—otherwise pass an empty OptionalRef/std::nullopt—so cross-pool
reuse prefill is disabled for variable-window cross caches; refer to
prefillWithChunkedContextsAlreadyExecuting, crossKvCacheManager, kvCacheManager,
and skippingIsRelevant to locate where to change.

In `@cpp/tensorrt_llm/thop/attentionOp.cpp`:
- Around line 664-688: Add explicit precondition checks when cross_attention is
true: assert that is_fused_qkv is false, cross_kv.has_value() is true, and
encoder_input_lengths.has_value() is true, and that caller did not pass per-call
k or v (i.e. !k.has_value() && !v.has_value()). Implement these via
TLLM_CHECK_WITH_INFO right after existing initial checks (use the same style as
the existing checks) so the cross-attention kernel path cannot be entered with
fused QKV, missing encoder metadata, or unexpected k/v buffers.

In `@tensorrt_llm/_torch/attention_backend/interface.py`:
- Around line 457-474: The cross-attention metadata currently shallow-copies
self.kv_cache_params when encoder_num_cached_tokens_per_seq is None, which can
leak self-attention cache state and wrongly inherit use_cache; update the block
that sets cross_md.kv_cache_params (the KVCacheParams construction) so that when
encoder_num_cached_tokens_per_seq is None you create a fresh KVCacheParams with
num_cached_tokens_per_seq set to zeros/empty per-seq list (representing the
documented “0 cached tokens”) and set use_cache strictly to
(cross_kv_cache_manager is not None) rather than inheriting
base_params.use_cache; for other optional fields (block_ids_per_seq,
host_max_attention_window_sizes, host_sink_token_length, num_extra_kv_tokens)
copy from base_params only if base_params is not None, otherwise use safe
defaults (None or 0).

In `@tensorrt_llm/_torch/models/modeling_bart.py`:
- Around line 177-183: The code hardcodes F.gelu and always applies
sqrt(d_model) embedding scaling which breaks HF BART parity; add a small helper
(e.g., map_activation(config.activation_function)) that returns the correct
torch activation (gelu, relu, gelu_new, etc.) and pass its result into MLP via
the self.mlp instantiation instead of F.gelu, and change embedding scaling to
only multiply by math.sqrt(self.config.d_model) when config.scale_embedding is
True (use config.scale_embedding) wherever embeddings are scaled (the embed
token creation/use sites and any places building the FFN like MLP and forward
uses). Update all usages that build MLP (self.mlp) and any embedding scaling
sites to read config.activation_function and config.scale_embedding so
checkpoint behavior matches HF BART.
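A self-contained sketch of such a helper; scalar math stand-ins are used here instead of the torch functionals the real code would return, and `gelu_new` is assumed to be the tanh approximation:

```python
import math


def _gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def _relu(x: float) -> float:
    return x if x > 0.0 else 0.0


def _gelu_new(x: float) -> float:
    # tanh approximation used by some HF checkpoints
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


_ACTIVATIONS = {"gelu": _gelu, "relu": _relu, "gelu_new": _gelu_new}


def map_activation(name: str):
    # Fail loudly on unknown names so a checkpoint mismatch is caught early.
    try:
        return _ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f"unsupported activation_function: {name}") from None
```

In the real module the dict values would be `torch.nn.functional` callables, and the result would be passed into the `self.mlp` construction in place of the hardcoded `F.gelu`.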
In `@tensorrt_llm/_torch/models/modeling_t5.py`:
- Around line 567-568: The code currently uses hidden_states.shape[0] to size
the relative position bias, which is the total packed token count and can OOM;
change the sizing to use a per-sequence bound such as
max(attn_metadata.seq_lens) (or an explicit per-request seq_len variable) when
calling self.relative_position_bias so you build an [sl, sl] bias per sequence
rather than [total_tokens, total_tokens]; update the same pattern found around
the other occurrence (lines ~620-621) to use the same per-sequence max(seq_lens)
value and ensure you pass the correct device (hidden_states.device) to
relative_position_bias.
- Around line 257-264: The current branch returns early when
attn_metadata.kv_cache_manager is present, which causes position_bias to be
dropped during cached decoding; change the logic in the T5 attention
implementation (the block that checks position_bias and
attn_metadata.kv_cache_manager before calling super().forward) so that when a kv
cache exists you still pass through and apply the existing position_bias for the
current query against cached keys rather than delegating to Attention.forward()
which omits it; specifically, ensure the call to super().forward (or the
T5Attention.forward path) receives the position_bias (and does not short-circuit
on attn_metadata.kv_cache_manager) and preserve usage of the learned relative
bias computation when handling cached decode in forward.
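One way to picture the fix: with a KV cache, the learned bias must still be sliced for the live query rows rather than dropped. A toy sketch with nested lists standing in for the bias tensor:

```python
def select_decode_bias(full_bias, q_len):
    # full_bias: [key_len][key_len] relative-position biases for one sequence.
    # During cached decode only the last q_len query rows are live, but each
    # row still spans every cached key, so the bias cannot simply be omitted
    # by short-circuiting to the bias-free Attention.forward() path.
    return full_bias[-q_len:]


full = [
    [0, 1, 2],
    [-1, 0, 1],
    [-2, -1, 0],
]
```

The real T5 path would do the equivalent slicing on the bias tensor before handing it to the attention kernel.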
In `@tensorrt_llm/_torch/modules/cross_attention.py`:
- Around line 188-196: The current fallback returns uniform seq_lens_kv for
multi-request batches (using num_encoder_tokens, num_requests) which silently
missegments ragged encoder inputs; instead, when num_requests > 1 require
explicit cross_attn_metadata by raising a ValueError that instructs the caller
to pass cross_attn_metadata with seq_lens_kv set (remove the torch.full fallback
and any per_request division path for multi-request cases), keeping the existing
behavior only for single-request batches.
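A sketch of the suggested guard; function and parameter shapes are illustrative, not the module's actual API:

```python
def resolve_seq_lens_kv(num_requests, num_encoder_tokens, cross_attn_metadata=None):
    # Explicit metadata always wins.
    if cross_attn_metadata is not None and cross_attn_metadata.get("seq_lens_kv") is not None:
        return cross_attn_metadata["seq_lens_kv"]
    if num_requests > 1:
        # Ragged encoder inputs cannot be recovered from a single total;
        # a uniform split would silently missegment them.
        raise ValueError(
            "multi-request batches require cross_attn_metadata with seq_lens_kv set")
    # Single-request batches keep the old behavior.
    return [num_encoder_tokens]
```

Raising here turns a silent correctness bug into an immediate, actionable error for the caller.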
In `@tensorrt_llm/_torch/modules/encoder_decoder_layer.py`:
- Around line 29-36: The abstract forward method(s) in
tensorrt_llm/_torch/modules/encoder_decoder_layer.py currently place the
ellipsis on the same line as the signature, triggering Flake8 E704; update the
forward declarations (the `@abstractmethod`-decorated `forward` function(s)) so the
trailing "..." is moved onto its own line (i.e., keep the signature and return
annotation as-is, then put a separate line containing only the ellipsis) to
resolve the lint error while preserving the type hints and optional parameters
like hidden_states, attn_metadata, and position_ids.
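The E704-compliant shape looks like this (parameter names follow the review; the types are simplified):

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class EncoderDecoderLayer(ABC):

    @abstractmethod
    def forward(
        self,
        hidden_states: Any,
        attn_metadata: Any,
        position_ids: Optional[Any] = None,
    ) -> Any:
        # E704 flags "def f(): ..." on one line; putting the ellipsis on
        # its own line keeps the annotation-only stub lint-clean.
        ...
```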
In `@tensorrt_llm/_torch/pyexecutor/_util.py`:
- Around line 1052-1062: The KV-capacity warmup bookkeeping in
configure_kv_cache_capacity() neglects the newly-added cross_kv_cache_manager;
update configure_kv_cache_capacity() so it also reads and adds allocated_bytes
(or equivalent temporary allocation metric) from
ResourceManagerType.CROSS_KV_CACHE_MANAGER (i.e., the cross_kv_cache_manager
created above) alongside the existing KV_CACHE_MANAGER and
DRAFT_KV_CACHE_MANAGER entries so encoder-decoder warmup accounts for
cross-attention temporary KV allocations when sizing the final cache.
- Around line 834-845: The cross-pool length cap in _get_cross_kv_cache_layout
should not use the mutable self._max_seq_len because
build_managers->_create_kv_cache_manager can shrink that field; instead capture
and use an immutable source for the initial max sequence length (e.g., an
original_max_seq_len saved at construction or compute from config/defaults) when
computing encoder_limit and max_seq_len inside _get_cross_kv_cache_layout so
encoder sizing is independent of later mutations to self._max_seq_len; update
references in _get_cross_kv_cache_layout (and any callers like
build_managers/_create_kv_cache_manager if they rely on it) to use that
immutable value.
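The snapshot idea in miniature (an illustrative stand-in, not the real builder class):

```python
class CrossKvCacheSizer:
    # Snapshot the construction-time value so later shrinking of
    # _max_seq_len (e.g. inside _create_kv_cache_manager) cannot
    # change encoder sizing.
    def __init__(self, max_seq_len: int) -> None:
        self._max_seq_len = max_seq_len
        self._original_max_seq_len = max_seq_len

    def shrink(self, new_len: int) -> None:
        # Mimics build_managers reducing the mutable field.
        self._max_seq_len = min(self._max_seq_len, new_len)

    def encoder_limit(self) -> int:
        # Cross-pool sizing reads the immutable snapshot.
        return self._original_max_seq_len
```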
In `@tensorrt_llm/_torch/pyexecutor/llm_request.py`:
- Line 281: The PyResult.Diff must not hold a live GPU tensor: change
set_encoder_output()/wherever encoder_output is assigned to detach the tensor,
move it off GPU and into a serializable form (e.g., .detach().cpu() and either
convert to a numpy array or create/share a CPU-backed shared-memory handle),
store that serializable object (or handle) instead of the raw tensor, and update
get_diff() to return the serializable payload/handle rather than the original
tensor; apply the same pattern for the other occurrences of encoder_output usage
(the other set/get paths referenced around encoder_output and PyResult.Diff).

In `@tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py`:
- Around line 388-391: The rollback path after calling
_try_schedule_cross_context currently calls _suspend_request(req) which only
pauses self and draft managers and therefore leaks cross-pool capacity owned via
cross_kv_cache_manager; update the rollback to also suspend/release the
request's cross-pool reservations by invoking the cross-pool suspension logic
(e.g., call the cross_kv_cache_manager suspend/release helper or extend
_suspend_request to also call cross_kv_cache_manager.suspend_request(req) /
release_cross_pages(req)) so cross pool pages are returned when cross_action !=
ScheduleAction.SCHEDULED; apply the same change to the other identical branch
that calls _try_schedule_cross_context (the second occurrence noted in the
review).
In `@tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py`:
- Around line 1105-1110: MaxUtilizationPolicy.schedule() initializes a
cross-pool manager but still calls _prefill_contributed_blocks()
unconditionally, which can invoke analyze_prefix_reuse() on a variable-window
cross KV cache (unsupported per _is_skipping_relevant()); update the calls that
prefill cross-pool reuse (the invocation of _prefill_contributed_blocks() and
the similar block around lines ~1151-1157) to run only when the cross KV cache
is NOT a variable-window cache — i.e., guard use of
scheduled_cross_blocks_manager and the prefill helper behind the same
variable-window check that _is_skipping_relevant() uses so
analyze_prefix_reuse() is never called for variable-window cross pools.
- Around line 350-353: The current can_schedule() drops unready decoder-context
requests via filter_unready_decoder_context_requests before comparing lengths,
which can incorrectly report a batch schedulable when some original requests are
unready; change both this class's can_schedule() and
SimpleUnifiedScheduler.can_schedule() to preserve the original requests count
when deciding schedulability: call filter_unready_decoder_context_requests only
to produce the list passed to capacity_scheduler.schedule_request (using
filter_unready_decoder_context_requests(requests)), but compare
len(fitting_requests) against len(requests) from the original input (or
explicitly return False if any requests were filtered out/unready) so that any
unready request prevents returning True; references: can_schedule,
filter_unready_decoder_context_requests, capacity_scheduler.schedule_request,
SimpleUnifiedScheduler.can_schedule.
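The count-preserving comparison can be sketched like this (a simplified shape; the real method takes no callables):

```python
def can_schedule(requests, filter_unready, schedule_fn) -> bool:
    # Filter only to build the capacity-scheduler input; compare against
    # the ORIGINAL request count so any unready request blocks scheduling.
    ready = filter_unready(requests)
    if len(ready) != len(requests):
        return False
    fitting = schedule_fn(ready)
    return len(fitting) == len(requests)
```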
In `@tensorrt_llm/inputs/data.py`:
- Around line 98-102: The current assertion in the function returning inputs
allows encoder-only dicts which violate the TextPrompt/TokensPrompt return
types; update the assertion so it requires at least one of the main prompt
fields (i.e., ensure inputs.get("prompt") is not None or
inputs.get("prompt_token_ids") is not None) rather than permitting
encoder_inputs or encoder_input_token_ids alone, and keep the return type as the
existing TypedDict (TextPrompt/TokensPrompt); adjust the assert expression
(referencing inputs, "prompt", "prompt_token_ids", "encoder_inputs",
"encoder_input_token_ids") accordingly so downstream code can rely on
prompt/prompt_token_ids being present.
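The tightened assertion, sketched standalone (`check_prompt_inputs` is an illustrative name):

```python
def check_prompt_inputs(inputs: dict) -> dict:
    # Encoder fields alone are not enough: the TypedDict return types
    # (TextPrompt / TokensPrompt) promise a main prompt field.
    assert (inputs.get("prompt") is not None
            or inputs.get("prompt_token_ids") is not None), (
        "prompt or prompt_token_ids is required, even when "
        "encoder_inputs / encoder_input_token_ids are supplied")
    return inputs
```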
In `@tests/unittest/_torch/executor/test_encoder_step.py`:
- Around line 1-16: The file tests/unittest/_torch/executor/test_encoder_step.py
only contains SPDX lines; update its top-of-file header to include the full
NVIDIA copyright notice and Apache-2.0 license notice block used across the repo
(matching other new .py files), i.e. add the multi-line copyright header with
the latest modification year (2026) plus the full Apache License, Version 2.0
text/block directly below the SPDX lines so the file contains both the NVIDIA
copyright header and the complete Apache 2.0 notice.

In `@tests/unittest/_torch/modeling/test_modeling_enc_dec.py`:
- Around line 1208-1225: The tests currently cast the HuggingFace model to bf16
before calling state_dict(), so load_weights()/_convert_hf_bart_weights never
see float32 tensors; change the test to obtain hf_weights from the HF model
while it is still in its original float32 dtype (call hf_model.state_dict()
before applying .to(self.dtype) or create a separate copy for casting), then set
hf_config.torch_dtype and cast the HF model for downstream eval as needed; apply
the same change to the other occurrence that casts before state_dict() so
load_weights(TllmT5.load_weights) exercises the float32→bf16 conversion logic.
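The ordering matters because casting mutates the model in place. A toy illustration with fake tensors (torch is deliberately not imported here; only dtype bookkeeping is simulated):

```python
class FakeTensor:
    # Minimal stand-in for a torch tensor: only tracks dtype.
    def __init__(self, dtype: str) -> None:
        self.dtype = dtype


class FakeHFModel:
    def __init__(self) -> None:
        self._weights = {"w": FakeTensor("float32")}

    def state_dict(self):
        return dict(self._weights)

    def to(self, dtype: str):
        for t in self._weights.values():
            t.dtype = dtype
        return self


hf_model = FakeHFModel()
# Grab weights while the model is still float32 so load_weights()
# exercises the float32 -> bf16 conversion path...
hf_weights = hf_model.state_dict()
src_dtype = hf_weights["w"].dtype
# ...then cast the HF model for the downstream eval comparison.
hf_model.to("bfloat16")
```

Note that `state_dict()` returns references, so a deep copy would be needed if the pre-cast tensors themselves must survive; capturing what you need before casting (as above) sidesteps that.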
Outside diff comments:
In `@tensorrt_llm/_torch/attention_backend/vanilla.py`:
- Around line 401-415: The flash-attn call unconditionally enables causal when
attention_mask == PredefinedAttentionMask.CAUSAL, causing triangular masking on
cross-attention unlike the SDPA fallback; change the causal argument in the
flash_attn_varlen_func call to only be true when the mask is causal AND Q/K
lengths match (e.g., causal=(attention_mask == PredefinedAttentionMask.CAUSAL
and max_seqlen_q == max_seqlen_k) or equivalent check using cu_seqlens), so
flash_attn_varlen_func follows the same sdpa_is_causal rule as the SDPA
fallback.
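The rule reduces to a small predicate (a sketch; the real code compares an enum, not a string):

```python
def flash_attn_causal_flag(attention_mask: str, max_seqlen_q: int, max_seqlen_k: int) -> bool:
    # Mirror the SDPA fallback: triangular masking only makes sense when
    # query and key lengths line up (self-attention); cross-attention
    # (q_len != k_len) must not be masked causally.
    return attention_mask == "CAUSAL" and max_seqlen_q == max_seqlen_k
```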
In `@tensorrt_llm/_torch/pyexecutor/_util.py`:
- Around line 734-776: _split_kv_cache_budget_for_draft currently derives
draft_kv by subtracting target_kv from total_kv returned by
_get_kv_size_per_token(), but _get_kv_size_per_token() includes cross-pool bytes
for encoder-decoder models (already removed by
_split_kv_cache_budget_for_cross), causing the draft share to be over-allocated;
fix by computing draft_kv directly from the draft layout instead of inferring it
from total_kv — call the KV cache manager's get_cache_size_per_token (via
_kv_cache_manager_cls.get_cache_size_per_token) with the draft mapping/layout
(or appropriate tokens_per_block for the draft) to obtain draft_kv, then use
that draft_kv and target_kv to split total_budget in
_split_kv_cache_budget_for_draft (and leave _get_kv_size_per_token and
build_managers unchanged).

In `@tensorrt_llm/executor/base_worker.py`:
- Line 1: Add the required NVIDIA copyright and Apache-2.0 license header at the
top of the modified Python file (tensorrt_llm/executor/base_worker.py) so the
file begins with the standard multi-line header before any imports (e.g., before
the existing "import copy" line); include the current year of latest
modification and the full Apache-2.0 notice as per project guidelines.

In `@tensorrt_llm/executor/executor.py`:
- Line 1: Add the standard NVIDIA copyright/header block (with the year of
latest modification and the Apache 2.0 license notice) at the very top of
tensorrt_llm/executor/executor.py before any imports (e.g., before the existing
"import atexit"); ensure the header follows the project's canonical format used
in other .py files and includes the copyright year and the full Apache-2.0
license statement.
- Around line 119-166: The synchronous API generate(...) is missing support for
encoder_input_token_ids while generate_async(...) accepts it; update the
generate function signature to add an optional encoder_input_token_ids:
Optional[Union[torch.Tensor, np.ndarray, list]] argument and forward that value
into the GenerationRequest constructor (the same way generate_async does) so
synchronous callers can pass encoder-decoder inputs; keep type hints consistent
with generate_async, update any docstring/comments, and run/adjust related tests
that construct GenerationRequest via generate to ensure parity.

In `@tensorrt_llm/executor/request.py`:
- Line 1: Add the standard NVIDIA copyright + Apache-2.0 header block at the
very top of tensorrt_llm/executor/request.py (before the existing import
statements); ensure the header includes the correct year of latest modification
and the full Apache-2.0 notice and SPDX identifiers so the file complies with
the project's source-file licensing guidelines.

In `@tensorrt_llm/executor/result.py`:
- Line 1: The file tensorrt_llm/executor/result.py is missing the required
NVIDIA copyright/license header; add the standard NVIDIA copyright block
(including the latest modification year and the Apache-2.0 license notice) at
the very top of the file before any imports (before the existing import
asyncio), formatted as a Python comment block so linters remain happy; ensure
the header text matches the project's canonical NVIDIA/Apache-2.0 header
template.

In `@tests/integration/defs/llmapi/test_llm_api_pytorch_t5.py`:
- Around line 374-606: Add the two new test functions
(test_t5_pytorch_generate_encoder_decoder_end_to_end and
test_t5_pytorch_generate_encoder_decoder_cuda_graph_mixed_encoder_lengths_batch)
to the functional QA test list so they run in scheduled QA; open the QA list
file (llm_function_core.txt) and append entries for these test names (one per
line) following the existing list format, ensuring any required markers/tags
used by the scheduler are included and consistent with other encoder-decoder/T5
tests.
Minor comments:
In `@encoder_decoder_porting_guide.md`:
- Around line 278-280: The table row under "Model" is incorrect: it lists
`google/t5-base` but describes a BART-base checkpoint; update the Model cell to
the actual Hugging Face BART ID (e.g., `facebook/bart-base`) and adjust the
description to match (remove or separate any reference to T5), keeping
`google/flan-t5-large` as the second size class if intended; ensure the row now
clearly shows `facebook/bart-base` as the BART baseline and
`google/flan-t5-large` as the T5 alternative.

In `@legacy_enc_dec_architecture.md`:
- Around line 123-126: The fenced code block containing the lines
"out//encoder/rank0.engine, config.json" and
"out//decoder/rank0.engine, config.json" is missing a language tag and
triggers MD040; update the opening fence from ``` to ```text (i.e., add the
"text" language tag) so the block is explicitly marked and linting passes while
preserving the existing content.

In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 1064-1068: The comment in trtllm_gen.py describing the
"Cross-attention context phase" contains a non-ASCII Greek letter α (in "5α"),
which triggers RUF003; update that comment to use plain ASCII such as `5a` or
`stage-5a` (leave surrounding text and variable references like
`cross_kv_input`, `cache_seq_lens`, and `store_encoder_kv_cache` unchanged) so
the file passes ruff linting.

In `@tensorrt_llm/llmapi/llm.py`:
- Around line 602-610: The multiline call to self._preprocess has hanging-indent
style that triggers E121; update the argument continuation lines so they either
align under the first character after the opening parenthesis or use a
consistent hanging indent (e.g., indent all continuation lines by one additional
level) to satisfy flake8/PEP8; specifically fix the call to _preprocess in
llm.py and the other similar multiline calls around the other
_preprocess/related invocation sites in the file so the continuation lines align
with the first argument or with a consistent hanging indent, removing the
current misaligned indentation that causes E121.

In `@tests/unittest/_torch/executor/test_py_scheduler.py`:
- Around line 2333-2411: Tests unpack the triple returned by
scheduler.schedule_request into unused names (disagg, paused) causing Ruff lint
failures; update the unpacking in the listed encoder-init tests (e.g.,
test_guaranteed_no_evict_encoder_does_not_consume_cross_pool,
test_guaranteed_no_evict_skips_encoder_without_cross_pool,
test_guaranteed_no_evict_encoder_does_not_consume_self_pool,
test_max_utilization_admits_encoder_with_cross_pool,
test_max_utilization_skips_encoder_without_cross_pool,
test_max_utilization_encoder_not_evictable_victim) to use throwaway names for
unused slots (for example: fitting, _, _ or fitting, _disagg, _paused) so only
fitting is asserted and linter warnings go away.
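The fix in miniature, with a stand-in for the scheduler call:

```python
def schedule_request():
    # Stand-in for scheduler.schedule_request(); the real call returns
    # (fitting_requests, disagg_requests, paused_requests).
    return (["req0"], [], [])


# Only the first slot is asserted, so bind the rest to throwaway names;
# Ruff no longer flags them as unused variables.
fitting, _, _ = schedule_request()
```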
Nitpick comments:
In `@tensorrt_llm/_torch/modules/cross_attention.py`:
- Around line 60-74: The constructors and new methods lack return type
annotations: add explicit "-> None" return annotations to the `__init__` method
and to create_weights (and any other new methods around that area, e.g., the
method at lines referenced in the comment) so they satisfy the repo rule "Always
annotate functions"; update the signatures for `__init__` in class CrossAttention
(and create_weights) to end with "-> None" while leaving the rest of the
signature unchanged.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around line 3148-3161: The loop over encoder_requests and encoder_seq_lens
should enforce equal lengths at the iteration site — update the for loop to use
zip(encoder_requests, encoder_seq_lens, strict=True) so mismatched lengths raise
immediately; keep the preceding assertions or remove them as you prefer, and
ensure the loop still assigns req.py_encoder_output, sets
req.py_skip_cross_kv_projection = False, and req.state =
LlmRequestState.CONTEXT_INIT as before.
- Around line 3097-3111: The current broad except Exception around the encoder
forward (the block that calls self.encoder_stream.wait_stream and
self.model_engine.forward_encoder with encoder_requests) should be narrowed to
catch specific expected exceptions (e.g., torch.cuda.CudaError / RuntimeError
for CUDA/torch issues and ValueError or TypeError for input problems) and let
truly unexpected errors propagate; update the except clause to catch those
specific exception types, keep the existing traceback.print_exc(),
logger.error(...) and call to self._handle_errors(error_msg,
requests=encoder_requests) for the handled cases, and only fall back to a
generic except Exception if you intentionally want the existing global behavior.
In `@tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py`:
- Around line 1062-1109: Add tests covering failure paths for the cross-pool
step by mocking cross_kv_cache_manager.prepare_context and/or
cross_kv_cache_manager.resize_context to return False and asserting the
scheduler (sched.schedule_request) does not admit a request with only one pool
prepared or leak token budget across retries; create variants for (a)
prepare_context returns False, (b) resize_context returns False, and (c) both
return False, using the existing helpers make_encoder_scheduler,
make_kv_cache_manager, make_ctx_request and verifying calls to
self_mgr.prepare_context/resize_context and
cross_mgr.prepare_context/resize_context (call counts) and the schedule_request
outcome (e.g., no context_requests admitted or appropriate deferred state),
matching the existing test style in
test_encoder_then_context_defers_cross_pool_to_context and
test_later_context_chunk_reuses_cross_pool_without_resizing.
- Around line 1471-1481: Add a complementary "ready" case to
test_context_init_waiting_on_encoder_event_is_filtered: create a ctx via
make_ctx_request(0) whose py_encoder_output_ready_event.query() returns True
(use a ready event helper), call sched.schedule_request([ready_ctx, gen_req],
set()), and assert that the context request is kept (ids(out.context_requests)
== [0]) while generation_requests remains [1]; reference the test function
test_context_init_waiting_on_encoder_event_is_filtered, make_ctx_request,
py_encoder_output_ready_event.query(), and sched.schedule_request to locate
where to add the new assertion.

In `@tests/unittest/_torch/modeling/test_modeling_enc_dec.py`:
- Around line 655-756: The timed loop in
TestCrossAttentionDualPoolSmokeBenchmark.test_v1_dual_pool_cross_attention_smoke
only exercises the context path (skip_cross_kv_projection=False) and never
measures the cached-generation hop; update the test to also time the generation
cross-attention call path (skip_cross_kv_projection=True) using the same setup:
add warmup invocation(s) for the generation path, perform a separate timed loop
(or include both calls per iteration) that calls cross_attn with
skip_cross_kv_projection=True (using the same
context_metadata/context_cross_metadata and kv_managers), compute ms_per_iter
for the generation hop (or both hops separately), print its latency/tokens/s,
and add a matching loose assert (similar to the existing self.assertLess) to
catch regressions on the cached read path in addition to the current
context-path check.
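The requested two-hop timing has this shape; a no-op stub stands in for the real `cross_attn` call and its metadata arguments:

```python
import time


def cross_attn(skip_cross_kv_projection: bool) -> None:
    # Stub for the real cross-attention call; the flag selects the
    # context path (False) vs. the cached-generation path (True).
    pass


def time_path(skip: bool, iters: int = 50, warmup: int = 5) -> float:
    # Warm up so one-time setup cost does not pollute the measurement,
    # then report milliseconds per iteration.
    for _ in range(warmup):
        cross_attn(skip_cross_kv_projection=skip)
    start = time.perf_counter()
    for _ in range(iters):
        cross_attn(skip_cross_kv_projection=skip)
    return (time.perf_counter() - start) * 1000.0 / iters


context_ms = time_path(skip=False)    # existing context-path measurement
generation_ms = time_path(skip=True)  # newly requested cached-read hop
```

In the real test, each hop would get its own loose `assertLess` bound so a regression on either path is caught.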
```diff
     kv_cache_manager::BaseKVCacheManager& kvCacheManager,
     OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager,
     OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
 {
     kvCacheManager.startScheduling();
     if (crossKvCacheManager)
     {
         crossKvCacheManager->startScheduling();
     }

     // The optimization of delaying requests won't work for variable window attention
     bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();

     // Keep track of number of requests and block needed for the scheduled requests
     auto scheduledBlocksManager
         = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mTwoStepsLookAhead);
     // Mirror the budget tracker for the cross pool when present.
     // Encoder-init requests do not consume either tracker; decoder
     // context/generation requests update both trackers in lockstep.
     std::optional<kv_cache_manager::MaxUtilizationScheduledBlocksManager> scheduledCrossBlocksManager;
     if (crossKvCacheManager)
     {
         scheduledCrossBlocksManager.emplace(*crossKvCacheManager, mTwoStepsLookAhead);
     }
     SizeType32 numScheduledPeftPages{0};
     std::unordered_set<uint64_t> seenTaskIds;

     // Keep track of blocks contributed by requests in context phase
     auto [newlyContributedContextBlocks, newlyContributedCrossContextBlocks]
-        = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager);
+        = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
```
Disable cross-pool reuse prefill when the cross cache is variable-window.
The new MAX_UTILIZATION path now forwards crossKvCacheManager into prefillWithChunkedContextsAlreadyExecuting(...), but skippingIsRelevant only checks kvCacheManager. If the cross manager has variable windows and block reuse enabled, prefillWithChunkedContextsAlreadyExecuting() can still reach crossKvCacheManager->analyzePrefixReuse(...), which the guaranteed-no-evict path already treats as unsupported.
Suggested fix

```diff
-    bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();
+    bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow()
+        && (!crossKvCacheManager || !crossKvCacheManager->getBlockManager().isVariableWindow());
@@
-    auto [newlyContributedContextBlocks, newlyContributedCrossContextBlocks]
-        = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
+    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
+    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
+    if (skippingIsRelevant)
+    {
+        std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
+            = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
+    }
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp` around lines 421 - 450,
The code is forwarding crossKvCacheManager into
prefillWithChunkedContextsAlreadyExecuting even when the cross cache uses
variable-window blocks; compute a crossSkippingRelevant flag (e.g. bool
crossSkippingRelevant = crossKvCacheManager &&
!crossKvCacheManager->getBlockManager().isVariableWindow()) and only pass
crossKvCacheManager into prefillWithChunkedContextsAlreadyExecuting when that
flag is true—otherwise pass an empty OptionalRef/std::nullopt—so cross-pool
reuse prefill is disabled for variable-window cross caches; refer to
prefillWithChunkedContextsAlreadyExecuting, crossKvCacheManager, kvCacheManager,
and skippingIsRelevant to locate where to change.
```diff
     int64_t num_contexts, int64_t num_ctx_tokens, bool const cross_attention, std::optional<torch::Tensor> cross_kv,
     std::optional<torch::Tensor> encoder_input_lengths)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
     bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value()
         && host_kv_cache_pool_mapping.has_value();

-    TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
-    TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
     // Cross-attention layers (Step 5β) carry only Q via ``q``; encoder K/V is
     // delivered through the dedicated ``cross_kv`` argument and lives in the
     // paged cross-KV cache. Hence neither ``is_fused_qkv`` nor
     // ``update_kv_cache`` are required to be true for cross-attention: during
     // the decoder context step we still write the (separately-projected)
     // encoder K/V into the cross pool via ``cross_kv``, and during decoder
     // generation we only read from that pool, so ``update_kv_cache`` is False.
+    TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || cross_attention,
+        "Only fused QKV is supported for non-MLA non-cross attention now");
+    TLLM_CHECK_WITH_INFO(update_kv_cache || cross_attention, "KV cache update cannot be disabled now");
     auto qkv_or_q = q;
     if (is_fused_qkv)
     {
         TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV");
         TLLM_CHECK_WITH_INFO(!v.has_value(), "The v tensor should be null if using fused QKV");
     }
-    if (!is_fused_qkv && update_kv_cache)
+    if (!is_fused_qkv && update_kv_cache && !cross_attention)
```
Strengthen the cross-attention precondition checks.
The new path assumes q is Q-only and that encoder-side metadata is present, but the wrapper currently only relaxes the old self-attention checks. With cross_attention=True, a caller can still pass is_fused_qkv=True, omit encoder_input_lengths, or omit cross_kv on the context step and still reach the kernel path with the wrong buffer layout or null metadata.
Suggested guardrail

```diff
 TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv || cross_attention,
     "Only fused QKV is supported for non-MLA non-cross attention now");
 TLLM_CHECK_WITH_INFO(update_kv_cache || cross_attention, "KV cache update cannot be disabled now");
+if (cross_attention)
+{
+    TLLM_CHECK_WITH_INFO(!is_fused_qkv,
+        "Cross attention expects q to contain Q only; use cross_kv for encoder K/V.");
+    TLLM_CHECK_WITH_INFO(encoder_input_lengths.has_value(),
+        "encoder_input_lengths is required for cross attention.");
+    TLLM_CHECK_WITH_INFO(num_contexts == 0 || cross_kv.has_value(),
+        "cross_kv is required when scheduling cross-attention context requests.");
+}
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/tensorrt_llm/thop/attentionOp.cpp` around lines 664 - 688, Add explicit
precondition checks when cross_attention is true: assert that is_fused_qkv is
false, cross_kv.has_value() is true, and encoder_input_lengths.has_value() is
true, and that caller did not pass per-call k or v (i.e. !k.has_value() &&
!v.has_value()). Implement these via TLLM_CHECK_WITH_INFO right after existing
initial checks (use the same style as the existing checks) so the
cross-attention kernel path cannot be entered with fused QKV, missing encoder
metadata, or unexpected k/v buffers.
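The same preconditions, sketched as plain Python assertions (optionals modeled as `None`):

```python
def check_cross_attention_preconditions(is_fused_qkv, cross_kv,
                                        encoder_input_lengths, k, v) -> None:
    # Python rendering of the TLLM_CHECK_WITH_INFO guards described above.
    assert not is_fused_qkv, "cross attention expects q to carry Q only"
    assert cross_kv is not None, "cross_kv is required for cross attention"
    assert encoder_input_lengths is not None, "encoder_input_lengths is required"
    assert k is None and v is None, "per-call k/v must not be passed with cross_kv"
```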
```python
if encoder_num_cached_tokens_per_seq is not None:
    from ..metadata import KVCacheParams
    base_params = self.kv_cache_params
    cross_md.kv_cache_params = KVCacheParams(
        use_cache=base_params.use_cache if base_params is not None else
        (cross_kv_cache_manager is not None),
        num_cached_tokens_per_seq=list(
            encoder_num_cached_tokens_per_seq),
        block_ids_per_seq=base_params.block_ids_per_seq
        if base_params is not None else None,
        host_max_attention_window_sizes=base_params.
        host_max_attention_window_sizes
        if base_params is not None else None,
        host_sink_token_length=base_params.host_sink_token_length
        if base_params is not None else None,
        num_extra_kv_tokens=base_params.num_extra_kv_tokens
        if base_params is not None else 0,
    )
```
Fix kv_cache_params initialization in cross metadata (current defaults can leak self-attention cache state).
When encoder_num_cached_tokens_per_seq is None, cross_md keeps shallow-copied self.kv_cache_params, which contradicts the docstring default (“0 cached tokens”). Also, use_cache should not inherit from base_params when cross_kv_cache_manager is None; that breaks the documented stateless path.
💡 Proposed fix
```diff
- if encoder_num_cached_tokens_per_seq is not None:
-     from ..metadata import KVCacheParams
-     base_params = self.kv_cache_params
-     cross_md.kv_cache_params = KVCacheParams(
-         use_cache=base_params.use_cache if base_params is not None else
-         (cross_kv_cache_manager is not None),
-         num_cached_tokens_per_seq=list(
-             encoder_num_cached_tokens_per_seq),
-         block_ids_per_seq=base_params.block_ids_per_seq
-         if base_params is not None else None,
-         host_max_attention_window_sizes=base_params.
-         host_max_attention_window_sizes
-         if base_params is not None else None,
-         host_sink_token_length=base_params.host_sink_token_length
-         if base_params is not None else None,
-         num_extra_kv_tokens=base_params.num_extra_kv_tokens
-         if base_params is not None else 0,
-     )
+ from ..metadata import KVCacheParams
+ base_params = self.kv_cache_params
+ cached_tokens = (
+     list(encoder_num_cached_tokens_per_seq)
+     if encoder_num_cached_tokens_per_seq is not None
+     else [0] * int(encoder_seq_lens.shape[0])
+ )
+ cross_md.kv_cache_params = KVCacheParams(
+     use_cache=(cross_kv_cache_manager is not None),
+     num_cached_tokens_per_seq=cached_tokens,
+     block_ids_per_seq=base_params.block_ids_per_seq
+     if base_params is not None else None,
+     host_max_attention_window_sizes=base_params.host_max_attention_window_sizes
+     if base_params is not None else None,
+     host_sink_token_length=base_params.host_sink_token_length
+     if base_params is not None else None,
+     num_extra_kv_tokens=base_params.num_extra_kv_tokens
+     if base_params is not None else 0,
+ )
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/attention_backend/interface.py` around lines 457 - 474,
The cross-attention metadata currently shallow-copies self.kv_cache_params when
encoder_num_cached_tokens_per_seq is None, which can leak self-attention cache
state and wrongly inherit use_cache; update the block that sets
cross_md.kv_cache_params (the KVCacheParams construction) so that when
encoder_num_cached_tokens_per_seq is None you create a fresh KVCacheParams with
num_cached_tokens_per_seq set to zeros/empty per-seq list (representing the
documented “0 cached tokens”) and set use_cache strictly to
(cross_kv_cache_manager is not None) rather than inheriting
base_params.use_cache; for other optional fields (block_ids_per_seq,
host_max_attention_window_sizes, host_sink_token_length, num_extra_kv_tokens)
copy from base_params only if base_params is not None, otherwise use safe
defaults (None or 0).
```python
self.mlp = MLP(
    hidden_size=hidden_size,
    intermediate_size=ffn_dim,
    bias=True,
    activation=F.gelu,
    dtype=config.torch_dtype,
    config=model_config,
```
🧩 Analysis chain
🌐 Web query:
In Hugging Face Transformers, how does BartModel use BartConfig.scale_embedding and BartConfig.activation_function when building BART/MBART?
💡 Result:
In Hugging Face Transformers' BartModel (which supports both BART and mBART via BartConfig), the two settings are used during model construction as follows:
- `BartConfig.scale_embedding` (default: `False`): determines whether token embeddings are scaled by the square root of the model dimension (`d_model`). In `BartModel.__init__`, the shared embedding layer is initialized as `BartScaledWordEmbedding` with `embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0`.[1][2]
- `BartConfig.activation_function` (default: `"gelu"`): specifies the activation used in the feed-forward networks of the encoder and decoder layers. The `BartEncoderLayer` and `BartDecoderLayer` constructors set `self.activation_fn = ACT2FN[config.activation_function]`, where `ACT2FN` maps strings like `"gelu"` to the corresponding function (e.g., `torch.nn.functional.gelu`).[1][3][4]

These parameters are passed to `BartEncoder(config)` and `BartDecoder(config)`, which instantiate the layers using them. mBART uses the same code path with `MBartConfig` subclassing `BartConfig`.[5][6]
Citations:
- 1: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py
- 2: https://docs.vllm.ai/en/v0.10.1/api/vllm/model_executor/models/bart.html
- 3: https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bart/configuration_bart.py
- 4: https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bart/modeling_flax_bart.py
- 5: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/configuration_bart.py
- 6: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/mbart/configuration_mbart.py
🏁 Scripts executed:
```shell
# Examine the file and the specific lines mentioned
wc -l tensorrt_llm/_torch/models/modeling_bart.py
sed -n '170,190p' tensorrt_llm/_torch/models/modeling_bart.py   # lines 177-183 (encoder FFN activation)
sed -n '250,270p' tensorrt_llm/_torch/models/modeling_bart.py   # lines 256-263 (decoder FFN activation)
sed -n '425,445p' tensorrt_llm/_torch/models/modeling_bart.py   # lines 433-435 (embedding scale)
sed -n '450,470p' tensorrt_llm/_torch/models/modeling_bart.py   # embed_scale usage in forward

# Inspect config usage and imports
rg -n "BartConfig|config\." tensorrt_llm/_torch/models/modeling_bart.py | head -30
head -50 tensorrt_llm/_torch/models/modeling_bart.py

# Confirm BartConfig defaults for the two attributes
python3 << 'EOF'
from transformers import BartConfig
config = BartConfig()
print("scale_embedding:", config.scale_embedding)
print("activation_function:", config.activation_function)
EOF

# Look for existing handling, helpers, and parity tests
rg -n "scale_embedding|activation_function" tensorrt_llm/_torch/models/
rg -n "ACT2FN|activation_fn" tensorrt_llm/_torch/models/modeling_bart.py
rg -n "getattr.*config" tensorrt_llm/_torch/models/modeling_bart.py
grep -n "self.embed_scale\|activation=F.gelu" tensorrt_llm/_torch/models/modeling_bart.py
grep -n "^class " tensorrt_llm/_torch/models/modeling_bart.py
find . -name "*test*bart*" -o -name "*bart*test*" | head -10
```
Use config-driven activation function and embedding scale to maintain HF BART checkpoint parity.
HF BART uses BartConfig.activation_function (default: "gelu") and BartConfig.scale_embedding (default: False) to control FFN activation and embedding scaling. Hardcoding F.gelu and always applying sqrt(d_model) scaling will break checkpoint compatibility, especially since the default BART configuration does not scale embeddings.
Add a helper function to map config.activation_function to the correct PyTorch function, and make embedding scale conditional on config.scale_embedding:
Suggested implementation
```diff
+def _bart_activation_fn(config: BartConfig):
+    act_name = getattr(config, "activation_function", "gelu")
+    if act_name == "relu":
+        return F.relu
+    if act_name == "gelu_new":
+        return lambda x: F.gelu(x, approximate="tanh")
+    return F.gelu
+
+
 class BartEncoderLayer(EncoderLayer):
     def __init__(
         self,
@@ -178,7 +187,7 @@ class BartEncoderLayer(EncoderLayer):
         self.mlp = MLP(
             hidden_size=hidden_size,
             intermediate_size=ffn_dim,
             bias=True,
-            activation=F.gelu,
+            activation=_bart_activation_fn(config),
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=layer_idx,
@@ -256,7 +265,7 @@ class BartDecoderLayer(EncoderDecoderLayer):
         self.mlp = MLP(
             hidden_size=hidden_size,
             intermediate_size=ffn_dim,
             bias=True,
-            activation=F.gelu,
+            activation=_bart_activation_fn(config),
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=layer_idx,
@@ -431,7 +440,11 @@ class BartModel(nn.Module):
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             gather_output=True,
         )
-        self.embed_scale = math.sqrt(config.d_model)
+        self.embed_scale = (
+            math.sqrt(config.d_model)
+            if getattr(config, "scale_embedding", False)
+            else 1.0
+        )
         # HF BART learned position embeddings reserve indices 0 and 1.
         self.position_id_offset = 2
```

Applies to: lines 181, 260, 433 (and usage at 456, 465)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/models/modeling_bart.py` around lines 177 - 183, The code
hardcodes F.gelu and always applies sqrt(d_model) embedding scaling which breaks
HF BART parity; add a small helper (e.g.,
map_activation(config.activation_function)) that returns the correct torch
activation (gelu, relu, gelu_new, etc.) and pass its result into MLP via the
self.mlp instantiation instead of F.gelu, and change embedding scaling to only
multiply by math.sqrt(self.config.d_model) when config.scale_embedding is True
(use config.scale_embedding) wherever embeddings are scaled (the embed token
creation/use sites and any places building the FFN like MLP and forward uses).
Update all usages that build MLP (self.mlp) and any embedding scaling sites to
read config.activation_function and config.scale_embedding so checkpoint
behavior matches HF BART.
```python
if position_bias is None or attn_metadata.kv_cache_manager is not None:
    return super().forward(
        position_ids=position_ids,
        hidden_states=hidden_states,
        attn_metadata=attn_metadata,
        attention_mask=attention_mask,
        **kwargs,
    )
```
Don't drop T5 relative bias on cached decode.
As soon as a KV cache is present, this falls back to Attention.forward() and stops applying position_bias entirely. That makes decoder generation incorrect after the first token, because T5 self-attention still needs the learned relative bias for the current query against cached keys.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/models/modeling_t5.py` around lines 257 - 264, The
current branch returns early when attn_metadata.kv_cache_manager is present,
which causes position_bias to be dropped during cached decoding; change the
logic in the T5 attention implementation (the block that checks position_bias
and attn_metadata.kv_cache_manager before calling super().forward) so that when
a kv cache exists you still pass through and apply the existing position_bias
for the current query against cached keys rather than delegating to
Attention.forward() which omits it; specifically, ensure the call to
super().forward (or the T5Attention.forward path) receives the position_bias
(and does not short-circuit on attn_metadata.kv_cache_manager) and preserve
usage of the learned relative bias computation when handling cached decode in
forward.
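To make the stakes concrete, here is a pure-Python sketch of T5-style relative-position bucketing (modeled on the semantics of HF T5's `_relative_position_bucket`; function names here are illustrative, not the repo's). It shows that on a cached decode step the single new query still needs a distinct bias bucket for every cached key, which is why the bias cannot be dropped once a KV cache is present:

```python
import math

def relative_position_bucket(relative_position: int,
                             bidirectional: bool = False,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> int:
    """Scalar sketch of T5 relative-position bucketing (assumed semantics).

    relative_position = key_position - query_position.
    """
    bucket = 0
    if bidirectional:
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets
        relative_position = abs(relative_position)
    else:
        # Decoder self-attention only looks backward.
        relative_position = -min(relative_position, 0)
    max_exact = num_buckets // 2
    if relative_position < max_exact:
        # Short distances get one bucket each.
        return bucket + relative_position
    # Longer distances map logarithmically into the remaining buckets.
    log_bucket = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact))
    return bucket + min(log_bucket, num_buckets - 1)

def decode_step_buckets(past_len: int) -> list:
    """Buckets for the one new query token against all cached keys."""
    query_pos = past_len
    return [relative_position_bucket(k - query_pos, bidirectional=False)
            for k in range(past_len + 1)]
```

For example, `decode_step_buckets(3)` returns `[3, 2, 1, 0]`: every cached key position maps to a different learned bias entry, so falling back to a bias-free `Attention.forward()` after the first token changes the attention logits.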
```python
scheduled_cross_blocks_manager: Optional[MaxUtilizationScheduledBlocksManager] = None
if scheduler.cross_kv_cache_manager is not None:
    scheduler.cross_kv_cache_manager.start_scheduling()
    scheduled_cross_blocks_manager = MaxUtilizationScheduledBlocksManager(
        scheduler.cross_kv_cache_manager, scheduler.two_step_lookahead
    )
```
Guard cross-pool reuse prefill behind the same variable-window check.
MaxUtilizationPolicy.schedule() now sets up a cross KV manager, but it still calls _prefill_contributed_blocks() unconditionally. That helper can invoke analyze_prefix_reuse() on the cross manager, while _is_skipping_relevant() explicitly marks variable-window caches as unsupported for this optimization. With a variable-window cross pool, this path can still hit the unsupported reuse analysis before the skip check is consulted.
Suggested fix
```diff
- newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks(active_requests)
+ newly_contributed_context_blocks: set = set()
+ if skipping_is_relevant:
+     newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks(
+         active_requests
+     )
```

Also applies to: 1151-1157
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py` around lines 1105 -
1110, MaxUtilizationPolicy.schedule() initializes a cross-pool manager but still
calls _prefill_contributed_blocks() unconditionally, which can invoke
analyze_prefix_reuse() on a variable-window cross KV cache (unsupported per
_is_skipping_relevant()); update the calls that prefill cross-pool reuse (the
invocation of _prefill_contributed_blocks() and the similar block around lines
~1151-1157) to run only when the cross KV cache is NOT a variable-window cache —
i.e., guard use of scheduled_cross_blocks_manager and the prefill helper behind
the same variable-window check that _is_skipping_relevant() uses so
analyze_prefix_reuse() is never called for variable-window cross pools.
```diff
 assert inputs.get("prompt") is not None \
-    or inputs.get("prompt_token_ids") is not None
+    or inputs.get("prompt_token_ids") is not None \
+    or inputs.get("encoder_inputs") is not None \
+    or inputs.get("encoder_input_token_ids") is not None
 return inputs
```
🧩 Analysis chain
🏁 Scripts executed:
```shell
#!/bin/bash
# Search for callers of prompt_inputs() to verify how the return value is consumed
rg -n -A5 'prompt_inputs\(' --type=py -g '!*test*'

# Inspect the prompt input types and encoder fields
cat -n tensorrt_llm/inputs/data.py | head -150
rg -n "class TextPrompt|class TokensPrompt" tensorrt_llm/inputs/data.py -A10
rg -n "encoder_inputs|encoder_input_token_ids" tensorrt_llm/inputs/data.py

# Trace how encoder inputs flow through the LLM API
rg -B3 -A3 'encoder_inputs.*=' tensorrt_llm/llmapi/llm.py | head -50
rg -B2 -A2 '"encoder_inputs"' tensorrt_llm/ --type=py | head -60
rg -B5 -A10 'inputs.pop\("encoder_inputs"' tensorrt_llm/llmapi/llm.py
rg -B5 -A5 "encoder.*decoder.*models" tensorrt_llm/ --type=py | head -40
cat tensorrt_llm/inputs/data.py | sed -n '1,90p'
```
Type contract violation: encoder-only inputs do not satisfy the return type.
The assertion allows dicts containing only encoder_inputs or encoder_input_token_ids (without prompt or prompt_token_ids) to pass and be returned. However, both TextPrompt and TokensPrompt TypedDicts have required fields (prompt: str and prompt_token_ids: List[int] respectively). A dict with only encoder fields does not satisfy either type contract.
While encoder fields are marked NotRequired in the schema, they are optional additions to dicts that still require one of the main fields. Downstream code relying on type safety would incorrectly assume prompt or prompt_token_ids is available.
Either the assertion should require at least one of prompt or prompt_token_ids, or the return type should be updated to reflect that encoder-only inputs are supported.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/inputs/data.py` around lines 98 - 102, The current assertion in
the function returning inputs allows encoder-only dicts which violate the
TextPrompt/TokensPrompt return types; update the assertion so it requires at
least one of the main prompt fields (i.e., ensure inputs.get("prompt") is not
None or inputs.get("prompt_token_ids") is not None) rather than permitting
encoder_inputs or encoder_input_token_ids alone, and keep the return type as the
existing TypedDict (TextPrompt/TokensPrompt); adjust the assert expression
(referencing inputs, "prompt", "prompt_token_ids", "encoder_inputs",
"encoder_input_token_ids") accordingly so downstream code can rely on
prompt/prompt_token_ids being present.
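A minimal standalone sketch of the stricter check (the helper name here is hypothetical; it mirrors the intended contract of the repo's `prompt_inputs`, where encoder fields are optional additions rather than substitutes for the required main field):

```python
def validate_prompt_inputs(inputs: dict) -> dict:
    """Hypothetical check mirroring the suggested fix: at least one main
    prompt field must be present; encoder-only dicts are rejected because
    they satisfy neither TextPrompt nor TokensPrompt."""
    if inputs.get("prompt") is None and inputs.get("prompt_token_ids") is None:
        raise ValueError(
            "Either 'prompt' or 'prompt_token_ids' must be provided; "
            "'encoder_inputs' / 'encoder_input_token_ids' alone do not "
            "satisfy the TextPrompt/TokensPrompt contract.")
    return inputs
```

With this shape, downstream code can rely on `prompt` or `prompt_token_ids` being present even when encoder fields accompany them.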
```diff
 unbatched = self._is_unbatched_optional_inputs(
     inputs,
     encoder_inputs,
     encoder_input_token_ids,
     decoder_input_token_ids,
 )
 if inputs is not None and not unbatched:
     if isinstance(inputs[0], int):
         unbatched = True

-if unbatched:
+if unbatched and inputs is not None:
     inputs = [inputs]

 inputs = [prompt_inputs(i) for i in inputs]
 if inputs is None:
     batch_len = 1
     for value in (encoder_inputs, encoder_input_token_ids,
                   decoder_input_token_ids):
         if isinstance(value,
                       list) and not self._is_token_id_list(value):
             batch_len = len(value)
             break
     request_inputs_list = [None] * batch_len
```
Validate encoder/decoder batching shape before fan-out.
This now infers unbatched from the first non-None auxiliary input and batch_len from the first list-like one, so mixed shapes can submit multiple requests but still follow the single-request return path later. For example, an unbatched encoder_inputs plus batched decoder_input_token_ids will fan out more than one request and then only return the first result. Please reject mixed batched/unbatched encoder-decoder inputs up front, and when batched, require all provided lists to have the same length.
```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the encoder iteration helpers in PyExecutor (Step 9).

Covers the two pure-Python helpers that drive the encoder branch of
``_executor_loop`` for encoder-decoder models:

* ``_split_encoder_decoder_context_requests`` — splits the scheduler's
  context bucket into encoder-init vs decoder-context subsets.
* ``_scatter_encoder_output`` — slices packed encoder hidden states
  back into per-request tensors and transitions request state from
  ``ENCODER_INIT`` to ``CONTEXT_INIT``.

These helpers do not touch the model engine or KV cache managers, so
the tests run on CPU only.
"""
```
Add the full Apache 2.0 notice block to this new source file.
This new Python file only has the SPDX lines. The repo guidelines for .py sources require the NVIDIA copyright header and the Apache 2.0 license notice block, like the other new Python files in this PR.
As per coding guidelines, "All source files (.cpp, .h, .cu, .py) should contain an NVIDIA copyright header with the year of latest modification and Apache 2.0 license notice."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unittest/_torch/executor/test_encoder_step.py` around lines 1 - 16, The
file tests/unittest/_torch/executor/test_encoder_step.py only contains SPDX
lines; update its top-of-file header to include the full NVIDIA copyright notice
and Apache-2.0 license notice block used across the repo (matching other new .py
files), i.e. add the multi-line copyright header with the latest modification
year (2026) plus the full Apache License, Version 2.0 text/block directly below
the SPDX lines so the file contains both the NVIDIA copyright header and the
complete Apache 2.0 notice.
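For reference, a header of the requested shape would look like the following (this uses the standard Apache 2.0 short notice; the repo's exact boilerplate wording may differ and should be copied from a sibling file):

```python
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```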
```python
hf_model = (
    transformers.T5ForConditionalGeneration.from_pretrained(self.model_path)
    .to(self.device)
    .to(self.dtype)
)
hf_model.eval()
hf_config = hf_model.config
hf_weights = hf_model.state_dict()

hf_config.torch_dtype = self.dtype
model_config = ModelConfig(
    pretrained_config=hf_config,
    attn_backend="VANILLA",
)
from tensorrt_llm._torch.models.modeling_t5 import T5ForConditionalGeneration as TllmT5

tllm_model = TllmT5(model_config).to(self.device)
tllm_model.load_weights(hf_weights)
```
These “real weights” tests never hit the float32→bf16 conversion path.
Both cases cast the HF model to bf16 before calling state_dict(), so load_weights() / _convert_hf_bart_weights() only receive already-converted tensors. That means the tests won't catch regressions in the precision-conversion logic this PR is supposed to validate.
As per coding guidelines, "Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix."
Also applies to: 1287-1306
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unittest/_torch/modeling/test_modeling_enc_dec.py` around lines 1208 -
1225, The tests currently cast the HuggingFace model to bf16 before calling
state_dict(), so load_weights()/_convert_hf_bart_weights never see float32
tensors; change the test to obtain hf_weights from the HF model while it is
still in its original float32 dtype (call hf_model.state_dict() before applying
.to(self.dtype) or create a separate copy for casting), then set
hf_config.torch_dtype and cast the HF model for downstream eval as needed; apply
the same change to the other occurrence that casts before state_dict() so
load_weights(TllmT5.load_weights) exercises the float32→bf16 conversion logic.
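The ordering bug can be illustrated without torch at all; a tiny stand-in that only tracks dtype (all names here are illustrative) shows why capturing the state dict before casting is what lets the conversion path run:

```python
class _FakeTensor:
    """Stand-in for a tensor; only tracks dtype (illustrative)."""
    def __init__(self, dtype="float32"):
        self.dtype = dtype
    def to(self, dtype):
        return _FakeTensor(dtype)

class _FakeHFModel:
    """Stand-in for an HF model with castable weights (illustrative)."""
    def __init__(self):
        self._weights = {"encoder.weight": _FakeTensor()}
    def to(self, dtype):
        self._weights = {k: v.to(dtype) for k, v in self._weights.items()}
        return self
    def state_dict(self):
        return dict(self._weights)

model = _FakeHFModel()
# Capture weights BEFORE casting so load_weights() receives real
# float32 tensors and must exercise the float32 -> bf16 conversion.
hf_weights = model.state_dict()
model.to("bfloat16")  # cast only the reference model used for eval
```

Casting first, as the tests currently do, would hand `load_weights()` tensors that are already bf16, leaving the conversion path untested.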
Description
This PR adds the core encoder-decoder request flow, including encoder execution, decoder context/generation handling, cross-attention metadata, and cross-KV cache management.
Key Changes
`beam_width > 1` for encoder-decoder generation.

Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.

Summary by CodeRabbit
Release Notes
New Features
Documentation