From aa37f4e61ab333aa67bed41e9a9a0a265e8a575a Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Tue, 10 Feb 2026 12:14:16 -0800 Subject: [PATCH 1/7] fix: read HookResult modify action from tool:post events (#4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a hook returns action='modify' on tool:post, the orchestrator now detects the modified data by comparing the returned result with the original, and uses it instead of the original get_serialized_output(). This enables tool output truncation, sanitization, and transformation hooks to work correctly. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-authored-by: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- amplifier_module_loop_basic/__init__.py | 24 ++- tests/test_hook_modify.py | 227 ++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 tests/test_hook_modify.py diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index 8bac1f1..af96822 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -5,6 +5,7 @@ # Amplifier module metadata __amplifier_module_type__ = "orchestrator" +import json import logging from typing import Any @@ -463,8 +464,27 @@ async def execute_single_tool( f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration" ) - # Return success with result content (JSON-serialized for dict/list) - result_content = result.get_serialized_output() + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if ( + returned_result is not None + and returned_result is not result_data + ): + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + result_content = json.dumps(modified_result) + else: + result_content = str(modified_result) + else: + result_content = result.get_serialized_output() return (tool_call_id, result_content) except Exception as te: diff --git a/tests/test_hook_modify.py b/tests/test_hook_modify.py new file mode 100644 index 0000000..491ea1c --- /dev/null +++ b/tests/test_hook_modify.py @@ -0,0 +1,227 @@ +"""Tests for hook modify action on tool:post events. + +Verifies that when a hook returns HookResult(action="modify", data={"result": ...}) +on a tool:post event, the orchestrator uses the modified data instead of the +original result.get_serialized_output(). +""" + +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from amplifier_core.hooks import HookRegistry +from amplifier_core.models import HookResult + + +def _make_tool_result(output, success=True): + """Create a mock tool result with get_serialized_output() method.""" + result = MagicMock() + result.success = success + result.output = output + result.error = None + + def get_serialized_output(): + if isinstance(output, (dict, list)): + return json.dumps(output) + return str(output) + + result.get_serialized_output = get_serialized_output + + def to_dict(): + return {"success": success, "output": output, "error": None} + + result.to_dict = to_dict + return result + + +def _make_provider_responses(tool_calls_response, text_response_str="Done"): + """Create mock provider responses: one with tool calls, one with text.""" + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {"key": "value"} + + tool_response = MagicMock() + tool_response.content = [MagicMock(type="text", text="Using tool")] + tool_response.tool_calls = [tool_call] + tool_response.usage = None + tool_response.content_blocks = None + tool_response.metadata = None + + text_block = MagicMock() + text_block.text = text_response_str + text_block.type = "text" + text_response = MagicMock() + text_response.content = [text_block] + text_response.tool_calls = None + text_response.usage = None + text_response.content_blocks = None + text_response.metadata = None + + mock_provider = AsyncMock() + mock_provider.complete = AsyncMock(side_effect=[tool_response, text_response]) + mock_provider.priority = 1 + + return mock_provider + + +def _make_context(): + """Create a mock context that captures add_message calls.""" + context = AsyncMock() + messages_added = [] + + async def capture_add_message(msg): + messages_added.append(msg) + + context.add_message = AsyncMock(side_effect=capture_add_message) + context.get_messages_for_request = AsyncMock( + return_value=[{"role": "user", "content": "test"}] + ) + return context, messages_added + + +def _get_tool_result_messages(messages_added): + """Extract tool result messages from captured messages.""" + return [msg for msg in messages_added if msg.get("role") == "tool"] + + +@pytest.mark.asyncio +async def test_tool_post_modify_replaces_result(): + """When a hook returns action='modify' on tool:post, the modified data + should be used instead of the original get_serialized_output().""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_basic import BasicOrchestrator + + orchestrator = BasicOrchestrator({"max_iterations": 5}) + + # Tool with original output + original_output = {"original": True, "big_data": "x" * 1000} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "A test tool" + mock_tool.input_schema = {"type": "object", "properties": {}} + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # Hook that returns modify with new data + modified_content = {"modified": True, "truncated": True} + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult(action="modify", data={"result": modified_content}) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + mock_provider = _make_provider_responses(tool_calls_response=True) + providers = {"test_provider": mock_provider} + context, messages_added = _make_context() + + await orchestrator.execute( + prompt="test prompt", + context=context, + providers=providers, + tools=tools, + hooks=hooks, + ) + + tool_msgs = _get_tool_result_messages(messages_added) + assert len(tool_msgs) == 1, f"Expected 1 tool result, got {len(tool_msgs)}" + + tool_result_content = tool_msgs[0]["content"] + + # The content should be the MODIFIED data, not the original + assert tool_result_content == json.dumps(modified_content), ( + f"Expected modified content {json.dumps(modified_content)}, " + f"got {tool_result_content}" + ) + + # Verify the original data is NOT used + assert "big_data" not in tool_result_content + + +@pytest.mark.asyncio +async def test_tool_post_no_modify_uses_original(): + """When no hook returns modify, the original get_serialized_output() is used.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_basic import BasicOrchestrator + + orchestrator = BasicOrchestrator({"max_iterations": 5}) + + original_output = {"original": True} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "A test tool" + mock_tool.input_schema = {"type": "object", "properties": {}} + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # No hooks registered — default continue action + hooks = HookRegistry() + + mock_provider = _make_provider_responses(tool_calls_response=True) + providers = {"test_provider": mock_provider} + context, messages_added = _make_context() + + await orchestrator.execute( + prompt="test prompt", + context=context, + providers=providers, + tools=tools, + hooks=hooks, + ) + + tool_msgs = _get_tool_result_messages(messages_added) + assert len(tool_msgs) == 1 + tool_result_content = tool_msgs[0]["content"] + + # Should use original serialized output + assert tool_result_content == json.dumps(original_output), ( + f"Expected original {json.dumps(original_output)}, got {tool_result_content}" + ) + + +@pytest.mark.asyncio +async def test_tool_post_modify_with_string_result(): + """When a hook returns modify with a string result, it should be used as-is.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_basic import BasicOrchestrator + + orchestrator = BasicOrchestrator({"max_iterations": 5}) + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.description = "A test tool" + mock_tool.input_schema = {"type": "object", "properties": {}} + mock_tool.execute = AsyncMock(return_value=_make_tool_result("original text")) + tools = {"test_tool": mock_tool} + + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult( + action="modify", + data={"result": "truncated string result"}, + ) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + mock_provider = _make_provider_responses(tool_calls_response=True) + providers = {"test_provider": mock_provider} + context, messages_added = _make_context() + + await orchestrator.execute( + prompt="test prompt", + context=context, + providers=providers, + tools=tools, + hooks=hooks, + ) + + tool_msgs = _get_tool_result_messages(messages_added) + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"] == "truncated string result" From 3a0246f791ce1e40f009179e75f4458d20647aa0 Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Mon, 23 Feb 2026 07:22:12 -0800 Subject: [PATCH 2/7] fix: use safe dual-access pattern in CancelledError handler (#6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace bare tc.id and tc.name attribute access with the safe getattr/dict.get dual-access pattern already used at every other tool_call access site in the file. Prevents AttributeError when providers return tool_calls as plain dicts and the user cancels during tool execution. Adds regression tests verifying CancelledError handler works with dict-based tool_calls. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-authored-by: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- amplifier_module_loop_basic/__init__.py | 5 +- tests/test_cancelled_error_dict_tool_calls.py | 105 ++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 tests/test_cancelled_error_dict_tool_calls.py diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index af96822..cdf6d7c 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -533,8 +533,9 @@ async def execute_single_tool( await context.add_message( { "role": "tool", - "tool_call_id": tc.id, - "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{tc.name}"}}', + "tool_call_id": getattr(tc, "id", None) + or tc.get("id"), + "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', } ) # Re-raise to let the cancellation propagate diff --git a/tests/test_cancelled_error_dict_tool_calls.py b/tests/test_cancelled_error_dict_tool_calls.py new file mode 100644 index 0000000..8145e76 --- /dev/null +++ b/tests/test_cancelled_error_dict_tool_calls.py @@ -0,0 +1,105 @@ +"""Tests for CancelledError handler with dict-based tool_calls. + +Regression test for unsafe tc.id / tc.name access at lines 536-537. +The CancelledError handler used bare attribute access on tool_call objects +that may be plain dicts. Every other access site (9 of them) uses the safe +dual-access pattern: getattr(tc, "id", None) or tc.get("id"). +""" + +import asyncio + +import pytest + +from amplifier_core.testing import EventRecorder, MockContextManager + +from amplifier_module_loop_basic import BasicOrchestrator + + +class DictToolCallProvider: + """Provider that returns tool_calls as plain dicts (not ToolCall objects). + + Some providers return tool_calls as dicts rather than objects. + The orchestrator explicitly accommodates this with a dual-access pattern. + """ + + name = "dict-provider" + + async def complete(self, request, **kwargs): + return type( + "Response", + (), + { + "content": "Calling tool", + "tool_calls": [ + {"id": "tc1", "tool": "cancel_tool", "arguments": {}} + ], + "usage": None, + "content_blocks": None, + "metadata": None, + }, + )() + + +class CancellingTool: + """Tool that raises CancelledError to simulate immediate cancellation.""" + + name = "cancel_tool" + description = "tool that simulates cancellation" + input_schema = {"type": "object", "properties": {}} + + async def execute(self, args): + raise asyncio.CancelledError() + + +@pytest.mark.asyncio +async def test_cancelled_error_handler_with_dict_tool_calls(): + """CancelledError handler must not crash when tool_calls are plain dicts. + + Without the fix, line 536 (tc.id) raises: + AttributeError: 'dict' object has no attribute 'id' + + With the fix, CancelledError propagates cleanly after synthesizing + cancelled tool results into the context. + """ + orchestrator = BasicOrchestrator({}) + context = MockContextManager() + hooks = EventRecorder() + + with pytest.raises(asyncio.CancelledError): + await orchestrator.execute( + prompt="Test", + context=context, + providers={"default": DictToolCallProvider()}, + tools={"cancel_tool": CancellingTool()}, + hooks=hooks, + ) + + +@pytest.mark.asyncio +async def test_cancelled_error_synthesizes_messages_for_dict_tool_calls(): + """After fix, cancelled tool results are properly added to context. + + Verifies the synthesized cancellation messages contain the correct + tool_call_id and tool name extracted via the safe dual-access pattern. + """ + orchestrator = BasicOrchestrator({}) + context = MockContextManager() + hooks = EventRecorder() + + with pytest.raises(asyncio.CancelledError): + await orchestrator.execute( + prompt="Test", + context=context, + providers={"default": DictToolCallProvider()}, + tools={"cancel_tool": CancellingTool()}, + hooks=hooks, + ) + + # Find the synthesized cancellation message in context + tool_messages = [m for m in context.messages if m.get("role") == "tool"] + assert len(tool_messages) >= 1, "Expected at least one synthesized tool message" + + cancel_msg = tool_messages[-1] + assert cancel_msg["tool_call_id"] == "tc1" + assert "cancelled" in cancel_msg["content"] + assert "cancel_tool" in cancel_msg["content"] From 4448ab02640f359fe1c57cb1a761baf17e6aedd8 Mon Sep 17 00:00:00 2001 From: Salil Das Date: Mon, 23 Feb 2026 10:03:23 -0800 Subject: [PATCH 3/7] fix: protect result-writing loops from CancelledError and catch CancelledError in tool execution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Harden the orchestrator's result-writing and tool-execution loops against CancelledError to prevent orphaned tool_calls that break the conversation state. When a CancelledError escapes a result-writing loop, the assistant message is committed without the corresponding tool result. The LLM sees a tool_use without a tool_result on the next turn, causing a protocol violation or infinite retry loop. This fix catches CancelledError at two levels: - In the result-writing loop, ensuring partial results are committed - In tool execution, ensuring the tool result is always recorded Rebased on main to pick up the defensive getattr access pattern for tool call attributes. Companion fixes already merged in loop-streaming#14 and loop-events#5. Fixes: microsoft-amplifier/amplifier-support#46 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- amplifier_module_loop_basic/__init__.py | 91 ++++++++++++++++++------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index cdf6d7c..1b39ed7 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -487,7 +487,7 @@ async def execute_single_tool( result_content = result.get_serialized_output() return (tool_call_id, result_content) - except Exception as te: + except (Exception, asyncio.CancelledError) as te: # Emit error event await hooks.emit( TOOL_ERROR, @@ -524,19 +524,28 @@ async def execute_single_tool( ) except asyncio.CancelledError: # Immediate cancellation (second Ctrl+C) - synthesize cancelled results - # for ALL tool_calls to maintain tool_use/tool_result pairing + # for ALL tool_calls to maintain tool_use/tool_result pairing. + # Protect from further CancelledError using kernel + # catch-continue-reraise pattern so all results are written. logger.info( "Tool execution cancelled - synthesizing cancelled results" ) for tc in tool_calls: - if hasattr(context, "add_message"): - await context.add_message( - { - "role": "tool", - "tool_call_id": getattr(tc, "id", None) - or tc.get("id"), - "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', - } + try: + if hasattr(context, "add_message"): + await context.add_message( + { + "role": "tool", + "tool_call_id": getattr(tc, "id", None) + or tc.get("id"), + "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', + } + ) + except asyncio.CancelledError: + logger.info( + "CancelledError during synthetic result write - " + "completing remaining writes to prevent " + "orphaned tool_calls" ) # Re-raise to let the cancellation propagate raise @@ -546,15 +555,30 @@ async def execute_single_tool( # MUST add tool results to context before returning # Otherwise we leave orphaned tool_calls without matching tool_results # which violates provider API contracts (Anthropic, OpenAI) + # Protect from CancelledError using kernel catch-continue-reraise + # pattern (coordinator.cleanup, hooks.emit) so all results are + # written even if force-cancel arrives mid-loop. + _cancel_error = None for tool_call_id, content in tool_results: - if hasattr(context, "add_message"): - await context.add_message( - { - "role": "tool", - "tool_call_id": tool_call_id, - "content": content, - } - ) + try: + if hasattr(context, "add_message"): + await context.add_message( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + ) + except asyncio.CancelledError: + if _cancel_error is None: + _cancel_error = asyncio.CancelledError() + logger.info( + "CancelledError during tool result write - " + "completing remaining writes to prevent " + "orphaned tool_calls" + ) + if _cancel_error is not None: + raise _cancel_error await hooks.emit( ORCHESTRATOR_COMPLETE, { @@ -566,15 +590,30 @@ async def execute_single_tool( return final_content # Add all tool results to context in original order (deterministic) + # Protect from CancelledError using kernel catch-continue-reraise + # pattern so all results are written even if force-cancel arrives + # mid-loop. + _cancel_error = None for tool_call_id, content in tool_results: - if hasattr(context, "add_message"): - await context.add_message( - { - "role": "tool", - "tool_call_id": tool_call_id, - "content": content, - } - ) + try: + if hasattr(context, "add_message"): + await context.add_message( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + ) + except asyncio.CancelledError: + if _cancel_error is None: + _cancel_error = asyncio.CancelledError() + logger.info( + "CancelledError during tool result write - " + "completing remaining writes to prevent " + "orphaned tool_calls" + ) + if _cancel_error is not None: + raise _cancel_error # After executing tools, continue loop to get final response iteration += 1 From 3b9ff5e54a13b5312b7c4d56929ac396c7f0e801 Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Fri, 6 Mar 2026 07:34:57 -0800 Subject: [PATCH 4/7] feat: CP-6 normalize loop-basic with execution:start/end, tool_call_id, observability.events (#7) - Add EXECUTION_START and EXECUTION_END imports from amplifier_core.events - Move asyncio to top-level import (was inline inside execute method) - Emit execution:start after prompt:submit processing (skipped on deny) - Wrap main loop body in try/except/finally to emit execution:end on ALL exit paths - status='completed' on normal completion - status='cancelled' on cancellation paths (loop-start, post-provider, post-tools) - status='cancelled' for asyncio.CancelledError exceptions - status='error' for all other exceptions - Add tool_call_id to TOOL_PRE event payload - Add tool_call_id to TOOL_POST event payload - Add tool_call_id to both TOOL_ERROR event payloads (not-found and exception paths) - Register observability.events in mount() with execution:start and execution:end Tests added in tests/test_lifecycle_events.py: - test_execution_start_fires_with_prompt - test_execution_end_fires_on_normal_completion - test_execution_end_response_matches_return_value - test_execution_start_not_fired_when_prompt_submit_denied - test_execution_end_fires_cancelled_at_loop_start - test_execution_end_fires_cancelled_after_provider - test_execution_end_fires_on_provider_exception - test_execution_end_fires_exactly_once - test_tool_pre_includes_tool_call_id - test_tool_post_includes_tool_call_id - test_tool_error_not_found_includes_tool_call_id - test_tool_error_exception_includes_tool_call_id - test_mount_registers_observability_events --- amplifier_module_loop_basic/__init__.py | 1346 ++++++++++++----------- tests/test_lifecycle_events.py | 475 ++++++++ 2 files changed, 1179 insertions(+), 642 deletions(-) create mode 100644 tests/test_lifecycle_events.py diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index 1b39ed7..2e29f62 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -5,6 +5,7 @@ # Amplifier module metadata __amplifier_module_type__ = "orchestrator" +import asyncio import json import logging from typing import Any @@ -14,6 +15,8 @@ from amplifier_core import ModuleCoordinator from amplifier_core.events import CONTENT_BLOCK_END from amplifier_core.events import CONTENT_BLOCK_START +from amplifier_core.events import EXECUTION_END +from amplifier_core.events import EXECUTION_START from amplifier_core.events import ORCHESTRATOR_COMPLETE from amplifier_core.events import PROMPT_COMPLETE from amplifier_core.events import PROMPT_SUBMIT @@ -33,6 +36,14 @@ async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = None): config = config or {} + coordinator.register_contributor( + "observability.events", + "loop-basic", + lambda: [ + "execution:start", + "execution:end", + ], + ) orchestrator = BasicOrchestrator(config) await coordinator.mount("orchestrator", orchestrator) logger.info("Mounted BasicOrchestrator (desired-state)") @@ -68,6 +79,9 @@ async def execute( if result.action == "deny": return f"Operation denied: {result.reason}" + # Emit execution start (after prompt:submit so denials skip this) + await hooks.emit(EXECUTION_START, {"prompt": prompt}) + # Add user message if hasattr(context, "add_message"): await context.add_message({"role": "user", "content": prompt}) @@ -85,479 +99,540 @@ async def execute( # Agentic loop: continue until we get a text response (no tool calls) iteration = 0 final_content = "" + execution_status = "completed" + + try: + while self.max_iterations == -1 or iteration < self.max_iterations: + # Check for cancellation at iteration start + if coordinator and coordinator.cancellation.is_cancelled: + # Emit orchestrator complete with cancelled status + await hooks.emit( + ORCHESTRATOR_COMPLETE, + { + "orchestrator": "loop-basic", + "turn_count": iteration, + "status": "cancelled", + }, + ) + execution_status = "cancelled" + return final_content - while self.max_iterations == -1 or iteration < self.max_iterations: - # Check for cancellation at iteration start - if coordinator and coordinator.cancellation.is_cancelled: - # Emit orchestrator complete with cancelled status - await hooks.emit( - ORCHESTRATOR_COMPLETE, - { - "orchestrator": "loop-basic", - "turn_count": iteration, - "status": "cancelled", - }, - ) - return final_content - - # Emit provider request BEFORE getting messages (allows hook injections) - result = await hooks.emit( - PROVIDER_REQUEST, {"provider": provider_name, "iteration": iteration} - ) - if coordinator: - result = await coordinator.process_hook_result( - result, "provider:request", "orchestrator" + # Emit provider request BEFORE getting messages (allows hook injections) + result = await hooks.emit( + PROVIDER_REQUEST, + {"provider": provider_name, "iteration": iteration}, ) - if result.action == "deny": - return f"Operation denied: {result.reason}" - - # Get messages for LLM request (context handles compaction internally) - # Pass provider for dynamic budget calculation based on model's context window - if hasattr(context, "get_messages_for_request"): - message_dicts = await context.get_messages_for_request( - provider=provider - ) - else: - # Fallback for simple contexts without the method - message_dicts = getattr( - context, "messages", [{"role": "user", "content": prompt}] - ) - - # Append ephemeral injection if present (temporary, not stored) - if ( - result.action == "inject_context" - and result.ephemeral - and result.context_injection - ): - message_dicts = list(message_dicts) # Copy to avoid modifying context - - # Check if we should append to last tool result - if result.append_to_last_tool_result and len(message_dicts) > 0: - last_msg = message_dicts[-1] - # Append to last message if it's a tool result - if last_msg.get("role") == "tool": - # Append to existing content - original_content = last_msg.get("content", "") - message_dicts[-1] = { - **last_msg, - "content": f"{original_content}\n\n{result.context_injection}", - } - logger.debug( - "Appended ephemeral injection to last tool result message" - ) - else: - # Fall back to new message if last message isn't a tool result - message_dicts.append( - { - "role": result.context_injection_role, - "content": result.context_injection, - } - ) - logger.debug( - f"Last message role is '{last_msg.get('role')}', not 'tool' - " - "created new message for injection" - ) + if coordinator: + result = await coordinator.process_hook_result( + result, "provider:request", "orchestrator" + ) + if result.action == "deny": + return f"Operation denied: {result.reason}" + + # Get messages for LLM request (context handles compaction internally) + # Pass provider for dynamic budget calculation based on model's context window + if hasattr(context, "get_messages_for_request"): + message_dicts = await context.get_messages_for_request( + provider=provider + ) else: - # Default behavior: append as new message - message_dicts.append( - { - "role": result.context_injection_role, - "content": result.context_injection, - } + # Fallback for simple contexts without the method + message_dicts = getattr( + context, "messages", [{"role": "user", "content": prompt}] ) - # Apply pending ephemeral injections from tool:post hooks - if self._pending_ephemeral_injections: - message_dicts = list(message_dicts) # Ensure we have a mutable list - for injection in self._pending_ephemeral_injections: - if ( - injection.get("append_to_last_tool_result") - and len(message_dicts) > 0 - ): + # Append ephemeral injection if present (temporary, not stored) + if ( + result.action == "inject_context" + and result.ephemeral + and result.context_injection + ): + message_dicts = list( + message_dicts + ) # Copy to avoid modifying context + + # Check if we should append to last tool result + if result.append_to_last_tool_result and len(message_dicts) > 0: last_msg = message_dicts[-1] + # Append to last message if it's a tool result if last_msg.get("role") == "tool": + # Append to existing content original_content = last_msg.get("content", "") message_dicts[-1] = { **last_msg, - "content": f"{original_content}\n\n{injection['content']}", + "content": f"{original_content}\n\n{result.context_injection}", } logger.debug( - "Applied pending ephemeral injection to last tool result" + "Appended ephemeral injection to last tool result message" ) else: + # Fall back to new message if last message isn't a tool result message_dicts.append( { - "role": injection["role"], - "content": injection["content"], + "role": result.context_injection_role, + "content": result.context_injection, } ) logger.debug( - "Last message not a tool result, created new message for injection" + f"Last message role is '{last_msg.get('role')}', not 'tool' - " + "created new message for injection" ) else: + # Default behavior: append as new message message_dicts.append( - {"role": injection["role"], "content": injection["content"]} - ) - logger.debug( - "Applied pending ephemeral injection as new message" - ) - # Clear pending injections after applying - self._pending_ephemeral_injections = [] - - # Convert to ChatRequest with Message objects - try: - messages_objects = [Message(**msg) for msg in message_dicts] - - # Convert tools to ToolSpec format for ChatRequest - tools_list = None - if tools: - tools_list = [ - ToolSpec( - name=t.name, - description=t.description, - parameters=t.input_schema, + { + "role": result.context_injection_role, + "content": result.context_injection, + } ) - for t in tools.values() - ] - chat_request = ChatRequest( - messages=messages_objects, - tools=tools_list, - reasoning_effort=self.config.get("reasoning_effort"), - ) - logger.debug( - f"Created ChatRequest with {len(messages_objects)} messages" - ) - logger.debug( - f"Message roles: {[m.role for m in chat_request.messages]}" - ) - except Exception as e: - logger.error(f"Failed to create ChatRequest: {e}") - logger.error(f"Message dicts: {message_dicts}") - raise - try: - if hasattr(provider, "complete"): - # Pass extended_thinking if enabled in orchestrator config - kwargs = {} - if self.extended_thinking: - kwargs["extended_thinking"] = True - response = await provider.complete(chat_request, **kwargs) + # Apply pending ephemeral injections from tool:post hooks + if self._pending_ephemeral_injections: + message_dicts = list(message_dicts) # Ensure we have a mutable list + for injection in self._pending_ephemeral_injections: + if ( + injection.get("append_to_last_tool_result") + and len(message_dicts) > 0 + ): + last_msg = message_dicts[-1] + if last_msg.get("role") == "tool": + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{injection['content']}", + } + logger.debug( + "Applied pending ephemeral injection to last tool result" + ) + else: + message_dicts.append( + { + "role": injection["role"], + "content": injection["content"], + } + ) + logger.debug( + "Last message not a tool result, created new message for injection" + ) + else: + message_dicts.append( + { + "role": injection["role"], + "content": injection["content"], + } + ) + logger.debug( + "Applied pending ephemeral injection as new message" + ) + # Clear pending injections after applying + self._pending_ephemeral_injections = [] + + # Convert to ChatRequest with Message objects + try: + messages_objects = [Message(**msg) for msg in message_dicts] + + # Convert tools to ToolSpec format for ChatRequest + tools_list = None + if tools: + tools_list = [ + ToolSpec( + name=t.name, + description=t.description, + parameters=t.input_schema, + ) + for t in tools.values() + ] - # Check for immediate cancellation after provider returns - # This allows force-cancel to take effect as soon as the blocking - # provider call completes, before processing the response - if coordinator and coordinator.cancellation.is_immediate: - # Emit cancelled status and exit - await hooks.emit( - ORCHESTRATOR_COMPLETE, - { - "orchestrator": "loop-basic", - "turn_count": iteration, - "status": "cancelled", - }, + chat_request = ChatRequest( + messages=messages_objects, + tools=tools_list, + reasoning_effort=self.config.get("reasoning_effort"), + ) + logger.debug( + f"Created ChatRequest with {len(messages_objects)} messages" + ) + logger.debug( + f"Message roles: {[m.role for m in chat_request.messages]}" + ) + except Exception as e: + logger.error(f"Failed to create ChatRequest: {e}") + logger.error(f"Message dicts: {message_dicts}") + raise + try: + if hasattr(provider, "complete"): + # Pass extended_thinking if enabled in orchestrator config + kwargs = {} + if self.extended_thinking: + kwargs["extended_thinking"] = True + response = await provider.complete(chat_request, **kwargs) + + # Check for immediate cancellation after provider returns + # This allows force-cancel to take effect as soon as the blocking + # provider call completes, before processing the response + if coordinator and coordinator.cancellation.is_immediate: + # Emit cancelled status and exit + await hooks.emit( + ORCHESTRATOR_COMPLETE, + { + "orchestrator": "loop-basic", + "turn_count": iteration, + "status": "cancelled", + }, + ) + execution_status = "cancelled" + return final_content + else: + raise RuntimeError( + f"Provider {provider_name} missing 'complete'" ) - return final_content - else: - raise RuntimeError(f"Provider {provider_name} missing 'complete'") - usage = getattr(response, "usage", None) - content = getattr(response, "content", None) - tool_calls = getattr(response, "tool_calls", None) + usage = getattr(response, "usage", None) + content = getattr(response, "content", None) + tool_calls = getattr(response, "tool_calls", None) - await hooks.emit( - PROVIDER_RESPONSE, - { - "provider": provider_name, - "usage": usage, - "tool_calls": bool(tool_calls), - }, - ) + await hooks.emit( + PROVIDER_RESPONSE, + { + "provider": provider_name, + "usage": usage, + "tool_calls": bool(tool_calls), + }, + ) - # Emit content block events if present - content_blocks = getattr(response, "content_blocks", None) - logger.info( - f"Response has content_blocks: {content_blocks is not None} - count: {len(content_blocks) if content_blocks else 0}" - ) - if content_blocks: - total_blocks = len(content_blocks) - logger.info(f"Emitting events for {total_blocks} content blocks") - for idx, block in enumerate(content_blocks): + # Emit content block events if present + content_blocks = getattr(response, "content_blocks", None) + logger.info( + f"Response has content_blocks: {content_blocks is not None} - count: {len(content_blocks) if content_blocks else 0}" + ) + if content_blocks: + total_blocks = len(content_blocks) logger.info( - f"Emitting CONTENT_BLOCK_START for block {idx}, type: {block.type.value}" + f"Emitting events for {total_blocks} content blocks" ) - # Emit block start (without non-serializable raw object) - await hooks.emit( - CONTENT_BLOCK_START, - { - "block_type": block.type.value, + for idx, block in enumerate(content_blocks): + logger.info( + f"Emitting CONTENT_BLOCK_START for block {idx}, type: {block.type.value}" + ) + # Emit block start (without non-serializable raw object) + await hooks.emit( + CONTENT_BLOCK_START, + { + "block_type": block.type.value, + "block_index": idx, + "total_blocks": total_blocks, + }, + ) + + # Emit block end with complete block, usage, and total count + event_data = { "block_index": idx, "total_blocks": total_blocks, - }, - ) - - # Emit block end with complete block, usage, and total count - event_data = { - "block_index": idx, - "total_blocks": total_blocks, - "block": block.to_dict(), - } - if usage: - event_data["usage"] = ( - usage.model_dump() - if hasattr(usage, "model_dump") - else usage - ) - await hooks.emit(CONTENT_BLOCK_END, event_data) - - # Handle tool calls (parallel execution) - if tool_calls: - # Add assistant message with tool calls BEFORE executing them - if hasattr(context, "add_message"): - # Store structured content from response.content (our Pydantic models) - response_content = getattr(response, "content", None) - if response_content and isinstance(response_content, list): - assistant_msg = { - "role": "assistant", - "content": [ - block.model_dump() - if hasattr(block, "model_dump") - else block - for block in response_content - ], - "tool_calls": [ - { - "id": getattr(tc, "id", None) or tc.get("id"), - "tool": getattr(tc, "name", None) - or tc.get("tool"), - "arguments": getattr(tc, "arguments", None) - or tc.get("arguments") - or {}, - } - for tc in tool_calls - ], - } - else: - assistant_msg = { - "role": "assistant", - "content": content if content else "", - "tool_calls": [ - { - "id": getattr(tc, "id", None) or tc.get("id"), - "tool": getattr(tc, "name", None) - or tc.get("tool"), - "arguments": getattr(tc, "arguments", None) - or tc.get("arguments") - or {}, - } - for tc in tool_calls - ], + "block": block.to_dict(), } + if usage: + event_data["usage"] = ( + usage.model_dump() + if hasattr(usage, "model_dump") + else usage + ) + await hooks.emit(CONTENT_BLOCK_END, event_data) + + # Handle tool calls (parallel execution) + if tool_calls: + # Add assistant message with tool calls BEFORE executing them + if hasattr(context, "add_message"): + # Store structured content from response.content (our Pydantic models) + response_content = getattr(response, "content", None) + if response_content and isinstance(response_content, list): + assistant_msg = { + "role": "assistant", + "content": [ + block.model_dump() + if hasattr(block, "model_dump") + else block + for block in response_content + ], + "tool_calls": [ + { + "id": getattr(tc, "id", None) + or tc.get("id"), + "tool": getattr(tc, "name", None) + or tc.get("tool"), + "arguments": getattr(tc, "arguments", None) + or tc.get("arguments") + or {}, + } + for tc in tool_calls + ], + } + else: + assistant_msg = { + "role": "assistant", + "content": content if content else "", + "tool_calls": [ + { + "id": getattr(tc, "id", None) + or tc.get("id"), + "tool": getattr(tc, "name", None) + or tc.get("tool"), + "arguments": getattr(tc, "arguments", None) + or tc.get("arguments") + or {}, + } + for tc in tool_calls + ], + } - # Preserve provider metadata (provider-agnostic passthrough) - # This enables providers to maintain state across steps (e.g., OpenAI reasoning items) - if hasattr(response, "metadata") and response.metadata: - assistant_msg["metadata"] = response.metadata - - await context.add_message(assistant_msg) - - # Execute tools in parallel (user guidance: assume parallel intent when multiple tool calls) - import asyncio - import uuid - - # Generate parallel group ID for event correlation - parallel_group_id = str(uuid.uuid4()) - - # Create tasks for parallel execution - async def execute_single_tool( - tc: Any, group_id: str - ) -> tuple[str, str]: - """Execute one tool, handling all errors gracefully. - - Always returns (tool_call_id, result_or_error) tuple. - Never raises - errors become error results. - """ - tool_name = getattr(tc, "name", None) or tc.get("tool") - tool_call_id = getattr(tc, "id", None) or tc.get("id") - args = ( - getattr(tc, "arguments", None) or tc.get("arguments") or {} - ) - tool = tools.get(tool_name) - - # Register tool with cancellation token for visibility - if coordinator: - coordinator.cancellation.register_tool_start( - tool_call_id, tool_name + # Preserve provider metadata (provider-agnostic passthrough) + # This enables providers to maintain state across steps (e.g., OpenAI reasoning items) + if hasattr(response, "metadata") and response.metadata: + assistant_msg["metadata"] = response.metadata + + await context.add_message(assistant_msg) + + # Execute tools in parallel (user guidance: assume parallel intent when multiple tool calls) + import uuid + + # Generate parallel group ID for event correlation + parallel_group_id = str(uuid.uuid4()) + + # Create tasks for parallel execution + async def execute_single_tool( + tc: Any, group_id: str + ) -> tuple[str, str]: + """Execute one tool, handling all errors gracefully. + + Always returns (tool_call_id, result_or_error) tuple. + Never raises - errors become error results. + """ + tool_name = getattr(tc, "name", None) or tc.get("tool") + tool_call_id = getattr(tc, "id", None) or tc.get("id") + args = ( + getattr(tc, "arguments", None) + or tc.get("arguments") + or {} ) + tool = tools.get(tool_name) - try: - try: - # Emit and process tool pre (allows hooks to block or request approval) - pre_result = await hooks.emit( - TOOL_PRE, - { - "tool_name": tool_name, - "tool_input": args, - "parallel_group_id": group_id, - }, + # Register tool with cancellation token for visibility + if coordinator: + coordinator.cancellation.register_tool_start( + tool_call_id, tool_name ) - if coordinator: - pre_result = await coordinator.process_hook_result( - pre_result, "tool:pre", tool_name - ) - if pre_result.action == "deny": - return ( - tool_call_id, - f"Denied by hook: {pre_result.reason}", - ) - if not tool: - error_msg = f"Error: Tool '{tool_name}' not found" - await hooks.emit( - TOOL_ERROR, + try: + try: + # Emit and process tool pre (allows hooks to block or request approval) + pre_result = await hooks.emit( + TOOL_PRE, { "tool_name": tool_name, - "error": { - "type": "RuntimeError", - "msg": error_msg, - }, + "tool_call_id": tool_call_id, + "tool_input": args, "parallel_group_id": group_id, }, ) - return (tool_call_id, error_msg) - - result = await tool.execute(args) + if coordinator: + pre_result = ( + await coordinator.process_hook_result( + pre_result, "tool:pre", tool_name + ) + ) + if pre_result.action == "deny": + return ( + tool_call_id, + f"Denied by hook: {pre_result.reason}", + ) + + if not tool: + error_msg = ( + f"Error: Tool '{tool_name}' not found" + ) + await hooks.emit( + TOOL_ERROR, + { + "tool_name": tool_name, + "tool_call_id": tool_call_id, + "error": { + "type": "RuntimeError", + "msg": error_msg, + }, + "parallel_group_id": group_id, + }, + ) + return (tool_call_id, error_msg) - # Serialize result for logging - result_data = result - if hasattr(result, "to_dict"): - result_data = result.to_dict() + result = await tool.execute(args) - # Emit and process tool post (allows hooks to inject feedback) - post_result = await hooks.emit( - TOOL_POST, - { - "tool_name": tool_name, - "tool_input": args, - "result": result_data, - "parallel_group_id": group_id, - }, - ) - if coordinator: - await coordinator.process_hook_result( - post_result, "tool:post", tool_name - ) + # Serialize result for logging + result_data = result + if hasattr(result, "to_dict"): + result_data = result.to_dict() - # Store ephemeral injection from tool:post for next iteration - if ( - post_result.action == "inject_context" - and post_result.ephemeral - and post_result.context_injection - ): - self._pending_ephemeral_injections.append( + # Emit and process tool post (allows hooks to inject feedback) + post_result = await hooks.emit( + TOOL_POST, { - "role": post_result.context_injection_role, - "content": post_result.context_injection, - "append_to_last_tool_result": post_result.append_to_last_tool_result, - } - ) - logger.debug( - f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration" + "tool_name": tool_name, + "tool_call_id": tool_call_id, + "tool_input": args, + "result": result_data, + "parallel_group_id": group_id, + }, ) + if coordinator: + await coordinator.process_hook_result( + post_result, "tool:post", tool_name + ) - # Check if a hook modified the tool result. - # hooks.emit() chains modify actions: when a hook - # returns action="modify", the data dict is replaced. - # We detect this by checking if the returned "result" - # is a different object than what we originally sent. - modified_result = None - if post_result and post_result.data is not None: - returned_result = post_result.data.get("result") + # Store ephemeral injection from tool:post for next iteration if ( - returned_result is not None - and returned_result is not result_data + post_result.action == "inject_context" + and post_result.ephemeral + and post_result.context_injection ): - modified_result = returned_result + self._pending_ephemeral_injections.append( + { + "role": post_result.context_injection_role, + "content": post_result.context_injection, + "append_to_last_tool_result": post_result.append_to_last_tool_result, + } + ) + logger.debug( + f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration" + ) - if modified_result is not None: - if isinstance(modified_result, (dict, list)): - result_content = json.dumps(modified_result) + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if ( + returned_result is not None + and returned_result is not result_data + ): + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + result_content = json.dumps(modified_result) + else: + result_content = str(modified_result) else: - result_content = str(modified_result) - else: - result_content = result.get_serialized_output() - return (tool_call_id, result_content) - - except (Exception, asyncio.CancelledError) as te: - # Emit error event - await hooks.emit( - TOOL_ERROR, - { - "tool_name": tool_name, - "error": { - "type": type(te).__name__, - "msg": str(te), + result_content = result.get_serialized_output() + return (tool_call_id, result_content) + + except (Exception, asyncio.CancelledError) as te: + # Emit error event + await hooks.emit( + TOOL_ERROR, + { + "tool_name": tool_name, + "tool_call_id": tool_call_id, + "error": { + "type": type(te).__name__, + "msg": str(te), + }, + "parallel_group_id": group_id, }, - "parallel_group_id": group_id, - }, - ) + ) - # Return failure with error message (don't raise!) - error_msg = f"Error executing tool: {str(te)}" - logger.error(f"Tool {tool_name} failed: {te}") - return (tool_call_id, error_msg) - finally: - # Unregister tool from cancellation token - if coordinator: - coordinator.cancellation.register_tool_complete( - tool_call_id - ) + # Return failure with error message (don't raise!) + error_msg = f"Error executing tool: {str(te)}" + logger.error(f"Tool {tool_name} failed: {te}") + return (tool_call_id, error_msg) + finally: + # Unregister tool from cancellation token + if coordinator: + coordinator.cancellation.register_tool_complete( + tool_call_id + ) - # Execute all tools in parallel with asyncio.gather - # return_exceptions=False because we handle exceptions inside execute_single_tool - # Wrap in try/except for CancelledError to handle immediate cancellation - try: - tool_results = await asyncio.gather( - *[ - execute_single_tool(tc, parallel_group_id) - for tc in tool_calls - ] - ) - except asyncio.CancelledError: - # Immediate cancellation (second Ctrl+C) - synthesize cancelled results - # for ALL tool_calls to maintain tool_use/tool_result pairing. - # Protect from further CancelledError using kernel - # catch-continue-reraise pattern so all results are written. - logger.info( - "Tool execution cancelled - synthesizing cancelled results" - ) - for tc in tool_calls: - try: - if hasattr(context, "add_message"): - await context.add_message( - { - "role": "tool", - "tool_call_id": getattr(tc, "id", None) - or tc.get("id"), - "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', - } + # Execute all tools in parallel with asyncio.gather + # return_exceptions=False because we handle exceptions inside execute_single_tool + # Wrap in try/except for CancelledError to handle immediate cancellation + try: + tool_results = await asyncio.gather( + *[ + execute_single_tool(tc, parallel_group_id) + for tc in tool_calls + ] + ) + except asyncio.CancelledError: + # Immediate cancellation (second Ctrl+C) - synthesize cancelled results + # for ALL tool_calls to maintain tool_use/tool_result pairing. + # Protect from further CancelledError using kernel + # catch-continue-reraise pattern so all results are written. + logger.info( + "Tool execution cancelled - synthesizing cancelled results" + ) + for tc in tool_calls: + try: + if hasattr(context, "add_message"): + await context.add_message( + { + "role": "tool", + "tool_call_id": getattr(tc, "id", None) + or tc.get("id"), + "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', + } + ) + except asyncio.CancelledError: + logger.info( + "CancelledError during synthetic result write - " + "completing remaining writes to prevent " + "orphaned tool_calls" ) - except asyncio.CancelledError: - logger.info( - "CancelledError during synthetic result write - " - "completing remaining writes to prevent " - "orphaned tool_calls" - ) - # Re-raise to let the cancellation propagate - raise - - # Check for immediate cancellation (graceful path - tools completed) - if coordinator and coordinator.cancellation.is_immediate: - # MUST add tool results to context before returning - # Otherwise we leave orphaned tool_calls without matching tool_results - # which violates provider API contracts (Anthropic, OpenAI) + # Re-raise to let the cancellation propagate + raise + + # Check for immediate cancellation (graceful path - tools completed) + if coordinator and coordinator.cancellation.is_immediate: + # MUST add tool results to context before returning + # Otherwise we leave orphaned tool_calls without matching tool_results + # which violates provider API contracts (Anthropic, OpenAI) + # Protect from CancelledError using kernel catch-continue-reraise + # pattern (coordinator.cleanup, hooks.emit) so all results are + # written even if force-cancel arrives mid-loop. + _cancel_error = None + for tool_call_id, content in tool_results: + try: + if hasattr(context, "add_message"): + await context.add_message( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + ) + except asyncio.CancelledError: + if _cancel_error is None: + _cancel_error = asyncio.CancelledError() + logger.info( + "CancelledError during tool result write - " + "completing remaining writes to prevent " + "orphaned tool_calls" + ) + if _cancel_error is not None: + raise _cancel_error + await hooks.emit( + ORCHESTRATOR_COMPLETE, + { + "orchestrator": "loop-basic", + "turn_count": iteration, + "status": "cancelled", + }, + ) + execution_status = "cancelled" + return final_content + + # Add all tool results to context in original order (deterministic) # Protect from CancelledError using kernel catch-continue-reraise - # pattern (coordinator.cleanup, hooks.emit) so all results are - # written even if force-cancel arrives mid-loop. + # pattern so all results are written even if force-cancel arrives + # mid-loop. _cancel_error = None for tool_call_id, content in tool_results: try: @@ -579,252 +654,239 @@ async def execute_single_tool( ) if _cancel_error is not None: raise _cancel_error - await hooks.emit( - ORCHESTRATOR_COMPLETE, - { - "orchestrator": "loop-basic", - "turn_count": iteration, - "status": "cancelled", - }, - ) - return final_content - - # Add all tool results to context in original order (deterministic) - # Protect from CancelledError using kernel catch-continue-reraise - # pattern so all results are written even if force-cancel arrives - # mid-loop. - _cancel_error = None - for tool_call_id, content in tool_results: - try: - if hasattr(context, "add_message"): - await context.add_message( - { - "role": "tool", - "tool_call_id": tool_call_id, - "content": content, - } - ) - except asyncio.CancelledError: - if _cancel_error is None: - _cancel_error = asyncio.CancelledError() - logger.info( - "CancelledError during tool result write - " - "completing remaining writes to prevent " - "orphaned tool_calls" - ) - if _cancel_error is not None: - raise _cancel_error - # After executing tools, continue loop to get final response - iteration += 1 - continue - - # If we have content (no tool calls), we're done - if content: - # Extract text from content blocks - if isinstance(content, list): - text_parts = [] - for block in content: - if hasattr(block, "text"): - text_parts.append(block.text) - elif isinstance(block, dict) and "text" in block: - text_parts.append(block["text"]) - final_content = "\n\n".join(text_parts) if text_parts else "" - else: - final_content = content - if hasattr(context, "add_message"): - # Store structured content from response.content (our Pydantic models) - response_content = getattr(response, "content", None) - if response_content and isinstance(response_content, list): - assistant_msg = { - "role": "assistant", - "content": [ - block.model_dump() - if hasattr(block, "model_dump") - else block - for block in response_content - ], - } + # After executing tools, continue loop to get final response + iteration += 1 + continue + + # If we have content (no tool calls), we're done + if content: + # Extract text from content blocks + if isinstance(content, list): + text_parts = [] + for block in content: + if hasattr(block, "text"): + text_parts.append(block.text) + elif isinstance(block, dict) and "text" in block: + text_parts.append(block["text"]) + final_content = ( + "\n\n".join(text_parts) if text_parts else "" + ) else: - assistant_msg = {"role": "assistant", "content": content} - # Preserve provider metadata (provider-agnostic passthrough) - if hasattr(response, "metadata") and response.metadata: - assistant_msg["metadata"] = response.metadata - await context.add_message(assistant_msg) - break - - # No content and no tool calls - this shouldn't happen but handle it - logger.warning("Provider returned neither content nor tool calls") - iteration += 1 - - except LLMError as e: + final_content = content + if hasattr(context, "add_message"): + # Store structured content from response.content (our Pydantic models) + response_content = getattr(response, "content", None) + if response_content and isinstance(response_content, list): + assistant_msg = { + "role": "assistant", + "content": [ + block.model_dump() + if hasattr(block, "model_dump") + else block + for block in response_content + ], + } + else: + assistant_msg = { + "role": "assistant", + "content": content, + } + # Preserve provider metadata (provider-agnostic passthrough) + if hasattr(response, "metadata") and response.metadata: + assistant_msg["metadata"] = response.metadata + await context.add_message(assistant_msg) + break + + # No content and no tool calls - this shouldn't happen but handle it + logger.warning("Provider returned neither content nor tool calls") + iteration += 1 + + except LLMError as e: + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + "retryable": e.retryable, + "status_code": e.status_code, + }, + ) + raise + except Exception as e: + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + }, + ) + raise + + # Check if we exceeded max iterations (only if not unlimited) + if ( + self.max_iterations != -1 + and iteration >= self.max_iterations + and not final_content + ): + logger.warning( + f"Max iterations ({self.max_iterations}) reached without final response" + ) + + # Inject system reminder to agent before final response await hooks.emit( - PROVIDER_ERROR, + PROVIDER_REQUEST, { "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - "retryable": e.retryable, - "status_code": e.status_code, + "iteration": iteration, + "max_reached": True, }, ) - raise - except Exception as e: - await hooks.emit( - PROVIDER_ERROR, + if coordinator: + # Inject ephemeral reminder (not stored in context) + await coordinator.process_hook_result( + HookResult( + action="inject_context", + context_injection=""" + You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed. + """, + context_injection_role="system", + ephemeral=True, + suppress_output=True, + ), + "provider:request", + "orchestrator", + ) + + # Get one final response with the reminder (context handles compaction internally) + if hasattr(context, "get_messages_for_request"): + message_dicts = await context.get_messages_for_request( + provider=provider + ) + else: + message_dicts = getattr( + context, "messages", [{"role": "user", "content": prompt}] + ) + message_dicts = list(message_dicts) + message_dicts.append( { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - }, + "role": "user", + "content": """ + You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed. + + DO NOT mention this iteration limit or reminder to the user explicitly. Simply wrap up naturally. + """, + } ) - raise - - # Check if we exceeded max iterations (only if not unlimited) - if ( - self.max_iterations != -1 - and iteration >= self.max_iterations - and not final_content - ): - logger.warning( - f"Max iterations ({self.max_iterations}) reached without final response" - ) - # Inject system reminder to agent before final response + try: + messages_objects = [Message(**msg) for msg in message_dicts] + + # Convert tools to ToolSpec format for ChatRequest + tools_list = None + if tools: + tools_list = [ + ToolSpec( + name=t.name, + description=t.description, + parameters=t.input_schema, + ) + for t in tools.values() + ] + + chat_request = ChatRequest( + messages=messages_objects, + tools=tools_list, + reasoning_effort=self.config.get("reasoning_effort"), + ) + + kwargs = {} + if self.extended_thinking: + kwargs["extended_thinking"] = True + + response = await provider.complete(chat_request, **kwargs) + content = getattr(response, "content", None) + content_blocks = getattr(response, "content_blocks", None) + + if content: + final_content = content + if hasattr(context, "add_message"): + # Store structured content from response.content (our Pydantic models) + response_content = getattr(response, "content", None) + if response_content and isinstance(response_content, list): + assistant_msg = { + "role": "assistant", + "content": [ + block.model_dump() + if hasattr(block, "model_dump") + else block + for block in response_content + ], + } + else: + assistant_msg = { + "role": "assistant", + "content": content, + } + # Preserve provider metadata (provider-agnostic passthrough) + if hasattr(response, "metadata") and response.metadata: + assistant_msg["metadata"] = response.metadata + await context.add_message(assistant_msg) + + except LLMError as e: + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + "retryable": e.retryable, + "status_code": e.status_code, + }, + ) + logger.error( + f"Error getting final response after max iterations: {e}" + ) + except Exception as e: + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + }, + ) + logger.error( + f"Error getting final response after max iterations: {e}" + ) + await hooks.emit( - PROVIDER_REQUEST, + PROMPT_COMPLETE, { - "provider": provider_name, - "iteration": iteration, - "max_reached": True, + "response_preview": (final_content or "")[:200], + "length": len(final_content or ""), }, ) - if coordinator: - # Inject ephemeral reminder (not stored in context) - await coordinator.process_hook_result( - HookResult( - action="inject_context", - context_injection=""" -You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed. -""", - context_injection_role="system", - ephemeral=True, - suppress_output=True, - ), - "provider:request", - "orchestrator", - ) - # Get one final response with the reminder (context handles compaction internally) - if hasattr(context, "get_messages_for_request"): - message_dicts = await context.get_messages_for_request( - provider=provider - ) - else: - message_dicts = getattr( - context, "messages", [{"role": "user", "content": prompt}] - ) - message_dicts = list(message_dicts) - message_dicts.append( + # Emit orchestrator complete event + await hooks.emit( + ORCHESTRATOR_COMPLETE, { - "role": "user", - "content": """ -You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed. - -DO NOT mention this iteration limit or reminder to the user explicitly. Simply wrap up naturally. -""", - } + "orchestrator": "loop-basic", + "turn_count": iteration, + "status": "success" if final_content else "incomplete", + }, ) - try: - messages_objects = [Message(**msg) for msg in message_dicts] - - # Convert tools to ToolSpec format for ChatRequest - tools_list = None - if tools: - tools_list = [ - ToolSpec( - name=t.name, - description=t.description, - parameters=t.input_schema, - ) - for t in tools.values() - ] - - chat_request = ChatRequest( - messages=messages_objects, - tools=tools_list, - reasoning_effort=self.config.get("reasoning_effort"), - ) - - kwargs = {} - if self.extended_thinking: - kwargs["extended_thinking"] = True - - response = await provider.complete(chat_request, **kwargs) - content = getattr(response, "content", None) - content_blocks = getattr(response, "content_blocks", None) - - if content: - final_content = content - if hasattr(context, "add_message"): - # Store structured content from response.content (our Pydantic models) - response_content = getattr(response, "content", None) - if response_content and isinstance(response_content, list): - assistant_msg = { - "role": "assistant", - "content": [ - block.model_dump() - if hasattr(block, "model_dump") - else block - for block in response_content - ], - } - else: - assistant_msg = {"role": "assistant", "content": content} - # Preserve provider metadata (provider-agnostic passthrough) - if hasattr(response, "metadata") and response.metadata: - assistant_msg["metadata"] = response.metadata - await context.add_message(assistant_msg) - - except LLMError as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - "retryable": e.retryable, - "status_code": e.status_code, - }, - ) - logger.error(f"Error getting final response after max iterations: {e}") - except Exception as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - }, - ) - logger.error(f"Error getting final response after max iterations: {e}") - - await hooks.emit( - PROMPT_COMPLETE, - { - "response_preview": (final_content or "")[:200], - "length": len(final_content or ""), - }, - ) - - # Emit orchestrator complete event - await hooks.emit( - ORCHESTRATOR_COMPLETE, - { - "orchestrator": "loop-basic", - "turn_count": iteration, - "status": "success" if final_content else "incomplete", - }, - ) - - return final_content + return final_content + except asyncio.CancelledError: + execution_status = "cancelled" + raise + except Exception: + execution_status = "error" + raise + finally: + await hooks.emit( + EXECUTION_END, {"response": final_content, "status": execution_status} + ) def _select_provider(self, providers: dict[str, Any]) -> Any: """Select a provider based on priority.""" diff --git a/tests/test_lifecycle_events.py b/tests/test_lifecycle_events.py new file mode 100644 index 0000000..47460a8 --- /dev/null +++ b/tests/test_lifecycle_events.py @@ -0,0 +1,475 @@ +"""Tests for CP-6: execution lifecycle events, tool_call_id in tool events, +and observability.events registration. + +Verifies that: +- execution:start fires with {prompt} after prompt:submit succeeds +- execution:end fires with {response, status="completed"} on normal exit +- execution:end fires with {response, status="cancelled"} on cancellation paths +- execution:end fires with {response, status="error"} on exception paths +- execution:start is NOT fired when prompt:submit returns "deny" +- tool_call_id appears in TOOL_PRE, TOOL_POST, and TOOL_ERROR event payloads +- mount() registers observability.events contributions +""" + +import pytest + +from amplifier_core import events as amp_events +from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall +from amplifier_core.models import HookResult +from amplifier_core.testing import EventRecorder, MockContextManager + +from amplifier_module_loop_basic import BasicOrchestrator + +# Import event constants via the module to avoid pyright false-positives on PyO3 re-exports +EXECUTION_START = amp_events.EXECUTION_START # type: ignore[attr-defined] +EXECUTION_END = amp_events.EXECUTION_END # type: ignore[attr-defined] +TOOL_PRE = amp_events.TOOL_PRE # type: ignore[attr-defined] +TOOL_POST = amp_events.TOOL_POST # type: ignore[attr-defined] +TOOL_ERROR = amp_events.TOOL_ERROR # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Helpers / shared fixtures +# --------------------------------------------------------------------------- + + +class SimpleTextProvider: + """Provider that always returns a plain text response.""" + + async def complete(self, request, **kwargs): + return ChatResponse( + content=[TextBlock(text="Hello, world!")], + tool_calls=None, + ) + + +class ToolThenTextProvider: + """Provider returns one tool call, then a plain text response.""" + + def __init__(self, tool_name="test_tool", tool_id="tc-abc123"): + self._tool_name = tool_name + self._tool_id = tool_id + self._call_count = 0 + + async def complete(self, request, **kwargs): + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall(id=self._tool_id, name=self._tool_name, arguments={"x": 1}) + ], + ) + return ChatResponse( + content=[TextBlock(text="Done!")], + tool_calls=None, + ) + + +class ErrorProvider: + """Provider that always raises an exception.""" + + def __init__(self, error=None): + self._error = error or RuntimeError("provider failed") + + async def complete(self, request, **kwargs): + raise self._error + + +class SimpleTool: + """Tool that returns a successful result.""" + + def __init__(self, name="test_tool"): + self.name = name + self.description = "A test tool" + self.input_schema = {"type": "object", "properties": {}} + + async def execute(self, args): + from amplifier_core import ToolResult + + return ToolResult(success=True, output="tool result") + + +class MockCancellation: + """Minimal cancellation token stub.""" + + def __init__(self, is_cancelled=False, is_immediate=False): + self.is_cancelled = is_cancelled + self.is_immediate = is_immediate + + def register_tool_start(self, tool_call_id, tool_name): + pass + + def register_tool_complete(self, tool_call_id): + pass + + +class MockCoordinator: + """Minimal coordinator stub for cancellation tests.""" + + def __init__(self, is_cancelled=False, is_immediate=False): + self.cancellation = MockCancellation( + is_cancelled=is_cancelled, is_immediate=is_immediate + ) + self._contributions: dict = {} + + async def process_hook_result(self, result, event_name, source): + return result + + def register_contributor(self, channel, name, callback): + self._contributions[(channel, name)] = callback + + +# --------------------------------------------------------------------------- +# execution:start and execution:end on normal completion +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_execution_start_fires_with_prompt(): + """execution:start event is emitted with the prompt payload.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="Say hello", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + ) + + events = hooks.get_events(EXECUTION_START) + assert len(events) == 1, f"Expected 1 execution:start, got {len(events)}" + _, data = events[0] + assert data["prompt"] == "Say hello" + + +@pytest.mark.asyncio +async def test_execution_end_fires_on_normal_completion(): + """execution:end fires with status='completed' after a normal response.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="Say hello", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + ) + + events = hooks.get_events(EXECUTION_END) + assert len(events) == 1, f"Expected 1 execution:end, got {len(events)}" + _, data = events[0] + assert data["status"] == "completed" + assert "response" in data + + +@pytest.mark.asyncio +async def test_execution_end_response_matches_return_value(): + """execution:end payload 'response' matches the value returned by execute().""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + return_val = await orchestrator.execute( + prompt="Say hello", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + ) + + _, data = hooks.get_events(EXECUTION_END)[0] + # The response in the event should equal the returned string + assert data["response"] == return_val + + +# --------------------------------------------------------------------------- +# execution:start NOT fired when prompt:submit is denied +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_execution_start_not_fired_when_prompt_submit_denied(): + """execution:start must NOT fire when the coordinator denies prompt:submit.""" + + class DenyingCoordinator(MockCoordinator): + async def process_hook_result(self, result, event_name, source): + if event_name == "prompt:submit": + return HookResult(action="deny", reason="blocked by policy") + return result + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + coordinator = DenyingCoordinator() + + result = await orchestrator.execute( + prompt="blocked", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + coordinator=coordinator, + ) + + assert "denied" in result.lower() + assert len(hooks.get_events(EXECUTION_START)) == 0, ( + "execution:start should not fire when prompt:submit is denied" + ) + assert len(hooks.get_events(EXECUTION_END)) == 0, ( + "execution:end should not fire when prompt:submit is denied (no execution began)" + ) + + +# --------------------------------------------------------------------------- +# execution:end on cancellation paths +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_execution_end_fires_cancelled_at_loop_start(): + """execution:end fires with status='cancelled' when cancelled at loop start.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + coordinator = MockCoordinator(is_cancelled=True) + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + coordinator=coordinator, + ) + + events = hooks.get_events(EXECUTION_END) + assert len(events) == 1 + _, data = events[0] + assert data["status"] == "cancelled" + + +@pytest.mark.asyncio +async def test_execution_end_fires_cancelled_after_provider(): + """execution:end fires with status='cancelled' on immediate cancellation + after the provider returns (is_immediate path).""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + # Cancellation that takes effect after first provider call + class ImmediateCancellationCoordinator(MockCoordinator): + def __init__(self): + super().__init__() + self.cancellation = MockCancellation(is_cancelled=False, is_immediate=True) + + coordinator = ImmediateCancellationCoordinator() + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + coordinator=coordinator, + ) + + events = hooks.get_events(EXECUTION_END) + assert len(events) == 1 + _, data = events[0] + assert data["status"] == "cancelled" + + +# --------------------------------------------------------------------------- +# execution:end on error path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_execution_end_fires_on_provider_exception(): + """execution:end fires with status='error' when provider raises.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + with pytest.raises(RuntimeError, match="provider failed"): + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ErrorProvider()}, + tools={}, + hooks=hooks, + ) + + events = hooks.get_events(EXECUTION_END) + assert len(events) == 1 + _, data = events[0] + assert data["status"] == "error" + + +@pytest.mark.asyncio +async def test_execution_end_fires_exactly_once(): + """execution:end is emitted exactly once per execute() call on the happy path.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": SimpleTextProvider()}, + tools={}, + hooks=hooks, + ) + + assert len(hooks.get_events(EXECUTION_START)) == 1 + assert len(hooks.get_events(EXECUTION_END)) == 1 + + +# --------------------------------------------------------------------------- +# tool_call_id in TOOL_PRE and TOOL_POST payloads +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tool_pre_includes_tool_call_id(): + """TOOL_PRE event payload must include tool_call_id.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + provider = ToolThenTextProvider(tool_id="my-call-id-123") + tool = SimpleTool("test_tool") + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": provider}, + tools={"test_tool": tool}, + hooks=hooks, + ) + + pre_events = hooks.get_events(TOOL_PRE) + assert len(pre_events) == 1 + _, data = pre_events[0] + assert "tool_call_id" in data, "TOOL_PRE must include tool_call_id" + assert data["tool_call_id"] == "my-call-id-123" + + +@pytest.mark.asyncio +async def test_tool_post_includes_tool_call_id(): + """TOOL_POST event payload must include tool_call_id.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + provider = ToolThenTextProvider(tool_id="post-call-id-456") + tool = SimpleTool("test_tool") + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": provider}, + tools={"test_tool": tool}, + hooks=hooks, + ) + + post_events = hooks.get_events(TOOL_POST) + assert len(post_events) == 1 + _, data = post_events[0] + assert "tool_call_id" in data, "TOOL_POST must include tool_call_id" + assert data["tool_call_id"] == "post-call-id-456" + + +@pytest.mark.asyncio +async def test_tool_error_not_found_includes_tool_call_id(): + """TOOL_ERROR (tool not found) event payload must include tool_call_id.""" + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + # Provider requests a tool that doesn't exist in tools dict + provider = ToolThenTextProvider( + tool_name="missing_tool", tool_id="error-call-id-789" + ) + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": provider}, + tools={}, # empty — 'missing_tool' won't be found + hooks=hooks, + ) + + error_events = hooks.get_events(TOOL_ERROR) + assert len(error_events) == 1 + _, data = error_events[0] + assert "tool_call_id" in data, "TOOL_ERROR (not found) must include tool_call_id" + assert data["tool_call_id"] == "error-call-id-789" + + +@pytest.mark.asyncio +async def test_tool_error_exception_includes_tool_call_id(): + """TOOL_ERROR (exception in tool.execute) event payload must include tool_call_id.""" + + class FailingTool: + name = "failing_tool" + description = "fails" + input_schema = {"type": "object", "properties": {}} + + async def execute(self, args): + raise ValueError("tool exploded") + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + provider = ToolThenTextProvider(tool_name="failing_tool", tool_id="exc-call-id-000") + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": provider}, + tools={"failing_tool": FailingTool()}, + hooks=hooks, + ) + + error_events = hooks.get_events(TOOL_ERROR) + assert len(error_events) == 1 + _, data = error_events[0] + assert "tool_call_id" in data, "TOOL_ERROR (exception) must include tool_call_id" + assert data["tool_call_id"] == "exc-call-id-000" + + +# --------------------------------------------------------------------------- +# observability.events registration in mount() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_mount_registers_observability_events(): + """mount() must call coordinator.register_contributor('observability.events', ...).""" + from amplifier_module_loop_basic import mount + + contributions = {} + + class CapturingCoordinator: + async def mount(self, role, instance): + pass + + def register_contributor(self, channel, name, callback): + contributions[(channel, name)] = callback + + coordinator = CapturingCoordinator() + await mount(coordinator) + + assert ("observability.events", "loop-basic") in contributions, ( + "mount() must register 'loop-basic' as a contributor to 'observability.events'" + ) + + # Verify the callback returns the expected events + callback = contributions[("observability.events", "loop-basic")] + events_list = callback() + assert "execution:start" in events_list + assert "execution:end" in events_list From ed8da2d3de44a53240a59f381e3d761539fd697c Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Fri, 6 Mar 2026 09:40:27 -0800 Subject: [PATCH 5/7] feat: set _tool_dispatch_context on coordinator before tool.execute() The orchestrator now sets coordinator._tool_dispatch_context = { "tool_call_id": tool_call_id, "parallel_group_id": group_id, } immediately before calling tool.execute() in execute_single_tool, and clears it in the outer finally block (alongside register_tool_complete). This lets tools that need framework correlation IDs (e.g. the delegate tool) read them via: dispatch_ctx = getattr(self.coordinator, '_tool_dispatch_context', {}) tool_call_id = dispatch_ctx.get('tool_call_id', '') Uses setattr() to avoid type errors on the dynamic private attribute. The finally block always clears the context to prevent leaking between parallel tool executions. Fixes Check 4: delegate:agent_spawned tool_call_id was always empty because the delegate tool was reading from tool input (which the LLM never populates with framework IDs). --- amplifier_module_loop_basic/__init__.py | 14 +- tests/test_tool_dispatch_context.py | 300 ++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool_dispatch_context.py diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index 2e29f62..be4f942 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -419,6 +419,17 @@ async def execute_single_tool( coordinator.cancellation.register_tool_start( tool_call_id, tool_name ) + # Set dispatch context so tools (e.g. delegate) can + # read the framework-assigned tool_call_id and + # parallel_group_id. Cleared in the finally block. + setattr( + coordinator, + "_tool_dispatch_context", + { + "tool_call_id": tool_call_id, + "parallel_group_id": group_id, + }, + ) try: try: @@ -545,8 +556,9 @@ async def execute_single_tool( logger.error(f"Tool {tool_name} failed: {te}") return (tool_call_id, error_msg) finally: - # Unregister tool from cancellation token + # Clear dispatch context and unregister tool if coordinator: + setattr(coordinator, "_tool_dispatch_context", {}) coordinator.cancellation.register_tool_complete( tool_call_id ) diff --git a/tests/test_tool_dispatch_context.py b/tests/test_tool_dispatch_context.py new file mode 100644 index 0000000..817a3bf --- /dev/null +++ b/tests/test_tool_dispatch_context.py @@ -0,0 +1,300 @@ +"""Tests for _tool_dispatch_context set on coordinator during tool.execute(). + +Verifies that BasicOrchestrator sets coordinator._tool_dispatch_context +with the correct tool_call_id and parallel_group_id immediately before +calling tool.execute(), and clears it in a finally block afterward. + +Covers: +- execute_single_tool (inner path): context set with tool_call_id and parallel_group_id +- context cleared after tool completes normally +- context cleared even when tool raises an exception +- Integration: full execute() path sets dispatch context during tool call +""" + +import pytest +from amplifier_core import ToolResult +from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall +from amplifier_core.testing import EventRecorder, MockContextManager + +from amplifier_module_loop_basic import BasicOrchestrator + + +# --------------------------------------------------------------------------- +# Helpers (reuse pattern from test_lifecycle_events.py) +# --------------------------------------------------------------------------- + + +class MockCancellation: + """Minimal cancellation token stub.""" + + is_cancelled: bool = False + is_immediate: bool = False + + def register_tool_start(self, tool_call_id: str, tool_name: str) -> None: + pass + + def register_tool_complete(self, tool_call_id: str) -> None: + pass + + +class MockCoordinator: + """Minimal coordinator stub that supports _tool_dispatch_context.""" + + def __init__(self, is_cancelled: bool = False, is_immediate: bool = False) -> None: + self.cancellation = MockCancellation() + self.cancellation.is_cancelled = is_cancelled + self.cancellation.is_immediate = is_immediate + self._contributions: dict = {} + + async def process_hook_result(self, result: object, event_name: str, source: str) -> object: + return result + + def register_contributor(self, channel: str, name: str, callback: object) -> None: + self._contributions[(channel, name)] = callback + + +# --------------------------------------------------------------------------- +# Integration tests via full execute() path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_execute_sets_tool_call_id_in_dispatch_context_during_tool_execution() -> None: + """BasicOrchestrator sets tool_call_id in coordinator._tool_dispatch_context + before calling tool.execute(). + """ + captured: dict = {} + coordinator = MockCoordinator() + + class CapturingTool: + name = "capture_tool" + description = "Captures dispatch context during execution" + input_schema: dict = {"type": "object", "properties": {}} + + async def execute(self, args: dict) -> ToolResult: + captured.update(getattr(coordinator, "_tool_dispatch_context", {})) + return ToolResult(success=True, output="done") + + class ToolThenTextProvider: + def __init__(self) -> None: + self._call_count = 0 + + async def complete(self, request: object, **kwargs: object) -> ChatResponse: + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall( + id="dispatch-call-id-001", + name="capture_tool", + arguments={"_": 1}, + ) + ], + ) + return ChatResponse(content=[TextBlock(text="Done!")]) + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] + tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item] + hooks=hooks, # type: ignore[arg-type] + coordinator=coordinator, # type: ignore[arg-type] + ) + + assert captured.get("tool_call_id") == "dispatch-call-id-001", ( + "coordinator._tool_dispatch_context must have tool_call_id set during tool.execute()" + ) + + +@pytest.mark.asyncio +async def test_execute_sets_parallel_group_id_in_dispatch_context() -> None: + """BasicOrchestrator sets parallel_group_id in coordinator._tool_dispatch_context.""" + captured: dict = {} + coordinator = MockCoordinator() + + class CapturingTool: + name = "capture_tool" + description = "Captures dispatch context" + input_schema: dict = {"type": "object", "properties": {}} + + async def execute(self, args: dict) -> ToolResult: + captured.update(getattr(coordinator, "_tool_dispatch_context", {})) + return ToolResult(success=True, output="done") + + class ToolThenTextProvider: + def __init__(self) -> None: + self._call_count = 0 + + async def complete(self, request: object, **kwargs: object) -> ChatResponse: + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall(id="group-test-call-id", name="capture_tool", arguments={"_": 1}) + ], + ) + return ChatResponse(content=[TextBlock(text="Done!")]) + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] + tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item] + hooks=hooks, # type: ignore[arg-type] + coordinator=coordinator, # type: ignore[arg-type] + ) + + # parallel_group_id is a UUID generated per-batch — verify it's a non-empty string + pgid = captured.get("parallel_group_id") + assert isinstance(pgid, str) and pgid, ( + "_tool_dispatch_context must have a non-empty parallel_group_id string" + ) + + +@pytest.mark.asyncio +async def test_execute_clears_dispatch_context_after_tool_completes() -> None: + """BasicOrchestrator clears coordinator._tool_dispatch_context after tool.execute().""" + coordinator = MockCoordinator() + + class SimpleTool: + name = "simple_tool" + description = "Returns a successful result" + input_schema: dict = {"type": "object", "properties": {}} + + async def execute(self, args: dict) -> ToolResult: + return ToolResult(success=True, output="done") + + class ToolThenTextProvider: + def __init__(self) -> None: + self._call_count = 0 + + async def complete(self, request: object, **kwargs: object) -> ChatResponse: + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall(id="clear-test-id", name="simple_tool", arguments={"_": 1}) + ], + ) + return ChatResponse(content=[TextBlock(text="Done!")]) + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] + tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item] + hooks=hooks, # type: ignore[arg-type] + coordinator=coordinator, # type: ignore[arg-type] + ) + + ctx_after = getattr(coordinator, "_tool_dispatch_context", None) + assert ctx_after == {}, ( + "_tool_dispatch_context must be cleared to {} after tool execution completes" + ) + + +@pytest.mark.asyncio +async def test_execute_clears_dispatch_context_after_tool_raises() -> None: + """BasicOrchestrator clears coordinator._tool_dispatch_context even when tool raises.""" + coordinator = MockCoordinator() + + class RaisingTool: + name = "raising_tool" + description = "Always raises an exception" + input_schema: dict = {"type": "object", "properties": {}} + + async def execute(self, args: dict) -> ToolResult: + raise ValueError("tool exploded") + + class ToolThenTextProvider: + def __init__(self) -> None: + self._call_count = 0 + + async def complete(self, request: object, **kwargs: object) -> ChatResponse: + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall(id="raise-test-id", name="raising_tool", arguments={"_": 1}) + ], + ) + return ChatResponse(content=[TextBlock(text="Done despite error!")]) + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + # BasicOrchestrator handles tool exceptions gracefully (no raise propagation) + await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] + tools={"raising_tool": RaisingTool()}, # type: ignore[dict-item] + hooks=hooks, # type: ignore[arg-type] + coordinator=coordinator, # type: ignore[arg-type] + ) + + ctx_after = getattr(coordinator, "_tool_dispatch_context", None) + assert ctx_after == {}, ( + "_tool_dispatch_context must be cleared even when tool.execute() raises" + ) + + +@pytest.mark.asyncio +async def test_execute_no_coordinator_does_not_set_dispatch_context() -> None: + """Without a coordinator, execute() still runs tools normally (no dispatch context set).""" + + class SimpleTool: + name = "simple_tool" + description = "Returns a successful result" + input_schema: dict = {"type": "object", "properties": {}} + + async def execute(self, args: dict) -> ToolResult: + return ToolResult(success=True, output="done") + + class ToolThenTextProvider: + def __init__(self) -> None: + self._call_count = 0 + + async def complete(self, request: object, **kwargs: object) -> ChatResponse: + self._call_count += 1 + if self._call_count == 1: + return ChatResponse( + content=[TextBlock(text="Using tool")], + tool_calls=[ + ToolCall(id="no-coord-id", name="simple_tool", arguments={"_": 1}) + ], + ) + return ChatResponse(content=[TextBlock(text="Done!")]) + + orchestrator = BasicOrchestrator({}) + hooks = EventRecorder() + context = MockContextManager() + + # Should complete without error even when coordinator=None + result = await orchestrator.execute( + prompt="test", + context=context, + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] + tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item] + hooks=hooks, # type: ignore[arg-type] + ) + + assert result is not None From 783b031e75cd425a2069855ad7f619781484c874 Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Fri, 20 Mar 2026 08:14:27 -0700 Subject: [PATCH 6/7] =?UTF-8?q?fix:=20remove=20[tool.uv.sources]=20overrid?= =?UTF-8?q?e=20=E2=80=94=20resolve=20amplifier-core=20from=20PyPI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The [tool.uv.sources] section forced amplifier-core to resolve from git (requiring a Rust toolchain). Removed so that uv resolves from PyPI, matching the standard install path. Also pinned amplifier-core>=1.0.10 in dev dependency-group. --- pyproject.toml | 4 +--- uv.lock | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 98b7ccb..bc0a1d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,11 @@ allow-direct-references = true [dependency-groups] dev = [ - "amplifier-core", + "amplifier-core>=1.0.10", "pytest>=8.0.0", "pytest-asyncio>=1.0.0", ] -[tool.uv.sources] -amplifier-core = { git = "https://github.com/microsoft/amplifier-core", branch = "main" } [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/uv.lock b/uv.lock index b3e7958..dbc7687 100644 --- a/uv.lock +++ b/uv.lock @@ -4,8 +4,8 @@ requires-python = ">=3.11" [[package]] name = "amplifier-core" -version = "1.0.0" -source = { git = "https://github.com/microsoft/amplifier-core?branch=main#f246c6f36589d45cb92a69f4e354fe9edc8249c2" } +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "pydantic" }, @@ -13,6 +13,16 @@ dependencies = [ { name = "tomli" }, { name = "typing-extensions" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/0a/c9f979aa34ff43d86323fd02c6e0b2049cf583f48a75eefe4e2d4ea39a5b/amplifier_core-1.3.0-cp311-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e94a6adfb538c2f621ddd4eb44f2f2dc645934fef804bc5318d4595520f4f4c9", size = 8096649, upload-time = "2026-03-19T14:01:24.218Z" }, + { url = "https://files.pythonhosted.org/packages/01/58/619daa63943870340673c625695038b86567adc9896b9d81ff8a2a707b3b/amplifier_core-1.3.0-cp311-abi3-macosx_11_0_arm64.whl", hash = "sha256:d051f91a2c3c61c240f828daf892ac2cfa8d34e9850b245e7ebc96e87f9ac606", size = 7216307, upload-time = "2026-03-19T14:01:26.032Z" }, + { url = "https://files.pythonhosted.org/packages/6b/35/6aeb099f012c8d97ccd4aad878a1422fe3148840922e4818a7c1279fdc57/amplifier_core-1.3.0-cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af25c5e691638496226aa88e80056cb1870b715ddebd311bccb855c653ace05f", size = 7593745, upload-time = "2026-03-19T14:01:28.027Z" }, + { url = "https://files.pythonhosted.org/packages/c8/8e/d2a092e31f1924c6c977a3a091b026560a31133fb1f94bf4389f71a8b249/amplifier_core-1.3.0-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18f1d9e2ffa6a9aac507699c039d3dc8dc23fe696a5af515e2fce1d90b34e11b", size = 8624234, upload-time = "2026-03-19T14:01:30.068Z" }, + { url = "https://files.pythonhosted.org/packages/b8/a3/d6e83e879ecdc300fb4516f1f68f6bab4ab5e2dcc4037a90b9df74861911/amplifier_core-1.3.0-cp311-abi3-win_amd64.whl", hash = "sha256:1bf419d8d659821589b6af68d7a9d6c5e4495095bd7b240ff67527e0ab137985", size = 8887729, upload-time = "2026-03-19T14:01:32.392Z" }, + { url = "https://files.pythonhosted.org/packages/12/00/36e1f6456a7a6782986918f4ec6890fe6526d3fda478bd46c57fd7cfa9b3/amplifier_core-1.3.0-cp311-abi3-win_arm64.whl", hash = "sha256:eebac607c14c5fac1e12eabb97c09ade64bd2d003f2e9fa1566ac4898bfe848d", size = 7658166, upload-time = "2026-03-19T14:01:34.443Z" }, + { url = "https://files.pythonhosted.org/packages/02/f0/3beca3cc30323e60f88c8126612a05e3d97f0c951c2d7920458e7ae8e480/amplifier_core-1.3.0-cp313-cp313t-win_arm64.whl", hash = "sha256:a2f8c128f78c8c615f6e411f378b6ad5dc70db9d342a347435df759e4ea4cdbf", size = 7648908, upload-time = "2026-03-19T14:01:36.763Z" }, + { url = "https://files.pythonhosted.org/packages/d8/38/84b012e1b50226ab97cf8ad9f688708bcb6a34a2c6dd139a28d60d40cd71/amplifier_core-1.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:5bbfac975bf6757c479b31bd07b1d465fde49d95dded497c4d161509919dd93b", size = 7647745, upload-time = "2026-03-19T14:01:38.955Z" }, +] [[package]] name = "amplifier-module-loop-basic" @@ -30,7 +40,7 @@ dev = [ [package.metadata.requires-dev] dev = [ - { name = "amplifier-core", git = "https://github.com/microsoft/amplifier-core?branch=main" }, + { name = "amplifier-core", specifier = ">=1.0.10" }, { name = "pytest", specifier = ">=8.0.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, ] From d9a445dd4bb704a462eb8dff04345e311e26f88f Mon Sep 17 00:00:00 2001 From: mark licta Date: Thu, 2 Apr 2026 21:55:07 +0000 Subject: [PATCH 7/7] fix: Handle ToolCall Pydantic models correctly The orchestrator was failing when processing ToolCall objects because it used getattr(tc, 'id', None) or tc.get('id'), which fails when: 1. tc is a Pydantic model (no .get() method) 2. The attribute exists but is falsy (empty string) Changed to use hasattr() checks: tc.id if hasattr(tc, 'id') else tc.get('id') This properly handles both Pydantic models and dictionaries. Fixes tool execution errors with custom tool modules. --- amplifier_module_loop_basic/__init__.py | 33 +++++++++---------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py index be4f942..6037aea 100644 --- a/amplifier_module_loop_basic/__init__.py +++ b/amplifier_module_loop_basic/__init__.py @@ -354,13 +354,9 @@ async def execute( ], "tool_calls": [ { - "id": getattr(tc, "id", None) - or tc.get("id"), - "tool": getattr(tc, "name", None) - or tc.get("tool"), - "arguments": getattr(tc, "arguments", None) - or tc.get("arguments") - or {}, + "id": tc.id if hasattr(tc, "id") else tc.get("id"), + "tool": tc.name if hasattr(tc, "name") else tc.get("tool"), + "arguments": tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {}), } for tc in tool_calls ], @@ -371,13 +367,9 @@ async def execute( "content": content if content else "", "tool_calls": [ { - "id": getattr(tc, "id", None) - or tc.get("id"), - "tool": getattr(tc, "name", None) - or tc.get("tool"), - "arguments": getattr(tc, "arguments", None) - or tc.get("arguments") - or {}, + "id": tc.id if hasattr(tc, "id") else tc.get("id"), + "tool": tc.name if hasattr(tc, "name") else tc.get("tool"), + "arguments": tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {}), } for tc in tool_calls ], @@ -405,12 +397,10 @@ async def execute_single_tool( Always returns (tool_call_id, result_or_error) tuple. Never raises - errors become error results. """ - tool_name = getattr(tc, "name", None) or tc.get("tool") - tool_call_id = getattr(tc, "id", None) or tc.get("id") + tool_name = tc.name if hasattr(tc, "name") else tc.get("tool") + tool_call_id = tc.id if hasattr(tc, "id") else tc.get("id") args = ( - getattr(tc, "arguments", None) - or tc.get("arguments") - or {} + tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {}) ) tool = tools.get(tool_name) @@ -587,9 +577,8 @@ async def execute_single_tool( await context.add_message( { "role": "tool", - "tool_call_id": getattr(tc, "id", None) - or tc.get("id"), - "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{getattr(tc, "name", None) or tc.get("tool")}"}}', + "tool_call_id": tc.id if hasattr(tc, "id") else tc.get("id"), + "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{tc.name if hasattr(tc, "name") else tc.get("tool")}"}}', } ) except asyncio.CancelledError: