-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_hook_modify.py
More file actions
227 lines (173 loc) · 7.53 KB
/
test_hook_modify.py
File metadata and controls
227 lines (173 loc) · 7.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Tests for hook modify action on tool:post events.
Verifies that when a hook returns HookResult(action="modify", data={"result": ...})
on a tool:post event, the orchestrator uses the modified data instead of the
original result.get_serialized_output().
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from amplifier_core.hooks import HookRegistry
from amplifier_core.models import HookResult
def _make_tool_result(output, success=True):
"""Create a mock tool result with get_serialized_output() method."""
result = MagicMock()
result.success = success
result.output = output
result.error = None
def get_serialized_output():
if isinstance(output, (dict, list)):
return json.dumps(output)
return str(output)
result.get_serialized_output = get_serialized_output
def to_dict():
return {"success": success, "output": output, "error": None}
result.to_dict = to_dict
return result
def _make_provider_responses(tool_calls_response, text_response_str="Done"):
"""Create mock provider responses: one with tool calls, one with text."""
tool_call = MagicMock()
tool_call.id = "tc_1"
tool_call.name = "test_tool"
tool_call.arguments = {"key": "value"}
tool_response = MagicMock()
tool_response.content = [MagicMock(type="text", text="Using tool")]
tool_response.tool_calls = [tool_call]
tool_response.usage = None
tool_response.content_blocks = None
tool_response.metadata = None
text_block = MagicMock()
text_block.text = text_response_str
text_block.type = "text"
text_response = MagicMock()
text_response.content = [text_block]
text_response.tool_calls = None
text_response.usage = None
text_response.content_blocks = None
text_response.metadata = None
mock_provider = AsyncMock()
mock_provider.complete = AsyncMock(side_effect=[tool_response, text_response])
mock_provider.priority = 1
return mock_provider
def _make_context():
"""Create a mock context that captures add_message calls."""
context = AsyncMock()
messages_added = []
async def capture_add_message(msg):
messages_added.append(msg)
context.add_message = AsyncMock(side_effect=capture_add_message)
context.get_messages_for_request = AsyncMock(
return_value=[{"role": "user", "content": "test"}]
)
return context, messages_added
def _get_tool_result_messages(messages_added):
"""Extract tool result messages from captured messages."""
return [msg for msg in messages_added if msg.get("role") == "tool"]
@pytest.mark.asyncio
async def test_tool_post_modify_replaces_result():
    """A ``modify`` result from a ``tool:post`` hook replaces the tool's own output.

    The tool message stored in the context must hold the hook-supplied payload
    rather than the text from the tool result's ``get_serialized_output()``.
    """
    # Stub amplifier_core.llm_errors in sys.modules while the orchestrator
    # module is imported.
    with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
        from amplifier_module_loop_basic import BasicOrchestrator

    loop = BasicOrchestrator({"max_iterations": 5})

    # Tool whose unmodified output carries a large "big_data" field.
    tool_output = {"original": True, "big_data": "x" * 1000}
    tool = MagicMock()
    tool.name = "test_tool"
    tool.description = "A test tool"
    tool.input_schema = {"type": "object", "properties": {}}
    tool.execute = AsyncMock(return_value=_make_tool_result(tool_output))

    # Payload the hook substitutes for the tool's output.
    replacement = {"modified": True, "truncated": True}

    async def modify_hook(event: str, data: dict) -> HookResult:
        if event != "tool:post":
            return HookResult()
        return HookResult(action="modify", data={"result": replacement})

    registry = HookRegistry()
    registry.register("tool:post", modify_hook, priority=50, name="test_modify")

    provider = _make_provider_responses(tool_calls_response=True)
    context, captured = _make_context()

    await loop.execute(
        prompt="test prompt",
        context=context,
        providers={"test_provider": provider},
        tools={"test_tool": tool},
        hooks=registry,
    )

    tool_msgs = _get_tool_result_messages(captured)
    assert len(tool_msgs) == 1, f"Expected 1 tool result, got {len(tool_msgs)}"

    content = tool_msgs[0]["content"]
    expected = json.dumps(replacement)

    # The stored content is the hook's payload ...
    assert content == expected, (
        f"Expected modified content {expected}, "
        f"got {content}"
    )
    # ... and nothing from the tool's original output leaked through.
    assert "big_data" not in content
@pytest.mark.asyncio
async def test_tool_post_no_modify_uses_original():
    """Without a modifying hook, the tool message holds the tool's own serialized output."""
    # Stub amplifier_core.llm_errors in sys.modules while the orchestrator
    # module is imported.
    with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
        from amplifier_module_loop_basic import BasicOrchestrator

    loop = BasicOrchestrator({"max_iterations": 5})

    tool_output = {"original": True}
    tool = MagicMock()
    tool.name = "test_tool"
    tool.description = "A test tool"
    tool.input_schema = {"type": "object", "properties": {}}
    tool.execute = AsyncMock(return_value=_make_tool_result(tool_output))

    # Empty registry: nothing is registered that could return a modify action.
    registry = HookRegistry()

    provider = _make_provider_responses(tool_calls_response=True)
    context, captured = _make_context()

    await loop.execute(
        prompt="test prompt",
        context=context,
        providers={"test_provider": provider},
        tools={"test_tool": tool},
        hooks=registry,
    )

    tool_msgs = _get_tool_result_messages(captured)
    assert len(tool_msgs) == 1

    content = tool_msgs[0]["content"]
    expected = json.dumps(tool_output)

    # The content must match what the tool result serializes to.
    assert content == expected, (
        f"Expected original {expected}, got {content}"
    )
@pytest.mark.asyncio
async def test_tool_post_modify_with_string_result():
    """A string payload from a ``modify`` hook is stored verbatim as the tool message content."""
    # Stub amplifier_core.llm_errors in sys.modules while the orchestrator
    # module is imported.
    with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
        from amplifier_module_loop_basic import BasicOrchestrator

    loop = BasicOrchestrator({"max_iterations": 5})

    tool = MagicMock()
    tool.name = "test_tool"
    tool.description = "A test tool"
    tool.input_schema = {"type": "object", "properties": {}}
    tool.execute = AsyncMock(return_value=_make_tool_result("original text"))

    async def modify_hook(event: str, data: dict) -> HookResult:
        if event != "tool:post":
            return HookResult()
        return HookResult(
            action="modify",
            data={"result": "truncated string result"},
        )

    registry = HookRegistry()
    registry.register("tool:post", modify_hook, priority=50, name="test_modify")

    provider = _make_provider_responses(tool_calls_response=True)
    context, captured = _make_context()

    await loop.execute(
        prompt="test prompt",
        context=context,
        providers={"test_provider": provider},
        tools={"test_tool": tool},
        hooks=registry,
    )

    tool_msgs = _get_tool_result_messages(captured)
    assert len(tool_msgs) == 1
    # The string arrives unchanged, with no extra JSON quoting around it.
    assert tool_msgs[0]["content"] == "truncated string result"