Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
result: str,
result: Any,
) -> None:
"""Called immediately after a local tool is invoked.

Expand Down Expand Up @@ -161,7 +161,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
result: str,
result: Any,
) -> None:
"""Called immediately after a local tool is invoked.

Expand Down
2 changes: 1 addition & 1 deletion tests/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
result: Any,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_agent_llm_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
result: Any,
) -> None:
self.events["on_tool_end"] += 1

Expand Down
4 changes: 2 additions & 2 deletions tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ async def on_tool_start(
self.started.append((agent, tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: Any
) -> None:
self.ended.append((agent, tool, result))

Expand All @@ -529,7 +529,7 @@ async def on_tool_start(
self.started.append((agent, tool))

async def on_tool_end(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: Any
) -> None:
self.ended.append((agent, tool, result))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_global_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
result: Any,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
Expand Down
54 changes: 52 additions & 2 deletions tests/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from agents.models.interface import Model
from agents.run import Runner
from agents.run_context import AgentHookContext, RunContextWrapper, TContext
from agents.tool import Tool
from agents.tool import Tool, function_tool
from agents.tool_context import ToolContext
from tests.test_agent_llm_hooks import AgentHooksForTests

Expand Down Expand Up @@ -60,7 +60,7 @@ async def on_tool_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
result: str,
result: Any,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
Expand Down Expand Up @@ -386,3 +386,53 @@ async def test_streamed_run_hooks_count_tool_and_handoff_invocations():
assert hooks.events["on_agent_start"] == 2
assert hooks.events["on_agent_end"] == 1
assert len(hooks.tool_context_ids) == 2


@pytest.mark.asyncio
async def test_run_hooks_receive_structured_tool_result():
class RecordingRunHooks(RunHooks):
def __init__(self):
self.result: Any = None

async def on_tool_end(
self,
context: RunContextWrapper[Any],
agent: Agent[Any],
tool: Tool,
result: Any,
) -> None:
self.result = result

class RecordingAgentHooks(AgentHooks):
def __init__(self):
self.result: Any = None

async def on_tool_end(
self,
context: RunContextWrapper[Any],
agent: Agent[Any],
tool: Tool,
result: Any,
) -> None:
self.result = result

@function_tool
def get_metadata() -> dict[str, Any]:
return {"status": "ok", "count": 1}

run_hooks = RecordingRunHooks()
agent_hooks = RecordingAgentHooks()
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_metadata], hooks=agent_hooks)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("get_metadata", "{}")],
[get_text_message("done")],
]
)

await Runner.run(agent, input="user_message", hooks=run_hooks)

assert run_hooks.result == {"status": "ok", "count": 1}
assert agent_hooks.result == {"status": "ok", "count": 1}
12 changes: 6 additions & 6 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
if tool.name != "ok_tool":
return
Expand Down Expand Up @@ -1117,7 +1117,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
seen_values.append(("hook", tool_state.get()))

Expand Down Expand Up @@ -1249,7 +1249,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
self.results[tool.name] = result
if tool.name == "ok_tool":
Expand Down Expand Up @@ -1308,7 +1308,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
if tool.name == "waiting_tool":
on_tool_end_called.set()
Expand Down Expand Up @@ -1378,7 +1378,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
on_tool_end_called.set()

Expand Down Expand Up @@ -1922,7 +1922,7 @@ async def on_tool_end(
context: RunContextWrapper[Any],
agent: Agent[Any],
tool,
result: str,
result: Any,
) -> None:
if tool.name != "ok_tool":
return
Expand Down