Skip to content

Commit 84e8306

Browse files
committed
refactor(provider): make googleFactory's gemini vs vertex dispatch testable
googleFactory contains the only non-trivial dispatch in the factory registry: it routes Gemini calls to gemini.NewClient and Vertex Model Garden calls (publisher != google) to vertexai.NewClient. Both arms reach out to real Google APIs, so neither was reachable from a unit test. Lift the two inner constructors into package-level providerFactory variables (geminiClientFactory / vertexClientFactory). Tests can swap them via t.Cleanup and assert that vertexai.IsModelGardenConfig is consulted correctly: - plain google config -> gemini, - publisher=google still -> gemini (documented edge case), - publisher=anthropic -> vertex (Model Garden), - errors from either inner factory propagate unchanged. Behaviour is unchanged. googleFactory goes from 0% to 100% coverage; package coverage 96.5% -> 97.1%. Assisted-By: docker-agent
1 parent 80865b1 commit 84e8306

2 files changed

Lines changed: 150 additions & 2 deletions

File tree

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/docker-agent/pkg/config/latest"
12+
"github.com/docker/docker-agent/pkg/environment"
13+
"github.com/docker/docker-agent/pkg/model/provider/options"
14+
)
15+
16+
// withGoogleFactories swaps the gemini and vertex inner factories for the
17+
// duration of a test, restoring the production functions via t.Cleanup.
18+
func withGoogleFactories(t *testing.T, gemini, vertex providerFactory) {
19+
t.Helper()
20+
originalGemini, originalVertex := geminiClientFactory, vertexClientFactory
21+
geminiClientFactory = gemini
22+
vertexClientFactory = vertex
23+
t.Cleanup(func() {
24+
geminiClientFactory = originalGemini
25+
vertexClientFactory = originalVertex
26+
})
27+
}
28+
29+
// TestGoogleFactory_RoutesGeminiByDefault verifies that a plain "google" model
30+
// (no Vertex Model Garden hints) is dispatched to the Gemini client.
31+
func TestGoogleFactory_RoutesGeminiByDefault(t *testing.T) {
32+
withGoogleFactories(t,
33+
tagFactory("gemini"),
34+
func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
35+
t.Errorf("vertex factory should not be called for plain Gemini config")
36+
return nil, errors.New("unreachable")
37+
},
38+
)
39+
40+
cfg := &latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"}
41+
42+
p, err := googleFactory(t.Context(), cfg, environment.NewNoEnvProvider())
43+
require.NoError(t, err)
44+
fp, ok := p.(*fakeProvider)
45+
require.True(t, ok)
46+
assert.Equal(t, "gemini", fp.id)
47+
}
48+
49+
// TestGoogleFactory_RoutesGeminiWhenPublisherIsGoogle covers the documented
50+
// edge case in vertexai.IsModelGardenConfig: publisher=google still routes to
51+
// Gemini (it's only a Model Garden config when publisher is non-google).
52+
func TestGoogleFactory_RoutesGeminiWhenPublisherIsGoogle(t *testing.T) {
53+
withGoogleFactories(t,
54+
tagFactory("gemini"),
55+
func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
56+
t.Errorf("vertex factory must not be called when publisher=google")
57+
return nil, errors.New("unreachable")
58+
},
59+
)
60+
61+
cfg := &latest.ModelConfig{
62+
Provider: "google",
63+
Model: "gemini-2.5-flash",
64+
ProviderOpts: map[string]any{"publisher": "google"},
65+
}
66+
67+
p, err := googleFactory(t.Context(), cfg, environment.NewNoEnvProvider())
68+
require.NoError(t, err)
69+
fp, ok := p.(*fakeProvider)
70+
require.True(t, ok)
71+
assert.Equal(t, "gemini", fp.id)
72+
}
73+
74+
// TestGoogleFactory_RoutesVertexForModelGarden verifies that any non-Google
75+
// publisher routes through the Vertex Model Garden factory.
76+
func TestGoogleFactory_RoutesVertexForModelGarden(t *testing.T) {
77+
withGoogleFactories(t,
78+
func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
79+
t.Errorf("gemini factory must not be called for Model Garden config")
80+
return nil, errors.New("unreachable")
81+
},
82+
tagFactory("vertex"),
83+
)
84+
85+
cfg := &latest.ModelConfig{
86+
Provider: "google",
87+
Model: "claude-3-5-sonnet@20240620",
88+
ProviderOpts: map[string]any{"publisher": "anthropic"},
89+
}
90+
91+
p, err := googleFactory(t.Context(), cfg, environment.NewNoEnvProvider())
92+
require.NoError(t, err)
93+
fp, ok := p.(*fakeProvider)
94+
require.True(t, ok)
95+
assert.Equal(t, "vertex", fp.id)
96+
}
97+
98+
// TestGoogleFactory_PropagatesGeminiError verifies that errors from the inner
99+
// gemini factory are surfaced unchanged.
100+
func TestGoogleFactory_PropagatesGeminiError(t *testing.T) {
101+
sentinel := errors.New("gemini-fail")
102+
withGoogleFactories(t,
103+
func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
104+
return nil, sentinel
105+
},
106+
tagFactory("vertex"),
107+
)
108+
109+
cfg := &latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"}
110+
111+
_, err := googleFactory(t.Context(), cfg, environment.NewNoEnvProvider())
112+
require.ErrorIs(t, err, sentinel)
113+
}
114+
115+
// TestGoogleFactory_PropagatesVertexError verifies that errors from the inner
116+
// vertex factory are surfaced unchanged.
117+
func TestGoogleFactory_PropagatesVertexError(t *testing.T) {
118+
sentinel := errors.New("vertex-fail")
119+
withGoogleFactories(t,
120+
tagFactory("gemini"),
121+
func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
122+
return nil, sentinel
123+
},
124+
)
125+
126+
cfg := &latest.ModelConfig{
127+
Provider: "google",
128+
Model: "claude-3-5-sonnet@20240620",
129+
ProviderOpts: map[string]any{"publisher": "anthropic"},
130+
}
131+
132+
_, err := googleFactory(t.Context(), cfg, environment.NewNoEnvProvider())
133+
require.ErrorIs(t, err, sentinel)
134+
}

pkg/model/provider/provider.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,25 @@ func googleFactory(ctx context.Context, cfg *latest.ModelConfig, env environment
306306
// Route non-Gemini models on Vertex AI (Model Garden) through the
307307
// vertexai package, which picks the right endpoint per publisher.
308308
if vertexai.IsModelGardenConfig(cfg) {
309-
return vertexai.NewClient(ctx, cfg, env, opts...)
309+
return vertexClientFactory(ctx, cfg, env, opts...)
310310
}
311-
return gemini.NewClient(ctx, cfg, env, opts...)
311+
return geminiClientFactory(ctx, cfg, env, opts...)
312312
}
313313

314+
// geminiClientFactory and vertexClientFactory are the inner constructors used
315+
// by googleFactory. They are package-level variables (rather than direct
316+
// references to gemini.NewClient / vertexai.NewClient) so that tests can swap
317+
// them with fakes via t.Cleanup and assert that googleFactory routes correctly
318+
// based on vertexai.IsModelGardenConfig — without spinning up real clients.
319+
var (
320+
geminiClientFactory providerFactory = func(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {
321+
return gemini.NewClient(ctx, cfg, env, opts...)
322+
}
323+
vertexClientFactory providerFactory = func(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {
324+
return vertexai.NewClient(ctx, cfg, env, opts...)
325+
}
326+
)
327+
314328
func dmrFactory(ctx context.Context, cfg *latest.ModelConfig, _ environment.Provider, opts ...options.Opt) (Provider, error) {
315329
return dmr.NewClient(ctx, cfg, opts...)
316330
}

0 commit comments

Comments
 (0)