|
| 1 | +"""Tests for _tool_dispatch_context set on coordinator during tool.execute(). |
| 2 | +
|
| 3 | +Verifies that BasicOrchestrator sets coordinator._tool_dispatch_context |
| 4 | +with the correct tool_call_id and parallel_group_id immediately before |
| 5 | +calling tool.execute(), and clears it in a finally block afterward. |
| 6 | +
|
| 7 | +Covers: |
| 8 | +- execute_single_tool (inner path): context set with tool_call_id and parallel_group_id |
| 9 | +- context cleared after tool completes normally |
| 10 | +- context cleared even when tool raises an exception |
| 11 | +- Integration: full execute() path sets dispatch context during tool call |
| 12 | +""" |
| 13 | + |
| 14 | +import pytest |
| 15 | +from amplifier_core import ToolResult |
| 16 | +from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall |
| 17 | +from amplifier_core.testing import EventRecorder, MockContextManager |
| 18 | + |
| 19 | +from amplifier_module_loop_basic import BasicOrchestrator |
| 20 | + |
| 21 | + |
| 22 | +# --------------------------------------------------------------------------- |
| 23 | +# Helpers (reuse pattern from test_lifecycle_events.py) |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | + |
| 26 | + |
| 27 | +class MockCancellation: |
| 28 | + """Minimal cancellation token stub.""" |
| 29 | + |
| 30 | + is_cancelled: bool = False |
| 31 | + is_immediate: bool = False |
| 32 | + |
| 33 | + def register_tool_start(self, tool_call_id: str, tool_name: str) -> None: |
| 34 | + pass |
| 35 | + |
| 36 | + def register_tool_complete(self, tool_call_id: str) -> None: |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +class MockCoordinator: |
| 41 | + """Minimal coordinator stub that supports _tool_dispatch_context.""" |
| 42 | + |
| 43 | + def __init__(self, is_cancelled: bool = False, is_immediate: bool = False) -> None: |
| 44 | + self.cancellation = MockCancellation() |
| 45 | + self.cancellation.is_cancelled = is_cancelled |
| 46 | + self.cancellation.is_immediate = is_immediate |
| 47 | + self._contributions: dict = {} |
| 48 | + |
| 49 | + async def process_hook_result(self, result: object, event_name: str, source: str) -> object: |
| 50 | + return result |
| 51 | + |
| 52 | + def register_contributor(self, channel: str, name: str, callback: object) -> None: |
| 53 | + self._contributions[(channel, name)] = callback |
| 54 | + |
| 55 | + |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | +# Integration tests via full execute() path |
| 58 | +# --------------------------------------------------------------------------- |
| 59 | + |
| 60 | + |
| 61 | +@pytest.mark.asyncio |
| 62 | +async def test_execute_sets_tool_call_id_in_dispatch_context_during_tool_execution() -> None: |
| 63 | + """BasicOrchestrator sets tool_call_id in coordinator._tool_dispatch_context |
| 64 | + before calling tool.execute(). |
| 65 | + """ |
| 66 | + captured: dict = {} |
| 67 | + coordinator = MockCoordinator() |
| 68 | + |
| 69 | + class CapturingTool: |
| 70 | + name = "capture_tool" |
| 71 | + description = "Captures dispatch context during execution" |
| 72 | + input_schema: dict = {"type": "object", "properties": {}} |
| 73 | + |
| 74 | + async def execute(self, args: dict) -> ToolResult: |
| 75 | + captured.update(getattr(coordinator, "_tool_dispatch_context", {})) |
| 76 | + return ToolResult(success=True, output="done") |
| 77 | + |
| 78 | + class ToolThenTextProvider: |
| 79 | + def __init__(self) -> None: |
| 80 | + self._call_count = 0 |
| 81 | + |
| 82 | + async def complete(self, request: object, **kwargs: object) -> ChatResponse: |
| 83 | + self._call_count += 1 |
| 84 | + if self._call_count == 1: |
| 85 | + return ChatResponse( |
| 86 | + content=[TextBlock(text="Using tool")], |
| 87 | + tool_calls=[ |
| 88 | + ToolCall( |
| 89 | + id="dispatch-call-id-001", |
| 90 | + name="capture_tool", |
| 91 | + arguments={"_": 1}, |
| 92 | + ) |
| 93 | + ], |
| 94 | + ) |
| 95 | + return ChatResponse(content=[TextBlock(text="Done!")]) |
| 96 | + |
| 97 | + orchestrator = BasicOrchestrator({}) |
| 98 | + hooks = EventRecorder() |
| 99 | + context = MockContextManager() |
| 100 | + |
| 101 | + await orchestrator.execute( |
| 102 | + prompt="test", |
| 103 | + context=context, |
| 104 | + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] |
| 105 | + tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item] |
| 106 | + hooks=hooks, # type: ignore[arg-type] |
| 107 | + coordinator=coordinator, # type: ignore[arg-type] |
| 108 | + ) |
| 109 | + |
| 110 | + assert captured.get("tool_call_id") == "dispatch-call-id-001", ( |
| 111 | + "coordinator._tool_dispatch_context must have tool_call_id set during tool.execute()" |
| 112 | + ) |
| 113 | + |
| 114 | + |
| 115 | +@pytest.mark.asyncio |
| 116 | +async def test_execute_sets_parallel_group_id_in_dispatch_context() -> None: |
| 117 | + """BasicOrchestrator sets parallel_group_id in coordinator._tool_dispatch_context.""" |
| 118 | + captured: dict = {} |
| 119 | + coordinator = MockCoordinator() |
| 120 | + |
| 121 | + class CapturingTool: |
| 122 | + name = "capture_tool" |
| 123 | + description = "Captures dispatch context" |
| 124 | + input_schema: dict = {"type": "object", "properties": {}} |
| 125 | + |
| 126 | + async def execute(self, args: dict) -> ToolResult: |
| 127 | + captured.update(getattr(coordinator, "_tool_dispatch_context", {})) |
| 128 | + return ToolResult(success=True, output="done") |
| 129 | + |
| 130 | + class ToolThenTextProvider: |
| 131 | + def __init__(self) -> None: |
| 132 | + self._call_count = 0 |
| 133 | + |
| 134 | + async def complete(self, request: object, **kwargs: object) -> ChatResponse: |
| 135 | + self._call_count += 1 |
| 136 | + if self._call_count == 1: |
| 137 | + return ChatResponse( |
| 138 | + content=[TextBlock(text="Using tool")], |
| 139 | + tool_calls=[ |
| 140 | + ToolCall(id="group-test-call-id", name="capture_tool", arguments={"_": 1}) |
| 141 | + ], |
| 142 | + ) |
| 143 | + return ChatResponse(content=[TextBlock(text="Done!")]) |
| 144 | + |
| 145 | + orchestrator = BasicOrchestrator({}) |
| 146 | + hooks = EventRecorder() |
| 147 | + context = MockContextManager() |
| 148 | + |
| 149 | + await orchestrator.execute( |
| 150 | + prompt="test", |
| 151 | + context=context, |
| 152 | + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] |
| 153 | + tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item] |
| 154 | + hooks=hooks, # type: ignore[arg-type] |
| 155 | + coordinator=coordinator, # type: ignore[arg-type] |
| 156 | + ) |
| 157 | + |
| 158 | + # parallel_group_id is a UUID generated per-batch — verify it's a non-empty string |
| 159 | + pgid = captured.get("parallel_group_id") |
| 160 | + assert isinstance(pgid, str) and pgid, ( |
| 161 | + "_tool_dispatch_context must have a non-empty parallel_group_id string" |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +@pytest.mark.asyncio |
| 166 | +async def test_execute_clears_dispatch_context_after_tool_completes() -> None: |
| 167 | + """BasicOrchestrator clears coordinator._tool_dispatch_context after tool.execute().""" |
| 168 | + coordinator = MockCoordinator() |
| 169 | + |
| 170 | + class SimpleTool: |
| 171 | + name = "simple_tool" |
| 172 | + description = "Returns a successful result" |
| 173 | + input_schema: dict = {"type": "object", "properties": {}} |
| 174 | + |
| 175 | + async def execute(self, args: dict) -> ToolResult: |
| 176 | + return ToolResult(success=True, output="done") |
| 177 | + |
| 178 | + class ToolThenTextProvider: |
| 179 | + def __init__(self) -> None: |
| 180 | + self._call_count = 0 |
| 181 | + |
| 182 | + async def complete(self, request: object, **kwargs: object) -> ChatResponse: |
| 183 | + self._call_count += 1 |
| 184 | + if self._call_count == 1: |
| 185 | + return ChatResponse( |
| 186 | + content=[TextBlock(text="Using tool")], |
| 187 | + tool_calls=[ |
| 188 | + ToolCall(id="clear-test-id", name="simple_tool", arguments={"_": 1}) |
| 189 | + ], |
| 190 | + ) |
| 191 | + return ChatResponse(content=[TextBlock(text="Done!")]) |
| 192 | + |
| 193 | + orchestrator = BasicOrchestrator({}) |
| 194 | + hooks = EventRecorder() |
| 195 | + context = MockContextManager() |
| 196 | + |
| 197 | + await orchestrator.execute( |
| 198 | + prompt="test", |
| 199 | + context=context, |
| 200 | + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] |
| 201 | + tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item] |
| 202 | + hooks=hooks, # type: ignore[arg-type] |
| 203 | + coordinator=coordinator, # type: ignore[arg-type] |
| 204 | + ) |
| 205 | + |
| 206 | + ctx_after = getattr(coordinator, "_tool_dispatch_context", None) |
| 207 | + assert ctx_after == {}, ( |
| 208 | + "_tool_dispatch_context must be cleared to {} after tool execution completes" |
| 209 | + ) |
| 210 | + |
| 211 | + |
| 212 | +@pytest.mark.asyncio |
| 213 | +async def test_execute_clears_dispatch_context_after_tool_raises() -> None: |
| 214 | + """BasicOrchestrator clears coordinator._tool_dispatch_context even when tool raises.""" |
| 215 | + coordinator = MockCoordinator() |
| 216 | + |
| 217 | + class RaisingTool: |
| 218 | + name = "raising_tool" |
| 219 | + description = "Always raises an exception" |
| 220 | + input_schema: dict = {"type": "object", "properties": {}} |
| 221 | + |
| 222 | + async def execute(self, args: dict) -> ToolResult: |
| 223 | + raise ValueError("tool exploded") |
| 224 | + |
| 225 | + class ToolThenTextProvider: |
| 226 | + def __init__(self) -> None: |
| 227 | + self._call_count = 0 |
| 228 | + |
| 229 | + async def complete(self, request: object, **kwargs: object) -> ChatResponse: |
| 230 | + self._call_count += 1 |
| 231 | + if self._call_count == 1: |
| 232 | + return ChatResponse( |
| 233 | + content=[TextBlock(text="Using tool")], |
| 234 | + tool_calls=[ |
| 235 | + ToolCall(id="raise-test-id", name="raising_tool", arguments={"_": 1}) |
| 236 | + ], |
| 237 | + ) |
| 238 | + return ChatResponse(content=[TextBlock(text="Done despite error!")]) |
| 239 | + |
| 240 | + orchestrator = BasicOrchestrator({}) |
| 241 | + hooks = EventRecorder() |
| 242 | + context = MockContextManager() |
| 243 | + |
| 244 | + # BasicOrchestrator handles tool exceptions gracefully (no raise propagation) |
| 245 | + await orchestrator.execute( |
| 246 | + prompt="test", |
| 247 | + context=context, |
| 248 | + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] |
| 249 | + tools={"raising_tool": RaisingTool()}, # type: ignore[dict-item] |
| 250 | + hooks=hooks, # type: ignore[arg-type] |
| 251 | + coordinator=coordinator, # type: ignore[arg-type] |
| 252 | + ) |
| 253 | + |
| 254 | + ctx_after = getattr(coordinator, "_tool_dispatch_context", None) |
| 255 | + assert ctx_after == {}, ( |
| 256 | + "_tool_dispatch_context must be cleared even when tool.execute() raises" |
| 257 | + ) |
| 258 | + |
| 259 | + |
| 260 | +@pytest.mark.asyncio |
| 261 | +async def test_execute_no_coordinator_does_not_set_dispatch_context() -> None: |
| 262 | + """Without a coordinator, execute() still runs tools normally (no dispatch context set).""" |
| 263 | + |
| 264 | + class SimpleTool: |
| 265 | + name = "simple_tool" |
| 266 | + description = "Returns a successful result" |
| 267 | + input_schema: dict = {"type": "object", "properties": {}} |
| 268 | + |
| 269 | + async def execute(self, args: dict) -> ToolResult: |
| 270 | + return ToolResult(success=True, output="done") |
| 271 | + |
| 272 | + class ToolThenTextProvider: |
| 273 | + def __init__(self) -> None: |
| 274 | + self._call_count = 0 |
| 275 | + |
| 276 | + async def complete(self, request: object, **kwargs: object) -> ChatResponse: |
| 277 | + self._call_count += 1 |
| 278 | + if self._call_count == 1: |
| 279 | + return ChatResponse( |
| 280 | + content=[TextBlock(text="Using tool")], |
| 281 | + tool_calls=[ |
| 282 | + ToolCall(id="no-coord-id", name="simple_tool", arguments={"_": 1}) |
| 283 | + ], |
| 284 | + ) |
| 285 | + return ChatResponse(content=[TextBlock(text="Done!")]) |
| 286 | + |
| 287 | + orchestrator = BasicOrchestrator({}) |
| 288 | + hooks = EventRecorder() |
| 289 | + context = MockContextManager() |
| 290 | + |
| 291 | + # Should complete without error even when coordinator=None |
| 292 | + result = await orchestrator.execute( |
| 293 | + prompt="test", |
| 294 | + context=context, |
| 295 | + providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item] |
| 296 | + tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item] |
| 297 | + hooks=hooks, # type: ignore[arg-type] |
| 298 | + ) |
| 299 | + |
| 300 | + assert result is not None |
0 commit comments