Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
from models.common.responses.types import ResponseInput
from models.common.turn_summary import TurnSummary
from models.config import Action
from utils.agents.streaming import (
generate_agent_response,
retrieve_agent_response_generator,
)
from utils.conversation_compaction import (
CompactionResult,
CompactionStartedEvent,
Expand Down Expand Up @@ -329,7 +333,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
media_type=response_media_type,
)

generator, turn_summary = await retrieve_response_generator(
generator, turn_summary = await retrieve_agent_response_generator(
responses_params=responses_params,
context=context,
endpoint_path=endpoint_path,
Expand All @@ -342,11 +346,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
)

return StreamingResponse(
generate_response(
generate_agent_response(
generator=generator,
context=context,
responses_params=responses_params,
turn_summary=turn_summary,
background_topic_summary_tasks=_background_topic_summary_tasks,
),
media_type=response_media_type,
)
Expand Down Expand Up @@ -762,6 +767,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
# Store MCP call item info for later lookup when arguments.done event occurs
elif event_type == "response.output_item.added":
item_added_chunk = cast(OutputItemAddedChunk, chunk)

if item_added_chunk.item.type == "mcp_call":
mcp_call_item = cast(MCPCall, item_added_chunk.item)
mcp_calls[item_added_chunk.output_index] = (
Expand Down
3 changes: 2 additions & 1 deletion src/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Pydantic AI provider for Llama Stack."""

from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel
from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider

__all__ = ["LlamaStackProvider"]
__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"]
242 changes: 242 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
"""Custom OpenAI Responses model that works around Llama Stack streaming quirks.

Llama Stack's Responses API emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP
tool calls *before* the corresponding ``ResponseOutputItemAddedEvent``. pydantic_ai's
default handler creates an orphan ``ToolCallPartDelta`` for the unannounced item_id,
which later causes an IndexError in ``part_end_event``.

Additionally, MCP tool calls arrive as ``McpCall`` items (not ``ResponseFunctionToolCall``),
and pydantic_ai registers them with a ``-call`` vendor_part_id suffix. The buffered
deltas must be replayed with the matching suffix so pydantic_ai can append the
streamed ``tool_args`` content to the correct part.

This module provides ``LlamaStackResponsesModel`` which wraps the event stream to
buffer those early delta events and replay them correctly once the item is announced.
"""

from __future__ import annotations as _annotations

from collections import defaultdict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, cast

from openai import AsyncStream
from openai.types import responses
from pydantic_ai import UnexpectedModelBehavior
from pydantic_ai._run_context import RunContext
from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import (
ModelRequestParameters,
StreamedResponse,
check_allow_model_requests,
)
from pydantic_ai.models.openai import (
OpenAIResponsesModel,
OpenAIResponsesModelSettings,
OpenAIResponsesStreamedResponse,
_map_api_errors,
)
from pydantic_ai.settings import ModelSettings

from log import get_logger

# Module-level logger obtained from the project's ``get_logger`` helper with this
# module's ``__name__``; used below to trace buffered and replayed stream events.
logger = get_logger(__name__)


class _FilteredResponseStream:
    """Wrap an OpenAI ``AsyncStream`` and reorder early events from Llama Stack.

    Llama Stack emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP tool calls
    *before* the ``ResponseOutputItemAddedEvent`` that announces them. This wrapper
    buffers those early deltas and replays them once the announcement arrives.

    For ``McpCall`` items specifically, pydantic_ai registers the part with a
    ``-call`` vendor_part_id suffix. Buffered deltas are therefore replayed as a
    single combined event with the suffixed ``item_id`` so they match the part, plus
    a closing ``}`` to complete the outer JSON object that pydantic_ai opens.

    Deltas whose item is never announced are not forwarded (doing so would
    recreate the orphan-part failure this wrapper exists to avoid); they are
    reported with a warning when the source stream ends.
    """

    def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None:
        """Wrap an existing stream with reordering logic.

        Args:
            source: The raw OpenAI AsyncStream to reorder.
        """
        self._source = source
        # IDs of items whose ``ResponseOutputItemAddedEvent`` has already been seen.
        self._announced_item_ids: set[str] = set()
        # Early argument deltas, keyed by item_id, held until the item is announced.
        self._buffered_deltas: dict[
            str, list[responses.ResponseFunctionCallArgumentsDeltaEvent]
        ] = defaultdict(list)

    async def close(self) -> None:
        """Close the underlying stream."""
        await self._source.close()

    def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]:
        """Return async iterator that reorders events."""
        return self._filtered_iter()

    async def _filtered_iter(
        self,
    ) -> AsyncIterator[responses.ResponseStreamEvent]:
        """Yield events, buffering early argument deltas until their item is announced.

        Yields:
            Every source event in order, except that argument deltas for a not yet
            announced item are withheld and emitted right after the announcement.
        """
        async for event in self._source:
            if isinstance(event, responses.ResponseOutputItemAddedEvent):
                item = event.item

                # Plain function tool call: replay the buffered deltas unchanged.
                # Items without an id cannot be matched to deltas, so they fall
                # through and are forwarded as-is.
                if isinstance(item, responses.ResponseFunctionToolCall) and item.id:
                    function_item_id = item.id
                    self._announced_item_ids.add(function_item_id)
                    yield event
                    for delta in self._replay_buffered_deltas(function_item_id):
                        yield delta
                    continue

                # MCP call: replay as one synthetic delta carrying the ``-call``
                # suffix that pydantic_ai uses for the registered part.
                if isinstance(item, responses.response_output_item.McpCall):
                    mcp_item_id = item.id
                    self._announced_item_ids.add(mcp_item_id)
                    yield event
                    for delta in self._replay_mcp_buffered_deltas(mcp_item_id):
                        yield delta
                    continue

            elif isinstance(event, responses.ResponseFunctionCallArgumentsDeltaEvent):
                if event.item_id not in self._announced_item_ids:
                    logger.debug(
                        "Buffering early argument delta for unannounced item_id=%s",
                        event.item_id,
                    )
                    self._buffered_deltas[event.item_id].append(event)
                    continue

            yield event

        # Only reached when the source is exhausted normally (not on early close).
        self._warn_about_unannounced_deltas()

    def _warn_about_unannounced_deltas(self) -> None:
        """Log and discard deltas whose item was never announced by the stream."""
        if not self._buffered_deltas:
            return

        for item_id, orphaned in self._buffered_deltas.items():
            logger.warning(
                "Dropping %d buffered argument deltas for item_id=%s: "
                "stream ended before the item was announced",
                len(orphaned),
                item_id,
            )
        self._buffered_deltas.clear()

    def _replay_buffered_deltas(
        self, item_id: str
    ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]:
        """Return buffered deltas for a ``ResponseFunctionToolCall`` announcement.

        The deltas are removed from the buffer, so each is replayed at most once.

        Args:
            item_id: The announced item ID.

        Returns:
            List of buffered delta events to yield, unchanged.
        """
        buffered = self._buffered_deltas.pop(item_id, [])
        if buffered:
            logger.debug(
                "Replaying %d buffered argument deltas for item_id=%s",
                len(buffered),
                item_id,
            )
        return buffered

    def _replay_mcp_buffered_deltas(
        self, item_id: str
    ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]:
        """Return buffered deltas for an ``McpCall`` announcement.

        pydantic_ai registers ``McpCall`` parts with ``vendor_part_id=f'{id}-call'``
        and seeds the args string with everything up to ``"tool_args":``. The
        buffered deltas contain the actual ``tool_args`` content. We combine them
        into a single delta with the suffixed ``item_id`` and append a closing ``}``
        to complete the outer JSON object that pydantic_ai opened.

        Args:
            item_id: The announced McpCall item ID.

        Returns:
            List containing one synthetic delta event, or empty if nothing buffered.
        """
        buffered = self._buffered_deltas.pop(item_id, [])
        if not buffered:
            return []

        combined_args = "".join(d.delta for d in buffered) + "}"
        logger.debug(
            "Replaying %d buffered MCP argument deltas as single event "
            "for item_id=%s-call",
            len(buffered),
            item_id,
        )
        return [
            responses.ResponseFunctionCallArgumentsDeltaEvent(
                delta=combined_args,
                item_id=f"{item_id}-call",
                output_index=buffered[0].output_index,
                # Synthetic event: numbered directly after the last buffered delta.
                sequence_number=buffered[-1].sequence_number + 1,
                type="response.function_call_arguments.delta",
            )
        ]


class LlamaStackResponsesModel(OpenAIResponsesModel):
    """OpenAI Responses model with Llama Stack streaming compatibility fixes.

    Overrides the streaming response processing to buffer and replay
    ``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits
    before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item.
    """

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Request a streaming response, filtering Llama Stack-specific event quirks.

        Args:
            messages: Model messages for the request.
            model_settings: Model-specific settings.
            model_request_parameters: Request parameters for the model.
            run_context: Optional run context from the agent. Accepted for
                interface compatibility; not used by this override.

        Yields:
            A StreamedResponse with the filtered event stream.

        Raises:
            UnexpectedModelBehavior: If the stream ends before any event arrives,
                or if its first event is not a ``ResponseCreatedEvent``.
        """
        check_allow_model_requests()
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        model_settings_cast = cast(OpenAIResponsesModelSettings, model_settings or {})
        response = await self._responses_create(
            messages, True, model_settings_cast, model_request_parameters
        )

        # Reorder early argument deltas before pydantic_ai sees the events.
        filtered_stream = _FilteredResponseStream(response)

        # The raw response owns the connection, so it is the one entered/closed.
        async with response:
            peekable: PeekableAsyncStream[
                responses.ResponseStreamEvent, _FilteredResponseStream
            ] = PeekableAsyncStream(filtered_stream)

            # Peek so that API errors raised on the first read are mapped, and so
            # the response metadata is available without consuming the event.
            with _map_api_errors(self.model_name):
                first_chunk = await peekable.peek()

            if isinstance(first_chunk, Unset):
                raise UnexpectedModelBehavior(
                    "Streamed response ended without content or tool calls"
                )

            # Explicit check instead of ``assert`` so the validation survives
            # ``python -O`` and fails with a descriptive error.
            if not isinstance(first_chunk, responses.ResponseCreatedEvent):
                raise UnexpectedModelBehavior(
                    "Streamed response did not start with a response.created event "
                    f"(got {type(first_chunk).__name__})"
                )

            created_at = first_chunk.response.created_at
            yield OpenAIResponsesStreamedResponse(
                model_request_parameters=model_request_parameters,
                _model_name=first_chunk.response.model,
                _model_settings=model_settings_cast,
                _response=peekable,  # type: ignore[arg-type]
                _provider_name=self._provider.name,
                _provider_url=self._provider.base_url,
                _provider_timestamp=(
                    number_to_datetime(created_at) if created_at else None
                ),
            )
14 changes: 11 additions & 3 deletions src/pydantic_ai_lightspeed/llamastack/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from llama_stack.core.server.routes import find_matching_route
from llama_stack.core.utils.context import preserve_contexts_async_generator
from starlette.responses import StreamingResponse


class _AsyncByteStream(httpx.AsyncByteStream):
Expand Down Expand Up @@ -183,9 +184,16 @@ async def _handle_streaming(
result = await func(**merged_body)

async def gen() -> AsyncGenerator[bytes, None]:
async for chunk in result:
data = json.dumps(convert_pydantic_to_json_value(chunk))
yield f"data: {data}\n\n".encode("utf-8")
if isinstance(result, StreamingResponse):
async for chunk in result.body_iterator:
if isinstance(chunk, str):
yield chunk.encode("utf-8")
else:
yield bytes(chunk)
else:
async for chunk in result:
data = json.dumps(convert_pydantic_to_json_value(chunk))
yield f"data: {data}\n\n".encode("utf-8")

wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

Expand Down
2 changes: 1 addition & 1 deletion src/utils/agents/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
TextPartDelta,
)

from app.endpoints.streaming_query import shield_violation_generator
from configuration import configuration
from constants import INTERRUPTED_RESPONSE_MESSAGE, MEDIA_TYPE_JSON
from log import get_logger
Expand Down Expand Up @@ -70,6 +69,7 @@
persist_interrupted_turn,
register_interrupt_callback,
)
from utils.streaming_sse import shield_violation_generator

AgentDispatchEvent: TypeAlias = AgentStreamEvent | AgentRunResultEvent

Expand Down
9 changes: 6 additions & 3 deletions src/utils/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from pydantic_ai.models.openai import OpenAIResponsesModelSettings

from models.common.responses.responses_api_params import ResponsesApiParams
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider
from pydantic_ai_lightspeed.llamastack import (
LlamaStackProvider,
LlamaStackResponsesModel,
)

_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
{
Expand Down Expand Up @@ -92,7 +95,7 @@ def build_agent(
provider = _llama_stack_provider_from_client(client)
settings = _model_settings_from_responses_params(responses_params)

model = OpenAIResponsesModel(
model = LlamaStackResponsesModel(
responses_params.model,
provider=provider,
settings=settings,
Expand Down
5 changes: 0 additions & 5 deletions tests/e2e/features/steps/llm_query_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ def _parse_streaming_response(response_text: str) -> dict:
full_response = ""
full_response_split = []
finished = False
first_token = True
stream_error = (
None # {"status_code": int, "response": str, "cause": str} if event "error"
)
Expand All @@ -380,10 +379,6 @@ def _parse_streaming_response(response_text: str) -> dict:
if event == "start":
conversation_id = data["data"]["conversation_id"]
elif event == "token":
# Skip the first token (shield status message)
if first_token:
first_token = False
continue
full_response_split.append(data["data"]["token"])
elif event == "turn_complete":
full_response = data["data"]["token"]
Expand Down
Loading
Loading