Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,321 changes: 722 additions & 599 deletions amplifier_module_loop_basic/__init__.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ allow-direct-references = true

[dependency-groups]
dev = [
"amplifier-core",
"amplifier-core>=1.0.10",
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
]

[tool.uv.sources]
amplifier-core = { git = "https://github.com/microsoft/amplifier-core", branch = "main" }

[tool.pytest.ini_options]
testpaths = ["tests"]
Expand Down
105 changes: 105 additions & 0 deletions tests/test_cancelled_error_dict_tool_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests for CancelledError handler with dict-based tool_calls.

Regression test for unsafe tc.id / tc.name access at lines 536-537.
The CancelledError handler used bare attribute access on tool_call objects
that may be plain dicts. Every other access site (9 of them) uses the safe
dual-access pattern: getattr(tc, "id", None) or tc.get("id").
"""

import asyncio

import pytest

from amplifier_core.testing import EventRecorder, MockContextManager

from amplifier_module_loop_basic import BasicOrchestrator


class DictToolCallProvider:
"""Provider that returns tool_calls as plain dicts (not ToolCall objects).

Some providers return tool_calls as dicts rather than objects.
The orchestrator explicitly accommodates this with a dual-access pattern.
"""

name = "dict-provider"

async def complete(self, request, **kwargs):
return type(
"Response",
(),
{
"content": "Calling tool",
"tool_calls": [
{"id": "tc1", "tool": "cancel_tool", "arguments": {}}
],
"usage": None,
"content_blocks": None,
"metadata": None,
},
)()


class CancellingTool:
"""Tool that raises CancelledError to simulate immediate cancellation."""

name = "cancel_tool"
description = "tool that simulates cancellation"
input_schema = {"type": "object", "properties": {}}

async def execute(self, args):
raise asyncio.CancelledError()


@pytest.mark.asyncio
async def test_cancelled_error_handler_with_dict_tool_calls():
"""CancelledError handler must not crash when tool_calls are plain dicts.

Without the fix, line 536 (tc.id) raises:
AttributeError: 'dict' object has no attribute 'id'

With the fix, CancelledError propagates cleanly after synthesizing
cancelled tool results into the context.
"""
orchestrator = BasicOrchestrator({})
context = MockContextManager()
hooks = EventRecorder()

with pytest.raises(asyncio.CancelledError):
await orchestrator.execute(
prompt="Test",
context=context,
providers={"default": DictToolCallProvider()},
tools={"cancel_tool": CancellingTool()},
hooks=hooks,
)


@pytest.mark.asyncio
async def test_cancelled_error_synthesizes_messages_for_dict_tool_calls():
"""After fix, cancelled tool results are properly added to context.

Verifies the synthesized cancellation messages contain the correct
tool_call_id and tool name extracted via the safe dual-access pattern.
"""
orchestrator = BasicOrchestrator({})
context = MockContextManager()
hooks = EventRecorder()

with pytest.raises(asyncio.CancelledError):
await orchestrator.execute(
prompt="Test",
context=context,
providers={"default": DictToolCallProvider()},
tools={"cancel_tool": CancellingTool()},
hooks=hooks,
)

# Find the synthesized cancellation message in context
tool_messages = [m for m in context.messages if m.get("role") == "tool"]
assert len(tool_messages) >= 1, "Expected at least one synthesized tool message"

cancel_msg = tool_messages[-1]
assert cancel_msg["tool_call_id"] == "tc1"
assert "cancelled" in cancel_msg["content"]
assert "cancel_tool" in cancel_msg["content"]
227 changes: 227 additions & 0 deletions tests/test_hook_modify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""Tests for hook modify action on tool:post events.

Verifies that when a hook returns HookResult(action="modify", data={"result": ...})
on a tool:post event, the orchestrator uses the modified data instead of the
original result.get_serialized_output().
"""

import json

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from amplifier_core.hooks import HookRegistry
from amplifier_core.models import HookResult


def _make_tool_result(output, success=True):
"""Create a mock tool result with get_serialized_output() method."""
result = MagicMock()
result.success = success
result.output = output
result.error = None

def get_serialized_output():
if isinstance(output, (dict, list)):
return json.dumps(output)
return str(output)

result.get_serialized_output = get_serialized_output

def to_dict():
return {"success": success, "output": output, "error": None}

result.to_dict = to_dict
return result


def _make_provider_responses(tool_calls_response, text_response_str="Done"):
"""Create mock provider responses: one with tool calls, one with text."""
tool_call = MagicMock()
tool_call.id = "tc_1"
tool_call.name = "test_tool"
tool_call.arguments = {"key": "value"}

tool_response = MagicMock()
tool_response.content = [MagicMock(type="text", text="Using tool")]
tool_response.tool_calls = [tool_call]
tool_response.usage = None
tool_response.content_blocks = None
tool_response.metadata = None

text_block = MagicMock()
text_block.text = text_response_str
text_block.type = "text"
text_response = MagicMock()
text_response.content = [text_block]
text_response.tool_calls = None
text_response.usage = None
text_response.content_blocks = None
text_response.metadata = None

mock_provider = AsyncMock()
mock_provider.complete = AsyncMock(side_effect=[tool_response, text_response])
mock_provider.priority = 1

return mock_provider


def _make_context():
"""Create a mock context that captures add_message calls."""
context = AsyncMock()
messages_added = []

async def capture_add_message(msg):
messages_added.append(msg)

context.add_message = AsyncMock(side_effect=capture_add_message)
context.get_messages_for_request = AsyncMock(
return_value=[{"role": "user", "content": "test"}]
)
return context, messages_added


def _get_tool_result_messages(messages_added):
"""Extract tool result messages from captured messages."""
return [msg for msg in messages_added if msg.get("role") == "tool"]


@pytest.mark.asyncio
async def test_tool_post_modify_replaces_result():
"""When a hook returns action='modify' on tool:post, the modified data
should be used instead of the original get_serialized_output()."""
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
from amplifier_module_loop_basic import BasicOrchestrator

orchestrator = BasicOrchestrator({"max_iterations": 5})

# Tool with original output
original_output = {"original": True, "big_data": "x" * 1000}
mock_tool = MagicMock()
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.input_schema = {"type": "object", "properties": {}}
mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
tools = {"test_tool": mock_tool}

# Hook that returns modify with new data
modified_content = {"modified": True, "truncated": True}
hooks = HookRegistry()

async def modify_hook(event: str, data: dict) -> HookResult:
if event == "tool:post":
return HookResult(action="modify", data={"result": modified_content})
return HookResult()

hooks.register("tool:post", modify_hook, priority=50, name="test_modify")

mock_provider = _make_provider_responses(tool_calls_response=True)
providers = {"test_provider": mock_provider}
context, messages_added = _make_context()

await orchestrator.execute(
prompt="test prompt",
context=context,
providers=providers,
tools=tools,
hooks=hooks,
)

tool_msgs = _get_tool_result_messages(messages_added)
assert len(tool_msgs) == 1, f"Expected 1 tool result, got {len(tool_msgs)}"

tool_result_content = tool_msgs[0]["content"]

# The content should be the MODIFIED data, not the original
assert tool_result_content == json.dumps(modified_content), (
f"Expected modified content {json.dumps(modified_content)}, "
f"got {tool_result_content}"
)

# Verify the original data is NOT used
assert "big_data" not in tool_result_content


@pytest.mark.asyncio
async def test_tool_post_no_modify_uses_original():
"""When no hook returns modify, the original get_serialized_output() is used."""
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
from amplifier_module_loop_basic import BasicOrchestrator

orchestrator = BasicOrchestrator({"max_iterations": 5})

original_output = {"original": True}
mock_tool = MagicMock()
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.input_schema = {"type": "object", "properties": {}}
mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output))
tools = {"test_tool": mock_tool}

# No hooks registered — default continue action
hooks = HookRegistry()

mock_provider = _make_provider_responses(tool_calls_response=True)
providers = {"test_provider": mock_provider}
context, messages_added = _make_context()

await orchestrator.execute(
prompt="test prompt",
context=context,
providers=providers,
tools=tools,
hooks=hooks,
)

tool_msgs = _get_tool_result_messages(messages_added)
assert len(tool_msgs) == 1
tool_result_content = tool_msgs[0]["content"]

# Should use original serialized output
assert tool_result_content == json.dumps(original_output), (
f"Expected original {json.dumps(original_output)}, got {tool_result_content}"
)


@pytest.mark.asyncio
async def test_tool_post_modify_with_string_result():
"""When a hook returns modify with a string result, it should be used as-is."""
with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}):
from amplifier_module_loop_basic import BasicOrchestrator

orchestrator = BasicOrchestrator({"max_iterations": 5})

mock_tool = MagicMock()
mock_tool.name = "test_tool"
mock_tool.description = "A test tool"
mock_tool.input_schema = {"type": "object", "properties": {}}
mock_tool.execute = AsyncMock(return_value=_make_tool_result("original text"))
tools = {"test_tool": mock_tool}

hooks = HookRegistry()

async def modify_hook(event: str, data: dict) -> HookResult:
if event == "tool:post":
return HookResult(
action="modify",
data={"result": "truncated string result"},
)
return HookResult()

hooks.register("tool:post", modify_hook, priority=50, name="test_modify")

mock_provider = _make_provider_responses(tool_calls_response=True)
providers = {"test_provider": mock_provider}
context, messages_added = _make_context()

await orchestrator.execute(
prompt="test prompt",
context=context,
providers=providers,
tools=tools,
hooks=hooks,
)

tool_msgs = _get_tool_result_messages(messages_added)
assert len(tool_msgs) == 1
assert tool_msgs[0]["content"] == "truncated string result"
Loading