Skip to content

Commit 0b9d040

Browse files
feat: add cooperative cancellation support
- Check cancellation token at iteration start to exit loop when cancelled - Register tool start/complete with cancellation module for visibility - Check immediate cancellation after parallel tool execution - Emit ORCHESTRATOR_COMPLETE with status 'cancelled' when cancelled Enables graceful Ctrl+C handling where current tools complete before stopping. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com>
1 parent ad757dc commit 0b9d040

File tree

1 file changed

+91
-55
lines changed

1 file changed

+91
-55
lines changed

amplifier_module_loop_basic/__init__.py

Lines changed: 91 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ async def execute(
8484
final_content = ""
8585

8686
while self.max_iterations == -1 or iteration < self.max_iterations:
87+
# Check for cancellation at iteration start
88+
if coordinator and coordinator.cancellation.is_cancelled:
89+
# Emit orchestrator complete with cancelled status
90+
await hooks.emit(
91+
ORCHESTRATOR_COMPLETE,
92+
{
93+
"orchestrator": "loop-basic",
94+
"turn_count": iteration,
95+
"status": "cancelled",
96+
},
97+
)
98+
return final_content
99+
87100
# Emit provider request BEFORE getting messages (allows hook injections)
88101
result = await hooks.emit(PROVIDER_REQUEST, {"provider": provider_name, "iteration": iteration})
89102
if coordinator:
@@ -258,79 +271,102 @@ async def execute_single_tool(tc: Any, group_id: str) -> tuple[str, str]:
258271
args = getattr(tc, "arguments", None) or tc.get("arguments") or {}
259272
tool = tools.get(tool_name)
260273

274+
# Register tool with cancellation token for visibility
275+
if coordinator:
276+
coordinator.cancellation.register_tool_start(tool_call_id, tool_name)
277+
261278
try:
262-
# Emit and process tool pre (allows hooks to block or request approval)
263-
pre_result = await hooks.emit(
264-
TOOL_PRE,
265-
{
266-
"tool_name": tool_name,
267-
"tool_input": args,
268-
"parallel_group_id": group_id,
269-
},
270-
)
271-
if coordinator:
272-
pre_result = await coordinator.process_hook_result(pre_result, "tool:pre", tool_name)
273-
if pre_result.action == "deny":
274-
return (tool_call_id, f"Denied by hook: {pre_result.reason}")
279+
try:
280+
# Emit and process tool pre (allows hooks to block or request approval)
281+
pre_result = await hooks.emit(
282+
TOOL_PRE,
283+
{
284+
"tool_name": tool_name,
285+
"tool_input": args,
286+
"parallel_group_id": group_id,
287+
},
288+
)
289+
if coordinator:
290+
pre_result = await coordinator.process_hook_result(pre_result, "tool:pre", tool_name)
291+
if pre_result.action == "deny":
292+
return (tool_call_id, f"Denied by hook: {pre_result.reason}")
293+
294+
if not tool:
295+
error_msg = f"Error: Tool '{tool_name}' not found"
296+
await hooks.emit(
297+
TOOL_ERROR,
298+
{
299+
"tool_name": tool_name,
300+
"error": {"type": "RuntimeError", "msg": error_msg},
301+
"parallel_group_id": group_id,
302+
},
303+
)
304+
return (tool_call_id, error_msg)
305+
306+
result = await tool.execute(args)
307+
308+
# Serialize result for logging
309+
result_data = result
310+
if hasattr(result, "to_dict"):
311+
result_data = result.to_dict()
312+
313+
# Emit and process tool post (allows hooks to inject feedback)
314+
post_result = await hooks.emit(
315+
TOOL_POST,
316+
{
317+
"tool_name": tool_name,
318+
"tool_input": args,
319+
"result": result_data,
320+
"parallel_group_id": group_id,
321+
},
322+
)
323+
if coordinator:
324+
await coordinator.process_hook_result(post_result, "tool:post", tool_name)
275325

276-
if not tool:
277-
error_msg = f"Error: Tool '{tool_name}' not found"
326+
# Return success with result content (JSON-serialized for dict/list)
327+
result_content = result.get_serialized_output()
328+
return (tool_call_id, result_content)
329+
330+
except Exception as te:
331+
# Emit error event
278332
await hooks.emit(
279333
TOOL_ERROR,
280334
{
281335
"tool_name": tool_name,
282-
"error": {"type": "RuntimeError", "msg": error_msg},
336+
"error": {"type": type(te).__name__, "msg": str(te)},
283337
"parallel_group_id": group_id,
284338
},
285339
)
286-
return (tool_call_id, error_msg)
287-
288-
result = await tool.execute(args)
289340

290-
# Serialize result for logging
291-
result_data = result
292-
if hasattr(result, "to_dict"):
293-
result_data = result.to_dict()
294-
295-
# Emit and process tool post (allows hooks to inject feedback)
296-
post_result = await hooks.emit(
297-
TOOL_POST,
298-
{
299-
"tool_name": tool_name,
300-
"tool_input": args,
301-
"result": result_data,
302-
"parallel_group_id": group_id,
303-
},
304-
)
341+
# Return failure with error message (don't raise!)
342+
error_msg = f"Error executing tool: {str(te)}"
343+
logger.error(f"Tool {tool_name} failed: {te}")
344+
return (tool_call_id, error_msg)
345+
finally:
346+
# Unregister tool from cancellation token
305347
if coordinator:
306-
await coordinator.process_hook_result(post_result, "tool:post", tool_name)
307-
308-
# Return success with result content (JSON-serialized for dict/list)
309-
result_content = result.get_serialized_output()
310-
return (tool_call_id, result_content)
311-
312-
except Exception as te:
313-
# Emit error event
314-
await hooks.emit(
315-
TOOL_ERROR,
316-
{
317-
"tool_name": tool_name,
318-
"error": {"type": type(te).__name__, "msg": str(te)},
319-
"parallel_group_id": group_id,
320-
},
321-
)
322-
323-
# Return failure with error message (don't raise!)
324-
error_msg = f"Error executing tool: {str(te)}"
325-
logger.error(f"Tool {tool_name} failed: {te}")
326-
return (tool_call_id, error_msg)
348+
coordinator.cancellation.register_tool_complete(tool_call_id)
327349

328350
# Execute all tools in parallel with asyncio.gather
329351
# return_exceptions=False because we handle exceptions inside execute_single_tool
330352
tool_results = await asyncio.gather(
331353
*[execute_single_tool(tc, parallel_group_id) for tc in tool_calls]
332354
)
333355

356+
# Check for immediate cancellation - synthesize results for any pending tools
357+
if coordinator and coordinator.cancellation.is_immediate:
358+
# Any tools that didn't complete will have been handled by gather
359+
# Just break out of the loop
360+
await hooks.emit(
361+
ORCHESTRATOR_COMPLETE,
362+
{
363+
"orchestrator": "loop-basic",
364+
"turn_count": iteration,
365+
"status": "cancelled",
366+
},
367+
)
368+
return final_content
369+
334370
# Add all tool results to context in original order (deterministic)
335371
for tool_call_id, content in tool_results:
336372
if hasattr(context, "add_message"):

0 commit comments

Comments
 (0)