Skip to content

Commit 92d8cd6

Browse files
authored
Merge pull request #2547 from dgageot/board/improving-provider-package-test-coverage-afd6319e
refactor(provider): improve testability and split provider.go
2 parents 672101e + d17a2c8 commit 92d8cd6

15 files changed

Lines changed: 1726 additions & 505 deletions

pkg/config/gather.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func addEnvVarsForModelConfig(model *latest.ModelConfig, customProviders map[str
106106
addEnvVarsForCoreProvider(effective, model, requiredEnv)
107107
}
108108
}
109-
} else if alias, exists := provider.Aliases[model.Provider]; exists {
109+
} else if alias, exists := provider.LookupAlias(model.Provider); exists {
110110
// Check built-in aliases
111111
if alias.TokenEnvVar != "" {
112112
requiredEnv[alias.TokenEnvVar] = true

pkg/model/provider/aliases.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package provider
2+
3+
import (
4+
"iter"
5+
"maps"
6+
"slices"
7+
"strings"
8+
)
9+
10+
// Alias defines the configuration for a provider alias.
11+
type Alias struct {
12+
APIType string // The actual API type to use (openai, anthropic, etc.)
13+
BaseURL string // Default base URL for the provider
14+
TokenEnvVar string // Environment variable name for the API token
15+
}
16+
17+
// CoreProviders lists all natively implemented provider types.
18+
// These are the provider types that have direct implementations (not aliases).
19+
var CoreProviders = []string{
20+
"openai",
21+
"anthropic",
22+
"google",
23+
"dmr",
24+
"amazon-bedrock",
25+
}
26+
27+
// Aliases maps provider names to their corresponding configurations.
28+
//
29+
// Most consumers should call [LookupAlias] for a single lookup or [EachAlias]
30+
// to iterate, both of which keep the rest of the codebase decoupled from this
31+
// concrete map. Direct mutation of Aliases is not supported.
32+
var Aliases = map[string]Alias{
33+
"requesty": {
34+
APIType: "openai",
35+
BaseURL: "https://router.requesty.ai/v1",
36+
TokenEnvVar: "REQUESTY_API_KEY",
37+
},
38+
"azure": {
39+
APIType: "openai",
40+
TokenEnvVar: "AZURE_API_KEY",
41+
},
42+
"xai": {
43+
APIType: "openai",
44+
BaseURL: "https://api.x.ai/v1",
45+
TokenEnvVar: "XAI_API_KEY",
46+
},
47+
"nebius": {
48+
APIType: "openai",
49+
BaseURL: "https://api.studio.nebius.com/v1",
50+
TokenEnvVar: "NEBIUS_API_KEY",
51+
},
52+
"mistral": {
53+
APIType: "openai",
54+
BaseURL: "https://api.mistral.ai/v1",
55+
TokenEnvVar: "MISTRAL_API_KEY",
56+
},
57+
"ollama": {
58+
APIType: "openai",
59+
BaseURL: "http://localhost:11434/v1",
60+
},
61+
"minimax": {
62+
APIType: "openai",
63+
BaseURL: "https://api.minimax.io/v1",
64+
TokenEnvVar: "MINIMAX_API_KEY",
65+
},
66+
"github-copilot": {
67+
APIType: "openai",
68+
BaseURL: "https://api.githubcopilot.com",
69+
TokenEnvVar: "GITHUB_TOKEN",
70+
},
71+
}
72+
73+
// LookupAlias returns the Alias registered for the given name (if any).
74+
// Lookup is case-sensitive; callers that need case-insensitive matching
75+
// should normalise the name first (e.g. [strings.ToLower]).
76+
func LookupAlias(name string) (Alias, bool) {
77+
alias, ok := Aliases[name]
78+
return alias, ok
79+
}
80+
81+
// EachAlias returns an iterator over every registered (name, Alias) pair.
82+
// Iteration order is not guaranteed; callers that need a deterministic order
83+
// should sort by name.
84+
func EachAlias() iter.Seq2[string, Alias] {
85+
return func(yield func(string, Alias) bool) {
86+
for name, alias := range Aliases {
87+
if !yield(name, alias) {
88+
return
89+
}
90+
}
91+
}
92+
}
93+
94+
// AllProviders returns all known provider names (core providers + aliases),
95+
// sorted for deterministic output.
96+
func AllProviders() []string {
97+
providers := slices.Concat(CoreProviders, slices.Collect(maps.Keys(Aliases)))
98+
slices.Sort(providers)
99+
return providers
100+
}
101+
102+
// IsKnownProvider returns true if the provider name is a core provider or an alias.
103+
func IsKnownProvider(name string) bool {
104+
if slices.Contains(CoreProviders, strings.ToLower(name)) {
105+
return true
106+
}
107+
_, exists := LookupAlias(strings.ToLower(name))
108+
return exists
109+
}
110+
111+
// CatalogProviders returns the list of provider names that should be shown in the model catalog.
112+
// This includes core providers and aliases that have a defined BaseURL (self-contained endpoints).
113+
// Aliases without a BaseURL (like azure) require user configuration and are excluded.
114+
func CatalogProviders() []string {
115+
providers := make([]string, 0, len(CoreProviders)+len(Aliases))
116+
117+
// Add all core providers
118+
providers = append(providers, CoreProviders...)
119+
120+
// Add aliases that have a defined BaseURL (they work out of the box)
121+
for name, alias := range EachAlias() {
122+
if alias.BaseURL != "" {
123+
providers = append(providers, name)
124+
}
125+
}
126+
127+
return providers
128+
}
129+
130+
// IsCatalogProvider returns true if the provider name is valid for the model catalog.
131+
func IsCatalogProvider(name string) bool {
132+
// Check core providers
133+
if slices.Contains(CoreProviders, name) {
134+
return true
135+
}
136+
// Check aliases with BaseURL
137+
if alias, exists := LookupAlias(name); exists && alias.BaseURL != "" {
138+
return true
139+
}
140+
return false
141+
}

pkg/model/provider/aliases_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package provider
2+
3+
import (
4+
"maps"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestLookupAlias(t *testing.T) {
12+
t.Parallel()
13+
14+
// Every entry in the table is reachable.
15+
for name, expected := range Aliases {
16+
got, ok := LookupAlias(name)
17+
assert.True(t, ok, "alias %q should be found", name)
18+
assert.Equal(t, expected, got)
19+
}
20+
21+
// Unknown name yields the zero Alias and false.
22+
got, ok := LookupAlias("does-not-exist")
23+
assert.False(t, ok)
24+
assert.Equal(t, Alias{}, got)
25+
26+
// Lookup is case-sensitive (callers normalise themselves).
27+
if _, ok := LookupAlias("MISTRAL"); ok {
28+
t.Errorf("LookupAlias should be case-sensitive")
29+
}
30+
}
31+
32+
func TestEachAlias(t *testing.T) {
33+
t.Parallel()
34+
35+
// Iterator yields every entry exactly once.
36+
collected := maps.Collect(EachAlias())
37+
assert.Equal(t, Aliases, collected)
38+
}
39+
40+
func TestEachAlias_EarlyTermination(t *testing.T) {
41+
t.Parallel()
42+
43+
// Iterator must respect a false return from the yield function.
44+
require.NotEmpty(t, Aliases, "test requires the alias table to be non-empty")
45+
46+
count := 0
47+
for range EachAlias() {
48+
count++
49+
if count == 1 {
50+
break
51+
}
52+
}
53+
assert.Equal(t, 1, count, "iteration should stop when consumer breaks out")
54+
}

pkg/model/provider/clone.go

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,57 @@ import (
44
"context"
55
"log/slog"
66

7+
"github.com/docker/docker-agent/pkg/config/latest"
8+
"github.com/docker/docker-agent/pkg/model/provider/base"
79
"github.com/docker/docker-agent/pkg/model/provider/options"
810
)
911

1012
// CloneWithOptions returns a new Provider instance using the same provider/model
1113
// as the base provider, applying the provided options. If cloning fails, the
1214
// original base provider is returned.
13-
func CloneWithOptions(ctx context.Context, base Provider, opts ...options.Opt) Provider {
14-
config := base.BaseConfig()
15+
func CloneWithOptions(ctx context.Context, baseProvider Provider, opts ...options.Opt) Provider {
16+
cfg := baseProvider.BaseConfig()
17+
modelConfig, mergedOpts := mergeCloneOptions(cfg, opts)
1518

19+
// Use NewWithModels to support cloning routers that reference other models.
20+
// cfg.Models is populated by routers; for other providers it's nil (which is fine).
21+
clone, err := NewWithModels(ctx, &modelConfig, cfg.Models, cfg.Env, mergedOpts...)
22+
if err != nil {
23+
slog.Debug("Failed to clone provider; using base provider", "error", err, "id", baseProvider.ID())
24+
return baseProvider
25+
}
26+
27+
return clone
28+
}
29+
30+
// mergeCloneOptions is the pure half of CloneWithOptions. Given the base
31+
// provider's configuration and the user-supplied overrides, it returns:
32+
//
33+
// - a copy of the base ModelConfig with explicit overrides applied (currently
34+
// MaxTokens and the no-thinking flag), and
35+
// - the full ordered slice of options that should be passed to NewWithModels
36+
// (existing options first, then user overrides; later opts win).
37+
//
38+
// Splitting this out from the impure NewWithModels call lets us table-test the
39+
// option-merging logic without spinning up an HTTP server.
40+
func mergeCloneOptions(cfg base.Config, opts []options.Opt) (latest.ModelConfig, []options.Opt) {
1641
// Preserve existing options, then apply overrides. Later opts take precedence.
17-
baseOpts := options.FromModelOptions(config.ModelOptions)
42+
baseOpts := options.FromModelOptions(cfg.ModelOptions)
1843
mergedOpts := append(baseOpts, opts...)
1944

20-
// Apply max_tokens override if present in options
21-
// We need to apply it to the ModelConfig itself since that's what providers use
22-
// Only update MaxTokens if an option explicitly sets it (non-zero value)
23-
modelConfig := config.ModelConfig
45+
// Apply every option to a single accumulator so we can read the final
46+
// effective values directly. "Later opt wins" semantics fall out naturally.
47+
var merged options.ModelOptions
2448
for _, opt := range mergedOpts {
25-
tempOpts := &options.ModelOptions{}
26-
opt(tempOpts)
27-
if mt := tempOpts.MaxTokens(); mt != 0 {
28-
modelConfig.MaxTokens = &mt
29-
}
30-
if tempOpts.NoThinking() {
31-
modelConfig.ThinkingBudget = nil
32-
}
49+
opt(&merged)
3350
}
3451

35-
// Use NewWithModels to support cloning routers that reference other models.
36-
// config.Models is populated by routers; for other providers it's nil (which is fine).
37-
clone, err := NewWithModels(ctx, &modelConfig, config.Models, config.Env, mergedOpts...)
38-
if err != nil {
39-
slog.Debug("Failed to clone provider; using base provider", "error", err, "id", base.ID())
40-
return base
52+
modelConfig := cfg.ModelConfig
53+
if mt := merged.MaxTokens(); mt != 0 {
54+
modelConfig.MaxTokens = &mt
4155
}
42-
43-
return clone
56+
if merged.NoThinking() {
57+
modelConfig.ThinkingBudget = nil
58+
}
59+
return modelConfig, mergedOpts
4460
}

0 commit comments

Comments
 (0)