Skip to content

Commit c54be36

Browse files
committed
fix(runtime,hooks): review-driven fixes (duplicate before_llm_call dispatch + 4 quality)
Review of the three feature commits surfaced one critical bug introduced during the rebase plus a handful of low-impact issues. Fixes them, with a regression test for the critical one. ## Critical bug **before_llm_call hook fired twice per loop iteration.** The rebase left two identical dispatch blocks in [LocalRuntime.RunStream], so any stateful before_llm_call hook \u2014 prominently the new max_iterations builtin \u2014 advanced its counter twice per LLM call and tripped at half the configured limit. Other affected handlers: any user-authored before_llm_call hook with side effects (audit logging, cost meters). Adds pkg/runtime/before_llm_call_test.go to pin "fires exactly once per iteration". I verified the test fails on the buggy code and passes on the fix before checking either in. ## Quality fixes * pkg/hooks/builtins/json.go: sortKeys was mutating []any slices in place. Currently safe because each builtin gets a freshly-decoded Input, but it's a foot-gun for any future caller that re-uses inputs. sortKeys now returns a deep copy. * pkg/runtime/tool_dispatch.go: processToolCalls switch had three arms each ending in nearly-identical span End() / SetStatus (codes.Ok, "...") boilerplate. Pulled span finalisation up before the switch so each arm only carries the logic that's actually distinct. * pkg/hooks/builtins/git.go: gitOutput's empty-dir error included the function name ("gitOutput: empty working directory") against Go style. Trimmed. * pkg/runtime/hooks_wiring_test.go: stale comment referred to "caching by agent name"; the post-rebase eager-build doesn't cache, it just looks up. Updated. ## Validation go test ./... -> all packages pass go test -race ./pkg/hooks/... ./pkg/runtime/... -> clean golangci-lint run ./... -> 0 issues Assisted-By: docker-agent
1 parent c34390e commit c54be36

6 files changed

Lines changed: 89 additions & 31 deletions

File tree

pkg/hooks/builtins/git.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func gitOutput(ctx context.Context, dir string, args ...string) (string, error)
5757
// Defensive: every caller guards on Cwd, but bailing out
5858
// here keeps a future caller from accidentally running git
5959
// in the process's working directory.
60-
return "", errors.New("gitOutput: empty working directory")
60+
return "", errors.New("empty working directory")
6161
}
6262
full := append([]string{"-C", dir}, args...)
6363
out, err := exec.CommandContext(ctx, "git", full...).Output()

pkg/hooks/builtins/json.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import (
66
"slices"
77
)
88

9-
// sortKeys recursively sorts map keys so [json.Marshal] produces
10-
// deterministic output regardless of how the input was constructed.
11-
// Slices are walked in place; non-collection values are returned
12-
// unchanged.
9+
// sortKeys returns a deep, deterministic copy of v with every nested
10+
// map's keys ordered. Slices and maps are copied rather than mutated
11+
// in place so the caller's input is never modified — important when
12+
// the same Input is reachable from a future hook handler.
1313
func sortKeys(v any) any {
1414
switch val := v.(type) {
1515
case map[string]any:
@@ -19,10 +19,11 @@ func sortKeys(v any) any {
1919
}
2020
return sorted
2121
case []any:
22+
copied := make([]any, len(val))
2223
for i, item := range val {
23-
val[i] = sortKeys(item)
24+
copied[i] = sortKeys(item)
2425
}
25-
return val
26+
return copied
2627
default:
2728
return v
2829
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"sync/atomic"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/docker-agent/pkg/agent"
12+
"github.com/docker/docker-agent/pkg/config/latest"
13+
"github.com/docker/docker-agent/pkg/hooks"
14+
"github.com/docker/docker-agent/pkg/session"
15+
"github.com/docker/docker-agent/pkg/team"
16+
)
17+
18+
// TestBeforeLLMCallHookFiresOncePerLoopIteration is a regression test
19+
// for a duplicate dispatch in [LocalRuntime.RunStream] that fired
20+
// [LocalRuntime.executeBeforeLLMCallHooks] twice per iteration. The
21+
// bug would silently break stateful before_llm_call hooks (the
22+
// max_iterations builtin would have tripped at half its configured
23+
// limit). A single-turn session must observe exactly one fire.
24+
func TestBeforeLLMCallHookFiresOncePerLoopIteration(t *testing.T) {
25+
t.Parallel()
26+
27+
const counterName = "test-before-llm-counter"
28+
var calls atomic.Int32
29+
30+
stream := newStreamBuilder().
31+
AddContent("Hello").
32+
AddStopWithUsage(3, 2).
33+
Build()
34+
prov := &mockProvider{id: "test/mock-model", stream: stream}
35+
36+
root := agent.New("root", "test agent",
37+
agent.WithModel(prov),
38+
agent.WithHooks(&latest.HooksConfig{
39+
BeforeLLMCall: []latest.HookDefinition{
40+
{Type: "builtin", Command: counterName},
41+
},
42+
}),
43+
)
44+
tm := team.New(team.WithAgents(root))
45+
46+
rt, err := NewLocalRuntime(tm,
47+
WithSessionCompaction(false),
48+
WithModelStore(mockModelStore{}),
49+
)
50+
require.NoError(t, err)
51+
52+
// Builtin lookup happens at dispatch time, not at executor build,
53+
// so registering after NewLocalRuntime is sufficient.
54+
require.NoError(t, rt.hooksRegistry.RegisterBuiltin(
55+
counterName,
56+
func(_ context.Context, _ *hooks.Input, _ []string) (*hooks.Output, error) {
57+
calls.Add(1)
58+
return nil, nil
59+
},
60+
))
61+
62+
sess := session.New(session.WithUserMessage("hi"))
63+
sess.Title = "Unit Test"
64+
65+
for range rt.RunStream(t.Context(), sess) {
66+
}
67+
68+
assert.Equal(t, int32(1), calls.Load(),
69+
"before_llm_call must fire exactly once per loop iteration; "+
70+
"a duplicate dispatch would silently break stateful hooks like max_iterations")
71+
}

pkg/runtime/hooks_wiring_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,9 @@ func TestHooksExecWiresAgentFlagsToBuiltins(t *testing.T) {
8686
exec := r.hooksExec(a)
8787
require.NotNil(t, exec, "loop_detector is always-on, so an executor is always built")
8888

89-
// hooksExec caches the executor by agent name. Calling it twice
90-
// returns the same pointer, so per-turn dispatches don't pay
91-
// the matcher-compilation cost repeatedly.
92-
assert.Same(t, exec, r.hooksExec(a), "hooksExec must cache by agent name")
89+
// hooksExec is read-only after [LocalRuntime.buildHooksExecutors],
90+
// so calling it twice returns the same pointer.
91+
assert.Same(t, exec, r.hooksExec(a), "hooksExec must be stable across calls")
9392

9493
assert.Equal(t, tc.wantTurnStart, exec.Has(hooks.EventTurnStart),
9594
"turn_start activation must match flags")

pkg/runtime/loop.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
402402
return
403403
}
404404

405-
// before_llm_call hooks fire just before the model is invoked.
406-
// A terminating verdict (e.g. from the max_iterations builtin)
407-
// stops the run loop here, before any tokens are spent.
408-
if stop, msg := r.executeBeforeLLMCallHooks(ctx, sess, a); stop {
409-
slog.Warn("before_llm_call hook signalled run termination",
410-
"agent", a.Name(), "session_id", sess.ID, "reason", msg)
411-
r.emitHookDrivenShutdown(ctx, a, sess, msg, events)
412-
streamSpan.End()
413-
return
414-
}
415-
416405
// Try primary model with fallback chain if configured
417406
res, usedModel, err := r.tryModelWithFallback(streamCtx, a, model, messages, agentTools, sess, m, events)
418407
if err != nil {

pkg/runtime/tool_dispatch.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,22 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi
8484

8585
outcome := r.executeWithApproval(callCtx, sess, toolCall, tool, events, a, invoke)
8686

87+
if outcome.canceled {
88+
callSpan.SetStatus(codes.Ok, "tool call canceled by user")
89+
} else {
90+
callSpan.SetStatus(codes.Ok, "tool call processed")
91+
}
92+
callSpan.End()
93+
8794
switch {
8895
case outcome.canceled:
89-
callSpan.SetStatus(codes.Ok, "tool call canceled by user")
90-
callSpan.End()
9196
synthesizeRemaining(calls[i+1:],
9297
"The tool call was canceled because a previous tool call in the same batch was canceled by the user.")
9398
return false, ""
94-
9599
case outcome.stopRun:
96-
callSpan.SetStatus(codes.Ok, "tool call processed")
97-
callSpan.End()
98100
synthesizeRemaining(calls[i+1:],
99101
"The tool call was skipped because a post_tool_use hook signalled run termination.")
100102
return true, outcome.stopMessage
101-
102-
default:
103-
callSpan.SetStatus(codes.Ok, "tool call processed")
104-
callSpan.End()
105103
}
106104
}
107105
return false, ""

0 commit comments

Comments
 (0)