Skip to content

Commit 8beaab0

Browse files
refactor(auth): move IsMalformedHeader from server to auth package
The isMalformedAuthHeader function implements RFC 7230 header validation logic that belongs in the internal/auth package alongside other header parsing utilities (ParseAuthHeader, ExtractSessionID, etc.). Changes: - Add exported auth.IsMalformedHeader to internal/auth/header.go - Remove private isMalformedAuthHeader from internal/server/auth.go - Update server/auth.go to call auth.IsMalformedHeader - Update server/auth_test.go to call auth.IsMalformedHeader - Add TestIsMalformedHeader with 12 cases to internal/auth/header_test.go This makes all header validation logic discoverable in one place and ensures future additions follow the same pattern. Closes #4138 (partial: addresses Issue 3 from the analysis) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 18b19a6 commit 8beaab0

4 files changed

Lines changed: 98 additions & 17 deletions

File tree

internal/auth/header.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,19 @@ func TruncateSessionID(sessionID string) string {
183183
return strutil.Truncate(sessionID, 8)
184184
}
185185

186+
// IsMalformedHeader returns true if the header value contains characters
187+
// that are not valid in HTTP header values per RFC 7230: null bytes, control
188+
// characters below 0x20 (except horizontal tab 0x09), or DEL (0x7F).
189+
// Per spec 7.2 item 3, such headers must be rejected with HTTP 400.
190+
func IsMalformedHeader(header string) bool {
191+
for _, c := range header {
192+
if c == 0x00 || (c < 0x20 && c != 0x09) || c == 0x7F {
193+
return true
194+
}
195+
}
196+
return false
197+
}
198+
186199
// GenerateRandomAPIKey generates a cryptographically random API key.
187200
// Per spec §7.3, the gateway SHOULD generate a random API key on startup
188201
// if none is provided. Returns a 32-byte hex-encoded string (64 chars).

internal/auth/header_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,84 @@ import (
99
"github.com/github/gh-aw-mcpg/internal/logger/sanitize"
1010
)
1111

12+
func TestIsMalformedHeader(t *testing.T) {
13+
assert := assert.New(t)
14+
15+
tests := []struct {
16+
name string
17+
header string
18+
want bool
19+
}{
20+
{
21+
name: "Empty string is valid",
22+
header: "",
23+
want: false,
24+
},
25+
{
26+
name: "Normal API key is valid",
27+
header: "my-secret-api-key",
28+
want: false,
29+
},
30+
{
31+
name: "Bearer token is valid",
32+
header: "Bearer my-token-123",
33+
want: false,
34+
},
35+
{
36+
name: "Horizontal tab (0x09) is valid per RFC 7230",
37+
header: "key\twith\ttabs",
38+
want: false,
39+
},
40+
{
41+
name: "Printable ASCII is valid",
42+
header: "!#$%&'*+-.0123456789ABCDEFabcdef~",
43+
want: false,
44+
},
45+
{
46+
name: "Null byte (0x00) is malformed",
47+
header: "key\x00value",
48+
want: true,
49+
},
50+
{
51+
name: "DEL (0x7F) is malformed",
52+
header: "key\x7fvalue",
53+
want: true,
54+
},
55+
{
56+
name: "Control char 0x01 is malformed",
57+
header: "key\x01value",
58+
want: true,
59+
},
60+
{
61+
name: "Newline (0x0A) is malformed",
62+
header: "key\nvalue",
63+
want: true,
64+
},
65+
{
66+
name: "Carriage return (0x0D) is malformed",
67+
header: "key\rvalue",
68+
want: true,
69+
},
70+
{
71+
name: "Leading null byte",
72+
header: "\x00key",
73+
want: true,
74+
},
75+
{
76+
name: "Trailing null byte",
77+
header: "key\x00",
78+
want: true,
79+
},
80+
}
81+
82+
for _, tt := range tests {
83+
t.Run(tt.name, func(t *testing.T) {
84+
got := IsMalformedHeader(tt.header)
85+
assert.Equal(tt.want, got)
86+
})
87+
}
88+
}
89+
1290
func TestTruncateSecret(t *testing.T) {
1391
assert := assert.New(t)
1492

internal/server/auth.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,12 @@ package server
33
import (
44
"net/http"
55

6+
"github.com/github/gh-aw-mcpg/internal/auth"
67
"github.com/github/gh-aw-mcpg/internal/logger"
78
)
89

910
var logAuth = logger.New("server:auth")
1011

11-
// isMalformedAuthHeader returns true if the header value contains characters
12-
// that are not valid in HTTP header values per RFC 7230: null bytes, control
13-
// characters below 0x20 (except horizontal tab 0x09), or DEL (0x7F).
14-
// Per spec 7.2 item 3, such headers must be rejected with HTTP 400.
15-
func isMalformedAuthHeader(header string) bool {
16-
for _, c := range header {
17-
if c == 0x00 || (c < 0x20 && c != 0x09) || c == 0x7F {
18-
return true
19-
}
20-
}
21-
return false
22-
}
23-
2412
// authMiddleware implements API key authentication per spec section 7.1
2513
// Per spec: Authorization header MUST contain the API key directly (NOT Bearer scheme)
2614
//
@@ -43,7 +31,7 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {
4331

4432
// Spec 7.2 item 3: Malformed Authorization headers (null bytes, non-printable
4533
// control characters) must return 400 Bad Request, not 401.
46-
if isMalformedAuthHeader(authHeader) {
34+
if auth.IsMalformedHeader(authHeader) {
4735
rejectRequest(w, r, http.StatusBadRequest, "bad_request", "malformed Authorization header", "auth", "authentication_failed", "malformed_auth_header")
4836
return
4937
}

internal/server/auth_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77

88
"github.com/stretchr/testify/assert"
99
"github.com/stretchr/testify/require"
10+
11+
"github.com/github/gh-aw-mcpg/internal/auth"
1012
)
1113

1214
// TestAuthMiddleware tests the authMiddleware function with various scenarios
@@ -292,7 +294,7 @@ func TestAuthMiddleware_ConcurrentRequests(t *testing.T) {
292294
}
293295
}
294296

295-
// TestIsMalformedAuthHeader tests the isMalformedAuthHeader helper.
297+
// TestIsMalformedAuthHeader tests auth.IsMalformedHeader via the server package.
296298
func TestIsMalformedAuthHeader(t *testing.T) {
297299
tests := []struct {
298300
name string
@@ -313,8 +315,8 @@ func TestIsMalformedAuthHeader(t *testing.T) {
313315

314316
for _, tt := range tests {
315317
t.Run(tt.name, func(t *testing.T) {
316-
got := isMalformedAuthHeader(tt.header)
317-
assert.Equal(t, tt.malformed, got, "isMalformedAuthHeader(%q) should return %v", tt.header, tt.malformed)
318+
got := auth.IsMalformedHeader(tt.header)
319+
assert.Equal(t, tt.malformed, got, "auth.IsMalformedHeader(%q) should return %v", tt.header, tt.malformed)
318320
})
319321
}
320322
}

0 commit comments

Comments
 (0)