From b7ebcaa0808b4cb5630928598398433162c050d9 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Mon, 8 Jun 2026 08:53:11 +0200 Subject: [PATCH] Wired agent.run_stream_events in streaming_query --- src/app/endpoints/streaming_query.py | 32 +- .../llamastack/__init__.py | 3 +- .../llamastack/_model.py | 242 ++++++++++++ .../llamastack/_transport.py | 14 +- src/utils/agents/streaming.py | 2 +- src/utils/pydantic_ai.py | 9 +- .../e2e/features/steps/llm_query_response.py | 5 - tests/integration/conftest.py | 353 +++++++++++++++++- .../test_streaming_query_byok_integration.py | 206 +++------- .../test_streaming_query_integration.py | 59 ++- .../app/endpoints/test_streaming_query.py | 50 +-- 11 files changed, 740 insertions(+), 235 deletions(-) create mode 100644 src/pydantic_ai_lightspeed/llamastack/_model.py diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 913909bb7..0a11dea98 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -36,6 +36,7 @@ APIStatusError as LLSApiStatusError, ) from openai._exceptions import APIStatusError as OpenAIAPIStatusError +from typing_extensions import deprecated from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -74,6 +75,10 @@ from models.common.responses.types import ResponseInput from models.common.turn_summary import TurnSummary from models.config import Action +from utils.agents.streaming import ( + generate_agent_response, + retrieve_agent_response_generator, +) from utils.conversation_compaction import ( CompactionResult, CompactionStartedEvent, @@ -329,7 +334,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals media_type=response_media_type, ) - generator, turn_summary = await retrieve_response_generator( + generator, turn_summary = await retrieve_agent_response_generator( responses_params=responses_params, context=context, endpoint_path=endpoint_path, @@ -342,16 +347,21 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) return StreamingResponse( - generate_response( + generate_agent_response( generator=generator, context=context, responses_params=responses_params, turn_summary=turn_summary, + background_topic_summary_tasks=_background_topic_summary_tasks, ), media_type=response_media_type, ) +@deprecated( + "Deprecated in favor of utils.agents.streaming.retrieve_agent_response_generator.", + stacklevel=2, +) async def retrieve_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, @@ -474,7 +484,7 @@ async def generate_response_with_compaction( request_id=context.request_id, ) - compacted = False + _compacted = False compacted_original_input: Optional[ResponseInput] = None try: async for item in apply_compaction( @@ -491,10 +501,10 @@ async def generate_response_with_compaction( yield stream_compaction_event(context.conversation_id) elif isinstance(item, CompactionResult): responses_params = item.params - compacted = item.compacted + _compacted = item.compacted compacted_original_input = item.original_input - generator, turn_summary = await retrieve_response_generator( + generator, turn_summary = await retrieve_agent_response_generator( responses_params=responses_params, context=context, endpoint_path=endpoint_path, @@ -531,18 +541,22 @@ async def generate_response_with_compaction( # The start event was already emitted above; delegate the rest (re-yield, # finalization, compacted-turn storage) to the shared generator. - async for event in generate_response( + async for event in generate_agent_response( generator, context, responses_params, turn_summary, + background_topic_summary_tasks=_background_topic_summary_tasks, emit_start=False, - compacted=compacted, original_input=compacted_original_input, ): yield event +@deprecated( + "Deprecated in favor of utils.agents.streaming.generate_agent_response.", + stacklevel=2, +) async def generate_response( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements generator: AsyncIterator[str], context: ResponseGeneratorContext, @@ -711,6 +725,10 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi ) +@deprecated( + "Deprecated in favor of utils.agents.streaming.agent_response_generator.", + stacklevel=2, +) async def response_generator( # pylint: disable=too-many-branches,too-many-statements,too-many-locals turn_response: AsyncIterator[OpenAIResponseObjectStream], context: ResponseGeneratorContext, diff --git a/src/pydantic_ai_lightspeed/llamastack/__init__.py b/src/pydantic_ai_lightspeed/llamastack/__init__.py index 47eda1e7d..fac9ee826 100644 --- a/src/pydantic_ai_lightspeed/llamastack/__init__.py +++ b/src/pydantic_ai_lightspeed/llamastack/__init__.py @@ -1,5 +1,6 @@ """Pydantic AI provider for Llama Stack.""" +from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider -__all__ = ["LlamaStackProvider"] +__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"] diff --git a/src/pydantic_ai_lightspeed/llamastack/_model.py b/src/pydantic_ai_lightspeed/llamastack/_model.py new file mode 100644 index 000000000..86ce804c1 --- /dev/null +++ b/src/pydantic_ai_lightspeed/llamastack/_model.py @@ -0,0 +1,242 @@ +"""Custom OpenAI Responses model that works around Llama Stack streaming quirks. + +Llama Stack's Responses API emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP +tool calls *before* the corresponding ``ResponseOutputItemAddedEvent``. pydantic_ai's +default handler creates an orphan ``ToolCallPartDelta`` for the unannounced item_id, +which later causes an IndexError in ``part_end_event``. + +Additionally, MCP tool calls arrive as ``McpCall`` items (not ``ResponseFunctionToolCall``), +and pydantic_ai registers them with a ``-call`` vendor_part_id suffix. The buffered +deltas must be replayed with the matching suffix so pydantic_ai can append the +streamed ``tool_args`` content to the correct part. + +This module provides ``LlamaStackResponsesModel`` which wraps the event stream to +buffer those early delta events and replay them correctly once the item is announced. +""" + +from __future__ import annotations as _annotations + +from collections import defaultdict +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, cast + +from openai import AsyncStream +from openai.types import responses +from pydantic_ai import UnexpectedModelBehavior +from pydantic_ai._run_context import RunContext +from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime +from pydantic_ai.messages import ModelMessage +from pydantic_ai.models import ( + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, +) +from pydantic_ai.models.openai import ( + OpenAIResponsesModel, + OpenAIResponsesModelSettings, + OpenAIResponsesStreamedResponse, + _map_api_errors, +) +from pydantic_ai.settings import ModelSettings + +from log import get_logger + +logger = get_logger(__name__) + + +class _FilteredResponseStream: + """Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack. + + Llama Stack emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP tool calls + *before* the ``ResponseOutputItemAddedEvent`` that announces them. This wrapper + buffers those early deltas and replays them once the announcement arrives. + + For ``McpCall`` items specifically, pydantic_ai registers the part with a + ``-call`` vendor_part_id suffix. Buffered deltas are therefore replayed as a + single combined event with the suffixed ``item_id`` so they match the part, plus + a closing ``}`` to complete the outer JSON object that pydantic_ai opens. + """ + + def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None: + """Wrap an existing stream with reordering logic. + + Args: + source: The raw OpenAI AsyncStream to reorder. + """ + self._source = source + self._announced_item_ids: set[str] = set() + self._buffered_deltas: dict[ + str, list[responses.ResponseFunctionCallArgumentsDeltaEvent] + ] = defaultdict(list) + + async def close(self) -> None: + """Close the underlying stream.""" + await self._source.close() + + def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]: + """Return async iterator that reorders events.""" + return self._filtered_iter() + + async def _filtered_iter( + self, + ) -> AsyncIterator[responses.ResponseStreamEvent]: + """Yield events, buffering early argument deltas until their item is announced.""" + async for event in self._source: + if isinstance(event, responses.ResponseOutputItemAddedEvent): + if ( + isinstance(event.item, responses.ResponseFunctionToolCall) + and event.item.id + ): + item_id = event.item.id + self._announced_item_ids.add(item_id) + yield event + for delta in self._replay_buffered_deltas(item_id): + yield delta + continue + + if isinstance(event.item, responses.response_output_item.McpCall): + item_id = event.item.id + self._announced_item_ids.add(item_id) + yield event + for delta in self._replay_mcp_buffered_deltas(item_id): + yield delta + continue + + elif isinstance(event, responses.ResponseFunctionCallArgumentsDeltaEvent): + if event.item_id not in self._announced_item_ids: + logger.debug( + "Buffering early argument delta for unannounced item_id=%s", + event.item_id, + ) + self._buffered_deltas[event.item_id].append(event) + continue + + yield event + + def _replay_buffered_deltas( + self, item_id: str + ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: + """Return buffered deltas for a ``ResponseFunctionToolCall`` announcement. + + Args: + item_id: The announced item ID. + + Returns: + List of buffered delta events to yield, unchanged. + """ + buffered = self._buffered_deltas.pop(item_id, []) + if buffered: + logger.debug( + "Replaying %d buffered argument deltas for item_id=%s", + len(buffered), + item_id, + ) + return buffered + + def _replay_mcp_buffered_deltas( + self, item_id: str + ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: + """Return buffered deltas for an ``McpCall`` announcement. + + pydantic_ai registers ``McpCall`` parts with ``vendor_part_id=f'{id}-call'`` + and seeds the args string with everything up to ``"tool_args":``. The + buffered deltas contain the actual ``tool_args`` content. We combine them + into a single delta with the suffixed ``item_id`` and append a closing ``}`` + to complete the outer JSON object that pydantic_ai opened. + + Args: + item_id: The announced McpCall item ID. + + Returns: + List containing one synthetic delta event, or empty if nothing buffered. + """ + buffered = self._buffered_deltas.pop(item_id, []) + if not buffered: + return [] + + combined_args = "".join(d.delta for d in buffered) + "}" + logger.debug( + "Replaying %d buffered MCP argument deltas as single event " + "for item_id=%s-call", + len(buffered), + item_id, + ) + return [ + responses.ResponseFunctionCallArgumentsDeltaEvent( + delta=combined_args, + item_id=f"{item_id}-call", + output_index=buffered[0].output_index, + sequence_number=buffered[-1].sequence_number + 1, + type="response.function_call_arguments.delta", + ) + ] + + +class LlamaStackResponsesModel(OpenAIResponsesModel): + """OpenAI Responses model with Llama Stack streaming compatibility fixes. + + Overrides the streaming response processing to buffer and replay + ``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits + before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item. + """ + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + """Request a streaming response, filtering Llama Stack-specific event quirks. + + Args: + messages: Model messages for the request. + model_settings: Model-specific settings. + model_request_parameters: Request parameters for the model. + run_context: Optional run context from the agent. + + Yields: + A StreamedResponse with the filtered event stream. + """ + check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) + model_settings_cast = cast(OpenAIResponsesModelSettings, model_settings or {}) + response = await self._responses_create( + messages, True, model_settings_cast, model_request_parameters + ) + + filtered_stream = _FilteredResponseStream(response) + + async with response: + peekable: PeekableAsyncStream[ + responses.ResponseStreamEvent, _FilteredResponseStream + ] = PeekableAsyncStream(filtered_stream) + + with _map_api_errors(self.model_name): + first_chunk = await peekable.peek() + + if isinstance(first_chunk, Unset): + raise UnexpectedModelBehavior( + "Streamed response ended without content or tool calls" + ) + + assert isinstance(first_chunk, responses.ResponseCreatedEvent) + + yield OpenAIResponsesStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=first_chunk.response.model, + _model_settings=model_settings_cast, + _response=peekable, # type: ignore[arg-type] + _provider_name=self._provider.name, + _provider_url=self._provider.base_url, + _provider_timestamp=( + number_to_datetime(first_chunk.response.created_at) + if first_chunk.response.created_at + else None + ), + ) diff --git a/src/pydantic_ai_lightspeed/llamastack/_transport.py b/src/pydantic_ai_lightspeed/llamastack/_transport.py index 1d63bd60f..e5401bd68 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_transport.py +++ b/src/pydantic_ai_lightspeed/llamastack/_transport.py @@ -17,6 +17,7 @@ ) from llama_stack.core.server.routes import find_matching_route from llama_stack.core.utils.context import preserve_contexts_async_generator +from starlette.responses import StreamingResponse class _AsyncByteStream(httpx.AsyncByteStream): @@ -183,9 +184,16 @@ async def _handle_streaming( result = await func(**merged_body) async def gen() -> AsyncGenerator[bytes, None]: - async for chunk in result: - data = json.dumps(convert_pydantic_to_json_value(chunk)) - yield f"data: {data}\n\n".encode("utf-8") + if isinstance(result, StreamingResponse): + async for chunk in result.body_iterator: + if isinstance(chunk, str): + yield chunk.encode("utf-8") + else: + yield bytes(chunk) + else: + async for chunk in result: + data = json.dumps(convert_pydantic_to_json_value(chunk)) + yield f"data: {data}\n\n".encode("utf-8") wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR]) diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py index ef64a75b6..d1799a322 100644 --- a/src/utils/agents/streaming.py +++ b/src/utils/agents/streaming.py @@ -24,7 +24,6 @@ TextPartDelta, ) -from app.endpoints.streaming_query import shield_violation_generator from configuration import configuration from constants import INTERRUPTED_RESPONSE_MESSAGE, MEDIA_TYPE_JSON from log import get_logger @@ -70,6 +69,7 @@ persist_interrupted_turn, register_interrupt_callback, ) +from utils.streaming_sse import shield_violation_generator AgentDispatchEvent: TypeAlias = AgentStreamEvent | AgentRunResultEvent diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index 5df570dc9..2574c0ca3 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -7,10 +7,13 @@ from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient from pydantic_ai import Agent -from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings +from pydantic_ai.models.openai import OpenAIResponsesModelSettings from models.common.responses.responses_api_params import ResponsesApiParams -from pydantic_ai_lightspeed.llamastack import LlamaStackProvider +from pydantic_ai_lightspeed.llamastack import ( + LlamaStackProvider, + LlamaStackResponsesModel, +) _LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( { @@ -92,7 +95,7 @@ def build_agent( provider = _llama_stack_provider_from_client(client) settings = _model_settings_from_responses_params(responses_params) - model = OpenAIResponsesModel( + model = LlamaStackResponsesModel( responses_params.model, provider=provider, settings=settings, diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 18c76a4cf..b0f992861 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -366,7 +366,6 @@ def _parse_streaming_response(response_text: str) -> dict: full_response = "" full_response_split = [] finished = False - first_token = True stream_error = ( None # {"status_code": int, "response": str, "cause": str} if event "error" ) @@ -380,10 +379,6 @@ def _parse_streaming_response(response_text: str) -> dict: if event == "start": conversation_id = data["data"]["conversation_id"] elif event == "token": - # Skip the first token (shield status message) - if first_token: - first_token = False - continue full_response_split.append(data["data"]["token"]) elif event == "turn_complete": full_response = data["data"]["token"] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7fc2edfa0..d7e811e00 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,12 +3,30 @@ import os from collections.abc import Generator from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from fastapi import Request, Response from fastapi.testclient import TestClient -from pytest_mock import MockerFixture +from llama_stack_api.openai_responses import OpenAIResponseObject +from llama_stack_client.types import VersionInfo +from pydantic_ai import AgentRunResultEvent +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + NativeToolCallPart, + NativeToolReturnPart, + PartEndEvent, + PartStartEvent, + TextPart, + ToolCallPart, + ToolReturnPart, +) +from pydantic_ai.native_tools import FileSearchTool, MCPServerTool +from pydantic_ai.run import AgentRunResult +from pydantic_ai.usage import RunUsage +from pytest_mock import AsyncMockType, MockerFixture from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, sessionmaker @@ -70,9 +88,6 @@ def create_mock_llm_response( # pylint: disable=too-many-arguments,too-many-pos Returns: Mock LLM response object with the specified configuration. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - mock_response = mocker.MagicMock(spec=OpenAIResponseObject) mock_response.id = "response-123" @@ -154,6 +169,326 @@ def create_mock_tool_call( return mock_tool_call +def create_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "This is a test response about Ansible.", + response_id: str = "response-123", + input_tokens: int = 10, + output_tokens: int = 5, + model_response: ModelResponse | None = None, + new_messages: list[ModelMessage] | None = None, +) -> AgentRunResult[str]: + """Create a mock AgentRunResult wired for retrieve_agent_response. + + Uses real pydantic-ai message types so build_turn_summary_from_agent_run + exercises the same path as production agent runs. + + Args: + mocker: pytest-mock fixture. + content: Assistant text content for the run. + response_id: Provider response identifier. + input_tokens: Input token count for the run. + output_tokens: Output token count for the run. + model_response: Optional pre-built ModelResponse. + new_messages: Optional message sequence returned by new_messages(). + + Returns: + Mock AgentRunResult compatible with build_turn_summary_from_agent_run. + """ + if model_response is None: + parts = [TextPart(content)] if content else [] + model_response = ModelResponse( + parts=parts, + finish_reason="stop", + provider_response_id=response_id, + ) + + messages = new_messages if new_messages is not None else [model_response] + run_result = mocker.MagicMock(spec=AgentRunResult) + run_result.response = model_response + run_result.usage = RunUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + requests=1, + ) + run_result.new_messages.return_value = messages + return run_result + + +def create_file_search_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-rag", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 10, + output_tokens: int = 5, +) -> AgentRunResult[str]: + """Create an AgentRunResult containing a native file_search tool call.""" + call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": queries or ["test query"]}, + tool_call_id="call-fs-1", + ) + return_part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="call-fs-1", + content={ + "status": "success", + "results": results or [], + }, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def create_mcp_list_tools_agent_run_result( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-mcplist", + server_label: str = "kubernetes-server", + tools: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 15, + output_tokens: int = 20, +) -> AgentRunResult[str]: + """Create an AgentRunResult containing an MCP list-tools native tool call.""" + call = NativeToolCallPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + args={"action": "list_tools"}, + tool_call_id="mcplist-101", + ) + return_part = NativeToolReturnPart( + tool_name=f"{MCPServerTool.kind}:{server_label}", + tool_call_id="mcplist-101", + content={"tools": tools or []}, + ) + model_response = ModelResponse( + parts=[call, return_part, TextPart(content)], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + ) + + +def create_multi_tool_agent_run_result( + mocker: MockerFixture, + *, + content: str = "Based on documentation and calculations...", + response_id: str = "response-multi", + input_tokens: int = 40, + output_tokens: int = 60, +) -> AgentRunResult[str]: + """Create an AgentRunResult with file_search and function tool calls.""" + file_search_call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": ["Kubernetes deployment"]}, + tool_call_id="search-1", + ) + file_search_return = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="search-1", + content={"status": "success", "results": []}, + ) + function_call = ToolCallPart( + tool_name="calculate", + args={"operation": "sum"}, + tool_call_id="func-2", + ) + function_return = ToolReturnPart( + tool_name="calculate", + content={"result": 2}, + tool_call_id="func-2", + ) + model_response = ModelResponse( + parts=[ + file_search_call, + file_search_return, + function_call, + TextPart(content), + ], + finish_reason="stop", + provider_response_id=response_id, + ) + return create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + model_response=model_response, + new_messages=[model_response, ModelRequest(parts=[function_return])], + ) + + +def set_query_agent_run( + mock_llama_stack_client: AsyncMockType, + mocker: MockerFixture, + **kwargs: Any, +) -> None: + """Configure mock agent.run return value for /query integration tests.""" + mock_llama_stack_client.query_agent.run.return_value = create_agent_run_result( + mocker, + **kwargs, + ) + + +def configure_query_agent_mock( + mocker: MockerFixture, + *, + run_result: AgentRunResult[str] | None = None, + run_side_effect: Any = None, +) -> Any: + """Patch build_agent for /query integration tests and return the mock agent. + + Args: + mocker: pytest-mock fixture. + run_result: AgentRunResult returned by agent.run(). + run_side_effect: Optional exception side effect for agent.run(). + + Returns: + Mock agent exposing AsyncMock run(). + """ + if run_result is None: + run_result = create_agent_run_result(mocker) + + mock_agent = mocker.AsyncMock() + if run_side_effect is not None: + mock_agent.run = mocker.AsyncMock(side_effect=run_side_effect) + else: + mock_agent.run = mocker.AsyncMock(return_value=run_result) + + build_agent_mock = mocker.patch( + "utils.agents.query.build_agent", + return_value=mock_agent, + ) + mock_agent.build_agent_mock = build_agent_mock + return mock_agent + + +def create_text_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "Based on the documentation, OpenShift is a Kubernetes distribution.", + response_id: str = "response-inline-stream", + input_tokens: int = 50, + output_tokens: int = 20, +) -> list[Any]: + """Build pydantic-ai stream events for a simple text agent response.""" + run_result = create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + text_part = TextPart(content=content) + return [ + PartStartEvent(index=0, part=text_part), + PartEndEvent(index=0, part=text_part), + AgentRunResultEvent(result=run_result), + ] + + +def create_file_search_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-stream", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 60, + output_tokens: int = 25, +) -> list[Any]: + """Build pydantic-ai stream events for a file_search tool agent response.""" + run_result = create_file_search_agent_run_result( + mocker, + content=content, + response_id=response_id, + queries=queries, + results=results, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + call_part, return_part, text_part = run_result.response.parts + return [ + PartEndEvent(index=0, part=call_part), + PartStartEvent(index=1, part=return_part), + PartStartEvent(index=2, part=text_part), + PartEndEvent(index=2, part=text_part), + AgentRunResultEvent(result=run_result), + ] + + +def configure_streaming_agent_mock( + mocker: MockerFixture, + *, + stream_events: Optional[list[Any]] = None, +) -> Any: + """Patch build_agent for /streaming_query integration tests. + + Args: + mocker: pytest-mock fixture. + stream_events: Optional pydantic-ai events yielded by run_stream_events. + + Returns: + Mock agent exposing run_stream_events(). + """ + events = stream_events or create_text_agent_stream_events(mocker) + + def _run_stream_events_side_effect(_prompt: str) -> Any: + async def _event_stream() -> Any: + for event in events: + yield event + + ctx = mocker.MagicMock() + ctx.__aenter__ = mocker.AsyncMock(return_value=_event_stream()) + ctx.__aexit__ = mocker.AsyncMock(return_value=None) + return ctx + + mock_agent = mocker.MagicMock() + mock_agent.run_stream_events = mocker.MagicMock( + side_effect=_run_stream_events_side_effect + ) + + build_agent_mock = mocker.patch( + "utils.agents.streaming.build_agent", + return_value=mock_agent, + ) + mock_agent.build_agent_mock = build_agent_mock + return mock_agent + + +def get_agent_responses_params(mock_client: Any) -> Any: + """Return ResponsesApiParams passed to the patched streaming build_agent.""" + return mock_client.build_agent_mock.call_args[0][1] + + +def get_agent_input_text(mock_client: Any) -> str: + """Return the agent prompt text from the patched streaming build_agent call.""" + return cast(str, get_agent_responses_params(mock_client).input) + + # ========================================== # Fixtures # ========================================== @@ -448,10 +783,6 @@ def mock_llama_stack_client_fixture( Yields: mock_client: The mocked Llama Stack client instance. """ - # pylint: disable=import-outside-toplevel - from llama_stack_api.openai_responses import OpenAIResponseObject - from llama_stack_client.types import VersionInfo - # Patch AsyncLlamaStackClientHolder at multiple import locations # This ensures the mock is active both during app startup (app.main) # and during endpoint execution (query, conversations_v1, responses, etc.) @@ -484,6 +815,10 @@ def mock_llama_stack_client_fixture( mock_client.responses.create.return_value = mock_response + mock_agent = configure_query_agent_mock(mocker) + mock_client.query_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + # Mock models.list mock_model = mocker.MagicMock() mock_model.id = "test-provider/test-model" diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index c539d4294..33bf90085 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -3,13 +3,12 @@ # pylint: disable=too-many-lines import json -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest from fastapi import Request, status from fastapi.responses import StreamingResponse -from llama_stack_api.openai_responses import OpenAIResponseObject from pytest_mock import AsyncMockType, MockerFixture import constants @@ -17,6 +16,12 @@ from authentication.interface import AuthTuple from configuration import AppConfig from models.api.requests import QueryRequest +from tests.integration.conftest import ( + configure_streaming_agent_mock, + create_file_search_agent_stream_events, + get_agent_input_text, + get_agent_responses_params, +) from tests.integration.endpoints.test_query_byok_integration import ( _build_base_mock_client, _make_byok_vector_io_response, @@ -50,39 +55,18 @@ async def _collect_sse_events(response: StreamingResponse) -> list[dict[str, Any def _build_base_streaming_mock_client(mocker: MockerFixture) -> Any: """Build a base mock Llama Stack client configured for streaming responses. - Extends the base query mock client with streaming-specific stubs: - conversations.items.create and a streaming responses.create. + Extends the base query mock client with a patched pydantic-ai streaming + agent and topic-summary responses.create stub. """ mock_client = _build_base_mock_client(mocker) - # Streaming additions + mock_agent = configure_streaming_agent_mock(mocker) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock + mock_client.conversations.items.create = mocker.AsyncMock() - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield chunk - - # Emit response.completed so referenced_documents propagate to end event - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final.id = "response-inline-stream" - mock_final.error = None - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 50 - mock_usage.output_tokens = 20 - mock_final.usage = mock_usage - mock_final.output = [] - completed_chunk.response = mock_final - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp @@ -152,78 +136,32 @@ def mock_streaming_byok_tool_client_fixture( # pylint: disable=too-many-stateme mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Build a streaming response with file_search and completion events - async def _mock_tool_stream() -> AsyncIterator[Any]: - # file_search output item done - item_done_chunk = mocker.MagicMock() - item_done_chunk.type = "response.output_item.done" - item_done_chunk.output_index = 0 - - mock_item = mocker.MagicMock() - mock_item.type = "file_search_call" - mock_item.id = "call-fs-stream-1" - mock_item.queries = ["What is OpenShift?"] - mock_item.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-ocp-1" - mock_result.filename = "openshift-docs.txt" - mock_result.score = 0.92 - mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." - mock_result.attributes = { - "doc_url": "https://docs.redhat.com/ocp/overview", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-ocp-1", - "filename": "openshift-docs.txt", - "score": 0.92, - "text": "OpenShift is a Kubernetes distribution.", - "attributes": {"doc_url": "https://docs.redhat.com/ocp/overview"}, - } - ) - mock_item.results = [mock_result] - item_done_chunk.item = mock_item - yield item_done_chunk - - # Text done - text_done_chunk = mocker.MagicMock() - text_done_chunk.type = "response.output_text.done" - text_done_chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield text_done_chunk - - # Response completed - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final_response.id = "response-tool-stream" - mock_final_response.error = None - - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 60 - mock_usage.output_tokens = 25 - mock_final_response.usage = mock_usage - - # file_search results in the final response output - mock_fs_output = mocker.MagicMock() - mock_fs_output.type = "file_search_call" - mock_fs_output.id = "call-fs-stream-1" - mock_fs_output.results = [mock_result] - mock_final_response.output = [mock_fs_output] - - completed_chunk.response = mock_final_response - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_tool_stream() - mock_resp = mocker.MagicMock() - mock_resp.output = [mocker.MagicMock(content="topic summary")] - return mock_resp - - mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) + mock_agent = configure_streaming_agent_mock( + mocker, + stream_events=create_file_search_agent_stream_events( + mocker, + content=( + "Based on the documentation, OpenShift is a Kubernetes distribution." + ), + response_id="response-tool-stream", + queries=["What is OpenShift?"], + results=[ + { + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "score": 0.92, + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "title": "openshift-docs.txt", + "document_id": "doc-ocp-1", + }, + } + ], + input_tokens=60, + output_tokens=25, + ), + ) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock mock_holder_class.return_value.get_client.return_value = mock_client yield mock_client @@ -309,12 +247,8 @@ async def test_streaming_query_byok_inline_rag_injects_context( assert isinstance(response, StreamingResponse) - # Verify RAG context was injected into responses.create input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] + # Verify RAG context was injected into the agent prompt + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" in input_text assert "OpenShift is a Kubernetes distribution" in input_text @@ -448,11 +382,8 @@ async def test_streaming_query_byok_request_vector_store_ids_filters_configured_ call_kwargs = mock_client.vector_io.query.call_args.kwargs assert call_kwargs["vector_store_id"] == "vs-source-a" - # Verify source-a context was injected into the LLM input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + # Verify source-a context was injected into the agent prompt + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text @@ -484,10 +415,7 @@ async def test_streaming_query_byok_inline_rag_empty_vector_store_ids_no_context assert isinstance(response, StreamingResponse) mock_streaming_byok_client.vector_io.query.assert_not_called() - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" not in input_text @@ -525,11 +453,7 @@ async def test_streaming_query_byok_inline_rag_error_handled_gracefully( assert isinstance(response, StreamingResponse) # No inline RAG context should be injected when the search fails. - # "file_search found" is the header added by _format_rag_context when chunks are present. - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_streaming_byok_client) assert "file_search found" not in input_text @@ -725,17 +649,14 @@ async def test_streaming_query_byok_combined_inline_and_tool_rag( assert isinstance(response, StreamingResponse) assert response.status_code == status.HTTP_200_OK - # Verify inline RAG context was injected - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] + # Verify inline RAG context was injected into the agent prompt + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text - # Verify tool RAG file_search was passed - assert call_kwargs.get("tools") is not None - assert any(tool.get("type") == "file_search" for tool in call_kwargs["tools"]) + # Verify tool RAG file_search was passed to the agent + responses_params = get_agent_responses_params(mock_client) + assert responses_params.tools is not None + assert any(tool.type == "file_search" for tool in responses_params.tools) # ============================================================================== @@ -812,10 +733,7 @@ async def test_streaming_query_byok_only_configured_rag_id_is_queried( ] assert "vs-source-b" not in queried_stores - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) assert "file_search found" in input_text @@ -897,10 +815,7 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) # Verify Doc B (weighted 2.0) appears before Doc A (weighted 0.9) in context - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) pos_b = input_text.find("Doc B low similarity boosted") pos_a = input_text.find("Doc A high similarity") assert pos_b != -1 and pos_a != -1 @@ -969,10 +884,7 @@ async def test_streaming_query_rag_content_limit_caps_context( # pylint: disabl assert isinstance(response, StreamingResponse) # Verify the context header reports the capped count - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" assert expected_header in input_text @@ -1058,10 +970,7 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" assert expected_header in input_text @@ -1132,8 +1041,7 @@ async def test_streaming_query_rag_content_limit_caps_inline_rag( # pylint: dis assert isinstance(response, StreamingResponse) - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + input_text = get_agent_input_text(mock_client) expected_header = "file_search found 3 chunks:" assert expected_header in input_text diff --git a/tests/integration/endpoints/test_streaming_query_integration.py b/tests/integration/endpoints/test_streaming_query_integration.py index 5a7e51620..10ecb5a6f 100644 --- a/tests/integration/endpoints/test_streaming_query_integration.py +++ b/tests/integration/endpoints/test_streaming_query_integration.py @@ -1,12 +1,13 @@ """Integration tests for the /streaming_query endpoint (using Responses API).""" -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest from fastapi import HTTPException, Request, status from fastapi.responses import StreamingResponse from fastapi.testclient import TestClient +from llama_stack_client.types import VersionInfo from pytest_mock import AsyncMockType, MockerFixture from app.endpoints.streaming_query import streaming_query_endpoint_handler @@ -14,6 +15,13 @@ from configuration import AppConfig from models.api.requests import QueryRequest from models.common.query import Attachment +from tests.integration.conftest import ( + configure_streaming_agent_mock, + create_text_agent_stream_events, +) +from tests.integration.endpoints.test_query_byok_integration import ( + _build_base_mock_client, +) @pytest.fixture(name="mock_streaming_llama_stack_client") @@ -22,32 +30,26 @@ def mock_llama_stack_streaming_fixture( ) -> Generator[Any, None, None]: """Mock only the Llama Stack client (holder + client). - Configures the client so the real handler runs: models, vector_stores, - conversations, shields, vector_io, and responses.create returning a minimal - stream. No other code paths are patched. + Configures the client so the real handler runs with a patched pydantic-ai + streaming agent. No other code paths are patched. """ mock_holder_class = mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" ) - mock_client = mocker.AsyncMock() - - mock_model = mocker.MagicMock() - mock_model.id = "test-provider/test-model" - mock_model.custom_metadata = { - "provider_id": "test-provider", - "model_type": "llm", - } - mock_client.models.list.return_value = [mock_model] - - mock_vector_stores_response = mocker.MagicMock() - mock_vector_stores_response.data = [] - mock_client.vector_stores.list.return_value = mock_vector_stores_response - - mock_conversation = mocker.MagicMock() - mock_conversation.id = "conv_" + "a" * 48 - mock_client.conversations.create = mocker.AsyncMock(return_value=mock_conversation) - - mock_client.shields.list.return_value = [] + mock_client = _build_base_mock_client(mocker) + + mock_agent = configure_streaming_agent_mock( + mocker, + stream_events=create_text_agent_stream_events( + mocker, + content="test", + response_id="response-stream-test", + input_tokens=10, + output_tokens=5, + ), + ) + mock_client.streaming_agent = mock_agent + mock_client.build_agent_mock = mock_agent.build_agent_mock mock_client.conversations.items.create = mocker.AsyncMock() @@ -56,20 +58,13 @@ def mock_llama_stack_streaming_fixture( mock_vector_io_response.scores = [] mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_vector_io_response) - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = "test" - yield chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) + mock_client.inspect.version.return_value = VersionInfo(version="0.2.22") mock_holder_class.return_value.get_client.return_value = mock_client diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 45762c03e..dd5efd227 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -204,19 +204,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -291,19 +291,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -389,19 +389,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -485,19 +485,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -583,19 +583,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id",