Skip to content

Commit adcfafb

Browse files
authored
Merge pull request #2531 from dgageot/board/simplifying-hooks-handling-7ba687cc
refactor(hooks): simplify caching, dispatch flow, and notification helpers
2 parents bc4893c + d929de7 commit adcfafb

4 files changed

Lines changed: 101 additions & 124 deletions

File tree

pkg/hooks/executor.go

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,12 @@ type Executor struct {
2929
// matcher is the compiled form of a [MatcherConfig]: an optional regex
3030
// pattern (nil means "match all") and the hooks to fire when it matches.
3131
type matcher struct {
32-
raw string
3332
pattern *regexp.Regexp
3433
hooks []Hook
3534
}
3635

3736
func (m *matcher) matches(toolName string) bool {
38-
if m.raw == "" || m.raw == "*" {
39-
return true
40-
}
41-
return m.pattern != nil && m.pattern.MatchString(toolName)
37+
return m.pattern == nil || m.pattern.MatchString(toolName)
4238
}
4339

4440
// hookResult is the outcome of a single hook invocation.
@@ -103,7 +99,7 @@ func compileMatchers(configs []MatcherConfig) []matcher {
10399
}
104100
out := make([]matcher, 0, len(configs))
105101
for _, mc := range configs {
106-
m := matcher{raw: mc.Matcher, hooks: mc.Hooks}
102+
m := matcher{hooks: mc.Hooks}
107103
if mc.Matcher != "" && mc.Matcher != "*" {
108104
p, err := regexp.Compile("^(?:" + mc.Matcher + ")$")
109105
if err != nil {
@@ -127,36 +123,17 @@ func (e *Executor) Has(event EventType) bool {
127123
// don't have to remember. Defaults [Input.Cwd] to the executor's
128124
// working directory when the caller didn't supply one.
129125
func (e *Executor) Dispatch(ctx context.Context, event EventType, input *Input) (*Result, error) {
130-
matchers := e.events[event]
131-
if len(matchers) == 0 {
126+
hooks := e.hooksFor(event, input.ToolName)
127+
if len(hooks) == 0 {
132128
return &Result{Allowed: true}, nil
133129
}
130+
134131
input.HookEventName = event
135132
if input.Cwd == "" {
136133
input.Cwd = e.workingDir
137134
}
138135

139-
// Collect, filter by tool name, and dedup by (type, command, args).
140-
// Dedup catches the common case of an explicit YAML hook overlapping
141-
// a runtime auto-injected one (e.g. WithAddDate plus a user-authored
142-
// add_date entry).
143-
seen := make(map[string]bool)
144-
var hooks []Hook
145-
for _, m := range matchers {
146-
if !m.matches(input.ToolName) {
147-
continue
148-
}
149-
for _, h := range m.hooks {
150-
key := dedupKey(h)
151-
if !seen[key] {
152-
seen[key] = true
153-
hooks = append(hooks, h)
154-
}
155-
}
156-
}
157-
if len(hooks) == 0 {
158-
return &Result{Allowed: true}, nil
159-
}
136+
slog.Debug("Executing hooks", "event", event, "session_id", input.SessionID, "count", len(hooks))
160137

161138
inputJSON, err := input.ToJSON()
162139
if err != nil {
@@ -173,6 +150,29 @@ func (e *Executor) Dispatch(ctx context.Context, event EventType, input *Input)
173150
return aggregate(results, event), nil
174151
}
175152

153+
// hooksFor returns the deduplicated list of hooks that should run for
154+
// (event, toolName). Dedup by (type, command, args) catches the common
155+
// case of an explicit YAML hook overlapping a runtime auto-injected
156+
// one (e.g. WithAddDate plus a user-authored add_date entry).
157+
func (e *Executor) hooksFor(event EventType, toolName string) []Hook {
158+
seen := make(map[string]bool)
159+
var hooks []Hook
160+
for _, m := range e.events[event] {
161+
if !m.matches(toolName) {
162+
continue
163+
}
164+
for _, h := range m.hooks {
165+
key := dedupKey(h)
166+
if seen[key] {
167+
continue
168+
}
169+
seen[key] = true
170+
hooks = append(hooks, h)
171+
}
172+
}
173+
return hooks
174+
}
175+
176176
// dedupKey returns a deterministic key identifying a hook by (type, command, args).
177177
func dedupKey(h Hook) string {
178178
var b strings.Builder

pkg/runtime/hooks.go

Lines changed: 60 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -11,60 +11,50 @@ import (
1111
"github.com/docker/docker-agent/pkg/session"
1212
)
1313

14-
// hooksExec returns the cached [hooks.Executor] for a, building one on
15-
// first lookup. Returns nil when the agent has no user-configured hooks
16-
// and no agent-flag (AddDate / AddEnvironmentInfo / AddPromptFiles) maps
17-
// to a builtin. Callers can then short-circuit without paying for a
18-
// no-op dispatch.
14+
// buildHooksExecutors builds a [hooks.Executor] for every agent in the
15+
// team that has user-configured hooks or an agent-flag that maps to a
16+
// builtin (AddDate / AddEnvironmentInfo / AddPromptFiles). Agents with
17+
// no hooks have no entry; lookups fall through to nil so callers can
18+
// short-circuit cheaply.
1919
//
20-
// The cache is keyed by agent name. Entries (including the nil sentinel)
21-
// are stable for the lifetime of the runtime, so repeated dispatches
22-
// during a turn don't re-translate agent flags into builtin hook entries
23-
// or rebuild matcher tables.
20+
// Called once from [NewLocalRuntime] after r.workingDir, r.env and
21+
// r.hooksRegistry are finalized; the resulting map is read-only for
22+
// the lifetime of the runtime, so per-dispatch lookups don't need to
23+
// lock.
24+
func (r *LocalRuntime) buildHooksExecutors() {
25+
r.hooksExecByAgent = make(map[string]*hooks.Executor)
26+
for _, name := range r.team.AgentNames() {
27+
a, err := r.team.Agent(name)
28+
if err != nil {
29+
continue
30+
}
31+
cfg := builtins.ApplyAgentDefaults(hooks.FromConfig(a.Hooks()), builtins.AgentDefaults{
32+
AddDate: a.AddDate(),
33+
AddEnvironmentInfo: a.AddEnvironmentInfo(),
34+
AddPromptFiles: a.AddPromptFiles(),
35+
})
36+
if cfg == nil {
37+
continue
38+
}
39+
r.hooksExecByAgent[name] = hooks.NewExecutorWithRegistry(cfg, r.workingDir, r.env, r.hooksRegistry)
40+
}
41+
}
42+
43+
// hooksExec returns the pre-built [hooks.Executor] for a, or nil when
44+
// the agent has no hooks (see [buildHooksExecutors]).
2445
func (r *LocalRuntime) hooksExec(a *agent.Agent) *hooks.Executor {
2546
if a == nil {
2647
return nil
2748
}
28-
name := a.Name()
29-
30-
r.hooksExecMu.RLock()
31-
if exec, ok := r.hooksExecByAgent[name]; ok {
32-
r.hooksExecMu.RUnlock()
33-
return exec
34-
}
35-
r.hooksExecMu.RUnlock()
36-
37-
r.hooksExecMu.Lock()
38-
defer r.hooksExecMu.Unlock()
39-
// Re-check under the write lock to avoid double-build under contention.
40-
if exec, ok := r.hooksExecByAgent[name]; ok {
41-
return exec
42-
}
43-
44-
cfg := builtins.ApplyAgentDefaults(hooks.FromConfig(a.Hooks()), builtins.AgentDefaults{
45-
AddDate: a.AddDate(),
46-
AddEnvironmentInfo: a.AddEnvironmentInfo(),
47-
AddPromptFiles: a.AddPromptFiles(),
48-
})
49-
50-
var exec *hooks.Executor
51-
if cfg != nil {
52-
exec = hooks.NewExecutorWithRegistry(cfg, r.workingDir, r.env, r.hooksRegistry)
53-
}
54-
if r.hooksExecByAgent == nil {
55-
r.hooksExecByAgent = make(map[string]*hooks.Executor)
56-
}
57-
r.hooksExecByAgent[name] = exec
58-
return exec
49+
return r.hooksExecByAgent[a.Name()]
5950
}
6051

6152
// dispatchHook is the common dispatch path shared by every hook
62-
// callsite: resolve the cached executor, short-circuit if no hook is
63-
// configured for event, then dispatch and emit any [Result.SystemMessage]
64-
// as a Warning event. Errors are logged at warn level and surfaced as
65-
// nil results so callers can use a single nil check to mean "nothing
66-
// useful came back" — covering the not-configured, no-agent, and
67-
// dispatch-failed cases uniformly.
53+
// callsite: resolve the pre-built executor, dispatch, and emit any
54+
// [Result.SystemMessage] as a Warning event. Errors are logged at warn
55+
// level and surfaced as nil results so callers can use a single nil
56+
// check to mean "nothing useful came back" — covering the
57+
// not-configured, no-agent, and dispatch-failed cases uniformly.
6858
//
6959
// events may be nil for fire-and-forget callsites (notification,
7060
// on_error, on_max_iterations, ...) where there's no Warning channel
@@ -79,11 +69,10 @@ func (r *LocalRuntime) dispatchHook(
7969
events chan Event,
8070
) *hooks.Result {
8171
exec := r.hooksExec(a)
82-
if exec == nil || !exec.Has(event) {
72+
if exec == nil {
8373
return nil
8474
}
8575

86-
slog.Debug("Executing hooks", "event", event, "agent", a.Name(), "session_id", input.SessionID)
8776
result, err := exec.Dispatch(ctx, event, input)
8877
if err != nil {
8978
slog.Warn("Hook execution failed", "event", event, "agent", a.Name(), "error", err)
@@ -150,41 +139,31 @@ func (r *LocalRuntime) executeStopHooks(ctx context.Context, sess *session.Sessi
150139
}, events)
151140
}
152141

153-
// executeNotificationHooks runs notification hooks when the agent emits
154-
// a user-facing notification. Hook output is informational — it does
155-
// not suppress or rewrite the notification.
156-
func (r *LocalRuntime) executeNotificationHooks(ctx context.Context, a *agent.Agent, sessionID, level, message string) {
157-
if level != "error" && level != "warning" {
158-
slog.Error("Invalid notification level", "level", level, "expected", "error|warning")
159-
return
160-
}
161-
r.dispatchHook(ctx, a, hooks.EventNotification, &hooks.Input{
162-
SessionID: sessionID,
163-
NotificationLevel: level,
164-
NotificationMessage: message,
165-
}, nil)
142+
// notifyError fires both notification(level=error) and on_error in one
143+
// call. They're always emitted together (an error is always also a
144+
// user-facing notification), so collapsing them into one call expresses
145+
// intent more directly than firing two events at every callsite.
146+
func (r *LocalRuntime) notifyError(ctx context.Context, a *agent.Agent, sessionID, message string) {
147+
r.notify(ctx, a, hooks.EventNotification, sessionID, "error", message)
148+
r.notify(ctx, a, hooks.EventOnError, sessionID, "error", message)
166149
}
167150

168-
// executeOnErrorHooks fires on_error when the runtime hits an error
169-
// during a turn (model failures, tool-call loops). Fires alongside the
170-
// broader notification event; on_error is the structured entry point
171-
// for users who want to react only to errors.
172-
func (r *LocalRuntime) executeOnErrorHooks(ctx context.Context, a *agent.Agent, sessionID, message string) {
173-
r.dispatchHook(ctx, a, hooks.EventOnError, &hooks.Input{
174-
SessionID: sessionID,
175-
NotificationLevel: "error",
176-
NotificationMessage: message,
177-
}, nil)
151+
// notifyMaxIterations fires both notification(level=warning) and
152+
// on_max_iterations. Same rationale as [notifyError]: the two are
153+
// always emitted together when the iteration limit is reached.
154+
func (r *LocalRuntime) notifyMaxIterations(ctx context.Context, a *agent.Agent, sessionID, message string) {
155+
r.notify(ctx, a, hooks.EventNotification, sessionID, "warning", message)
156+
r.notify(ctx, a, hooks.EventOnMaxIterations, sessionID, "warning", message)
178157
}
179158

180-
// executeOnMaxIterationsHooks fires on_max_iterations when the runtime
181-
// reaches its configured max_iterations limit. Fires alongside the
182-
// broader notification event; on_max_iterations is the structured entry
183-
// point for users who want to react only to that condition.
184-
func (r *LocalRuntime) executeOnMaxIterationsHooks(ctx context.Context, a *agent.Agent, sessionID, message string) {
185-
r.dispatchHook(ctx, a, hooks.EventOnMaxIterations, &hooks.Input{
159+
// notify is the shared dispatch path for the (level, message)-shaped
160+
// hook events: notification, on_error, on_max_iterations. They all
161+
// take the same Input fields and are observational (no Result is
162+
// honored), so a single helper covers them all.
163+
func (r *LocalRuntime) notify(ctx context.Context, a *agent.Agent, event hooks.EventType, sessionID, level, message string) {
164+
r.dispatchHook(ctx, a, event, &hooks.Input{
186165
SessionID: sessionID,
187-
NotificationLevel: "warning",
166+
NotificationLevel: level,
188167
NotificationMessage: message,
189168
}, nil)
190169
}
@@ -215,11 +194,10 @@ func (r *LocalRuntime) executeAfterLLMCallHooks(ctx context.Context, sess *sessi
215194

216195
// executeOnUserInputHooks fires on_user_input when the runtime is about
217196
// to wait for the user (tool confirmation, elicitation, max iterations,
218-
// stream stopped). Resolves the agent from r.team itself so callsites
219-
// in code paths without an agent handle (like the elicitation handler)
220-
// stay short.
197+
// stream stopped). Resolves the agent itself so callsites in code paths
198+
// without an agent handle (like the elicitation handler) stay short.
221199
func (r *LocalRuntime) executeOnUserInputHooks(ctx context.Context, sessionID, logContext string) {
222-
a, _ := r.team.Agent(r.CurrentAgentName())
200+
a := r.CurrentAgent()
223201
if a == nil {
224202
return
225203
}

pkg/runtime/loop.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
252252
events <- MaxIterationsReached(runtimeMaxIterations)
253253

254254
maxIterMsg := fmt.Sprintf("Maximum iterations reached (%d)", runtimeMaxIterations)
255-
r.executeNotificationHooks(ctx, a, sess.ID, "warning", maxIterMsg)
256-
r.executeOnMaxIterationsHooks(ctx, a, sess.ID, maxIterMsg)
255+
r.notifyMaxIterations(ctx, a, sess.ID, maxIterMsg)
257256
r.executeOnUserInputHooks(ctx, sess.ID, "max iterations reached")
258257

259258
// In non-interactive mode (e.g. MCP server), auto-stop instead of
@@ -434,8 +433,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
434433
telemetry.RecordError(ctx, err.Error())
435434
errMsg := modelerrors.FormatError(err)
436435
events <- Error(errMsg)
437-
r.executeNotificationHooks(ctx, a, sess.ID, "error", errMsg)
438-
r.executeOnErrorHooks(ctx, a, sess.ID, errMsg)
436+
r.notifyError(ctx, a, sess.ID, errMsg)
439437
streamSpan.End()
440438
return
441439
}
@@ -501,8 +499,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
501499
"This indicates a degenerate loop where the model is not making progress.",
502500
loopDetector.consecutive, toolName)
503501
events <- Error(errMsg)
504-
r.executeNotificationHooks(ctx, a, sess.ID, "error", errMsg)
505-
r.executeOnErrorHooks(ctx, a, sess.ID, errMsg)
502+
r.notifyError(ctx, a, sess.ID, errMsg)
506503
loopDetector.reset()
507504
return
508505
}

pkg/runtime/runtime.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,12 @@ type LocalRuntime struct {
206206
// touching any process-wide state.
207207
hooksRegistry *hooks.Registry
208208

209-
// hooksExecByAgent caches the per-agent [hooks.Executor], keyed by
210-
// agent name. Building one requires translating agent flags into
211-
// implicit builtin hook entries and compiling matchers — work we
212-
// don't want to repeat on every dispatch within a turn. Entries are
213-
// stable for the runtime's lifetime; a nil value caches the
214-
// "no hooks configured" verdict so repeat lookups stay cheap.
209+
// hooksExecByAgent holds the per-agent [hooks.Executor], keyed by
210+
// agent name. Built once in [NewLocalRuntime.buildHooksExecutors]
211+
// after team and runtime config are finalized; agents with no hooks
212+
// have no entry, so [hooksExec] returns nil for them. Read-only after
213+
// construction, so no locking is needed.
215214
hooksExecByAgent map[string]*hooks.Executor
216-
hooksExecMu sync.RWMutex
217215

218216
// retryOnRateLimit enables retry-with-backoff for HTTP 429 (rate limit) errors
219217
// when no fallback models are configured. When false (default), 429 errors are
@@ -394,6 +392,10 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
394392
// RunStream on the same runtime (e.g. background agent sessions).
395393
r.registerDefaultTools()
396394

395+
// Pre-build per-agent hook executors now that workingDir, env and
396+
// the team are finalized. Read-only afterwards.
397+
r.buildHooksExecutors()
398+
397399
slog.Debug("Creating new runtime", "agent", r.currentAgent, "available_agents", agents.Size())
398400

399401
return r, nil

0 commit comments

Comments
 (0)