Skip to content

Commit aa37f4e

Browse files
fix: read HookResult modify action from tool:post events (#4)
When a hook returns action='modify' on tool:post, the orchestrator now detects the modified data by comparing the returned result with the original, and uses it instead of the original get_serialized_output(). This enables tool output truncation, sanitization, and transformation hooks to work correctly. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-authored-by: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com>
1 parent e344b9f commit aa37f4e

File tree

2 files changed

+249
-2
lines changed

2 files changed

+249
-2
lines changed

amplifier_module_loop_basic/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Amplifier module metadata
66
__amplifier_module_type__ = "orchestrator"
77

8+
import json
89
import logging
910
from typing import Any
1011

@@ -463,8 +464,27 @@ async def execute_single_tool(
463464
f"Stored ephemeral injection from tool:post ({tool_name}) for next iteration"
464465
)
465466

466-
# Return success with result content (JSON-serialized for dict/list)
467-
result_content = result.get_serialized_output()
467+
# Check if a hook modified the tool result.
468+
# hooks.emit() chains modify actions: when a hook
469+
# returns action="modify", the data dict is replaced.
470+
# We detect this by checking if the returned "result"
471+
# is a different object than what we originally sent.
472+
modified_result = None
473+
if post_result and post_result.data is not None:
474+
returned_result = post_result.data.get("result")
475+
if (
476+
returned_result is not None
477+
and returned_result is not result_data
478+
):
479+
modified_result = returned_result
480+
481+
if modified_result is not None:
482+
if isinstance(modified_result, (dict, list)):
483+
result_content = json.dumps(modified_result)
484+
else:
485+
result_content = str(modified_result)
486+
else:
487+
result_content = result.get_serialized_output()
468488
return (tool_call_id, result_content)
469489

470490
except Exception as te:

tests/test_hook_modify.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""Tests for hook modify action on tool:post events.
2+
3+
Verifies that when a hook returns HookResult(action="modify", data={"result": ...})
4+
on a tool:post event, the orchestrator uses the modified data instead of the
5+
original result.get_serialized_output().
6+
"""
7+
8+
import json
9+
10+
import pytest
11+
from unittest.mock import AsyncMock, MagicMock, patch
12+
13+
from amplifier_core.hooks import HookRegistry
14+
from amplifier_core.models import HookResult
15+
16+
17+
def _make_tool_result(output, success=True):
18+
"""Create a mock tool result with get_serialized_output() method."""
19+
result = MagicMock()
20+
result.success = success
21+
result.output = output
22+
result.error = None
23+
24+
def get_serialized_output():
25+
if isinstance(output, (dict, list)):
26+
return json.dumps(output)
27+
return str(output)
28+
29+
result.get_serialized_output = get_serialized_output
30+
31+
def to_dict():
32+
return {"success": success, "output": output, "error": None}
33+
34+
result.to_dict = to_dict
35+
return result
36+
37+
38+
def _make_provider_responses(tool_calls_response, text_response_str="Done"):
39+
"""Create mock provider responses: one with tool calls, one with text."""
40+
tool_call = MagicMock()
41+
tool_call.id = "tc_1"
42+
tool_call.name = "test_tool"
43+
tool_call.arguments = {"key": "value"}
44+
45+
tool_response = MagicMock()
46+
tool_response.content = [MagicMock(type="text", text="Using tool")]
47+
tool_response.tool_calls = [tool_call]
48+
tool_response.usage = None
49+
tool_response.content_blocks = None
50+
tool_response.metadata = None
51+
52+
text_block = MagicMock()
53+
text_block.text = text_response_str
54+
text_block.type = "text"
55+
text_response = MagicMock()
56+
text_response.content = [text_block]
57+
text_response.tool_calls = None
58+
text_response.usage = None
59+
text_response.content_blocks = None
60+
text_response.metadata = None
61+
62+
mock_provider = AsyncMock()
63+
mock_provider.complete = AsyncMock(side_effect=[tool_response, text_response])
64+
mock_provider.priority = 1
65+
66+
return mock_provider
67+
68+
69+
def _make_context():
70+
"""Create a mock context that captures add_message calls."""
71+
context = AsyncMock()
72+
messages_added = []
73+
74+
async def capture_add_message(msg):
75+
messages_added.append(msg)
76+
77+
context.add_message = AsyncMock(side_effect=capture_add_message)
78+
context.get_messages_for_request = AsyncMock(
79+
return_value=[{"role": "user", "content": "test"}]
80+
)
81+
return context, messages_added
82+
83+
84+
def _get_tool_result_messages(messages_added):
85+
"""Extract tool result messages from captured messages."""
86+
return [msg for msg in messages_added if msg.get("role") == "tool"]
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_tool_post_modify_replaces_result():
91+
"""When a hook returns action='modify' on tool:post, the modified data
92+
should be used instead of the original get_serialized_output()."""
93+
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
94+
from amplifier_module_loop_basic import BasicOrchestrator
95+
96+
orchestrator = BasicOrchestrator({"max_iterations": 5})
97+
98+
# Tool with original output
99+
original_output = {"original": True, "big_data": "x" * 1000}
100+
mock_tool = MagicMock()
101+
mock_tool.name = "test_tool"
102+
mock_tool.description = "A test tool"
103+
mock_tool.input_schema = {"type": "object", "properties": {}}
104+
mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
105+
tools = {"test_tool": mock_tool}
106+
107+
# Hook that returns modify with new data
108+
modified_content = {"modified": True, "truncated": True}
109+
hooks = HookRegistry()
110+
111+
async def modify_hook(event: str, data: dict) -> HookResult:
112+
if event == "tool:post":
113+
return HookResult(action="modify", data={"result": modified_content})
114+
return HookResult()
115+
116+
hooks.register("tool:post", modify_hook, priority=50, name="test_modify")
117+
118+
mock_provider = _make_provider_responses(tool_calls_response=True)
119+
providers = {"test_provider": mock_provider}
120+
context, messages_added = _make_context()
121+
122+
await orchestrator.execute(
123+
prompt="test prompt",
124+
context=context,
125+
providers=providers,
126+
tools=tools,
127+
hooks=hooks,
128+
)
129+
130+
tool_msgs = _get_tool_result_messages(messages_added)
131+
assert len(tool_msgs) == 1, f"Expected 1 tool result, got {len(tool_msgs)}"
132+
133+
tool_result_content = tool_msgs[0]["content"]
134+
135+
# The content should be the MODIFIED data, not the original
136+
assert tool_result_content == json.dumps(modified_content), (
137+
f"Expected modified content {json.dumps(modified_content)}, "
138+
f"got {tool_result_content}"
139+
)
140+
141+
# Verify the original data is NOT used
142+
assert "big_data" not in tool_result_content
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_tool_post_no_modify_uses_original():
147+
"""When no hook returns modify, the original get_serialized_output() is used."""
148+
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
149+
from amplifier_module_loop_basic import BasicOrchestrator
150+
151+
orchestrator = BasicOrchestrator({"max_iterations": 5})
152+
153+
original_output = {"original": True}
154+
mock_tool = MagicMock()
155+
mock_tool.name = "test_tool"
156+
mock_tool.description = "A test tool"
157+
mock_tool.input_schema = {"type": "object", "properties": {}}
158+
mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
159+
tools = {"test_tool": mock_tool}
160+
161+
# No hooks registered — default continue action
162+
hooks = HookRegistry()
163+
164+
mock_provider = _make_provider_responses(tool_calls_response=True)
165+
providers = {"test_provider": mock_provider}
166+
context, messages_added = _make_context()
167+
168+
await orchestrator.execute(
169+
prompt="test prompt",
170+
context=context,
171+
providers=providers,
172+
tools=tools,
173+
hooks=hooks,
174+
)
175+
176+
tool_msgs = _get_tool_result_messages(messages_added)
177+
assert len(tool_msgs) == 1
178+
tool_result_content = tool_msgs[0]["content"]
179+
180+
# Should use original serialized output
181+
assert tool_result_content == json.dumps(original_output), (
182+
f"Expected original {json.dumps(original_output)}, got {tool_result_content}"
183+
)
184+
185+
186+
@pytest.mark.asyncio
187+
async def test_tool_post_modify_with_string_result():
188+
"""When a hook returns modify with a string result, it should be used as-is."""
189+
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
190+
from amplifier_module_loop_basic import BasicOrchestrator
191+
192+
orchestrator = BasicOrchestrator({"max_iterations": 5})
193+
194+
mock_tool = MagicMock()
195+
mock_tool.name = "test_tool"
196+
mock_tool.description = "A test tool"
197+
mock_tool.input_schema = {"type": "object", "properties": {}}
198+
mock_tool.execute = AsyncMock(return_value=_make_tool_result("original text"))
199+
tools = {"test_tool": mock_tool}
200+
201+
hooks = HookRegistry()
202+
203+
async def modify_hook(event: str, data: dict) -> HookResult:
204+
if event == "tool:post":
205+
return HookResult(
206+
action="modify",
207+
data={"result": "truncated string result"},
208+
)
209+
return HookResult()
210+
211+
hooks.register("tool:post", modify_hook, priority=50, name="test_modify")
212+
213+
mock_provider = _make_provider_responses(tool_calls_response=True)
214+
providers = {"test_provider": mock_provider}
215+
context, messages_added = _make_context()
216+
217+
await orchestrator.execute(
218+
prompt="test prompt",
219+
context=context,
220+
providers=providers,
221+
tools=tools,
222+
hooks=hooks,
223+
)
224+
225+
tool_msgs = _get_tool_result_messages(messages_added)
226+
assert len(tool_msgs) == 1
227+
assert tool_msgs[0]["content"] == "truncated string result"

0 commit comments

Comments
 (0)