Skip to content

Commit ed8da2d

Browse files
committed
feat: set _tool_dispatch_context on coordinator before tool.execute()
The orchestrator now sets coordinator._tool_dispatch_context = { "tool_call_id": tool_call_id, "parallel_group_id": group_id, } immediately before calling tool.execute() in execute_single_tool, and clears it in the outer finally block (alongside register_tool_complete). This lets tools that need framework correlation IDs (e.g. the delegate tool) read them via: dispatch_ctx = getattr(self.coordinator, '_tool_dispatch_context', {}) tool_call_id = dispatch_ctx.get('tool_call_id', '') Uses setattr() to avoid type errors on the dynamic private attribute. The finally block always clears the context to prevent leaking between parallel tool executions. Fixes Check 4: delegate:agent_spawned tool_call_id was always empty because the delegate tool was reading from tool input (which the LLM never populates with framework IDs).
1 parent 3b9ff5e commit ed8da2d

File tree

2 files changed

+313
-1
lines changed

2 files changed

+313
-1
lines changed

amplifier_module_loop_basic/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,17 @@ async def execute_single_tool(
419419
coordinator.cancellation.register_tool_start(
420420
tool_call_id, tool_name
421421
)
422+
# Set dispatch context so tools (e.g. delegate) can
423+
# read the framework-assigned tool_call_id and
424+
# parallel_group_id. Cleared in the finally block.
425+
setattr(
426+
coordinator,
427+
"_tool_dispatch_context",
428+
{
429+
"tool_call_id": tool_call_id,
430+
"parallel_group_id": group_id,
431+
},
432+
)
422433

423434
try:
424435
try:
@@ -545,8 +556,9 @@ async def execute_single_tool(
545556
logger.error(f"Tool {tool_name} failed: {te}")
546557
return (tool_call_id, error_msg)
547558
finally:
548-
# Unregister tool from cancellation token
559+
# Clear dispatch context and unregister tool
549560
if coordinator:
561+
setattr(coordinator, "_tool_dispatch_context", {})
550562
coordinator.cancellation.register_tool_complete(
551563
tool_call_id
552564
)
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""Tests for _tool_dispatch_context set on coordinator during tool.execute().
2+
3+
Verifies that BasicOrchestrator sets coordinator._tool_dispatch_context
4+
with the correct tool_call_id and parallel_group_id immediately before
5+
calling tool.execute(), and clears it in a finally block afterward.
6+
7+
Covers:
8+
- execute_single_tool (inner path): context set with tool_call_id and parallel_group_id
9+
- context cleared after tool completes normally
10+
- context cleared even when tool raises an exception
11+
- Integration: full execute() path sets dispatch context during tool call
12+
"""
13+
14+
import pytest
15+
from amplifier_core import ToolResult
16+
from amplifier_core.message_models import ChatResponse, TextBlock, ToolCall
17+
from amplifier_core.testing import EventRecorder, MockContextManager
18+
19+
from amplifier_module_loop_basic import BasicOrchestrator
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# Helpers (reuse pattern from test_lifecycle_events.py)
24+
# ---------------------------------------------------------------------------
25+
26+
27+
class MockCancellation:
28+
"""Minimal cancellation token stub."""
29+
30+
is_cancelled: bool = False
31+
is_immediate: bool = False
32+
33+
def register_tool_start(self, tool_call_id: str, tool_name: str) -> None:
34+
pass
35+
36+
def register_tool_complete(self, tool_call_id: str) -> None:
37+
pass
38+
39+
40+
class MockCoordinator:
41+
"""Minimal coordinator stub that supports _tool_dispatch_context."""
42+
43+
def __init__(self, is_cancelled: bool = False, is_immediate: bool = False) -> None:
44+
self.cancellation = MockCancellation()
45+
self.cancellation.is_cancelled = is_cancelled
46+
self.cancellation.is_immediate = is_immediate
47+
self._contributions: dict = {}
48+
49+
async def process_hook_result(self, result: object, event_name: str, source: str) -> object:
50+
return result
51+
52+
def register_contributor(self, channel: str, name: str, callback: object) -> None:
53+
self._contributions[(channel, name)] = callback
54+
55+
56+
# ---------------------------------------------------------------------------
57+
# Integration tests via full execute() path
58+
# ---------------------------------------------------------------------------
59+
60+
61+
@pytest.mark.asyncio
62+
async def test_execute_sets_tool_call_id_in_dispatch_context_during_tool_execution() -> None:
63+
"""BasicOrchestrator sets tool_call_id in coordinator._tool_dispatch_context
64+
before calling tool.execute().
65+
"""
66+
captured: dict = {}
67+
coordinator = MockCoordinator()
68+
69+
class CapturingTool:
70+
name = "capture_tool"
71+
description = "Captures dispatch context during execution"
72+
input_schema: dict = {"type": "object", "properties": {}}
73+
74+
async def execute(self, args: dict) -> ToolResult:
75+
captured.update(getattr(coordinator, "_tool_dispatch_context", {}))
76+
return ToolResult(success=True, output="done")
77+
78+
class ToolThenTextProvider:
79+
def __init__(self) -> None:
80+
self._call_count = 0
81+
82+
async def complete(self, request: object, **kwargs: object) -> ChatResponse:
83+
self._call_count += 1
84+
if self._call_count == 1:
85+
return ChatResponse(
86+
content=[TextBlock(text="Using tool")],
87+
tool_calls=[
88+
ToolCall(
89+
id="dispatch-call-id-001",
90+
name="capture_tool",
91+
arguments={"_": 1},
92+
)
93+
],
94+
)
95+
return ChatResponse(content=[TextBlock(text="Done!")])
96+
97+
orchestrator = BasicOrchestrator({})
98+
hooks = EventRecorder()
99+
context = MockContextManager()
100+
101+
await orchestrator.execute(
102+
prompt="test",
103+
context=context,
104+
providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item]
105+
tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item]
106+
hooks=hooks, # type: ignore[arg-type]
107+
coordinator=coordinator, # type: ignore[arg-type]
108+
)
109+
110+
assert captured.get("tool_call_id") == "dispatch-call-id-001", (
111+
"coordinator._tool_dispatch_context must have tool_call_id set during tool.execute()"
112+
)
113+
114+
115+
@pytest.mark.asyncio
116+
async def test_execute_sets_parallel_group_id_in_dispatch_context() -> None:
117+
"""BasicOrchestrator sets parallel_group_id in coordinator._tool_dispatch_context."""
118+
captured: dict = {}
119+
coordinator = MockCoordinator()
120+
121+
class CapturingTool:
122+
name = "capture_tool"
123+
description = "Captures dispatch context"
124+
input_schema: dict = {"type": "object", "properties": {}}
125+
126+
async def execute(self, args: dict) -> ToolResult:
127+
captured.update(getattr(coordinator, "_tool_dispatch_context", {}))
128+
return ToolResult(success=True, output="done")
129+
130+
class ToolThenTextProvider:
131+
def __init__(self) -> None:
132+
self._call_count = 0
133+
134+
async def complete(self, request: object, **kwargs: object) -> ChatResponse:
135+
self._call_count += 1
136+
if self._call_count == 1:
137+
return ChatResponse(
138+
content=[TextBlock(text="Using tool")],
139+
tool_calls=[
140+
ToolCall(id="group-test-call-id", name="capture_tool", arguments={"_": 1})
141+
],
142+
)
143+
return ChatResponse(content=[TextBlock(text="Done!")])
144+
145+
orchestrator = BasicOrchestrator({})
146+
hooks = EventRecorder()
147+
context = MockContextManager()
148+
149+
await orchestrator.execute(
150+
prompt="test",
151+
context=context,
152+
providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item]
153+
tools={"capture_tool": CapturingTool()}, # type: ignore[dict-item]
154+
hooks=hooks, # type: ignore[arg-type]
155+
coordinator=coordinator, # type: ignore[arg-type]
156+
)
157+
158+
# parallel_group_id is a UUID generated per-batch — verify it's a non-empty string
159+
pgid = captured.get("parallel_group_id")
160+
assert isinstance(pgid, str) and pgid, (
161+
"_tool_dispatch_context must have a non-empty parallel_group_id string"
162+
)
163+
164+
165+
@pytest.mark.asyncio
166+
async def test_execute_clears_dispatch_context_after_tool_completes() -> None:
167+
"""BasicOrchestrator clears coordinator._tool_dispatch_context after tool.execute()."""
168+
coordinator = MockCoordinator()
169+
170+
class SimpleTool:
171+
name = "simple_tool"
172+
description = "Returns a successful result"
173+
input_schema: dict = {"type": "object", "properties": {}}
174+
175+
async def execute(self, args: dict) -> ToolResult:
176+
return ToolResult(success=True, output="done")
177+
178+
class ToolThenTextProvider:
179+
def __init__(self) -> None:
180+
self._call_count = 0
181+
182+
async def complete(self, request: object, **kwargs: object) -> ChatResponse:
183+
self._call_count += 1
184+
if self._call_count == 1:
185+
return ChatResponse(
186+
content=[TextBlock(text="Using tool")],
187+
tool_calls=[
188+
ToolCall(id="clear-test-id", name="simple_tool", arguments={"_": 1})
189+
],
190+
)
191+
return ChatResponse(content=[TextBlock(text="Done!")])
192+
193+
orchestrator = BasicOrchestrator({})
194+
hooks = EventRecorder()
195+
context = MockContextManager()
196+
197+
await orchestrator.execute(
198+
prompt="test",
199+
context=context,
200+
providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item]
201+
tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item]
202+
hooks=hooks, # type: ignore[arg-type]
203+
coordinator=coordinator, # type: ignore[arg-type]
204+
)
205+
206+
ctx_after = getattr(coordinator, "_tool_dispatch_context", None)
207+
assert ctx_after == {}, (
208+
"_tool_dispatch_context must be cleared to {} after tool execution completes"
209+
)
210+
211+
212+
@pytest.mark.asyncio
213+
async def test_execute_clears_dispatch_context_after_tool_raises() -> None:
214+
"""BasicOrchestrator clears coordinator._tool_dispatch_context even when tool raises."""
215+
coordinator = MockCoordinator()
216+
217+
class RaisingTool:
218+
name = "raising_tool"
219+
description = "Always raises an exception"
220+
input_schema: dict = {"type": "object", "properties": {}}
221+
222+
async def execute(self, args: dict) -> ToolResult:
223+
raise ValueError("tool exploded")
224+
225+
class ToolThenTextProvider:
226+
def __init__(self) -> None:
227+
self._call_count = 0
228+
229+
async def complete(self, request: object, **kwargs: object) -> ChatResponse:
230+
self._call_count += 1
231+
if self._call_count == 1:
232+
return ChatResponse(
233+
content=[TextBlock(text="Using tool")],
234+
tool_calls=[
235+
ToolCall(id="raise-test-id", name="raising_tool", arguments={"_": 1})
236+
],
237+
)
238+
return ChatResponse(content=[TextBlock(text="Done despite error!")])
239+
240+
orchestrator = BasicOrchestrator({})
241+
hooks = EventRecorder()
242+
context = MockContextManager()
243+
244+
# BasicOrchestrator handles tool exceptions gracefully (no raise propagation)
245+
await orchestrator.execute(
246+
prompt="test",
247+
context=context,
248+
providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item]
249+
tools={"raising_tool": RaisingTool()}, # type: ignore[dict-item]
250+
hooks=hooks, # type: ignore[arg-type]
251+
coordinator=coordinator, # type: ignore[arg-type]
252+
)
253+
254+
ctx_after = getattr(coordinator, "_tool_dispatch_context", None)
255+
assert ctx_after == {}, (
256+
"_tool_dispatch_context must be cleared even when tool.execute() raises"
257+
)
258+
259+
260+
@pytest.mark.asyncio
261+
async def test_execute_no_coordinator_does_not_set_dispatch_context() -> None:
262+
"""Without a coordinator, execute() still runs tools normally (no dispatch context set)."""
263+
264+
class SimpleTool:
265+
name = "simple_tool"
266+
description = "Returns a successful result"
267+
input_schema: dict = {"type": "object", "properties": {}}
268+
269+
async def execute(self, args: dict) -> ToolResult:
270+
return ToolResult(success=True, output="done")
271+
272+
class ToolThenTextProvider:
273+
def __init__(self) -> None:
274+
self._call_count = 0
275+
276+
async def complete(self, request: object, **kwargs: object) -> ChatResponse:
277+
self._call_count += 1
278+
if self._call_count == 1:
279+
return ChatResponse(
280+
content=[TextBlock(text="Using tool")],
281+
tool_calls=[
282+
ToolCall(id="no-coord-id", name="simple_tool", arguments={"_": 1})
283+
],
284+
)
285+
return ChatResponse(content=[TextBlock(text="Done!")])
286+
287+
orchestrator = BasicOrchestrator({})
288+
hooks = EventRecorder()
289+
context = MockContextManager()
290+
291+
# Should complete without error even when coordinator=None
292+
result = await orchestrator.execute(
293+
prompt="test",
294+
context=context,
295+
providers={"default": ToolThenTextProvider()}, # type: ignore[dict-item]
296+
tools={"simple_tool": SimpleTool()}, # type: ignore[dict-item]
297+
hooks=hooks, # type: ignore[arg-type]
298+
)
299+
300+
assert result is not None

0 commit comments

Comments
 (0)