Skip to content

Commit cd89cfa

Browse files
authored
Merge pull request #584 from rumpl/move-things
Move things around
2 parents 6445526 + 26819c4 commit cd89cfa

23 files changed

Lines changed: 529 additions & 442 deletions

cmd/root/run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error {
192192
runConfig.RedirectURI = "http://localhost:8083/oauth-callback"
193193
}
194194

195-
agents, err = teamloader.LoadWithOverrides(ctx, agentFilename, runConfig, modelOverrides)
195+
agents, err = teamloader.Load(ctx, agentFilename, runConfig, teamloader.WithModelOverrides(modelOverrides))
196196
if err != nil {
197197
return err
198198
}

pkg/config/config.go

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"path/filepath"
@@ -11,6 +12,7 @@ import (
1112
v0 "github.com/docker/cagent/pkg/config/v0"
1213
v1 "github.com/docker/cagent/pkg/config/v1"
1314
latest "github.com/docker/cagent/pkg/config/v2"
15+
"github.com/docker/cagent/pkg/environment"
1416
"github.com/docker/cagent/pkg/filesystem"
1517
)
1618

@@ -55,6 +57,24 @@ func LoadConfig(path string, fs filesystem.FS) (*latest.Config, error) {
5557
return &config, nil
5658
}
5759

60+
// CheckRequiredEnvVars checks which environment variables are required by the models and tools.
61+
//
62+
// This allows exiting early with a proper error message instead of failing later when trying to use a model or tool.
63+
func CheckRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environment.Provider, runtimeConfig RuntimeConfig) error {
64+
requiredEnv, err := GatherMissingEnvVars(ctx, cfg, env, runtimeConfig)
65+
if err != nil {
66+
return fmt.Errorf("gathering required environment variables: %w", err)
67+
}
68+
69+
if len(requiredEnv) == 0 {
70+
return nil
71+
}
72+
73+
return &environment.RequiredEnvError{
74+
Missing: requiredEnv,
75+
}
76+
}
77+
5878
func parseCurrentVersion(dir string, data []byte, version any) (any, error) {
5979
options := []yaml.DecodeOption{yaml.Strict(), yaml.ReferenceDirs(dir)}
6080

@@ -142,58 +162,3 @@ func validateConfig(cfg *latest.Config) error {
142162
func boolPtr(b bool) *bool {
143163
return &b
144164
}
145-
146-
func ValidatePathInDirectory(path, allowedDir string) (string, error) {
147-
if path == "" {
148-
return "", fmt.Errorf("empty path")
149-
}
150-
151-
cleanPath := filepath.Clean(path)
152-
153-
if cleanPath == "" || cleanPath == "." {
154-
return "", fmt.Errorf("empty or invalid path")
155-
}
156-
157-
if filepath.IsAbs(cleanPath) && allowedDir == "" {
158-
if strings.Contains(path, "..") {
159-
return "", fmt.Errorf("path contains directory traversal sequences")
160-
}
161-
return cleanPath, nil
162-
}
163-
164-
if allowedDir == "" {
165-
if strings.HasPrefix(cleanPath, "..") {
166-
return "", fmt.Errorf("path contains directory traversal sequences")
167-
}
168-
return cleanPath, nil
169-
}
170-
171-
cleanAllowedDir := filepath.Clean(allowedDir)
172-
absAllowedDir, err := filepath.Abs(cleanAllowedDir)
173-
if err != nil {
174-
return "", fmt.Errorf("invalid allowed directory: %w", err)
175-
}
176-
177-
var targetPath string
178-
if filepath.IsAbs(cleanPath) {
179-
targetPath = cleanPath
180-
} else {
181-
targetPath = filepath.Join(absAllowedDir, cleanPath)
182-
}
183-
184-
absTargetPath, err := filepath.Abs(targetPath)
185-
if err != nil {
186-
return "", fmt.Errorf("invalid path: %w", err)
187-
}
188-
189-
relPath, err := filepath.Rel(absAllowedDir, absTargetPath)
190-
if err != nil {
191-
return "", fmt.Errorf("cannot determine relative path: %w", err)
192-
}
193-
194-
if strings.HasPrefix(relPath, "..") {
195-
return "", fmt.Errorf("path outside allowed directory: %s", path)
196-
}
197-
198-
return absTargetPath, nil
199-
}

pkg/config/config_test.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package config
22

33
import (
4+
"context"
45
"os"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
89
"github.com/stretchr/testify/require"
10+
11+
latest "github.com/docker/cagent/pkg/config/v2"
12+
"github.com/docker/cagent/pkg/environment"
913
)
1014

1115
func TestAutoRegisterModels(t *testing.T) {
@@ -188,3 +192,235 @@ func openRoot(t *testing.T, dir string) *os.Root {
188192

189193
return root
190194
}
195+
196+
type noEnvProvider struct{}
197+
198+
func (p *noEnvProvider) Get(context.Context, string) string { return "" }
199+
200+
func TestCheckRequiredEnvVars(t *testing.T) {
201+
tests := []struct {
202+
yaml string
203+
expectedMissing []string
204+
}{
205+
{
206+
yaml: "openai_inline.yaml",
207+
expectedMissing: []string{"OPENAI_API_KEY"},
208+
},
209+
{
210+
yaml: "anthropic_inline.yaml",
211+
expectedMissing: []string{"ANTHROPIC_API_KEY"},
212+
},
213+
{
214+
yaml: "google_inline.yaml",
215+
expectedMissing: []string{"GOOGLE_API_KEY"},
216+
},
217+
{
218+
yaml: "dmr_inline.yaml",
219+
expectedMissing: []string{},
220+
},
221+
{
222+
yaml: "openai_model.yaml",
223+
expectedMissing: []string{"OPENAI_API_KEY"},
224+
},
225+
{
226+
yaml: "anthropic_model.yaml",
227+
expectedMissing: []string{"ANTHROPIC_API_KEY"},
228+
},
229+
{
230+
yaml: "google_model.yaml",
231+
expectedMissing: []string{"GOOGLE_API_KEY"},
232+
},
233+
{
234+
yaml: "dmr_model.yaml",
235+
expectedMissing: []string{},
236+
},
237+
{
238+
yaml: "all.yaml",
239+
expectedMissing: []string{"ANTHROPIC_API_KEY", "GOOGLE_API_KEY", "OPENAI_API_KEY"},
240+
},
241+
}
242+
for _, test := range tests {
243+
t.Run(test.yaml, func(t *testing.T) {
244+
t.Parallel()
245+
246+
root := openRoot(t, "testdata/env")
247+
248+
cfg, err := LoadConfig(test.yaml, root)
249+
require.NoError(t, err)
250+
251+
err = CheckRequiredEnvVars(t.Context(), cfg, &noEnvProvider{}, RuntimeConfig{})
252+
253+
if len(test.expectedMissing) == 0 {
254+
require.NoError(t, err)
255+
} else {
256+
require.Error(t, err)
257+
assert.Equal(t, test.expectedMissing, err.(*environment.RequiredEnvError).Missing)
258+
}
259+
})
260+
}
261+
}
262+
263+
func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) {
264+
t.Parallel()
265+
266+
root := openRoot(t, "testdata/env")
267+
268+
cfg, err := LoadConfig("all.yaml", root)
269+
require.NoError(t, err)
270+
271+
err = CheckRequiredEnvVars(t.Context(), cfg, &noEnvProvider{}, RuntimeConfig{
272+
ModelsGateway: "gateway:8080",
273+
})
274+
275+
require.NoError(t, err)
276+
}
277+
278+
func TestApplyModelOverrides(t *testing.T) {
279+
tests := []struct {
280+
name string
281+
agents map[string]latest.AgentConfig
282+
overrides []string
283+
expected map[string]string // agent name -> expected model
284+
expectError bool
285+
errorMsg string
286+
}{
287+
{
288+
name: "global override",
289+
agents: map[string]latest.AgentConfig{
290+
"root": {Model: "openai/gpt-4"},
291+
"other": {Model: "anthropic/claude-3"},
292+
},
293+
overrides: []string{"google/gemini-pro"},
294+
expected: map[string]string{
295+
"root": "google/gemini-pro",
296+
"other": "google/gemini-pro",
297+
},
298+
},
299+
{
300+
name: "single per-agent override",
301+
agents: map[string]latest.AgentConfig{
302+
"root": {Model: "openai/gpt-4"},
303+
"other": {Model: "anthropic/claude-3"},
304+
},
305+
overrides: []string{"other=google/gemini-pro"},
306+
expected: map[string]string{
307+
"root": "openai/gpt-4",
308+
"other": "google/gemini-pro",
309+
},
310+
},
311+
{
312+
name: "multiple separate flags",
313+
agents: map[string]latest.AgentConfig{
314+
"root": {Model: "openai/gpt-4"},
315+
"other": {Model: "anthropic/claude-3"},
316+
},
317+
overrides: []string{"root=openai/gpt-5", "other=anthropic/claude-sonnet-4-0"},
318+
expected: map[string]string{
319+
"root": "openai/gpt-5",
320+
"other": "anthropic/claude-sonnet-4-0",
321+
},
322+
},
323+
{
324+
name: "comma-separated format",
325+
agents: map[string]latest.AgentConfig{
326+
"root": {Model: "openai/gpt-4"},
327+
"other": {Model: "anthropic/claude-3"},
328+
"third": {Model: "google/gemini-pro"},
329+
},
330+
overrides: []string{"root=openai/gpt-5,other=anthropic/claude-sonnet-4-0"},
331+
expected: map[string]string{
332+
"root": "openai/gpt-5",
333+
"other": "anthropic/claude-sonnet-4-0",
334+
"third": "google/gemini-pro",
335+
},
336+
},
337+
{
338+
name: "mixed formats",
339+
agents: map[string]latest.AgentConfig{
340+
"root": {Model: "openai/gpt-4"},
341+
"other": {Model: "anthropic/claude-3"},
342+
"third": {Model: "google/gemini-pro"},
343+
"reviewer": {Model: "openai/gpt-3.5-turbo"},
344+
},
345+
overrides: []string{"root=openai/gpt-5,other=anthropic/claude-4", "reviewer=google/gemini-1.5-pro"},
346+
expected: map[string]string{
347+
"root": "openai/gpt-5",
348+
"other": "anthropic/claude-4",
349+
"third": "google/gemini-pro",
350+
"reviewer": "google/gemini-1.5-pro",
351+
},
352+
},
353+
{
354+
name: "last override wins",
355+
agents: map[string]latest.AgentConfig{
356+
"root": {Model: "openai/gpt-4"},
357+
},
358+
overrides: []string{"root=openai/gpt-5", "root=anthropic/claude-4"},
359+
expected: map[string]string{
360+
"root": "anthropic/claude-4",
361+
},
362+
},
363+
{
364+
name: "unknown agent error",
365+
agents: map[string]latest.AgentConfig{
366+
"root": {Model: "openai/gpt-4"},
367+
},
368+
overrides: []string{"nonexistent=openai/gpt-5"},
369+
expectError: true,
370+
errorMsg: "unknown agent 'nonexistent'",
371+
},
372+
{
373+
name: "empty model spec error",
374+
agents: map[string]latest.AgentConfig{
375+
"root": {Model: "openai/gpt-4"},
376+
},
377+
overrides: []string{"root="},
378+
expectError: true,
379+
errorMsg: "empty model specification in override: root=",
380+
},
381+
{
382+
name: "empty global model spec is skipped",
383+
agents: map[string]latest.AgentConfig{
384+
"root": {Model: "openai/gpt-4"},
385+
},
386+
overrides: []string{""},
387+
expected: map[string]string{
388+
"root": "openai/gpt-4",
389+
},
390+
},
391+
{
392+
name: "whitespace handling",
393+
agents: map[string]latest.AgentConfig{
394+
"root": {Model: "openai/gpt-4"},
395+
"other": {Model: "anthropic/claude-3"},
396+
},
397+
overrides: []string{" root = openai/gpt-5 , other = anthropic/claude-4 "},
398+
expected: map[string]string{
399+
"root": "openai/gpt-5",
400+
"other": "anthropic/claude-4",
401+
},
402+
},
403+
}
404+
405+
for _, tt := range tests {
406+
t.Run(tt.name, func(t *testing.T) {
407+
t.Parallel()
408+
409+
cfg := &latest.Config{
410+
Agents: tt.agents,
411+
Models: make(map[string]latest.ModelConfig),
412+
}
413+
414+
err := ApplyModelOverrides(cfg, tt.overrides)
415+
416+
if tt.expectError {
417+
require.ErrorContains(t, err, tt.errorMsg)
418+
} else {
419+
require.NoError(t, err)
420+
for agentName, expectedModel := range tt.expected {
421+
assert.Equal(t, expectedModel, cfg.Agents[agentName].Model)
422+
}
423+
}
424+
})
425+
}
426+
}
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
package secrets
1+
package config
22

33
import (
44
"context"
55
"fmt"
66
"sort"
77

8-
"github.com/docker/cagent/pkg/config"
98
latest "github.com/docker/cagent/pkg/config/v2"
109
"github.com/docker/cagent/pkg/environment"
1110
"github.com/docker/cagent/pkg/gateway"
@@ -15,7 +14,7 @@ import (
1514
// GatherMissingEnvVars finds out which environment variables are required by the models and tools.
1615
// This allows exiting early with a proper error message instead of failing later when trying to use a model or tool.
1716
// TODO(dga): This code contains lots of duplication and ought to be refactored.
18-
func GatherMissingEnvVars(ctx context.Context, cfg *latest.Config, env environment.Provider, runtimeConfig config.RuntimeConfig) ([]string, error) {
17+
func GatherMissingEnvVars(ctx context.Context, cfg *latest.Config, env environment.Provider, runtimeConfig RuntimeConfig) ([]string, error) {
1918
requiredEnv := map[string]bool{}
2019

2120
// Models
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)