Skip to content

Commit 80865b1

Browse files
committed
refactor(provider): expose LookupAlias/EachAlias accessors for the alias table
The Aliases map is package-level mutable global state. External consumers (pkg/config/gather.go, pkg/runtime/model_switcher.go) reach into the map directly, and tests that need to assert what happens when an alias has only a BaseURL or only a TokenEnvVar have to either pick whatever the live table happens to contain or risk corrupting global state for parallel tests. Introduce two accessors that hide the storage detail: - LookupAlias(name) (Alias, bool) — single lookup, replaces all uses of Aliases[name]. - EachAlias() iter.Seq2[string, Alias] — Go 1.23 iterator for ranging. Migrate every external call site (gather.go, model_switcher.go) and all in-package usage outside provider.go to the accessors. The Aliases var stays exported for backwards compatibility but its godoc now points readers at the accessors and notes that direct mutation is unsupported. Add unit tests that lock in the contracts (LookupAlias case-sensitivity and zero-value miss semantics, EachAlias completeness and early-termination via break). Behaviour is unchanged. Both new accessors at 100% coverage; package coverage 96.4% -> 96.5%. Assisted-By: docker-agent
1 parent 1c9e856 commit 80865b1

6 files changed

Lines changed: 94 additions & 14 deletions

File tree

pkg/config/gather.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func addEnvVarsForModelConfig(model *latest.ModelConfig, customProviders map[str
106106
addEnvVarsForCoreProvider(effective, model, requiredEnv)
107107
}
108108
}
109-
} else if alias, exists := provider.Aliases[model.Provider]; exists {
109+
} else if alias, exists := provider.LookupAlias(model.Provider); exists {
110110
// Check built-in aliases
111111
if alias.TokenEnvVar != "" {
112112
requiredEnv[alias.TokenEnvVar] = true

pkg/model/provider/aliases_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package provider
2+
3+
import (
4+
"maps"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestLookupAlias(t *testing.T) {
12+
t.Parallel()
13+
14+
// Every entry in the table is reachable.
15+
for name, expected := range Aliases {
16+
got, ok := LookupAlias(name)
17+
assert.True(t, ok, "alias %q should be found", name)
18+
assert.Equal(t, expected, got)
19+
}
20+
21+
// Unknown name yields the zero Alias and false.
22+
got, ok := LookupAlias("does-not-exist")
23+
assert.False(t, ok)
24+
assert.Equal(t, Alias{}, got)
25+
26+
// Lookup is case-sensitive (callers normalise themselves).
27+
if _, ok := LookupAlias("MISTRAL"); ok {
28+
t.Errorf("LookupAlias should be case-sensitive")
29+
}
30+
}
31+
32+
func TestEachAlias(t *testing.T) {
33+
t.Parallel()
34+
35+
// Iterator yields every entry exactly once.
36+
collected := maps.Collect(EachAlias())
37+
assert.Equal(t, Aliases, collected)
38+
}
39+
40+
func TestEachAlias_EarlyTermination(t *testing.T) {
41+
t.Parallel()
42+
43+
// Iterator must respect a false return from the yield function.
44+
require.NotEmpty(t, Aliases, "test requires the alias table to be non-empty")
45+
46+
count := 0
47+
for range EachAlias() {
48+
count++
49+
if count == 1 {
50+
break
51+
}
52+
}
53+
assert.Equal(t, 1, count, "iteration should stop when consumer breaks out")
54+
}

pkg/model/provider/provider.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package provider
33
import (
44
"context"
55
"fmt"
6+
"iter"
67
"log/slog"
78
"maps"
89
"slices"
@@ -54,7 +55,7 @@ func IsKnownProvider(name string) bool {
5455
if slices.Contains(CoreProviders, strings.ToLower(name)) {
5556
return true
5657
}
57-
_, exists := Aliases[strings.ToLower(name)]
58+
_, exists := LookupAlias(strings.ToLower(name))
5859
return exists
5960
}
6061

@@ -68,7 +69,7 @@ func CatalogProviders() []string {
6869
providers = append(providers, CoreProviders...)
6970

7071
// Add aliases that have a defined BaseURL (they work out of the box)
71-
for name, alias := range Aliases {
72+
for name, alias := range EachAlias() {
7273
if alias.BaseURL != "" {
7374
providers = append(providers, name)
7475
}
@@ -84,13 +85,17 @@ func IsCatalogProvider(name string) bool {
8485
return true
8586
}
8687
// Check aliases with BaseURL
87-
if alias, exists := Aliases[name]; exists && alias.BaseURL != "" {
88+
if alias, exists := LookupAlias(name); exists && alias.BaseURL != "" {
8889
return true
8990
}
9091
return false
9192
}
9293

93-
// Aliases maps provider names to their corresponding configurations
94+
// Aliases maps provider names to their corresponding configurations.
95+
//
96+
// Most consumers should call [LookupAlias] for a single lookup or [EachAlias]
97+
// to iterate, both of which keep the rest of the codebase decoupled from this
98+
// concrete map. Direct mutation of Aliases is not supported.
9499
var Aliases = map[string]Alias{
95100
"requesty": {
96101
APIType: "openai",
@@ -132,6 +137,27 @@ var Aliases = map[string]Alias{
132137
},
133138
}
134139

140+
// LookupAlias returns the Alias registered for the given name (if any).
141+
// Lookup is case-sensitive; callers that need case-insensitive matching
142+
// should normalise the name first (e.g. [strings.ToLower]).
143+
func LookupAlias(name string) (Alias, bool) {
144+
alias, ok := Aliases[name]
145+
return alias, ok
146+
}
147+
148+
// EachAlias returns an iterator over every registered (name, Alias) pair.
149+
// Iteration order is not guaranteed; callers that need a deterministic order
150+
// should sort by name.
151+
func EachAlias() iter.Seq2[string, Alias] {
152+
return func(yield func(string, Alias) bool) {
153+
for name, alias := range Aliases {
154+
if !yield(name, alias) {
155+
return
156+
}
157+
}
158+
}
159+
}
160+
135161
// Provider defines the interface for model providers
136162
type Provider interface {
137163
// ID returns the model provider ID
@@ -305,7 +331,7 @@ func resolveProviderType(cfg *latest.ModelConfig) string {
305331
return apiType
306332
}
307333
}
308-
if alias, exists := Aliases[cfg.Provider]; exists && alias.APIType != "" {
334+
if alias, exists := LookupAlias(cfg.Provider); exists && alias.APIType != "" {
309335
return alias.APIType
310336
}
311337
return cfg.Provider
@@ -337,7 +363,7 @@ func applyProviderDefaults(cfg *latest.ModelConfig, customProviders map[string]l
337363
return enhancedCfg
338364
}
339365

340-
if alias, exists := Aliases[cfg.Provider]; exists {
366+
if alias, exists := LookupAlias(cfg.Provider); exists {
341367
applyAliasFallbacks(enhancedCfg, alias)
342368
}
343369

@@ -571,7 +597,7 @@ func isOpenAICompatibleProvider(providerType string) bool {
571597
return true
572598
default:
573599
// Check if it's an alias that maps to openai
574-
if alias, exists := Aliases[providerType]; exists {
600+
if alias, exists := LookupAlias(providerType); exists {
575601
return alias.APIType == "openai"
576602
}
577603
return false

pkg/model/provider/provider_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestCatalogProviders(t *testing.T) {
1717
}
1818

1919
// Should include aliases with BaseURL
20-
for name, alias := range Aliases {
20+
for name, alias := range EachAlias() {
2121
if alias.BaseURL != "" {
2222
assert.Contains(t, providers, name, "should include alias %s with BaseURL", name)
2323
} else {
@@ -35,7 +35,7 @@ func TestIsCatalogProvider(t *testing.T) {
3535
}
3636

3737
// Aliases: catalog if and only if they have a BaseURL
38-
for name, alias := range Aliases {
38+
for name, alias := range EachAlias() {
3939
if alias.BaseURL != "" {
4040
assert.True(t, IsCatalogProvider(name), "alias %s with BaseURL should be a catalog provider", name)
4141
} else {
@@ -59,7 +59,7 @@ func TestAllProviders(t *testing.T) {
5959
}
6060

6161
// Should include all aliases
62-
for name := range Aliases {
62+
for name := range EachAlias() {
6363
assert.Contains(t, all, name, "should include alias %s", name)
6464
}
6565

@@ -76,7 +76,7 @@ func TestIsKnownProvider(t *testing.T) {
7676
}
7777

7878
// All aliases should be known
79-
for name := range Aliases {
79+
for name := range EachAlias() {
8080
assert.True(t, IsKnownProvider(name), "alias %s should be known", name)
8181
}
8282

pkg/model/provider/resolve_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func TestIsOpenAICompatibleProvider(t *testing.T) {
5454
}
5555

5656
// Aliases that point to the openai api — the previously-uncovered tail.
57-
for name, alias := range Aliases {
57+
for name, alias := range EachAlias() {
5858
if alias.APIType == "openai" {
5959
t.Run("alias/"+name, func(t *testing.T) {
6060
t.Parallel()

pkg/runtime/model_switcher.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ func (r *LocalRuntime) getAvailableProviders(ctx context.Context) map[string]boo
466466
}
467467

468468
// Check credentials for each alias provider
469-
for name, alias := range provider.Aliases {
469+
for name, alias := range provider.EachAlias() {
470470
if alias.TokenEnvVar == "" {
471471
continue
472472
}

0 commit comments

Comments
 (0)