Skip to content

Commit 8adec31

Browse files
committed
refactor(provider): extract router factory closure to resolveRoutedModel
createRuleBasedRouter inlined a 25-line closure that resolves routing targets — either by looking up a name in the models map or by parsing it as an inline 'provider/model' spec. Because the closure was anonymous and only invoked through rulebased.NewClient (which wraps every error), four error paths were 0% covered: - the recursion-prevention branch (target itself has routing rules), - the ParseModelRef failure branch, - the named-reference factory error path, - the inline-spec factory error path. Promote the closure to a package-level function resolveRoutedModel. Tests can now call it directly with a swapped factory registry to verify all four error paths and that factoryOpts (e.g. WithMaxTokens) are forwarded unchanged. Behaviour is unchanged. createRuleBasedRouter and resolveRoutedModel both reach 100% coverage; package coverage 91.8% -> 93.9%. Assisted-By: docker-agent
1 parent 962e1fe commit 8adec31

2 files changed

Lines changed: 180 additions & 24 deletions

File tree

pkg/model/provider/provider.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -195,34 +195,39 @@ func NewWithModels(ctx context.Context, cfg *latest.ModelConfig, models map[stri
195195

196196
// createRuleBasedRouter creates a rule-based routing provider.
197197
func createRuleBasedRouter(ctx context.Context, cfg *latest.ModelConfig, models map[string]latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {
198-
// Create a provider factory that can resolve model references
199-
factory := func(ctx context.Context, modelSpec string, models map[string]latest.ModelConfig, env environment.Provider, factoryOpts ...options.Opt) (rulebased.Provider, error) {
200-
// Check if modelSpec is a reference to a model in the models map
201-
if modelCfg, exists := models[modelSpec]; exists {
202-
// Prevent infinite recursion - referenced models cannot have routing rules
203-
if len(modelCfg.Routing) > 0 {
204-
return nil, fmt.Errorf("model %q has routing rules and cannot be used as a routing target", modelSpec)
205-
}
206-
p, err := createDirectProvider(ctx, &modelCfg, env, factoryOpts...)
207-
if err != nil {
208-
return nil, err
209-
}
210-
return p, nil
211-
}
198+
return rulebased.NewClient(ctx, cfg, models, env, resolveRoutedModel, opts...)
199+
}
212200

213-
// Otherwise, treat as an inline model spec (e.g., "openai/gpt-4o")
214-
inlineCfg, parseErr := latest.ParseModelRef(modelSpec)
215-
if parseErr != nil {
216-
return nil, fmt.Errorf("invalid model spec %q: expected 'provider/model' format or a model reference", modelSpec)
217-
}
218-
p, err := createDirectProvider(ctx, &inlineCfg, env, factoryOpts...)
219-
if err != nil {
220-
return nil, err
201+
// resolveRoutedModel is the rulebased.ProviderFactory used by
202+
// createRuleBasedRouter. It resolves a routing target — which is either a name
203+
// from the models map or an inline "provider/model" spec — and returns the
204+
// provider for it. Routing targets cannot themselves have routing rules.
205+
//
206+
// Defined as a package-level function (rather than an inline closure) so the
207+
// recursion-prevention, parse-error and factory-error paths can be unit-tested
208+
// directly without going through rulebased.NewClient.
209+
func resolveRoutedModel(
210+
ctx context.Context,
211+
modelSpec string,
212+
models map[string]latest.ModelConfig,
213+
env environment.Provider,
214+
factoryOpts ...options.Opt,
215+
) (rulebased.Provider, error) {
216+
// Check if modelSpec is a reference to a model in the models map.
217+
if modelCfg, exists := models[modelSpec]; exists {
218+
// Prevent infinite recursion - referenced models cannot have routing rules.
219+
if len(modelCfg.Routing) > 0 {
220+
return nil, fmt.Errorf("model %q has routing rules and cannot be used as a routing target", modelSpec)
221221
}
222-
return p, nil
222+
return createDirectProvider(ctx, &modelCfg, env, factoryOpts...)
223223
}
224224

225-
return rulebased.NewClient(ctx, cfg, models, env, factory, opts...)
225+
// Otherwise, treat as an inline model spec (e.g., "openai/gpt-4o").
226+
inlineCfg, parseErr := latest.ParseModelRef(modelSpec)
227+
if parseErr != nil {
228+
return nil, fmt.Errorf("invalid model spec %q: expected 'provider/model' format or a model reference", modelSpec)
229+
}
230+
return createDirectProvider(ctx, &inlineCfg, env, factoryOpts...)
226231
}
227232

228233
// createDirectProvider creates a provider without routing (direct model access).
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/docker-agent/pkg/config/latest"
12+
"github.com/docker/docker-agent/pkg/environment"
13+
"github.com/docker/docker-agent/pkg/model/provider/options"
14+
)
15+
16+
// TestResolveRoutedModel_NamedReference verifies that a model name in the
17+
// models map is resolved to its config and dispatched through the registry.
18+
func TestResolveRoutedModel_NamedReference(t *testing.T) {
19+
var capturedCfg *latest.ModelConfig
20+
withFactories(t, map[string]providerFactory{
21+
"openai": func(_ context.Context, cfg *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
22+
capturedCfg = cfg
23+
return &fakeProvider{id: "captured"}, nil
24+
},
25+
})
26+
27+
models := map[string]latest.ModelConfig{
28+
"fast": {Provider: "openai", Model: "gpt-4o-mini"},
29+
}
30+
31+
p, err := resolveRoutedModel(t.Context(), "fast", models, environment.NewNoEnvProvider())
32+
require.NoError(t, err)
33+
require.NotNil(t, p)
34+
35+
require.NotNil(t, capturedCfg)
36+
assert.Equal(t, "openai", capturedCfg.Provider)
37+
assert.Equal(t, "gpt-4o-mini", capturedCfg.Model)
38+
}
39+
40+
// TestResolveRoutedModel_InlineSpec verifies that a "provider/model" string
41+
// not present in the models map is parsed as an inline reference.
42+
func TestResolveRoutedModel_InlineSpec(t *testing.T) {
43+
var capturedCfg *latest.ModelConfig
44+
withFactories(t, map[string]providerFactory{
45+
"openai": func(_ context.Context, cfg *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
46+
capturedCfg = cfg
47+
return &fakeProvider{id: "captured"}, nil
48+
},
49+
})
50+
51+
p, err := resolveRoutedModel(t.Context(), "openai/gpt-4o", nil, environment.NewNoEnvProvider())
52+
require.NoError(t, err)
53+
require.NotNil(t, p)
54+
55+
require.NotNil(t, capturedCfg)
56+
assert.Equal(t, "openai", capturedCfg.Provider)
57+
assert.Equal(t, "gpt-4o", capturedCfg.Model)
58+
}
59+
60+
// TestResolveRoutedModel_InvalidInlineSpec covers the previously-unreachable
61+
// ParseModelRef error branch (modelSpec not in map and not parseable).
62+
func TestResolveRoutedModel_InvalidInlineSpec(t *testing.T) {
63+
withFactories(t, map[string]providerFactory{
64+
"openai": tagFactory("openai"),
65+
})
66+
67+
_, err := resolveRoutedModel(t.Context(), "no-slash-here", nil, environment.NewNoEnvProvider())
68+
require.Error(t, err)
69+
assert.Contains(t, err.Error(), "invalid model spec")
70+
assert.Contains(t, err.Error(), "no-slash-here")
71+
}
72+
73+
// TestResolveRoutedModel_RecursiveRoutingRejected covers the recursion-prevention
74+
// branch: a routing target may not itself have routing rules.
75+
func TestResolveRoutedModel_RecursiveRoutingRejected(t *testing.T) {
76+
withFactories(t, map[string]providerFactory{
77+
"openai": tagFactory("openai"),
78+
})
79+
80+
models := map[string]latest.ModelConfig{
81+
"router-as-target": {
82+
Provider: "openai",
83+
Model: "gpt-4o",
84+
Routing: []latest.RoutingRule{{Model: "x", Examples: []string{"hi"}}},
85+
},
86+
}
87+
88+
_, err := resolveRoutedModel(t.Context(), "router-as-target", models, environment.NewNoEnvProvider())
89+
require.Error(t, err)
90+
assert.Contains(t, err.Error(), "routing rules")
91+
assert.Contains(t, err.Error(), "router-as-target")
92+
}
93+
94+
// TestResolveRoutedModel_FactoryErrorPropagated_NamedRef verifies that when a
95+
// named model's factory fails, the error is returned to the caller.
96+
func TestResolveRoutedModel_FactoryErrorPropagated_NamedRef(t *testing.T) {
97+
sentinel := errors.New("named-fail")
98+
withFactories(t, map[string]providerFactory{
99+
"openai": func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
100+
return nil, sentinel
101+
},
102+
})
103+
104+
models := map[string]latest.ModelConfig{
105+
"fast": {Provider: "openai", Model: "gpt-4o-mini"},
106+
}
107+
108+
_, err := resolveRoutedModel(t.Context(), "fast", models, environment.NewNoEnvProvider())
109+
require.ErrorIs(t, err, sentinel)
110+
}
111+
112+
// TestResolveRoutedModel_FactoryErrorPropagated_Inline verifies the same for
113+
// an inline "provider/model" spec.
114+
func TestResolveRoutedModel_FactoryErrorPropagated_Inline(t *testing.T) {
115+
sentinel := errors.New("inline-fail")
116+
withFactories(t, map[string]providerFactory{
117+
"openai": func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
118+
return nil, sentinel
119+
},
120+
})
121+
122+
_, err := resolveRoutedModel(t.Context(), "openai/gpt-4o", nil, environment.NewNoEnvProvider())
123+
require.ErrorIs(t, err, sentinel)
124+
}
125+
126+
// TestResolveRoutedModel_OptionsForwarded verifies that factoryOpts reach the
127+
// downstream factory unchanged. This guards against accidentally dropping
128+
// options (e.g. WithProviders) when extracting the closure.
129+
func TestResolveRoutedModel_OptionsForwarded(t *testing.T) {
130+
var capturedOpts []options.Opt
131+
withFactories(t, map[string]providerFactory{
132+
"openai": func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, opts ...options.Opt) (Provider, error) {
133+
capturedOpts = opts
134+
return &fakeProvider{id: "ok"}, nil
135+
},
136+
})
137+
138+
maxTokens := int64(2048)
139+
_, err := resolveRoutedModel(
140+
t.Context(), "openai/gpt-4o", nil, environment.NewNoEnvProvider(),
141+
options.WithMaxTokens(maxTokens),
142+
)
143+
require.NoError(t, err)
144+
145+
require.NotEmpty(t, capturedOpts)
146+
var probe options.ModelOptions
147+
for _, o := range capturedOpts {
148+
o(&probe)
149+
}
150+
assert.Equal(t, maxTokens, probe.MaxTokens())
151+
}

0 commit comments

Comments
 (0)