Skip to content

Commit c3167b2

Browse files
committed
Calculate max_tokens value for anthropic requests based on available context using official token counting api
Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 8ffa57e commit c3167b2

4 files changed

Lines changed: 642 additions & 6 deletions

File tree

pkg/model/provider/anthropic/beta_client.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ func (c *Client) createBetaStream(
3838
}
3939
}
4040

41+
sys := extractBetaSystemBlocks(messages)
42+
43+
if used, err := countAnthropicTokensBeta(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools); err == nil {
44+
configuredMaxTokens := maxTokens
45+
maxTokens = clampMaxTokens(anthropicContextLimit(c.ModelConfig.Model), used, maxTokens)
46+
if maxTokens < configuredMaxTokens {
47+
slog.Warn("Anthropic Beta API max_tokens clamped to", "max_tokens", maxTokens)
48+
}
49+
}
50+
4151
params := anthropic.BetaMessageNewParams{
4252
Model: anthropic.Model(c.ModelConfig.Model),
4353
MaxTokens: maxTokens,
@@ -47,7 +57,7 @@ func (c *Client) createBetaStream(
4757
}
4858

4959
// Populate proper Anthropic system prompt from input messages
50-
if sys := extractBetaSystemBlocks(messages); len(sys) > 0 {
60+
if len(sys) > 0 {
5161
params.System = sys
5262
}
5363

@@ -186,3 +196,42 @@ func contentArrayBeta(m map[string]any) []any {
186196
}
187197
return nil
188198
}
199+
200+
// countAnthropicTokensBeta calls Anthropic's Count Tokens API for the provided Beta API payload
201+
// and returns the number of input tokens.
202+
func countAnthropicTokensBeta(
203+
ctx context.Context,
204+
client anthropic.Client,
205+
model anthropic.Model,
206+
messages []anthropic.BetaMessageParam,
207+
system []anthropic.BetaTextBlockParam,
208+
anthropicTools []anthropic.BetaToolUnionParam,
209+
) (int64, error) {
210+
params := anthropic.BetaMessageCountTokensParams{
211+
Model: model,
212+
Messages: messages,
213+
}
214+
if len(system) > 0 {
215+
params.System = anthropic.BetaMessageCountTokensParamsSystemUnion{
216+
OfBetaTextBlockArray: system,
217+
}
218+
}
219+
if len(anthropicTools) > 0 {
220+
// Convert BetaToolUnionParam to BetaMessageCountTokensParamsToolUnion
221+
toolParams := make([]anthropic.BetaMessageCountTokensParamsToolUnion, len(anthropicTools))
222+
for i, tool := range anthropicTools {
223+
if tool.OfTool != nil {
224+
toolParams[i] = anthropic.BetaMessageCountTokensParamsToolUnion{
225+
OfTool: tool.OfTool,
226+
}
227+
}
228+
}
229+
params.Tools = toolParams
230+
}
231+
232+
result, err := client.Beta.Messages.CountTokens(ctx, params)
233+
if err != nil {
234+
return 0, err
235+
}
236+
return result.InputTokens, nil
237+
}
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
package anthropic
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/anthropics/anthropic-sdk-go"
10+
"github.com/anthropics/anthropic-sdk-go/option"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
14+
"github.com/docker/cagent/pkg/chat"
15+
)
16+
17+
// TestCountAnthropicTokensBeta_Success tests successful token counting for beta API
18+
func TestCountAnthropicTokensBeta_Success(t *testing.T) {
19+
// Setup mock server
20+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
assert.Equal(t, "/v1/messages/count_tokens", r.URL.Path)
22+
assert.Equal(t, "application/json", r.Header.Get("content-type"))
23+
assert.NotEmpty(t, r.Header.Get("x-api-key"))
24+
25+
// Verify request body contains expected fields
26+
var payload map[string]any
27+
err := json.NewDecoder(r.Body).Decode(&payload)
28+
assert.NoError(t, err)
29+
assert.Equal(t, "claude-3-5-sonnet-20241022", payload["model"])
30+
assert.NotNil(t, payload["messages"])
31+
32+
// Return mock response
33+
w.Header().Set("content-type", "application/json")
34+
err = json.NewEncoder(w).Encode(map[string]int64{"input_tokens": 150})
35+
assert.NoError(t, err)
36+
}))
37+
defer server.Close()
38+
39+
// Create test data
40+
messages := []anthropic.BetaMessageParam{
41+
{
42+
Role: anthropic.BetaMessageParamRoleUser,
43+
Content: []anthropic.BetaContentBlockParamUnion{
44+
{OfText: &anthropic.BetaTextBlockParam{Text: "Hello"}},
45+
},
46+
},
47+
}
48+
system := []anthropic.BetaTextBlockParam{
49+
{Text: "You are helpful"},
50+
}
51+
52+
// Create client with test server URL
53+
client := anthropic.NewClient(
54+
option.WithAPIKey("test-key"),
55+
option.WithBaseURL(server.URL),
56+
)
57+
58+
// Call function
59+
tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)
60+
61+
// Verify
62+
require.NoError(t, err)
63+
assert.Equal(t, int64(150), tokens)
64+
}
65+
66+
// TestCountAnthropicTokensBeta_NoAPIKey tests error when API key is missing
67+
func TestCountAnthropicTokensBeta_NoAPIKey(t *testing.T) {
68+
messages := []anthropic.BetaMessageParam{}
69+
system := []anthropic.BetaTextBlockParam{}
70+
71+
// Create client without base URL to trigger error
72+
client := anthropic.NewClient(
73+
option.WithAPIKey("test-key"),
74+
// No base URL set
75+
)
76+
77+
tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)
78+
79+
require.Error(t, err)
80+
assert.Equal(t, int64(0), tokens)
81+
}
82+
83+
// TestCountAnthropicTokensBeta_ServerError tests error handling for server errors
84+
func TestCountAnthropicTokensBeta_ServerError(t *testing.T) {
85+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
86+
w.WriteHeader(http.StatusInternalServerError)
87+
}))
88+
defer server.Close()
89+
90+
messages := []anthropic.BetaMessageParam{}
91+
system := []anthropic.BetaTextBlockParam{}
92+
93+
// Create client with test server URL
94+
client := anthropic.NewClient(
95+
option.WithAPIKey("test-key"),
96+
option.WithBaseURL(server.URL),
97+
)
98+
99+
tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, nil)
100+
require.Error(t, err)
101+
assert.Equal(t, int64(0), tokens)
102+
}
103+
104+
// TestCountAnthropicTokensBeta_WithTools tests token counting includes tools
105+
func TestCountAnthropicTokensBeta_WithTools(t *testing.T) {
106+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
var payload map[string]any
108+
err := json.NewDecoder(r.Body).Decode(&payload)
109+
assert.NoError(t, err)
110+
111+
// Verify tools are included in payload
112+
assert.NotNil(t, payload["tools"])
113+
tools, ok := payload["tools"].([]any)
114+
assert.True(t, ok)
115+
assert.Len(t, tools, 1)
116+
117+
w.Header().Set("content-type", "application/json")
118+
err = json.NewEncoder(w).Encode(map[string]int64{"input_tokens": 200})
119+
assert.NoError(t, err)
120+
}))
121+
defer server.Close()
122+
123+
messages := []anthropic.BetaMessageParam{}
124+
system := []anthropic.BetaTextBlockParam{}
125+
tools := []anthropic.BetaToolUnionParam{
126+
{OfTool: &anthropic.BetaToolParam{
127+
Name: "test_tool",
128+
Description: anthropic.String("A test tool"),
129+
}},
130+
}
131+
132+
// Create client with test server URL
133+
client := anthropic.NewClient(
134+
option.WithAPIKey("test-key"),
135+
option.WithBaseURL(server.URL),
136+
)
137+
138+
tokens, err := countAnthropicTokensBeta(t.Context(), client, "claude-3-5-sonnet-20241022", messages, system, tools)
139+
140+
require.NoError(t, err)
141+
assert.Equal(t, int64(200), tokens)
142+
}
143+
144+
// TestClampMaxTokens_WithinLimit tests clamping when configured tokens are within limit
145+
func TestClampMaxTokens_WithinLimit(t *testing.T) {
146+
// Context limit: 200k, used: 50k, safety: 1k, remaining: 149k
147+
// Configured: 8k (within limit)
148+
result := clampMaxTokens(200000, 50000, 8000)
149+
assert.Equal(t, int64(8000), result)
150+
}
151+
152+
// TestClampMaxTokens_ExceedsLimit tests clamping when configured tokens exceed remaining
153+
func TestClampMaxTokens_ExceedsLimit(t *testing.T) {
154+
// Context limit: 200k, used: 190k, safety: 1024, remaining: 8976
155+
// Configured: 16k (exceeds limit)
156+
result := clampMaxTokens(200000, 190000, 16000)
157+
assert.Equal(t, int64(8976), result)
158+
}
159+
160+
// TestClampMaxTokens_MinimumOne tests clamping never returns less than 1
161+
func TestClampMaxTokens_MinimumOne(t *testing.T) {
162+
// Context limit: 200k, used: 199k, safety: 1k, remaining: 0 (would be negative)
163+
result := clampMaxTokens(200000, 199000, 8000)
164+
assert.Equal(t, int64(1), result)
165+
}
166+
167+
// TestClampMaxTokens_ExactlyAtLimit tests clamping when used + safety equals limit
168+
func TestClampMaxTokens_ExactlyAtLimit(t *testing.T) {
169+
// Context limit: 200k, used: 199k, safety: 1k, remaining: 0
170+
result := clampMaxTokens(200000, 199000, 1000)
171+
assert.Equal(t, int64(1), result)
172+
}
173+
174+
// TestAnthropicContextLimit_ReturnsCorrectLimit tests context limit function
175+
func TestAnthropicContextLimit_ReturnsCorrectLimit(t *testing.T) {
176+
limit := anthropicContextLimit("claude-3-5-sonnet-20241022")
177+
assert.Equal(t, int64(200000), limit)
178+
}
179+
180+
// TestExtractBetaSystemBlocks_SingleSystemMessage tests extracting system messages
181+
func TestExtractBetaSystemBlocks_SingleSystemMessage(t *testing.T) {
182+
msgs := []chat.Message{
183+
{
184+
Role: chat.MessageRoleSystem,
185+
Content: "You are a helpful assistant",
186+
},
187+
}
188+
189+
blocks := extractBetaSystemBlocks(msgs)
190+
191+
require.Len(t, blocks, 1)
192+
assert.Equal(t, "You are a helpful assistant", blocks[0].Text)
193+
}
194+
195+
// TestExtractBetaSystemBlocks_MultipleSystemMessages tests extracting multiple system messages
196+
func TestExtractBetaSystemBlocks_MultipleSystemMessages(t *testing.T) {
197+
msgs := []chat.Message{
198+
{
199+
Role: chat.MessageRoleSystem,
200+
Content: "You are helpful",
201+
},
202+
{
203+
Role: chat.MessageRoleUser,
204+
Content: "Hello",
205+
},
206+
{
207+
Role: chat.MessageRoleSystem,
208+
Content: "Be concise",
209+
},
210+
}
211+
212+
blocks := extractBetaSystemBlocks(msgs)
213+
214+
require.Len(t, blocks, 2)
215+
assert.Equal(t, "You are helpful", blocks[0].Text)
216+
assert.Equal(t, "Be concise", blocks[1].Text)
217+
}
218+
219+
// TestExtractBetaSystemBlocks_SkipsEmptyText tests that empty system text is skipped
220+
func TestExtractBetaSystemBlocks_SkipsEmptyText(t *testing.T) {
221+
msgs := []chat.Message{
222+
{
223+
Role: chat.MessageRoleSystem,
224+
Content: " \n\t ",
225+
},
226+
{
227+
Role: chat.MessageRoleSystem,
228+
Content: "Valid system prompt",
229+
},
230+
}
231+
232+
blocks := extractBetaSystemBlocks(msgs)
233+
234+
require.Len(t, blocks, 1)
235+
assert.Equal(t, "Valid system prompt", blocks[0].Text)
236+
}
237+
238+
// TestExtractBetaSystemBlocks_MultiContent tests extracting from multi-content system messages
239+
func TestExtractBetaSystemBlocks_MultiContent(t *testing.T) {
240+
msgs := []chat.Message{
241+
{
242+
Role: chat.MessageRoleSystem,
243+
MultiContent: []chat.MessagePart{
244+
{Type: chat.MessagePartTypeText, Text: "Part 1"},
245+
{Type: chat.MessagePartTypeText, Text: "Part 2"},
246+
},
247+
},
248+
}
249+
250+
blocks := extractBetaSystemBlocks(msgs)
251+
252+
require.Len(t, blocks, 2)
253+
assert.Equal(t, "Part 1", blocks[0].Text)
254+
assert.Equal(t, "Part 2", blocks[1].Text)
255+
}
256+
257+
// TestConvertBetaMessages_UserMessage tests converting user messages
258+
func TestConvertBetaMessages_UserMessage(t *testing.T) {
259+
msgs := []chat.Message{
260+
{
261+
Role: chat.MessageRoleUser,
262+
Content: "Hello, assistant!",
263+
},
264+
}
265+
266+
converted := convertBetaMessages(msgs)
267+
268+
require.Len(t, converted, 1)
269+
assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role)
270+
require.Len(t, converted[0].Content, 1)
271+
}
272+
273+
// TestConvertBetaMessages_SkipsSystemMessages tests that system messages are skipped
274+
func TestConvertBetaMessages_SkipsSystemMessages(t *testing.T) {
275+
msgs := []chat.Message{
276+
{
277+
Role: chat.MessageRoleSystem,
278+
Content: "System prompt",
279+
},
280+
{
281+
Role: chat.MessageRoleUser,
282+
Content: "User message",
283+
},
284+
}
285+
286+
converted := convertBetaMessages(msgs)
287+
288+
require.Len(t, converted, 1)
289+
assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role)
290+
}
291+
292+
// TestConvertBetaMessages_AssistantMessage tests converting assistant messages
293+
func TestConvertBetaMessages_AssistantMessage(t *testing.T) {
294+
msgs := []chat.Message{
295+
{
296+
Role: chat.MessageRoleAssistant,
297+
Content: "I can help with that",
298+
},
299+
}
300+
301+
converted := convertBetaMessages(msgs)
302+
303+
require.Len(t, converted, 1)
304+
assert.Equal(t, anthropic.BetaMessageParamRoleAssistant, converted[0].Role)
305+
}

0 commit comments

Comments
 (0)