Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,25 @@ type SessionResponse struct {
OutputTokens int64 `json:"output_tokens"`
WorkingDir string `json:"working_dir,omitempty"`
Permissions *session.PermissionsConfig `json:"permissions,omitempty"`
Mode session.Mode `json:"mode,omitempty"`
}

// UpdateSessionPermissionsRequest represents a request to update session permissions.
type UpdateSessionPermissionsRequest struct {
// Permissions is carried under the "permissions" JSON key. The tag has no
// omitempty, so a nil pointer is encoded as JSON null rather than dropped.
Permissions *session.PermissionsConfig `json:"permissions"`
}

// UpdateSessionModeRequest represents a request to update a session's mode.
// It is the body bound by the PATCH /sessions/:id/mode handler.
type UpdateSessionModeRequest struct {
// Mode is the requested interaction mode, carried under the "mode" JSON
// key. The handler rejects values for which Mode.IsValid() is false.
Mode session.Mode `json:"mode"`
}

// UpdateSessionModeResponse represents the response from updating a session's mode.
type UpdateSessionModeResponse struct {
// ID is the identifier of the session whose mode was updated.
ID string `json:"id"`
// Mode is the mode reported back to the caller after the update.
Mode session.Mode `json:"mode"`
}

// ResumeSessionRequest represents a request to resume a session
type ResumeSessionRequest struct {
Confirmation string `json:"confirmation"`
Expand Down Expand Up @@ -304,6 +316,7 @@ type SessionSnapshotResponse struct {
Messages []session.Message `json:"messages"`
ToolsApproved bool `json:"tools_approved"`
Permissions *session.PermissionsConfig `json:"permissions,omitempty"`
Mode session.Mode `json:"mode,omitempty"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`

Expand Down
3 changes: 2 additions & 1 deletion pkg/runtime/harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func (r *LocalRuntime) runHarnessAgent(ctx context.Context, sess *session.Sessio
}()

turnStartMsgs := r.executeTurnStartHooks(ctx, sess, a, events)
messages := sess.GetMessages(a, append(baseExtra, turnStartMsgs...)...)
planReminder := planModeReminderMessages(sess)
Comment thread
trungutt marked this conversation as resolved.
messages := sess.GetMessages(a, append(append(baseExtra, turnStartMsgs...), planReminder...)...)
stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, 1, messages)
if stop {
slog.WarnContext(ctx, "before_llm_call hook signalled run termination",
Expand Down
41 changes: 37 additions & 4 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session,
sink.Emit(ErrorWithCode(ErrorCodeToolFailed, fmt.Sprintf("failed to get tools: %v", err)))
return
}
agentTools = filterExcludedTools(agentTools, sess.ExcludedTools)
agentTools = filterToolsForSession(agentTools, sess)

sink.Emit(ToolsetInfo(len(agentTools), false, a.Name()))

Expand Down Expand Up @@ -348,7 +348,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session,
sink.Emit(ErrorWithCode(ErrorCodeToolFailed, fmt.Sprintf("failed to get tools: %v", err)))
return
}
agentTools = filterExcludedTools(agentTools, sess.ExcludedTools)
agentTools = filterToolsForSession(agentTools, sess)

// Emit updated tool count. After a ToolListChanged MCP notification
// the cache is invalidated, so getTools above re-fetches from the
Expand Down Expand Up @@ -554,7 +554,13 @@ func (r *LocalRuntime) runTurn(
// files) refresh every turn while session-level context (cwd, OS,
// arch) stays stable — all without bloating the stored history.
turnStartMsgs := r.executeTurnStartHooks(ctx, sess, a, events)
messages := sess.GetMessages(a, slices.Concat(ls.sessionStartMsgs, ls.userPromptMsgs, turnStartMsgs)...)
// Plan-mode reminder rides alongside the turn_start hook output so it
// participates in the same per-turn splice (and the cache_control marker
// that GetMessages applies to the last extra). It is appended last so its
// instruction is the most recent system context the model sees before the
// user prompt — minimising the chance the model ignores it.
planReminder := planModeReminderMessages(sess)
messages := sess.GetMessages(a, slices.Concat(ls.sessionStartMsgs, ls.userPromptMsgs, turnStartMsgs, planReminder)...)
slog.DebugContext(ctx, "Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages))

// before_llm_call hooks fire just before the model is invoked.
Expand Down Expand Up @@ -990,6 +996,33 @@ func filterExcludedTools(agentTools []tools.Tool, excluded []string) []tools.Too
return filtered
}

// filterToolsForSession narrows agentTools to what the given session is
// allowed to see. Two filters are applied in order:
//
//  1. the session's ExcludedTools name list is removed, and
//  2. when the session is in plan mode, only tools whose definition
//     advertises ReadOnlyHint are kept.
//
// Sessions in any other mode get the result of step 1 unchanged.
func filterToolsForSession(agentTools []tools.Tool, sess *session.Session) []tools.Tool {
	visible := filterExcludedTools(agentTools, sess.ExcludedTools)
	if sess.Mode != session.ModePlan {
		return visible
	}
	return filterToReadOnlyTools(visible)
}

// filterToReadOnlyTools returns a new slice holding, in their original order,
// only the tools whose Annotations.ReadOnlyHint is set. Plan mode uses it to
// hide every tool that does not advertise itself as read-only.
func filterToReadOnlyTools(agentTools []tools.Tool) []tools.Tool {
	readOnly := make([]tools.Tool, 0, len(agentTools))
	for i := range agentTools {
		if !agentTools[i].Annotations.ReadOnlyHint {
			continue
		}
		readOnly = append(readOnly, agentTools[i])
	}
	return readOnly
}

// reprobe re-runs ensureToolSetsAreStarted after a batch of tool calls.
// If new tools became available (by name-set diff), it emits a ToolsetInfo
// event to update the TUI immediately. The new tools will be picked up by
Expand All @@ -1010,7 +1043,7 @@ func (r *LocalRuntime) reprobe(
slog.WarnContext(ctx, "reprobe: getTools failed", "agent", a.Name(), "error", err)
return
}
updated = filterExcludedTools(updated, sess.ExcludedTools)
updated = filterToolsForSession(updated, sess)

// Emit any pending warnings that getTools just generated.
r.emitAgentWarnings(a, events)
Expand Down
45 changes: 45 additions & 0 deletions pkg/runtime/plan_mode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package runtime

import (
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/session"
)

// planModeReminder is the per-turn system instruction injected when a session
// is in plan mode. Two layers enforce plan mode: the runtime hides every
// non-read-only tool from the model (see filterToolsForSession in loop.go),
// and this reminder tells the model how it should behave. Hiding the tools
// is the hard guarantee; the reminder is the explanation, so the model
// produces a useful plan instead of just bouncing off missing tools.
//
// The text is used verbatim as the Content of the system-role message built
// by planModeReminderMessages, so any edit here changes what the model sees.
const planModeReminder = `<system-reminder>
You are currently in PLAN MODE.

In this mode you research the codebase, ask clarifying questions, and write a
clear, actionable plan for the user. You MUST NOT make any changes to the
system:

- No edits to files (no write, edit, create, or delete).
- No shell commands or background jobs.
- No state-changing tool calls of any kind.

Only read-only tools have been made available to you for this turn. If you try
to call a tool that isn't in your list, the user has explicitly disabled it
for planning.

End the turn by presenting the plan in your final message and asking the user
to review it. The user will switch you to BUILD MODE when they want execution
to begin.
</system-reminder>`

// planModeReminderMessages builds the extra messages a plan-mode session
// contributes to a turn: a single system-role message whose content is
// planModeReminder. For a nil session, or a session in any other mode, it
// returns nil so callers can append the result unconditionally.
func planModeReminderMessages(sess *session.Session) []chat.Message {
	inPlanMode := sess != nil && sess.Mode == session.ModePlan
	if !inPlanMode {
		return nil
	}

	reminder := chat.Message{
		Role:    chat.MessageRoleSystem,
		Content: planModeReminder,
	}
	return []chat.Message{reminder}
}
54 changes: 54 additions & 0 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,60 @@ func TestFilterExcludedTools(t *testing.T) {
})
}

// TestFilterToolsForSession_PlanMode checks how the session mode and the
// ExcludedTools list combine when filtering the tool set.
func TestFilterToolsForSession_PlanMode(t *testing.T) {
	readOnly := tools.Tool{Name: "read_file", Annotations: tools.ToolAnnotations{ReadOnlyHint: true}}
	readOnly2 := tools.Tool{Name: "list_directory", Annotations: tools.ToolAnnotations{ReadOnlyHint: true}}
	mutating := tools.Tool{Name: "write_file"}
	all := []tools.Tool{readOnly, mutating, {Name: "shell"}}

	cases := []struct {
		name      string
		sess      *session.Session
		input     []tools.Tool
		wantLen   int
		wantFirst string // name of the first surviving tool; checked only when non-empty
	}{
		{
			name:    "build mode keeps all tools",
			sess:    &session.Session{Mode: session.ModeBuild},
			input:   all,
			wantLen: 3,
		},
		{
			// Sessions loaded before the mode column existed have Mode == "".
			name:    "empty mode is treated as build",
			sess:    &session.Session{},
			input:   all,
			wantLen: 3,
		},
		{
			name:      "plan mode keeps only read-only tools",
			sess:      &session.Session{Mode: session.ModePlan},
			input:     all,
			wantLen:   1,
			wantFirst: "read_file",
		},
		{
			name: "plan mode still respects ExcludedTools",
			sess: &session.Session{
				Mode:          session.ModePlan,
				ExcludedTools: []string{"read_file"},
			},
			input:     []tools.Tool{readOnly, readOnly2, mutating},
			wantLen:   1,
			wantFirst: "list_directory",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := filterToolsForSession(tc.input, tc.sess)
			assert.Len(t, got, tc.wantLen)
			if tc.wantFirst != "" {
				assert.Equal(t, tc.wantFirst, got[0].Name)
			}
		})
	}
}

func TestPlanModeReminderMessages(t *testing.T) {
t.Run("build mode returns nil", func(t *testing.T) {
assert.Nil(t, planModeReminderMessages(&session.Session{Mode: session.ModeBuild}))
})

t.Run("nil session returns nil", func(t *testing.T) {
assert.Nil(t, planModeReminderMessages(nil))
})

t.Run("plan mode returns a single system reminder", func(t *testing.T) {
msgs := planModeReminderMessages(&session.Session{Mode: session.ModePlan})
assert.Len(t, msgs, 1)
assert.Equal(t, chat.MessageRoleSystem, msgs[0].Role)
assert.Contains(t, msgs[0].Content, "PLAN MODE")
})
}

func TestMergeExcludedTools(t *testing.T) {
t.Run("both empty", func(t *testing.T) {
assert.Nil(t, mergeExcludedTools(nil, nil))
Expand Down
22 changes: 22 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func (s *Server) registerRoutes() {
group.POST("/sessions/:id/resume", s.resumeSession)
group.POST("/sessions/:id/tools/toggle", s.toggleSessionYolo)
group.PATCH("/sessions/:id/permissions", s.updateSessionPermissions)
group.PATCH("/sessions/:id/mode", s.updateSessionMode)
group.PATCH("/sessions/:id/title", s.updateSessionTitle)
group.PATCH("/sessions/:id/tokens", s.updateSessionTokens)
group.PATCH("/sessions/:id/starred", s.setSessionStarred)
Expand Down Expand Up @@ -249,6 +250,7 @@ func (s *Server) getSession(c echo.Context) error {
OutputTokens: sess.OutputTokens,
WorkingDir: sess.WorkingDir,
Permissions: sess.Permissions,
Mode: sess.Mode,
})
}

Expand Down Expand Up @@ -329,6 +331,26 @@ func (s *Server) updateSessionPermissions(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"message": "session permissions updated"})
}

// updateSessionMode handles PATCH /sessions/:id/mode.
//
// It binds the request body into an api.UpdateSessionModeRequest, rejects
// modes for which Mode.IsValid() is false with 400 Bad Request, and asks the
// session manager to apply the change. Any error from the session manager is
// reported as 500 Internal Server Error. On success it responds with 200 and
// an api.UpdateSessionModeResponse carrying the session ID and the mode.
func (s *Server) updateSessionMode(c echo.Context) error {
	sessionID := c.Param("id")

	var req api.UpdateSessionModeRequest
	if err := c.Bind(&req); err != nil {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid request body: %v", err))
	}
	if !req.Mode.IsValid() {
		return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid mode %q; must be one of: %s, %s", req.Mode, session.ModeBuild, session.ModePlan))
	}

	if err := s.sm.UpdateSessionMode(c.Request().Context(), sessionID, req.Mode); err != nil {
		return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to update session mode: %v", err))
	}

	// The session manager stores session.NormalizeMode(mode), not the raw
	// request value. Report the same normalized value so the response always
	// matches what a subsequent GET of the session returns.
	return c.JSON(http.StatusOK, api.UpdateSessionModeResponse{
		ID:   sessionID,
		Mode: session.NormalizeMode(req.Mode),
	})
}

func (s *Server) updateSessionTitle(c echo.Context) error {
sessionID := c.Param("id")
var req api.UpdateSessionTitleRequest
Expand Down
58 changes: 58 additions & 0 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,64 @@ func TestServer_UpdateSessionTitle(t *testing.T) {
assert.Equal(t, newTitle, sessionResp.Title)
}

// TestServer_UpdateSessionMode drives the mode endpoint end to end: create a
// session, switch it to plan mode, confirm a GET reflects the change, then
// switch it back to build mode.
func TestServer_UpdateSessionMode(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	store := session.NewInMemorySessionStore()
	lnPath := startServerWithStore(t, ctx, prepareAgentsDir(t), store)

	// Create a session in default (build) mode.
	var created session.Session
	unmarshal(t, httpDo(t, ctx, http.MethodPost, lnPath, "/api/sessions", map[string]any{}), &created)
	require.NotEmpty(t, created.ID)

	sessionPath := "/api/sessions/" + created.ID

	// patchMode issues the PATCH and decodes the handler's response.
	patchMode := func(mode session.Mode) api.UpdateSessionModeResponse {
		var out api.UpdateSessionModeResponse
		resp := httpDo(t, ctx, http.MethodPatch, lnPath,
			sessionPath+"/mode",
			api.UpdateSessionModeRequest{Mode: mode})
		unmarshal(t, resp, &out)
		return out
	}

	// Switch the session into plan mode.
	planResp := patchMode(session.ModePlan)
	assert.Equal(t, created.ID, planResp.ID)
	assert.Equal(t, session.ModePlan, planResp.Mode)

	// GET should reflect the new mode.
	var fetched api.SessionResponse
	unmarshal(t, httpGET(t, ctx, lnPath, sessionPath), &fetched)
	assert.Equal(t, session.ModePlan, fetched.Mode)

	// Switch back to build mode.
	buildResp := patchMode(session.ModeBuild)
	assert.Equal(t, session.ModeBuild, buildResp.Mode)
}

// TestServer_CreateSession_AcceptsMode checks that a mode supplied at session
// creation is returned by the create call and by a later GET.
func TestServer_CreateSession_AcceptsMode(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	store := session.NewInMemorySessionStore()
	lnPath := startServerWithStore(t, ctx, prepareAgentsDir(t), store)

	// Creating a session with mode=plan should persist that mode.
	body := map[string]any{"mode": string(session.ModePlan)}
	var created session.Session
	unmarshal(t, httpDo(t, ctx, http.MethodPost, lnPath, "/api/sessions", body), &created)
	require.NotEmpty(t, created.ID)
	assert.Equal(t, session.ModePlan, created.Mode)

	// Reading the session back must report the same mode.
	sessionPath := "/api/sessions/" + created.ID
	var fetched api.SessionResponse
	unmarshal(t, httpGET(t, ctx, lnPath, sessionPath), &fetched)
	assert.Equal(t, session.ModePlan, fetched.Mode)
}

func startServerWithStore(t *testing.T, ctx context.Context, agentsDir string, store session.Store) string {
t.Helper()

Expand Down
28 changes: 28 additions & 0 deletions pkg/server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ func (sm *SessionManager) GetSessionSnapshot(ctx context.Context, id string) (*a
Messages: sess.GetAllMessages(),
ToolsApproved: sess.ToolsApproved,
Permissions: sess.Permissions,
Mode: sess.Mode,
InputTokens: sess.InputTokens,
OutputTokens: sess.OutputTokens,
Streaming: streaming,
Expand Down Expand Up @@ -353,6 +354,10 @@ func (sm *SessionManager) CreateSession(ctx context.Context, sessionTemplate *se
opts = append(opts, session.WithPermissions(sessionTemplate.Permissions))
}

if sessionTemplate.Mode != "" {
opts = append(opts, session.WithMode(sessionTemplate.Mode))
}

sess := session.New(opts...)

// Copy model-related fields from the template so callers can pin a
Expand Down Expand Up @@ -741,6 +746,29 @@ func (sm *SessionManager) UpdateSessionPermissions(ctx context.Context, sessionI
return sm.sessionStore.UpdateSession(ctx, sess)
}

// UpdateSessionMode updates the interaction mode (build/plan) for a session.
// If the session is actively running, it also updates the in-memory session
// object so the next turn's tool filter and plan-mode reminder see the new
// mode without having to round-trip through the store.
func (sm *SessionManager) UpdateSessionMode(ctx context.Context, sessionID string, mode session.Mode) error {
mode = session.NormalizeMode(mode)
sm.mux.Lock()
defer sm.mux.Unlock()

if rt, ok := sm.runtimeSessions.Load(sessionID); ok && rt.session != nil {
rt.session.Mode = mode
Comment thread
trungutt marked this conversation as resolved.
slog.DebugContext(ctx, "Updated mode for active session", "session_id", sessionID, "mode", mode)
return sm.sessionStore.UpdateSession(ctx, rt.session)
Comment thread
trungutt marked this conversation as resolved.
}

sess, err := sm.sessionStore.GetSession(ctx, sessionID)
if err != nil {
return err
}
sess.Mode = mode
return sm.sessionStore.UpdateSession(ctx, sess)
}

// UpdateSessionTitle updates the title for a session.
// If the session is actively running, it also updates the in-memory session
// object to prevent subsequent runtime saves from overwriting the title.
Expand Down
1 change: 1 addition & 0 deletions pkg/session/branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func (s *Session) Clone() *Session {
CustomModelsUsed: cloneStringSlice(s.CustomModelsUsed),
AttachedFiles: cloneStringSlice(s.AttachedFiles),
ExcludedTools: cloneStringSlice(s.ExcludedTools),
Mode: s.Mode,
AgentName: s.AgentName,
ParentID: s.ParentID,
MessageUsageHistory: slices.Clone(s.MessageUsageHistory),
Expand Down
Loading
Loading