diff --git a/amplifier_module_loop_basic/__init__.py b/amplifier_module_loop_basic/__init__.py
index 8bac1f1..6037aea 100644
--- a/amplifier_module_loop_basic/__init__.py
+++ b/amplifier_module_loop_basic/__init__.py
@@ -5,6 +5,8 @@
# Amplifier module metadata
__amplifier_module_type__ = "orchestrator"
+import asyncio
+import json
import logging
from typing import Any
@@ -13,6 +15,8 @@
from amplifier_core import ModuleCoordinator
from amplifier_core.events import CONTENT_BLOCK_END
from amplifier_core.events import CONTENT_BLOCK_START
+from amplifier_core.events import EXECUTION_END
+from amplifier_core.events import EXECUTION_START
from amplifier_core.events import ORCHESTRATOR_COMPLETE
from amplifier_core.events import PROMPT_COMPLETE
from amplifier_core.events import PROMPT_SUBMIT
@@ -32,6 +36,14 @@
async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = None):
config = config or {}
+ coordinator.register_contributor(
+ "observability.events",
+ "loop-basic",
+ lambda: [
+ "execution:start",
+ "execution:end",
+ ],
+ )
orchestrator = BasicOrchestrator(config)
await coordinator.mount("orchestrator", orchestrator)
logger.info("Mounted BasicOrchestrator (desired-state)")
@@ -67,6 +79,9 @@ async def execute(
if result.action == "deny":
return f"Operation denied: {result.reason}"
+ # Emit execution start (after prompt:submit so denials skip this)
+ await hooks.emit(EXECUTION_START, {"prompt": prompt})
+
# Add user message
if hasattr(context, "add_message"):
await context.add_message({"role": "user", "content": prompt})
@@ -84,687 +99,795 @@ async def execute(
# Agentic loop: continue until we get a text response (no tool calls)
iteration = 0
final_content = ""
+ execution_status = "completed"
+
+ try:
+ while self.max_iterations == -1 or iteration < self.max_iterations:
+ # Check for cancellation at iteration start
+ if coordinator and coordinator.cancellation.is_cancelled:
+ # Emit orchestrator complete with cancelled status
+ await hooks.emit(
+ ORCHESTRATOR_COMPLETE,
+ {
+ "orchestrator": "loop-basic",
+ "turn_count": iteration,
+ "status": "cancelled",
+ },
+ )
+ execution_status = "cancelled"
+ return final_content
- while self.max_iterations == -1 or iteration < self.max_iterations:
- # Check for cancellation at iteration start
- if coordinator and coordinator.cancellation.is_cancelled:
- # Emit orchestrator complete with cancelled status
- await hooks.emit(
- ORCHESTRATOR_COMPLETE,
- {
- "orchestrator": "loop-basic",
- "turn_count": iteration,
- "status": "cancelled",
- },
- )
- return final_content
-
- # Emit provider request BEFORE getting messages (allows hook injections)
- result = await hooks.emit(
- PROVIDER_REQUEST, {"provider": provider_name, "iteration": iteration}
- )
- if coordinator:
- result = await coordinator.process_hook_result(
- result, "provider:request", "orchestrator"
- )
- if result.action == "deny":
- return f"Operation denied: {result.reason}"
-
- # Get messages for LLM request (context handles compaction internally)
- # Pass provider for dynamic budget calculation based on model's context window
- if hasattr(context, "get_messages_for_request"):
- message_dicts = await context.get_messages_for_request(
- provider=provider
+ # Emit provider request BEFORE getting messages (allows hook injections)
+ result = await hooks.emit(
+ PROVIDER_REQUEST,
+ {"provider": provider_name, "iteration": iteration},
)
- else:
- # Fallback for simple contexts without the method
- message_dicts = getattr(
- context, "messages", [{"role": "user", "content": prompt}]
- )
-
- # Append ephemeral injection if present (temporary, not stored)
- if (
- result.action == "inject_context"
- and result.ephemeral
- and result.context_injection
- ):
- message_dicts = list(message_dicts) # Copy to avoid modifying context
-
- # Check if we should append to last tool result
- if result.append_to_last_tool_result and len(message_dicts) > 0:
- last_msg = message_dicts[-1]
- # Append to last message if it's a tool result
- if last_msg.get("role") == "tool":
- # Append to existing content
- original_content = last_msg.get("content", "")
- message_dicts[-1] = {
- **last_msg,
- "content": f"{original_content}\n\n{result.context_injection}",
- }
- logger.debug(
- "Appended ephemeral injection to last tool result message"
- )
- else:
- # Fall back to new message if last message isn't a tool result
- message_dicts.append(
- {
- "role": result.context_injection_role,
- "content": result.context_injection,
- }
- )
- logger.debug(
- f"Last message role is '{last_msg.get('role')}', not 'tool' - "
- "created new message for injection"
- )
+ if coordinator:
+ result = await coordinator.process_hook_result(
+ result, "provider:request", "orchestrator"
+ )
+ if result.action == "deny":
+ return f"Operation denied: {result.reason}"
+
+ # Get messages for LLM request (context handles compaction internally)
+ # Pass provider for dynamic budget calculation based on model's context window
+ if hasattr(context, "get_messages_for_request"):
+ message_dicts = await context.get_messages_for_request(
+ provider=provider
+ )
else:
- # Default behavior: append as new message
- message_dicts.append(
- {
- "role": result.context_injection_role,
- "content": result.context_injection,
- }
+ # Fallback for simple contexts without the method
+ message_dicts = getattr(
+ context, "messages", [{"role": "user", "content": prompt}]
)
- # Apply pending ephemeral injections from tool:post hooks
- if self._pending_ephemeral_injections:
- message_dicts = list(message_dicts) # Ensure we have a mutable list
- for injection in self._pending_ephemeral_injections:
- if (
- injection.get("append_to_last_tool_result")
- and len(message_dicts) > 0
- ):
+ # Append ephemeral injection if present (temporary, not stored)
+ if (
+ result.action == "inject_context"
+ and result.ephemeral
+ and result.context_injection
+ ):
+ message_dicts = list(
+ message_dicts
+ ) # Copy to avoid modifying context
+
+ # Check if we should append to last tool result
+ if result.append_to_last_tool_result and len(message_dicts) > 0:
last_msg = message_dicts[-1]
+ # Append to last message if it's a tool result
if last_msg.get("role") == "tool":
+ # Append to existing content
original_content = last_msg.get("content", "")
message_dicts[-1] = {
**last_msg,
- "content": f"{original_content}\n\n{injection['content']}",
+ "content": f"{original_content}\n\n{result.context_injection}",
}
logger.debug(
- "Applied pending ephemeral injection to last tool result"
+ "Appended ephemeral injection to last tool result message"
)
else:
+ # Fall back to new message if last message isn't a tool result
message_dicts.append(
{
- "role": injection["role"],
- "content": injection["content"],
+ "role": result.context_injection_role,
+ "content": result.context_injection,
}
)
logger.debug(
- "Last message not a tool result, created new message for injection"
+ f"Last message role is '{last_msg.get('role')}', not 'tool' - "
+ "created new message for injection"
)
else:
+ # Default behavior: append as new message
message_dicts.append(
- {"role": injection["role"], "content": injection["content"]}
- )
- logger.debug(
- "Applied pending ephemeral injection as new message"
- )
- # Clear pending injections after applying
- self._pending_ephemeral_injections = []
-
- # Convert to ChatRequest with Message objects
- try:
- messages_objects = [Message(**msg) for msg in message_dicts]
-
- # Convert tools to ToolSpec format for ChatRequest
- tools_list = None
- if tools:
- tools_list = [
- ToolSpec(
- name=t.name,
- description=t.description,
- parameters=t.input_schema,
+ {
+ "role": result.context_injection_role,
+ "content": result.context_injection,
+ }
)
- for t in tools.values()
- ]
- chat_request = ChatRequest(
- messages=messages_objects,
- tools=tools_list,
- reasoning_effort=self.config.get("reasoning_effort"),
- )
- logger.debug(
- f"Created ChatRequest with {len(messages_objects)} messages"
- )
- logger.debug(
- f"Message roles: {[m.role for m in chat_request.messages]}"
- )
- except Exception as e:
- logger.error(f"Failed to create ChatRequest: {e}")
- logger.error(f"Message dicts: {message_dicts}")
- raise
- try:
- if hasattr(provider, "complete"):
- # Pass extended_thinking if enabled in orchestrator config
- kwargs = {}
- if self.extended_thinking:
- kwargs["extended_thinking"] = True
- response = await provider.complete(chat_request, **kwargs)
+ # Apply pending ephemeral injections from tool:post hooks
+ if self._pending_ephemeral_injections:
+ message_dicts = list(message_dicts) # Ensure we have a mutable list
+ for injection in self._pending_ephemeral_injections:
+ if (
+ injection.get("append_to_last_tool_result")
+ and len(message_dicts) > 0
+ ):
+ last_msg = message_dicts[-1]
+ if last_msg.get("role") == "tool":
+ original_content = last_msg.get("content", "")
+ message_dicts[-1] = {
+ **last_msg,
+ "content": f"{original_content}\n\n{injection['content']}",
+ }
+ logger.debug(
+ "Applied pending ephemeral injection to last tool result"
+ )
+ else:
+ message_dicts.append(
+ {
+ "role": injection["role"],
+ "content": injection["content"],
+ }
+ )
+ logger.debug(
+ "Last message not a tool result, created new message for injection"
+ )
+ else:
+ message_dicts.append(
+ {
+ "role": injection["role"],
+ "content": injection["content"],
+ }
+ )
+ logger.debug(
+ "Applied pending ephemeral injection as new message"
+ )
+ # Clear pending injections after applying
+ self._pending_ephemeral_injections = []
+
+ # Convert to ChatRequest with Message objects
+ try:
+ messages_objects = [Message(**msg) for msg in message_dicts]
+
+ # Convert tools to ToolSpec format for ChatRequest
+ tools_list = None
+ if tools:
+ tools_list = [
+ ToolSpec(
+ name=t.name,
+ description=t.description,
+ parameters=t.input_schema,
+ )
+ for t in tools.values()
+ ]
- # Check for immediate cancellation after provider returns
- # This allows force-cancel to take effect as soon as the blocking
- # provider call completes, before processing the response
- if coordinator and coordinator.cancellation.is_immediate:
- # Emit cancelled status and exit
- await hooks.emit(
- ORCHESTRATOR_COMPLETE,
- {
- "orchestrator": "loop-basic",
- "turn_count": iteration,
- "status": "cancelled",
- },
+ chat_request = ChatRequest(
+ messages=messages_objects,
+ tools=tools_list,
+ reasoning_effort=self.config.get("reasoning_effort"),
+ )
+ logger.debug(
+ f"Created ChatRequest with {len(messages_objects)} messages"
+ )
+ logger.debug(
+ f"Message roles: {[m.role for m in chat_request.messages]}"
+ )
+ except Exception as e:
+ logger.error(f"Failed to create ChatRequest: {e}")
+ logger.error(f"Message dicts: {message_dicts}")
+ raise
+ try:
+ if hasattr(provider, "complete"):
+ # Pass extended_thinking if enabled in orchestrator config
+ kwargs = {}
+ if self.extended_thinking:
+ kwargs["extended_thinking"] = True
+ response = await provider.complete(chat_request, **kwargs)
+
+ # Check for immediate cancellation after provider returns
+ # This allows force-cancel to take effect as soon as the blocking
+ # provider call completes, before processing the response
+ if coordinator and coordinator.cancellation.is_immediate:
+ # Emit cancelled status and exit
+ await hooks.emit(
+ ORCHESTRATOR_COMPLETE,
+ {
+ "orchestrator": "loop-basic",
+ "turn_count": iteration,
+ "status": "cancelled",
+ },
+ )
+ execution_status = "cancelled"
+ return final_content
+ else:
+ raise RuntimeError(
+ f"Provider {provider_name} missing 'complete'"
)
- return final_content
- else:
- raise RuntimeError(f"Provider {provider_name} missing 'complete'")
- usage = getattr(response, "usage", None)
- content = getattr(response, "content", None)
- tool_calls = getattr(response, "tool_calls", None)
+ usage = getattr(response, "usage", None)
+ content = getattr(response, "content", None)
+ tool_calls = getattr(response, "tool_calls", None)
- await hooks.emit(
- PROVIDER_RESPONSE,
- {
- "provider": provider_name,
- "usage": usage,
- "tool_calls": bool(tool_calls),
- },
- )
+ await hooks.emit(
+ PROVIDER_RESPONSE,
+ {
+ "provider": provider_name,
+ "usage": usage,
+ "tool_calls": bool(tool_calls),
+ },
+ )
- # Emit content block events if present
- content_blocks = getattr(response, "content_blocks", None)
- logger.info(
- f"Response has content_blocks: {content_blocks is not None} - count: {len(content_blocks) if content_blocks else 0}"
- )
- if content_blocks:
- total_blocks = len(content_blocks)
- logger.info(f"Emitting events for {total_blocks} content blocks")
- for idx, block in enumerate(content_blocks):
+ # Emit content block events if present
+ content_blocks = getattr(response, "content_blocks", None)
+ logger.info(
+ f"Response has content_blocks: {content_blocks is not None} - count: {len(content_blocks) if content_blocks else 0}"
+ )
+ if content_blocks:
+ total_blocks = len(content_blocks)
logger.info(
- f"Emitting CONTENT_BLOCK_START for block {idx}, type: {block.type.value}"
+ f"Emitting events for {total_blocks} content blocks"
)
- # Emit block start (without non-serializable raw object)
- await hooks.emit(
- CONTENT_BLOCK_START,
- {
- "block_type": block.type.value,
+ for idx, block in enumerate(content_blocks):
+ logger.info(
+ f"Emitting CONTENT_BLOCK_START for block {idx}, type: {block.type.value}"
+ )
+ # Emit block start (without non-serializable raw object)
+ await hooks.emit(
+ CONTENT_BLOCK_START,
+ {
+ "block_type": block.type.value,
+ "block_index": idx,
+ "total_blocks": total_blocks,
+ },
+ )
+
+ # Emit block end with complete block, usage, and total count
+ event_data = {
"block_index": idx,
"total_blocks": total_blocks,
- },
- )
-
- # Emit block end with complete block, usage, and total count
- event_data = {
- "block_index": idx,
- "total_blocks": total_blocks,
- "block": block.to_dict(),
- }
- if usage:
- event_data["usage"] = (
- usage.model_dump()
- if hasattr(usage, "model_dump")
- else usage
- )
- await hooks.emit(CONTENT_BLOCK_END, event_data)
-
- # Handle tool calls (parallel execution)
- if tool_calls:
- # Add assistant message with tool calls BEFORE executing them
- if hasattr(context, "add_message"):
- # Store structured content from response.content (our Pydantic models)
- response_content = getattr(response, "content", None)
- if response_content and isinstance(response_content, list):
- assistant_msg = {
- "role": "assistant",
- "content": [
- block.model_dump()
- if hasattr(block, "model_dump")
- else block
- for block in response_content
- ],
- "tool_calls": [
- {
- "id": getattr(tc, "id", None) or tc.get("id"),
- "tool": getattr(tc, "name", None)
- or tc.get("tool"),
- "arguments": getattr(tc, "arguments", None)
- or tc.get("arguments")
- or {},
- }
- for tc in tool_calls
- ],
- }
- else:
- assistant_msg = {
- "role": "assistant",
- "content": content if content else "",
- "tool_calls": [
- {
- "id": getattr(tc, "id", None) or tc.get("id"),
- "tool": getattr(tc, "name", None)
- or tc.get("tool"),
- "arguments": getattr(tc, "arguments", None)
- or tc.get("arguments")
- or {},
- }
- for tc in tool_calls
- ],
+ "block": block.to_dict(),
}
+ if usage:
+ event_data["usage"] = (
+ usage.model_dump()
+ if hasattr(usage, "model_dump")
+ else usage
+ )
+ await hooks.emit(CONTENT_BLOCK_END, event_data)
- # Preserve provider metadata (provider-agnostic passthrough)
- # This enables providers to maintain state across steps (e.g., OpenAI reasoning items)
- if hasattr(response, "metadata") and response.metadata:
- assistant_msg["metadata"] = response.metadata
-
- await context.add_message(assistant_msg)
-
- # Execute tools in parallel (user guidance: assume parallel intent when multiple tool calls)
- import asyncio
- import uuid
-
- # Generate parallel group ID for event correlation
- parallel_group_id = str(uuid.uuid4())
-
- # Create tasks for parallel execution
- async def execute_single_tool(
- tc: Any, group_id: str
- ) -> tuple[str, str]:
- """Execute one tool, handling all errors gracefully.
-
- Always returns (tool_call_id, result_or_error) tuple.
- Never raises - errors become error results.
- """
- tool_name = getattr(tc, "name", None) or tc.get("tool")
- tool_call_id = getattr(tc, "id", None) or tc.get("id")
- args = (
- getattr(tc, "arguments", None) or tc.get("arguments") or {}
- )
- tool = tools.get(tool_name)
+ # Handle tool calls (parallel execution)
+ if tool_calls:
+ # Add assistant message with tool calls BEFORE executing them
+ if hasattr(context, "add_message"):
+ # Store structured content from response.content (our Pydantic models)
+ response_content = getattr(response, "content", None)
+ if response_content and isinstance(response_content, list):
+ assistant_msg = {
+ "role": "assistant",
+ "content": [
+ block.model_dump()
+ if hasattr(block, "model_dump")
+ else block
+ for block in response_content
+ ],
+ "tool_calls": [
+ {
+ "id": tc.id if hasattr(tc, "id") else tc.get("id"),
+ "tool": tc.name if hasattr(tc, "name") else tc.get("tool"),
+ "arguments": tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {}),
+ }
+ for tc in tool_calls
+ ],
+ }
+ else:
+ assistant_msg = {
+ "role": "assistant",
+ "content": content if content else "",
+ "tool_calls": [
+ {
+ "id": tc.id if hasattr(tc, "id") else tc.get("id"),
+ "tool": tc.name if hasattr(tc, "name") else tc.get("tool"),
+ "arguments": tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {}),
+ }
+ for tc in tool_calls
+ ],
+ }
+
+ # Preserve provider metadata (provider-agnostic passthrough)
+ # This enables providers to maintain state across steps (e.g., OpenAI reasoning items)
+ if hasattr(response, "metadata") and response.metadata:
+ assistant_msg["metadata"] = response.metadata
+
+ await context.add_message(assistant_msg)
+
+ # Execute tools in parallel (user guidance: assume parallel intent when multiple tool calls)
+ import uuid
+
+ # Generate parallel group ID for event correlation
+ parallel_group_id = str(uuid.uuid4())
+
+ # Create tasks for parallel execution
+ async def execute_single_tool(
+ tc: Any, group_id: str
+ ) -> tuple[str, str]:
+ """Execute one tool, handling all errors gracefully.
- # Register tool with cancellation token for visibility
- if coordinator:
- coordinator.cancellation.register_tool_start(
- tool_call_id, tool_name
+ Always returns (tool_call_id, result_or_error) tuple.
+ Never raises - errors become error results.
+ """
+ tool_name = tc.name if hasattr(tc, "name") else tc.get("tool")
+ tool_call_id = tc.id if hasattr(tc, "id") else tc.get("id")
+ args = (
+ tc.arguments if hasattr(tc, "arguments") else (tc.get("arguments") or {})
)
+ tool = tools.get(tool_name)
- try:
- try:
- # Emit and process tool pre (allows hooks to block or request approval)
- pre_result = await hooks.emit(
- TOOL_PRE,
+ # Register tool with cancellation token for visibility
+ if coordinator:
+ coordinator.cancellation.register_tool_start(
+ tool_call_id, tool_name
+ )
+ # Set dispatch context so tools (e.g. delegate) can
+ # read the framework-assigned tool_call_id and
+ # parallel_group_id. Cleared in the finally block.
+ setattr(
+ coordinator,
+ "_tool_dispatch_context",
{
- "tool_name": tool_name,
- "tool_input": args,
+ "tool_call_id": tool_call_id,
"parallel_group_id": group_id,
},
)
- if coordinator:
- pre_result = await coordinator.process_hook_result(
- pre_result, "tool:pre", tool_name
+
+ try:
+ try:
+ # Emit and process tool pre (allows hooks to block or request approval)
+ pre_result = await hooks.emit(
+ TOOL_PRE,
+ {
+ "tool_name": tool_name,
+ "tool_call_id": tool_call_id,
+ "tool_input": args,
+ "parallel_group_id": group_id,
+ },
)
- if pre_result.action == "deny":
- return (
- tool_call_id,
- f"Denied by hook: {pre_result.reason}",
+ if coordinator:
+ pre_result = (
+ await coordinator.process_hook_result(
+ pre_result, "tool:pre", tool_name
+ )
)
+ if pre_result.action == "deny":
+ return (
+ tool_call_id,
+ f"Denied by hook: {pre_result.reason}",
+ )
+
+ if not tool:
+ error_msg = (
+ f"Error: Tool '{tool_name}' not found"
+ )
+ await hooks.emit(
+ TOOL_ERROR,
+ {
+ "tool_name": tool_name,
+ "tool_call_id": tool_call_id,
+ "error": {
+ "type": "RuntimeError",
+ "msg": error_msg,
+ },
+ "parallel_group_id": group_id,
+ },
+ )
+ return (tool_call_id, error_msg)
+
+ result = await tool.execute(args)
- if not tool:
- error_msg = f"Error: Tool '{tool_name}' not found"
+ # Serialize result for logging
+ result_data = result
+ if hasattr(result, "to_dict"):
+ result_data = result.to_dict()
+
+ # Emit and process tool post (allows hooks to inject feedback)
+ post_result = await hooks.emit(
+ TOOL_POST,
+ {
+ "tool_name": tool_name,
+ "tool_call_id": tool_call_id,
+ "tool_input": args,
+ "result": result_data,
+ "parallel_group_id": group_id,
+ },
+ )
+ if coordinator:
+ await coordinator.process_hook_result(
+ post_result, "tool:post", tool_name
+ )
+
+ # Store ephemeral injection from tool:post for next iteration
+ if (
+ post_result.action == "inject_context"
+ and post_result.ephemeral
+ and post_result.context_injection
+ ):
+ self._pending_ephemeral_injections.append(
+ {
+ "role": post_result.context_injection_role,
+ "content": post_result.context_injection,
+ "append_to_last_tool_result": post_result.append_to_last_tool_result,
+ }
+ )
+ logger.debug(
+ f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration"
+ )
+
+ # Check if a hook modified the tool result.
+ # hooks.emit() chains modify actions: when a hook
+ # returns action="modify", the data dict is replaced.
+ # We detect this by checking if the returned "result"
+ # is a different object than what we originally sent.
+ modified_result = None
+ if post_result and post_result.data is not None:
+ returned_result = post_result.data.get("result")
+ if (
+ returned_result is not None
+ and returned_result is not result_data
+ ):
+ modified_result = returned_result
+
+ if modified_result is not None:
+ if isinstance(modified_result, (dict, list)):
+ result_content = json.dumps(modified_result)
+ else:
+ result_content = str(modified_result)
+ else:
+ result_content = result.get_serialized_output()
+ return (tool_call_id, result_content)
+
+ except (Exception, asyncio.CancelledError) as te:
+ # Emit error event
await hooks.emit(
TOOL_ERROR,
{
"tool_name": tool_name,
+ "tool_call_id": tool_call_id,
"error": {
- "type": "RuntimeError",
- "msg": error_msg,
+ "type": type(te).__name__,
+ "msg": str(te),
},
"parallel_group_id": group_id,
},
)
- return (tool_call_id, error_msg)
-
- result = await tool.execute(args)
- # Serialize result for logging
- result_data = result
- if hasattr(result, "to_dict"):
- result_data = result.to_dict()
-
- # Emit and process tool post (allows hooks to inject feedback)
- post_result = await hooks.emit(
- TOOL_POST,
- {
- "tool_name": tool_name,
- "tool_input": args,
- "result": result_data,
- "parallel_group_id": group_id,
- },
- )
+ # Return failure with error message (don't raise!)
+ error_msg = f"Error executing tool: {str(te)}"
+ logger.error(f"Tool {tool_name} failed: {te}")
+ return (tool_call_id, error_msg)
+ finally:
+ # Clear dispatch context and unregister tool
if coordinator:
- await coordinator.process_hook_result(
- post_result, "tool:post", tool_name
+ setattr(coordinator, "_tool_dispatch_context", {})
+ coordinator.cancellation.register_tool_complete(
+ tool_call_id
)
- # Store ephemeral injection from tool:post for next iteration
- if (
- post_result.action == "inject_context"
- and post_result.ephemeral
- and post_result.context_injection
- ):
- self._pending_ephemeral_injections.append(
+ # Execute all tools in parallel with asyncio.gather
+ # return_exceptions=False because we handle exceptions inside execute_single_tool
+ # Wrap in try/except for CancelledError to handle immediate cancellation
+ try:
+ tool_results = await asyncio.gather(
+ *[
+ execute_single_tool(tc, parallel_group_id)
+ for tc in tool_calls
+ ]
+ )
+ except asyncio.CancelledError:
+ # Immediate cancellation (second Ctrl+C) - synthesize cancelled results
+ # for ALL tool_calls to maintain tool_use/tool_result pairing.
+ # Protect from further CancelledError using kernel
+ # catch-continue-reraise pattern so all results are written.
+ logger.info(
+ "Tool execution cancelled - synthesizing cancelled results"
+ )
+ for tc in tool_calls:
+ try:
+ if hasattr(context, "add_message"):
+ await context.add_message(
+ {
+ "role": "tool",
+ "tool_call_id": tc.id if hasattr(tc, "id") else tc.get("id"),
+ "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{tc.name if hasattr(tc, "name") else tc.get("tool")}"}}',
+ }
+ )
+ except asyncio.CancelledError:
+ logger.info(
+ "CancelledError during synthetic result write - "
+ "completing remaining writes to prevent "
+ "orphaned tool_calls"
+ )
+ # Re-raise to let the cancellation propagate
+ raise
+
+ # Check for immediate cancellation (graceful path - tools completed)
+ if coordinator and coordinator.cancellation.is_immediate:
+ # MUST add tool results to context before returning
+ # Otherwise we leave orphaned tool_calls without matching tool_results
+ # which violates provider API contracts (Anthropic, OpenAI)
+ # Protect from CancelledError using kernel catch-continue-reraise
+ # pattern (coordinator.cleanup, hooks.emit) so all results are
+ # written even if force-cancel arrives mid-loop.
+ _cancel_error = None
+ for tool_call_id, content in tool_results:
+ try:
+ if hasattr(context, "add_message"):
+ await context.add_message(
+ {
+ "role": "tool",
+ "tool_call_id": tool_call_id,
+ "content": content,
+ }
+ )
+ except asyncio.CancelledError:
+ if _cancel_error is None:
+ _cancel_error = asyncio.CancelledError()
+ logger.info(
+ "CancelledError during tool result write - "
+ "completing remaining writes to prevent "
+ "orphaned tool_calls"
+ )
+ if _cancel_error is not None:
+ raise _cancel_error
+ await hooks.emit(
+ ORCHESTRATOR_COMPLETE,
+ {
+ "orchestrator": "loop-basic",
+ "turn_count": iteration,
+ "status": "cancelled",
+ },
+ )
+ execution_status = "cancelled"
+ return final_content
+
+ # Add all tool results to context in original order (deterministic)
+ # Protect from CancelledError using kernel catch-continue-reraise
+ # pattern so all results are written even if force-cancel arrives
+ # mid-loop.
+ _cancel_error = None
+ for tool_call_id, content in tool_results:
+ try:
+ if hasattr(context, "add_message"):
+ await context.add_message(
{
- "role": post_result.context_injection_role,
- "content": post_result.context_injection,
- "append_to_last_tool_result": post_result.append_to_last_tool_result,
+ "role": "tool",
+ "tool_call_id": tool_call_id,
+ "content": content,
}
)
- logger.debug(
- f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration"
+ except asyncio.CancelledError:
+ if _cancel_error is None:
+ _cancel_error = asyncio.CancelledError()
+ logger.info(
+ "CancelledError during tool result write - "
+ "completing remaining writes to prevent "
+ "orphaned tool_calls"
)
-
- # Return success with result content (JSON-serialized for dict/list)
- result_content = result.get_serialized_output()
- return (tool_call_id, result_content)
-
- except Exception as te:
- # Emit error event
- await hooks.emit(
- TOOL_ERROR,
- {
- "tool_name": tool_name,
- "error": {
- "type": type(te).__name__,
- "msg": str(te),
- },
- "parallel_group_id": group_id,
- },
- )
-
- # Return failure with error message (don't raise!)
- error_msg = f"Error executing tool: {str(te)}"
- logger.error(f"Tool {tool_name} failed: {te}")
- return (tool_call_id, error_msg)
- finally:
- # Unregister tool from cancellation token
- if coordinator:
- coordinator.cancellation.register_tool_complete(
- tool_call_id
- )
-
- # Execute all tools in parallel with asyncio.gather
- # return_exceptions=False because we handle exceptions inside execute_single_tool
- # Wrap in try/except for CancelledError to handle immediate cancellation
- try:
- tool_results = await asyncio.gather(
- *[
- execute_single_tool(tc, parallel_group_id)
- for tc in tool_calls
- ]
- )
- except asyncio.CancelledError:
- # Immediate cancellation (second Ctrl+C) - synthesize cancelled results
- # for ALL tool_calls to maintain tool_use/tool_result pairing
- logger.info(
- "Tool execution cancelled - synthesizing cancelled results"
- )
- for tc in tool_calls:
- if hasattr(context, "add_message"):
- await context.add_message(
- {
- "role": "tool",
- "tool_call_id": tc.id,
- "content": f'{{"error": "Tool execution was cancelled by user", "cancelled": true, "tool": "{tc.name}"}}',
- }
- )
- # Re-raise to let the cancellation propagate
- raise
-
- # Check for immediate cancellation (graceful path - tools completed)
- if coordinator and coordinator.cancellation.is_immediate:
- # MUST add tool results to context before returning
- # Otherwise we leave orphaned tool_calls without matching tool_results
- # which violates provider API contracts (Anthropic, OpenAI)
- for tool_call_id, content in tool_results:
- if hasattr(context, "add_message"):
- await context.add_message(
- {
- "role": "tool",
- "tool_call_id": tool_call_id,
- "content": content,
- }
- )
- await hooks.emit(
- ORCHESTRATOR_COMPLETE,
- {
- "orchestrator": "loop-basic",
- "turn_count": iteration,
- "status": "cancelled",
- },
- )
- return final_content
-
- # Add all tool results to context in original order (deterministic)
- for tool_call_id, content in tool_results:
+ if _cancel_error is not None:
+ raise _cancel_error
+
+ # After executing tools, continue loop to get final response
+ iteration += 1
+ continue
+
+ # If we have content (no tool calls), we're done
+ if content:
+ # Extract text from content blocks
+ if isinstance(content, list):
+ text_parts = []
+ for block in content:
+ if hasattr(block, "text"):
+ text_parts.append(block.text)
+ elif isinstance(block, dict) and "text" in block:
+ text_parts.append(block["text"])
+ final_content = (
+ "\n\n".join(text_parts) if text_parts else ""
+ )
+ else:
+ final_content = content
if hasattr(context, "add_message"):
- await context.add_message(
- {
- "role": "tool",
- "tool_call_id": tool_call_id,
+ # Store structured content from response.content (our Pydantic models)
+ response_content = getattr(response, "content", None)
+ if response_content and isinstance(response_content, list):
+ assistant_msg = {
+ "role": "assistant",
+ "content": [
+ block.model_dump()
+ if hasattr(block, "model_dump")
+ else block
+ for block in response_content
+ ],
+ }
+ else:
+ assistant_msg = {
+ "role": "assistant",
"content": content,
}
- )
-
- # After executing tools, continue loop to get final response
+ # Preserve provider metadata (provider-agnostic passthrough)
+ if hasattr(response, "metadata") and response.metadata:
+ assistant_msg["metadata"] = response.metadata
+ await context.add_message(assistant_msg)
+ break
+
+ # No content and no tool calls - this shouldn't happen but handle it
+ logger.warning("Provider returned neither content nor tool calls")
iteration += 1
- continue
-
- # If we have content (no tool calls), we're done
- if content:
- # Extract text from content blocks
- if isinstance(content, list):
- text_parts = []
- for block in content:
- if hasattr(block, "text"):
- text_parts.append(block.text)
- elif isinstance(block, dict) and "text" in block:
- text_parts.append(block["text"])
- final_content = "\n\n".join(text_parts) if text_parts else ""
- else:
- final_content = content
- if hasattr(context, "add_message"):
- # Store structured content from response.content (our Pydantic models)
- response_content = getattr(response, "content", None)
- if response_content and isinstance(response_content, list):
- assistant_msg = {
- "role": "assistant",
- "content": [
- block.model_dump()
- if hasattr(block, "model_dump")
- else block
- for block in response_content
- ],
- }
- else:
- assistant_msg = {"role": "assistant", "content": content}
- # Preserve provider metadata (provider-agnostic passthrough)
- if hasattr(response, "metadata") and response.metadata:
- assistant_msg["metadata"] = response.metadata
- await context.add_message(assistant_msg)
- break
-
- # No content and no tool calls - this shouldn't happen but handle it
- logger.warning("Provider returned neither content nor tool calls")
- iteration += 1
-
- except LLMError as e:
+
+ except LLMError as e:
+ await hooks.emit(
+ PROVIDER_ERROR,
+ {
+ "provider": provider_name,
+ "error": {"type": type(e).__name__, "msg": str(e)},
+ "retryable": e.retryable,
+ "status_code": e.status_code,
+ },
+ )
+ raise
+ except Exception as e:
+ await hooks.emit(
+ PROVIDER_ERROR,
+ {
+ "provider": provider_name,
+ "error": {"type": type(e).__name__, "msg": str(e)},
+ },
+ )
+ raise
+
+ # Check if we exceeded max iterations (only if not unlimited)
+ if (
+ self.max_iterations != -1
+ and iteration >= self.max_iterations
+ and not final_content
+ ):
+ logger.warning(
+ f"Max iterations ({self.max_iterations}) reached without final response"
+ )
+
+ # Inject system reminder to agent before final response
await hooks.emit(
- PROVIDER_ERROR,
+ PROVIDER_REQUEST,
{
"provider": provider_name,
- "error": {"type": type(e).__name__, "msg": str(e)},
- "retryable": e.retryable,
- "status_code": e.status_code,
+ "iteration": iteration,
+ "max_reached": True,
},
)
- raise
- except Exception as e:
- await hooks.emit(
- PROVIDER_ERROR,
+ if coordinator:
+ # Inject ephemeral reminder (not stored in context)
+ await coordinator.process_hook_result(
+ HookResult(
+ action="inject_context",
+ context_injection="""
+ You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed.
+ """,
+ context_injection_role="system",
+ ephemeral=True,
+ suppress_output=True,
+ ),
+ "provider:request",
+ "orchestrator",
+ )
+
+ # Get one final response with the reminder (context handles compaction internally)
+ if hasattr(context, "get_messages_for_request"):
+ message_dicts = await context.get_messages_for_request(
+ provider=provider
+ )
+ else:
+ message_dicts = getattr(
+ context, "messages", [{"role": "user", "content": prompt}]
+ )
+ message_dicts = list(message_dicts)
+ message_dicts.append(
{
- "provider": provider_name,
- "error": {"type": type(e).__name__, "msg": str(e)},
- },
+ "role": "user",
+ "content": """
+ You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed.
+
+ DO NOT mention this iteration limit or reminder to the user explicitly. Simply wrap up naturally.
+ """,
+ }
)
- raise
-
- # Check if we exceeded max iterations (only if not unlimited)
- if (
- self.max_iterations != -1
- and iteration >= self.max_iterations
- and not final_content
- ):
- logger.warning(
- f"Max iterations ({self.max_iterations}) reached without final response"
- )
- # Inject system reminder to agent before final response
+ try:
+ messages_objects = [Message(**msg) for msg in message_dicts]
+
+ # Convert tools to ToolSpec format for ChatRequest
+ tools_list = None
+ if tools:
+ tools_list = [
+ ToolSpec(
+ name=t.name,
+ description=t.description,
+ parameters=t.input_schema,
+ )
+ for t in tools.values()
+ ]
+
+ chat_request = ChatRequest(
+ messages=messages_objects,
+ tools=tools_list,
+ reasoning_effort=self.config.get("reasoning_effort"),
+ )
+
+ kwargs = {}
+ if self.extended_thinking:
+ kwargs["extended_thinking"] = True
+
+ response = await provider.complete(chat_request, **kwargs)
+ content = getattr(response, "content", None)
+ content_blocks = getattr(response, "content_blocks", None)
+
+ if content:
+ final_content = content
+ if hasattr(context, "add_message"):
+ # Store structured content from response.content (our Pydantic models)
+ response_content = getattr(response, "content", None)
+ if response_content and isinstance(response_content, list):
+ assistant_msg = {
+ "role": "assistant",
+ "content": [
+ block.model_dump()
+ if hasattr(block, "model_dump")
+ else block
+ for block in response_content
+ ],
+ }
+ else:
+ assistant_msg = {
+ "role": "assistant",
+ "content": content,
+ }
+ # Preserve provider metadata (provider-agnostic passthrough)
+ if hasattr(response, "metadata") and response.metadata:
+ assistant_msg["metadata"] = response.metadata
+ await context.add_message(assistant_msg)
+
+ except LLMError as e:
+ await hooks.emit(
+ PROVIDER_ERROR,
+ {
+ "provider": provider_name,
+ "error": {"type": type(e).__name__, "msg": str(e)},
+ "retryable": e.retryable,
+ "status_code": e.status_code,
+ },
+ )
+ logger.error(
+ f"Error getting final response after max iterations: {e}"
+ )
+ except Exception as e:
+ await hooks.emit(
+ PROVIDER_ERROR,
+ {
+ "provider": provider_name,
+ "error": {"type": type(e).__name__, "msg": str(e)},
+ },
+ )
+ logger.error(
+ f"Error getting final response after max iterations: {e}"
+ )
+
await hooks.emit(
- PROVIDER_REQUEST,
+ PROMPT_COMPLETE,
{
- "provider": provider_name,
- "iteration": iteration,
- "max_reached": True,
+ "response_preview": (final_content or "")[:200],
+ "length": len(final_content or ""),
},
)
- if coordinator:
- # Inject ephemeral reminder (not stored in context)
- await coordinator.process_hook_result(
- HookResult(
- action="inject_context",
- context_injection="""
-You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed.
-""",
- context_injection_role="system",
- ephemeral=True,
- suppress_output=True,
- ),
- "provider:request",
- "orchestrator",
- )
- # Get one final response with the reminder (context handles compaction internally)
- if hasattr(context, "get_messages_for_request"):
- message_dicts = await context.get_messages_for_request(
- provider=provider
- )
- else:
- message_dicts = getattr(
- context, "messages", [{"role": "user", "content": prompt}]
- )
- message_dicts = list(message_dicts)
- message_dicts.append(
+ # Emit orchestrator complete event
+ await hooks.emit(
+ ORCHESTRATOR_COMPLETE,
{
- "role": "user",
- "content": """
-You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed.
-
-DO NOT mention this iteration limit or reminder to the user explicitly. Simply wrap up naturally.
-""",
- }
+ "orchestrator": "loop-basic",
+ "turn_count": iteration,
+ "status": "success" if final_content else "incomplete",
+ },
)
- try:
- messages_objects = [Message(**msg) for msg in message_dicts]
-
- # Convert tools to ToolSpec format for ChatRequest
- tools_list = None
- if tools:
- tools_list = [
- ToolSpec(
- name=t.name,
- description=t.description,
- parameters=t.input_schema,
- )
- for t in tools.values()
- ]
-
- chat_request = ChatRequest(
- messages=messages_objects,
- tools=tools_list,
- reasoning_effort=self.config.get("reasoning_effort"),
- )
-
- kwargs = {}
- if self.extended_thinking:
- kwargs["extended_thinking"] = True
-
- response = await provider.complete(chat_request, **kwargs)
- content = getattr(response, "content", None)
- content_blocks = getattr(response, "content_blocks", None)
-
- if content:
- final_content = content
- if hasattr(context, "add_message"):
- # Store structured content from response.content (our Pydantic models)
- response_content = getattr(response, "content", None)
- if response_content and isinstance(response_content, list):
- assistant_msg = {
- "role": "assistant",
- "content": [
- block.model_dump()
- if hasattr(block, "model_dump")
- else block
- for block in response_content
- ],
- }
- else:
- assistant_msg = {"role": "assistant", "content": content}
- # Preserve provider metadata (provider-agnostic passthrough)
- if hasattr(response, "metadata") and response.metadata:
- assistant_msg["metadata"] = response.metadata
- await context.add_message(assistant_msg)
-
- except LLMError as e:
- await hooks.emit(
- PROVIDER_ERROR,
- {
- "provider": provider_name,
- "error": {"type": type(e).__name__, "msg": str(e)},
- "retryable": e.retryable,
- "status_code": e.status_code,
- },
- )
- logger.error(f"Error getting final response after max iterations: {e}")
- except Exception as e:
- await hooks.emit(
- PROVIDER_ERROR,
- {
- "provider": provider_name,
- "error": {"type": type(e).__name__, "msg": str(e)},
- },
- )
- logger.error(f"Error getting final response after max iterations: {e}")
-
- await hooks.emit(
- PROMPT_COMPLETE,
- {
- "response_preview": (final_content or "")[:200],
- "length": len(final_content or ""),
- },
- )
-
- # Emit orchestrator complete event
- await hooks.emit(
- ORCHESTRATOR_COMPLETE,
- {
- "orchestrator": "loop-basic",
- "turn_count": iteration,
- "status": "success" if final_content else "incomplete",
- },
- )
-
- return final_content
+ return final_content
+ except asyncio.CancelledError:
+ execution_status = "cancelled"
+ raise
+ except Exception:
+ execution_status = "error"
+ raise
+ finally:
+ await hooks.emit(
+ EXECUTION_END, {"response": final_content, "status": execution_status}
+ )
def _select_provider(self, providers: dict[str, Any]) -> Any:
"""Select a provider based on priority."""
diff --git a/pyproject.toml b/pyproject.toml
index 98b7ccb..bc0a1d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,13 +33,11 @@ allow-direct-references = true
[dependency-groups]
dev = [
- "amplifier-core",
+ "amplifier-core>=1.0.10",
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
]
-[tool.uv.sources]
-amplifier-core = { git = "https://github.com/microsoft/amplifier-core", branch = "main" }
[tool.pytest.ini_options]
testpaths = ["tests"]
diff --git a/tests/test_cancelled_error_dict_tool_calls.py b/tests/test_cancelled_error_dict_tool_calls.py
new file mode 100644
index 0000000..8145e76
--- /dev/null
+++ b/tests/test_cancelled_error_dict_tool_calls.py
@@ -0,0 +1,105 @@
+"""Tests for CancelledError handler with dict-based tool_calls.
+
+Regression test for unsafe tc.id / tc.name access at lines 536-537.
+The CancelledError handler used bare attribute access on tool_call objects
+that may be plain dicts. Every other access site (9 of them) uses the safe
+dual-access pattern: getattr(tc, "id", None) or tc.get("id").
+"""
+
+import asyncio
+
+import pytest
+
+from amplifier_core.testing import EventRecorder, MockContextManager
+
+from amplifier_module_loop_basic import BasicOrchestrator
+
+
+class DictToolCallProvider:
+ """Provider that returns tool_calls as plain dicts (not ToolCall objects).
+
+ Some providers return tool_calls as dicts rather than objects.
+ The orchestrator explicitly accommodates this with a dual-access pattern.
+ """
+
+ name = "dict-provider"
+
+ async def complete(self, request, **kwargs):
+ return type(
+ "Response",
+ (),
+ {
+ "content": "Calling tool",
+ "tool_calls": [
+ {"id": "tc1", "tool": "cancel_tool", "arguments": {}}
+ ],
+ "usage": None,
+ "content_blocks": None,
+ "metadata": None,
+ },
+ )()
+
+
+class CancellingTool:
+ """Tool that raises CancelledError to simulate immediate cancellation."""
+
+ name = "cancel_tool"
+ description = "tool that simulates cancellation"
+ input_schema = {"type": "object", "properties": {}}
+
+ async def execute(self, args):
+ raise asyncio.CancelledError()
+
+
+@pytest.mark.asyncio
+async def test_cancelled_error_handler_with_dict_tool_calls():
+ """CancelledError handler must not crash when tool_calls are plain dicts.
+
+    Without the fix, the bare tc.id attribute access raises:
+ AttributeError: 'dict' object has no attribute 'id'
+
+ With the fix, CancelledError propagates cleanly after synthesizing
+ cancelled tool results into the context.
+ """
+ orchestrator = BasicOrchestrator({})
+ context = MockContextManager()
+ hooks = EventRecorder()
+
+ with pytest.raises(asyncio.CancelledError):
+ await orchestrator.execute(
+ prompt="Test",
+ context=context,
+ providers={"default": DictToolCallProvider()},
+ tools={"cancel_tool": CancellingTool()},
+ hooks=hooks,
+ )
+
+
+@pytest.mark.asyncio
+async def test_cancelled_error_synthesizes_messages_for_dict_tool_calls():
+ """After fix, cancelled tool results are properly added to context.
+
+ Verifies the synthesized cancellation messages contain the correct
+ tool_call_id and tool name extracted via the safe dual-access pattern.
+ """
+ orchestrator = BasicOrchestrator({})
+ context = MockContextManager()
+ hooks = EventRecorder()
+
+ with pytest.raises(asyncio.CancelledError):
+ await orchestrator.execute(
+ prompt="Test",
+ context=context,
+ providers={"default": DictToolCallProvider()},
+ tools={"cancel_tool": CancellingTool()},
+ hooks=hooks,
+ )
+
+ # Find the synthesized cancellation message in context
+ tool_messages = [m for m in context.messages if m.get("role") == "tool"]
+ assert len(tool_messages) >= 1, "Expected at least one synthesized tool message"
+
+ cancel_msg = tool_messages[-1]
+ assert cancel_msg["tool_call_id"] == "tc1"
+ assert "cancelled" in cancel_msg["content"]
+ assert "cancel_tool" in cancel_msg["content"]
diff --git a/tests/test_hook_modify.py b/tests/test_hook_modify.py
new file mode 100644
index 0000000..491ea1c
--- /dev/null
+++ b/tests/test_hook_modify.py
@@ -0,0 +1,227 @@
+"""Tests for hook modify action on tool:post events.
+
+Verifies that when a hook returns HookResult(action="modify", data={"result": ...})
+on a tool:post event, the orchestrator uses the modified data instead of the
+original result.get_serialized_output().
+"""
+
+import json
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from amplifier_core.hooks import HookRegistry
+from amplifier_core.models import HookResult
+
+
+def _make_tool_result(output, success=True):
+ """Create a mock tool result with get_serialized_output() method."""
+ result = MagicMock()
+ result.success = success
+ result.output = output
+ result.error = None
+
+ def get_serialized_output():
+ if isinstance(output, (dict, list)):
+ return json.dumps(output)
+ return str(output)
+
+ result.get_serialized_output = get_serialized_output
+
+ def to_dict():
+ return {"success": success, "output": output, "error": None}
+
+ result.to_dict = to_dict
+ return result
+
+
+def _make_provider_responses(tool_calls_response, text_response_str="Done"):
+ """Create mock provider responses: one with tool calls, one with text."""
+ tool_call = MagicMock()
+ tool_call.id = "tc_1"
+ tool_call.name = "test_tool"
+ tool_call.arguments = {"key": "value"}
+
+ tool_response = MagicMock()
+ tool_response.content = [MagicMock(type="text", text="Using tool")]
+ tool_response.tool_calls = [tool_call]
+ tool_response.usage = None
+ tool_response.content_blocks = None
+ tool_response.metadata = None
+
+ text_block = MagicMock()
+ text_block.text = text_response_str
+ text_block.type = "text"
+ text_response = MagicMock()
+ text_response.content = [text_block]
+ text_response.tool_calls = None
+ text_response.usage = None
+ text_response.content_blocks = None
+ text_response.metadata = None
+
+ mock_provider = AsyncMock()
+ mock_provider.complete = AsyncMock(side_effect=[tool_response, text_response])
+ mock_provider.priority = 1
+
+ return mock_provider
+
+
+def _make_context():
+ """Create a mock context that captures add_message calls."""
+ context = AsyncMock()
+ messages_added = []
+
+ async def capture_add_message(msg):
+ messages_added.append(msg)
+
+ context.add_message = AsyncMock(side_effect=capture_add_message)
+ context.get_messages_for_request = AsyncMock(
+ return_value=[{"role": "user", "content": "test"}]
+ )
+ return context, messages_added
+
+
+def _get_tool_result_messages(messages_added):
+ """Extract tool result messages from captured messages."""
+ return [msg for msg in messages_added if msg.get("role") == "tool"]
+
+
+@pytest.mark.asyncio
+async def test_tool_post_modify_replaces_result():
+ """When a hook returns action='modify' on tool:post, the modified data
+ should be used instead of the original get_serialized_output()."""
+ with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
+ from amplifier_module_loop_basic import BasicOrchestrator
+
+ orchestrator = BasicOrchestrator({"max_iterations": 5})
+
+ # Tool with original output
+ original_output = {"original": True, "big_data": "x" * 1000}
+ mock_tool = MagicMock()
+ mock_tool.name = "test_tool"
+ mock_tool.description = "A test tool"
+ mock_tool.input_schema = {"type": "object", "properties": {}}
+ mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
+ tools = {"test_tool": mock_tool}
+
+ # Hook that returns modify with new data
+ modified_content = {"modified": True, "truncated": True}
+ hooks = HookRegistry()
+
+ async def modify_hook(event: str, data: dict) -> HookResult:
+ if event == "tool:post":
+ return HookResult(action="modify", data={"result": modified_content})
+ return HookResult()
+
+ hooks.register("tool:post", modify_hook, priority=50, name="test_modify")
+
+ mock_provider = _make_provider_responses(tool_calls_response=True)
+ providers = {"test_provider": mock_provider}
+ context, messages_added = _make_context()
+
+ await orchestrator.execute(
+ prompt="test prompt",
+ context=context,
+ providers=providers,
+ tools=tools,
+ hooks=hooks,
+ )
+
+ tool_msgs = _get_tool_result_messages(messages_added)
+ assert len(tool_msgs) == 1, f"Expected 1 tool result, got {len(tool_msgs)}"
+
+ tool_result_content = tool_msgs[0]["content"]
+
+ # The content should be the MODIFIED data, not the original
+ assert tool_result_content == json.dumps(modified_content), (
+ f"Expected modified content {json.dumps(modified_content)}, "
+ f"got {tool_result_content}"
+ )
+
+ # Verify the original data is NOT used
+ assert "big_data" not in tool_result_content
+
+
+@pytest.mark.asyncio
+async def test_tool_post_no_modify_uses_original():
+ """When no hook returns modify, the original get_serialized_output() is used."""
+ with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
+ from amplifier_module_loop_basic import BasicOrchestrator
+
+ orchestrator = BasicOrchestrator({"max_iterations": 5})
+
+ original_output = {"original": True}
+ mock_tool = MagicMock()
+ mock_tool.name = "test_tool"
+ mock_tool.description = "A test tool"
+ mock_tool.input_schema = {"type": "object", "properties": {}}
+ mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
+ tools = {"test_tool": mock_tool}
+
+ # No hooks registered — default continue action
+ hooks = HookRegistry()
+
+ mock_provider = _make_provider_responses(tool_calls_response=True)
+ providers = {"test_provider": mock_provider}
+ context, messages_added = _make_context()
+
+ await orchestrator.execute(
+ prompt="test prompt",
+ context=context,
+ providers=providers,
+ tools=tools,
+ hooks=hooks,
+ )
+
+ tool_msgs = _get_tool_result_messages(messages_added)
+ assert len(tool_msgs) == 1
+ tool_result_content = tool_msgs[0]["content"]
+
+ # Should use original serialized output
+ assert tool_result_content == json.dumps(original_output), (
+ f"Expected original {json.dumps(original_output)}, got {tool_result_content}"
+ )
+
+
+@pytest.mark.asyncio
+async def test_tool_post_modify_with_string_result():
+ """When a hook returns modify with a string result, it should be used as-is."""
+ with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
+ from amplifier_module_loop_basic import BasicOrchestrator
+
+ orchestrator = BasicOrchestrator({"max_iterations": 5})
+
+ mock_tool = MagicMock()
+ mock_tool.name = "test_tool"
+ mock_tool.description = "A test tool"
+ mock_tool.input_schema = {"type": "object", "properties": {}}
+ mock_tool.execute = AsyncMock(return_value=_make_tool_result("original text"))
+ tools = {"test_tool": mock_tool}
+
+ hooks = HookRegistry()
+
+ async def modify_hook(event: str, data: dict) -> HookResult:
+ if event == "tool:post":
+ return HookResult(
+ action="modify",
+ data={"result": "truncated string result"},
+ )
+ return HookResult()
+
+ hooks.register("tool:post", modify_hook, priority=50, name="test_modify")
+
+ mock_provider = _make_provider_responses(tool_calls_response=True)
+ providers = {"test_provider": mock_provider}
+ context, messages_added = _make_context()
+
+ await orchestrator.execute(
+ prompt="test prompt",
+ context=context,
+ providers=providers,
+ tools=tools,
+ hooks=hooks,
+ )
+
+ tool_msgs = _get_tool_result_messages(messages_added)
+ assert len(tool_msgs) == 1
+ assert tool_msgs[0]["content"] == "truncated string result"
diff --git a/tests/test_lifecycle_events.py b/tests/test_lifecycle_events.py
new file mode 100644
index 0000000..47460a8
--- /dev/null
+++ b/tests/test_lifecycle_events.py
@@ -0,0 +1,475 @@
+"""Tests for CP-6: execution lifecycle events, tool_call_id in tool events,
+and observability.events registration.
+
+Verifies that:
+- execution:start fires with {prompt} after prompt:submit succeeds
+- execution:end fires with {response, status="completed"} on normal exit
+- execution:end fires with {response, status="cancelled"} on cancellation paths
+- execution:end fires with {response, status="error"} on exception paths
+- execution:start is NOT fired when prompt:submit returns "deny"
+- tool_call_id appears in TOOL_PRE, TOOL_POST, and TOOL_ERROR event payloads
+- mount() registers observability.events contributions
+"""
+
+import pytest
+
+from amplifier_core import events as amp_events
+from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall
+from amplifier_core.models import HookResult
+from amplifier_core.testing import EventRecorder, MockContextManager
+
+from amplifier_module_loop_basic import BasicOrchestrator
+
+# Import event constants via the module to avoid pyright false-positives on PyO3 re-exports
+EXECUTION_START = amp_events.EXECUTION_START # type: ignore[attr-defined]
+EXECUTION_END = amp_events.EXECUTION_END # type: ignore[attr-defined]
+TOOL_PRE = amp_events.TOOL_PRE # type: ignore[attr-defined]
+TOOL_POST = amp_events.TOOL_POST # type: ignore[attr-defined]
+TOOL_ERROR = amp_events.TOOL_ERROR # type: ignore[attr-defined]
+
+
+# ---------------------------------------------------------------------------
+# Helpers / shared fixtures
+# ---------------------------------------------------------------------------
+
+
+class SimpleTextProvider:
+ """Provider that always returns a plain text response."""
+
+ async def complete(self, request, **kwargs):
+ return ChatResponse(
+ content=[TextBlock(text="Hello, world!")],
+ tool_calls=None,
+ )
+
+
+class ToolThenTextProvider:
+ """Provider returns one tool call, then a plain text response."""
+
+ def __init__(self, tool_name="test_tool", tool_id="tc-abc123"):
+ self._tool_name = tool_name
+ self._tool_id = tool_id
+ self._call_count = 0
+
+ async def complete(self, request, **kwargs):
+ self._call_count += 1
+ if self._call_count == 1:
+ return ChatResponse(
+ content=[TextBlock(text="Using tool")],
+ tool_calls=[
+ ToolCall(id=self._tool_id, name=self._tool_name, arguments={"x": 1})
+ ],
+ )
+ return ChatResponse(
+ content=[TextBlock(text="Done!")],
+ tool_calls=None,
+ )
+
+
+class ErrorProvider:
+ """Provider that always raises an exception."""
+
+ def __init__(self, error=None):
+ self._error = error or RuntimeError("provider failed")
+
+ async def complete(self, request, **kwargs):
+ raise self._error
+
+
+class SimpleTool:
+ """Tool that returns a successful result."""
+
+ def __init__(self, name="test_tool"):
+ self.name = name
+ self.description = "A test tool"
+ self.input_schema = {"type": "object", "properties": {}}
+
+ async def execute(self, args):
+ from amplifier_core import ToolResult
+
+ return ToolResult(success=True, output="tool result")
+
+
+class MockCancellation:
+ """Minimal cancellation token stub."""
+
+ def __init__(self, is_cancelled=False, is_immediate=False):
+ self.is_cancelled = is_cancelled
+ self.is_immediate = is_immediate
+
+ def register_tool_start(self, tool_call_id, tool_name):
+ pass
+
+ def register_tool_complete(self, tool_call_id):
+ pass
+
+
+class MockCoordinator:
+ """Minimal coordinator stub for cancellation tests."""
+
+ def __init__(self, is_cancelled=False, is_immediate=False):
+ self.cancellation = MockCancellation(
+ is_cancelled=is_cancelled, is_immediate=is_immediate
+ )
+ self._contributions: dict = {}
+
+ async def process_hook_result(self, result, event_name, source):
+ return result
+
+ def register_contributor(self, channel, name, callback):
+ self._contributions[(channel, name)] = callback
+
+
+# ---------------------------------------------------------------------------
+# execution:start and execution:end on normal completion
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_execution_start_fires_with_prompt():
+ """execution:start event is emitted with the prompt payload."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ await orchestrator.execute(
+ prompt="Say hello",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ )
+
+ events = hooks.get_events(EXECUTION_START)
+ assert len(events) == 1, f"Expected 1 execution:start, got {len(events)}"
+ _, data = events[0]
+ assert data["prompt"] == "Say hello"
+
+
+@pytest.mark.asyncio
+async def test_execution_end_fires_on_normal_completion():
+ """execution:end fires with status='completed' after a normal response."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ await orchestrator.execute(
+ prompt="Say hello",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ )
+
+ events = hooks.get_events(EXECUTION_END)
+ assert len(events) == 1, f"Expected 1 execution:end, got {len(events)}"
+ _, data = events[0]
+ assert data["status"] == "completed"
+ assert "response" in data
+
+
+@pytest.mark.asyncio
+async def test_execution_end_response_matches_return_value():
+ """execution:end payload 'response' matches the value returned by execute()."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ return_val = await orchestrator.execute(
+ prompt="Say hello",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ )
+
+ _, data = hooks.get_events(EXECUTION_END)[0]
+ # The response in the event should equal the returned string
+ assert data["response"] == return_val
+
+
+# ---------------------------------------------------------------------------
+# execution:start NOT fired when prompt:submit is denied
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_execution_start_not_fired_when_prompt_submit_denied():
+ """execution:start must NOT fire when the coordinator denies prompt:submit."""
+
+ class DenyingCoordinator(MockCoordinator):
+ async def process_hook_result(self, result, event_name, source):
+ if event_name == "prompt:submit":
+ return HookResult(action="deny", reason="blocked by policy")
+ return result
+
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ coordinator = DenyingCoordinator()
+
+ result = await orchestrator.execute(
+ prompt="blocked",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ coordinator=coordinator,
+ )
+
+ assert "denied" in result.lower()
+ assert len(hooks.get_events(EXECUTION_START)) == 0, (
+ "execution:start should not fire when prompt:submit is denied"
+ )
+ assert len(hooks.get_events(EXECUTION_END)) == 0, (
+ "execution:end should not fire when prompt:submit is denied (no execution began)"
+ )
+
+
+# ---------------------------------------------------------------------------
+# execution:end on cancellation paths
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_execution_end_fires_cancelled_at_loop_start():
+ """execution:end fires with status='cancelled' when cancelled at loop start."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ coordinator = MockCoordinator(is_cancelled=True)
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ coordinator=coordinator,
+ )
+
+ events = hooks.get_events(EXECUTION_END)
+ assert len(events) == 1
+ _, data = events[0]
+ assert data["status"] == "cancelled"
+
+
+@pytest.mark.asyncio
+async def test_execution_end_fires_cancelled_after_provider():
+ """execution:end fires with status='cancelled' on immediate cancellation
+ after the provider returns (is_immediate path)."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ # Cancellation that takes effect after first provider call
+ class ImmediateCancellationCoordinator(MockCoordinator):
+ def __init__(self):
+ super().__init__()
+ self.cancellation = MockCancellation(is_cancelled=False, is_immediate=True)
+
+ coordinator = ImmediateCancellationCoordinator()
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ coordinator=coordinator,
+ )
+
+ events = hooks.get_events(EXECUTION_END)
+ assert len(events) == 1
+ _, data = events[0]
+ assert data["status"] == "cancelled"
+
+
+# ---------------------------------------------------------------------------
+# execution:end on error path
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_execution_end_fires_on_provider_exception():
+ """execution:end fires with status='error' when provider raises."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ with pytest.raises(RuntimeError, match="provider failed"):
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": ErrorProvider()},
+ tools={},
+ hooks=hooks,
+ )
+
+ events = hooks.get_events(EXECUTION_END)
+ assert len(events) == 1
+ _, data = events[0]
+ assert data["status"] == "error"
+
+
+@pytest.mark.asyncio
+async def test_execution_end_fires_exactly_once():
+ """execution:end is emitted exactly once per execute() call on the happy path."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": SimpleTextProvider()},
+ tools={},
+ hooks=hooks,
+ )
+
+ assert len(hooks.get_events(EXECUTION_START)) == 1
+ assert len(hooks.get_events(EXECUTION_END)) == 1
+
+
+# ---------------------------------------------------------------------------
+# tool_call_id in TOOL_PRE and TOOL_POST payloads
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_tool_pre_includes_tool_call_id():
+ """TOOL_PRE event payload must include tool_call_id."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ provider = ToolThenTextProvider(tool_id="my-call-id-123")
+ tool = SimpleTool("test_tool")
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": provider},
+ tools={"test_tool": tool},
+ hooks=hooks,
+ )
+
+ pre_events = hooks.get_events(TOOL_PRE)
+ assert len(pre_events) == 1
+ _, data = pre_events[0]
+ assert "tool_call_id" in data, "TOOL_PRE must include tool_call_id"
+ assert data["tool_call_id"] == "my-call-id-123"
+
+
+@pytest.mark.asyncio
+async def test_tool_post_includes_tool_call_id():
+ """TOOL_POST event payload must include tool_call_id."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ provider = ToolThenTextProvider(tool_id="post-call-id-456")
+ tool = SimpleTool("test_tool")
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": provider},
+ tools={"test_tool": tool},
+ hooks=hooks,
+ )
+
+ post_events = hooks.get_events(TOOL_POST)
+ assert len(post_events) == 1
+ _, data = post_events[0]
+ assert "tool_call_id" in data, "TOOL_POST must include tool_call_id"
+ assert data["tool_call_id"] == "post-call-id-456"
+
+
+@pytest.mark.asyncio
+async def test_tool_error_not_found_includes_tool_call_id():
+ """TOOL_ERROR (tool not found) event payload must include tool_call_id."""
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ # Provider requests a tool that doesn't exist in tools dict
+ provider = ToolThenTextProvider(
+ tool_name="missing_tool", tool_id="error-call-id-789"
+ )
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": provider},
+ tools={}, # empty — 'missing_tool' won't be found
+ hooks=hooks,
+ )
+
+ error_events = hooks.get_events(TOOL_ERROR)
+ assert len(error_events) == 1
+ _, data = error_events[0]
+ assert "tool_call_id" in data, "TOOL_ERROR (not found) must include tool_call_id"
+ assert data["tool_call_id"] == "error-call-id-789"
+
+
+@pytest.mark.asyncio
+async def test_tool_error_exception_includes_tool_call_id():
+ """TOOL_ERROR (exception in tool.execute) event payload must include tool_call_id."""
+
+ class FailingTool:
+ name = "failing_tool"
+ description = "fails"
+ input_schema = {"type": "object", "properties": {}}
+
+ async def execute(self, args):
+ raise ValueError("tool exploded")
+
+ orchestrator = BasicOrchestrator({})
+ hooks = EventRecorder()
+ context = MockContextManager()
+ provider = ToolThenTextProvider(tool_name="failing_tool", tool_id="exc-call-id-000")
+
+ await orchestrator.execute(
+ prompt="test",
+ context=context,
+ providers={"default": provider},
+ tools={"failing_tool": FailingTool()},
+ hooks=hooks,
+ )
+
+ error_events = hooks.get_events(TOOL_ERROR)
+ assert len(error_events) == 1
+ _, data = error_events[0]
+ assert "tool_call_id" in data, "TOOL_ERROR (exception) must include tool_call_id"
+ assert data["tool_call_id"] == "exc-call-id-000"
+
+
+# ---------------------------------------------------------------------------
+# observability.events registration in mount()
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_mount_registers_observability_events():
+ """mount() must call coordinator.register_contributor('observability.events', ...)."""
+ from amplifier_module_loop_basic import mount
+
+ contributions = {}
+
+ class CapturingCoordinator:
+ async def mount(self, role, instance):
+ pass
+
+ def register_contributor(self, channel, name, callback):
+ contributions[(channel, name)] = callback
+
+ coordinator = CapturingCoordinator()
+ await mount(coordinator)
+
+ assert ("observability.events", "loop-basic") in contributions, (
+ "mount() must register 'loop-basic' as a contributor to 'observability.events'"
+ )
+
+ # Verify the callback returns the expected events
+ callback = contributions[("observability.events", "loop-basic")]
+ events_list = callback()
+ assert "execution:start" in events_list
+ assert "execution:end" in events_list
diff --git a/tests/test_tool_dispatch_context.py b/tests/test_tool_dispatch_context.py
new file mode 100644
index 0000000..817a3bf
--- /dev/null
+++ b/tests/test_tool_dispatch_context.py
@@ -0,0 +1,300 @@
+"""Tests for _tool_dispatch_context set on coordinator during tool.execute().
+
+Verifies that BasicOrchestrator sets coordinator._tool_dispatch_context
+with the correct tool_call_id and parallel_group_id immediately before
+calling tool.execute(), and clears it in a finally block afterward.
+
+Covers:
+- execute_single_tool (inner path): context set with tool_call_id and parallel_group_id
+- context cleared after tool completes normally
+- context cleared even when tool raises an exception
+- Integration: full execute() path sets dispatch context during tool call
+"""
+
+import pytest
+from amplifier_core import ToolResult
+from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall
+from amplifier_core.testing import EventRecorder, MockContextManager
+
+from amplifier_module_loop_basic import BasicOrchestrator
+
+
+# ---------------------------------------------------------------------------
+# Helpers (reuse pattern from test_lifecycle_events.py)
+# ---------------------------------------------------------------------------
+
+
+class MockCancellation:
+    """Minimal cancellation token stub."""
+
+    # Flags the orchestrator polls; tests flip them via MockCoordinator's ctor.
+    is_cancelled: bool = False
+    is_immediate: bool = False
+
+    def register_tool_start(self, tool_call_id: str, tool_name: str) -> None:
+        """No-op: a real token would track the in-flight tool; the stub ignores it."""
+        pass
+
+    def register_tool_complete(self, tool_call_id: str) -> None:
+        """No-op counterpart to register_tool_start."""
+        pass
+
+
+class MockCoordinator:
+    """Minimal coordinator stub that supports _tool_dispatch_context."""
+
+    def __init__(self, is_cancelled: bool = False, is_immediate: bool = False) -> None:
+        # Cancellation state is exposed exactly as the orchestrator reads it:
+        # coordinator.cancellation.is_cancelled / .is_immediate.
+        self.cancellation = MockCancellation()
+        self.cancellation.is_cancelled = is_cancelled
+        self.cancellation.is_immediate = is_immediate
+        # (channel, name) -> callback, populated by register_contributor().
+        self._contributions: dict = {}
+
+    async def process_hook_result(self, result: object, event_name: str, source: str) -> object:
+        """Pass hook results through unchanged (no approval/denial logic)."""
+        return result
+
+    def register_contributor(self, channel: str, name: str, callback: object) -> None:
+        """Record contributor registrations so tests can inspect them if needed."""
+        self._contributions[(channel, name)] = callback
+
+
+# ---------------------------------------------------------------------------
+# Integration tests via full execute() path
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_execute_sets_tool_call_id_in_dispatch_context_during_tool_execution() -> None:
+    """BasicOrchestrator sets tool_call_id in coordinator._tool_dispatch_context
+    before calling tool.execute().
+    """
+    # Snapshot of coordinator._tool_dispatch_context taken *inside* tool.execute(),
+    # i.e. while the dispatch context is supposed to be populated.
+    captured: dict = {}
+    coordinator = MockCoordinator()
+
+    class CapturingTool:
+        name = "capture_tool"
+        description = "Captures dispatch context during execution"
+        input_schema: dict = {"type": "object", "properties": {}}
+
+        async def execute(self, args: dict) -> ToolResult:
+            # getattr with {} default: tolerates the attribute not being set at all,
+            # which would make the assertion below fail with a clear message.
+            captured.update(getattr(coordinator, "_tool_dispatch_context", {}))
+            return ToolResult(success=True, output="done")
+
+    class ToolThenTextProvider:
+        # First completion requests the tool; second returns plain text so the
+        # agentic loop terminates.
+        def __init__(self) -> None:
+            self._call_count = 0
+
+        async def complete(self, request: object, **kwargs: object) -> ChatResponse:
+            self._call_count += 1
+            if self._call_count == 1:
+                return ChatResponse(
+                    content=[TextBlock(text="Using tool")],
+                    tool_calls=[
+                        ToolCall(
+                            id="dispatch-call-id-001",
+                            name="capture_tool",
+                            arguments={"_": 1},
+                        )
+                    ],
+                )
+            return ChatResponse(content=[TextBlock(text="Done!")])
+
+    orchestrator = BasicOrchestrator({})
+    hooks = EventRecorder()
+    context = MockContextManager()
+
+    await orchestrator.execute(
+        prompt="test",
+        context=context,
+        providers={"default": ToolThenTextProvider()},  # type: ignore[dict-item]
+        tools={"capture_tool": CapturingTool()},  # type: ignore[dict-item]
+        hooks=hooks,  # type: ignore[arg-type]
+        coordinator=coordinator,  # type: ignore[arg-type]
+    )
+
+    assert captured.get("tool_call_id") == "dispatch-call-id-001", (
+        "coordinator._tool_dispatch_context must have tool_call_id set during tool.execute()"
+    )
+
+
+@pytest.mark.asyncio
+async def test_execute_sets_parallel_group_id_in_dispatch_context() -> None:
+    """BasicOrchestrator sets parallel_group_id in coordinator._tool_dispatch_context."""
+    # Snapshot taken from inside tool.execute(), while the context is live.
+    captured: dict = {}
+    coordinator = MockCoordinator()
+
+    class CapturingTool:
+        name = "capture_tool"
+        description = "Captures dispatch context"
+        input_schema: dict = {"type": "object", "properties": {}}
+
+        async def execute(self, args: dict) -> ToolResult:
+            captured.update(getattr(coordinator, "_tool_dispatch_context", {}))
+            return ToolResult(success=True, output="done")
+
+    class ToolThenTextProvider:
+        # One tool call on the first completion, then text to end the loop.
+        def __init__(self) -> None:
+            self._call_count = 0
+
+        async def complete(self, request: object, **kwargs: object) -> ChatResponse:
+            self._call_count += 1
+            if self._call_count == 1:
+                return ChatResponse(
+                    content=[TextBlock(text="Using tool")],
+                    tool_calls=[
+                        ToolCall(id="group-test-call-id", name="capture_tool", arguments={"_": 1})
+                    ],
+                )
+            return ChatResponse(content=[TextBlock(text="Done!")])
+
+    orchestrator = BasicOrchestrator({})
+    hooks = EventRecorder()
+    context = MockContextManager()
+
+    await orchestrator.execute(
+        prompt="test",
+        context=context,
+        providers={"default": ToolThenTextProvider()},  # type: ignore[dict-item]
+        tools={"capture_tool": CapturingTool()},  # type: ignore[dict-item]
+        hooks=hooks,  # type: ignore[arg-type]
+        coordinator=coordinator,  # type: ignore[arg-type]
+    )
+
+    # parallel_group_id is a UUID generated per-batch — verify it's a non-empty string
+    # (the exact value is not predictable, so only its shape is asserted).
+    pgid = captured.get("parallel_group_id")
+    assert isinstance(pgid, str) and pgid, (
+        "_tool_dispatch_context must have a non-empty parallel_group_id string"
+    )
+
+
+@pytest.mark.asyncio
+async def test_execute_clears_dispatch_context_after_tool_completes() -> None:
+    """BasicOrchestrator clears coordinator._tool_dispatch_context after tool.execute()."""
+    coordinator = MockCoordinator()
+
+    class SimpleTool:
+        name = "simple_tool"
+        description = "Returns a successful result"
+        input_schema: dict = {"type": "object", "properties": {}}
+
+        async def execute(self, args: dict) -> ToolResult:
+            return ToolResult(success=True, output="done")
+
+    class ToolThenTextProvider:
+        # Tool call first, then text so execute() returns normally.
+        def __init__(self) -> None:
+            self._call_count = 0
+
+        async def complete(self, request: object, **kwargs: object) -> ChatResponse:
+            self._call_count += 1
+            if self._call_count == 1:
+                return ChatResponse(
+                    content=[TextBlock(text="Using tool")],
+                    tool_calls=[
+                        ToolCall(id="clear-test-id", name="simple_tool", arguments={"_": 1})
+                    ],
+                )
+            return ChatResponse(content=[TextBlock(text="Done!")])
+
+    orchestrator = BasicOrchestrator({})
+    hooks = EventRecorder()
+    context = MockContextManager()
+
+    await orchestrator.execute(
+        prompt="test",
+        context=context,
+        providers={"default": ToolThenTextProvider()},  # type: ignore[dict-item]
+        tools={"simple_tool": SimpleTool()},  # type: ignore[dict-item]
+        hooks=hooks,  # type: ignore[arg-type]
+        coordinator=coordinator,  # type: ignore[arg-type]
+    )
+
+    # The contract (per module docstring) is reset-to-empty-dict, not deletion,
+    # hence == {} rather than is None.
+    ctx_after = getattr(coordinator, "_tool_dispatch_context", None)
+    assert ctx_after == {}, (
+        "_tool_dispatch_context must be cleared to {} after tool execution completes"
+    )
+
+
+@pytest.mark.asyncio
+async def test_execute_clears_dispatch_context_after_tool_raises() -> None:
+    """BasicOrchestrator clears coordinator._tool_dispatch_context even when tool raises."""
+    coordinator = MockCoordinator()
+
+    class RaisingTool:
+        name = "raising_tool"
+        description = "Always raises an exception"
+        input_schema: dict = {"type": "object", "properties": {}}
+
+        async def execute(self, args: dict) -> ToolResult:
+            # Exercises the finally-block cleanup path in the orchestrator.
+            raise ValueError("tool exploded")
+
+    class ToolThenTextProvider:
+        # Tool call first; the second completion lets the loop finish despite
+        # the tool failure being reported back to the model.
+        def __init__(self) -> None:
+            self._call_count = 0
+
+        async def complete(self, request: object, **kwargs: object) -> ChatResponse:
+            self._call_count += 1
+            if self._call_count == 1:
+                return ChatResponse(
+                    content=[TextBlock(text="Using tool")],
+                    tool_calls=[
+                        ToolCall(id="raise-test-id", name="raising_tool", arguments={"_": 1})
+                    ],
+                )
+            return ChatResponse(content=[TextBlock(text="Done despite error!")])
+
+    orchestrator = BasicOrchestrator({})
+    hooks = EventRecorder()
+    context = MockContextManager()
+
+    # BasicOrchestrator handles tool exceptions gracefully (no raise propagation)
+    await orchestrator.execute(
+        prompt="test",
+        context=context,
+        providers={"default": ToolThenTextProvider()},  # type: ignore[dict-item]
+        tools={"raising_tool": RaisingTool()},  # type: ignore[dict-item]
+        hooks=hooks,  # type: ignore[arg-type]
+        coordinator=coordinator,  # type: ignore[arg-type]
+    )
+
+    ctx_after = getattr(coordinator, "_tool_dispatch_context", None)
+    assert ctx_after == {}, (
+        "_tool_dispatch_context must be cleared even when tool.execute() raises"
+    )
+
+
+@pytest.mark.asyncio
+async def test_execute_no_coordinator_does_not_set_dispatch_context() -> None:
+    """Without a coordinator, execute() still runs tools normally (no dispatch context set)."""
+
+    class SimpleTool:
+        name = "simple_tool"
+        description = "Returns a successful result"
+        input_schema: dict = {"type": "object", "properties": {}}
+
+        async def execute(self, args: dict) -> ToolResult:
+            return ToolResult(success=True, output="done")
+
+    class ToolThenTextProvider:
+        # Tool call first, then text so the loop terminates.
+        def __init__(self) -> None:
+            self._call_count = 0
+
+        async def complete(self, request: object, **kwargs: object) -> ChatResponse:
+            self._call_count += 1
+            if self._call_count == 1:
+                return ChatResponse(
+                    content=[TextBlock(text="Using tool")],
+                    tool_calls=[
+                        ToolCall(id="no-coord-id", name="simple_tool", arguments={"_": 1})
+                    ],
+                )
+            return ChatResponse(content=[TextBlock(text="Done!")])
+
+    orchestrator = BasicOrchestrator({})
+    hooks = EventRecorder()
+    context = MockContextManager()
+
+    # Should complete without error even when coordinator=None
+    # (the default when the kwarg is omitted — guards against AttributeError
+    # on the dispatch-context bookkeeping).
+    result = await orchestrator.execute(
+        prompt="test",
+        context=context,
+        providers={"default": ToolThenTextProvider()},  # type: ignore[dict-item]
+        tools={"simple_tool": SimpleTool()},  # type: ignore[dict-item]
+        hooks=hooks,  # type: ignore[arg-type]
+    )
+
+    assert result is not None
diff --git a/uv.lock b/uv.lock
index b3e7958..dbc7687 100644
--- a/uv.lock
+++ b/uv.lock
@@ -4,8 +4,8 @@ requires-python = ">=3.11"
[[package]]
name = "amplifier-core"
-version = "1.0.0"
-source = { git = "https://github.com/microsoft/amplifier-core?branch=main#f246c6f36589d45cb92a69f4e354fe9edc8249c2" }
+version = "1.3.0"
+source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "pydantic" },
@@ -13,6 +13,16 @@ dependencies = [
{ name = "tomli" },
{ name = "typing-extensions" },
]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/64/0a/c9f979aa34ff43d86323fd02c6e0b2049cf583f48a75eefe4e2d4ea39a5b/amplifier_core-1.3.0-cp311-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e94a6adfb538c2f621ddd4eb44f2f2dc645934fef804bc5318d4595520f4f4c9", size = 8096649, upload-time = "2026-03-19T14:01:24.218Z" },
+ { url = "https://files.pythonhosted.org/packages/01/58/619daa63943870340673c625695038b86567adc9896b9d81ff8a2a707b3b/amplifier_core-1.3.0-cp311-abi3-macosx_11_0_arm64.whl", hash = "sha256:d051f91a2c3c61c240f828daf892ac2cfa8d34e9850b245e7ebc96e87f9ac606", size = 7216307, upload-time = "2026-03-19T14:01:26.032Z" },
+ { url = "https://files.pythonhosted.org/packages/6b/35/6aeb099f012c8d97ccd4aad878a1422fe3148840922e4818a7c1279fdc57/amplifier_core-1.3.0-cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af25c5e691638496226aa88e80056cb1870b715ddebd311bccb855c653ace05f", size = 7593745, upload-time = "2026-03-19T14:01:28.027Z" },
+ { url = "https://files.pythonhosted.org/packages/c8/8e/d2a092e31f1924c6c977a3a091b026560a31133fb1f94bf4389f71a8b249/amplifier_core-1.3.0-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18f1d9e2ffa6a9aac507699c039d3dc8dc23fe696a5af515e2fce1d90b34e11b", size = 8624234, upload-time = "2026-03-19T14:01:30.068Z" },
+ { url = "https://files.pythonhosted.org/packages/b8/a3/d6e83e879ecdc300fb4516f1f68f6bab4ab5e2dcc4037a90b9df74861911/amplifier_core-1.3.0-cp311-abi3-win_amd64.whl", hash = "sha256:1bf419d8d659821589b6af68d7a9d6c5e4495095bd7b240ff67527e0ab137985", size = 8887729, upload-time = "2026-03-19T14:01:32.392Z" },
+ { url = "https://files.pythonhosted.org/packages/12/00/36e1f6456a7a6782986918f4ec6890fe6526d3fda478bd46c57fd7cfa9b3/amplifier_core-1.3.0-cp311-abi3-win_arm64.whl", hash = "sha256:eebac607c14c5fac1e12eabb97c09ade64bd2d003f2e9fa1566ac4898bfe848d", size = 7658166, upload-time = "2026-03-19T14:01:34.443Z" },
+ { url = "https://files.pythonhosted.org/packages/02/f0/3beca3cc30323e60f88c8126612a05e3d97f0c951c2d7920458e7ae8e480/amplifier_core-1.3.0-cp313-cp313t-win_arm64.whl", hash = "sha256:a2f8c128f78c8c615f6e411f378b6ad5dc70db9d342a347435df759e4ea4cdbf", size = 7648908, upload-time = "2026-03-19T14:01:36.763Z" },
+ { url = "https://files.pythonhosted.org/packages/d8/38/84b012e1b50226ab97cf8ad9f688708bcb6a34a2c6dd139a28d60d40cd71/amplifier_core-1.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:5bbfac975bf6757c479b31bd07b1d465fde49d95dded497c4d161509919dd93b", size = 7647745, upload-time = "2026-03-19T14:01:38.955Z" },
+]
[[package]]
name = "amplifier-module-loop-basic"
@@ -30,7 +40,7 @@ dev = [
[package.metadata.requires-dev]
dev = [
- { name = "amplifier-core", git = "https://github.com/microsoft/amplifier-core?branch=main" },
+ { name = "amplifier-core", specifier = ">=1.0.10" },
{ name = "pytest", specifier = ">=8.0.0" },
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]