Skip to content

Commit e2fcfcc

Browse files
Copilot and lpcox authored
fix: prevent GraphQL authorAssociation injection into User-type nodes and add 503 logging
Issue 1: The GraphQL rewriter was blindly injecting authorAssociation into ALL nodes {} blocks using ReplaceAllString. This broke queries with assignees, participants, labels, and other connections returning User/Team/Label types that don't support authorAssociation.

Fix: Added findParentField() to identify the connection field name containing each nodes block. Injection now only occurs when the parent is in a safe set (pullRequests, issues, comments, reviews, search for issue/PR fields; history for commit fields).

Issue 2: When --policy is missing, the proxy returned 503 "proxy enforcement not configured" but did not log the error, making debugging difficult.

Fix: Added logger.LogError and logHandler.Printf calls so the 503 cause appears in both proxy.log and debug output.

Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/483eafda-7d9e-4972-abe8-062c6a85813e
Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com>
1 parent f603d02 commit e2fcfcc

3 files changed

Lines changed: 258 additions & 16 deletions

File tree

internal/proxy/graphql_rewrite.go

Lines changed: 125 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,36 @@ var commitFields = []guardFieldSet{
3232
{"author{user{login}}", regexp.MustCompile(`\bauthor\s*\{[^}]*\buser\s*\{[^}]*\blogin\b`)},
3333
}
3434

35-
// fieldsForTool returns the guard fields applicable to the given tool name,
36-
// or nil if no injection is needed.
37-
func fieldsForTool(toolName string) []guardFieldSet {
35+
// issueAndPRSafeParents are connection field names whose node types support
36+
// author{login} and authorAssociation (i.e., types implementing the Comment
37+
// interface: Issue, PullRequest, IssueComment, PullRequestReview, etc.).
38+
// Injection into other connection nodes (e.g. assignees→User, labels→Label)
39+
// would cause GraphQL validation errors.
40+
var issueAndPRSafeParents = map[string]bool{
41+
"pullRequests": true,
42+
"issues": true,
43+
"comments": true,
44+
"reviews": true,
45+
"search": true,
46+
}
47+
48+
// commitSafeParents are connection field names whose node types support
49+
// author{user{login}} (Commit type).
50+
var commitSafeParents = map[string]bool{
51+
"history": true,
52+
}
53+
54+
// fieldsForTool returns the guard fields and safe parent connection names
55+
// applicable to the given tool name, or nil if no injection is needed.
56+
func fieldsForTool(toolName string) ([]guardFieldSet, map[string]bool) {
3857
switch toolName {
3958
case "list_issues", "list_pull_requests", "issue_read", "pull_request_read",
4059
"search_issues":
41-
return issueAndPRFields
60+
return issueAndPRFields, issueAndPRSafeParents
4261
case "list_commits":
43-
return commitFields
62+
return commitFields, commitSafeParents
4463
default:
45-
return nil
64+
return nil, nil
4665
}
4766
}
4867

@@ -73,7 +92,7 @@ func missingFields(query string, fields []guardFieldSet) []string {
7392
// Returns the (possibly modified) body. If injection is not needed or fails,
7493
// the original body is returned unchanged.
7594
func InjectGuardFields(body []byte, toolName string) []byte {
76-
fields := fieldsForTool(toolName)
95+
fields, safeParents := fieldsForTool(toolName)
7796
if fields == nil {
7897
logGraphQLRewrite.Printf("No guard field injection needed for tool=%s", toolName)
7998
return body
@@ -90,7 +109,7 @@ func InjectGuardFields(body []byte, toolName string) []byte {
90109
}
91110

92111
missing := missingFields(gql.Query, fields)
93-
modified := injectFieldsIntoQuery(gql.Query, missing)
112+
modified := injectFieldsIntoQuery(gql.Query, missing, safeParents)
94113
if modified == gql.Query {
95114
return body
96115
}
@@ -108,7 +127,10 @@ func InjectGuardFields(body []byte, toolName string) []byte {
108127
// injectFieldsIntoQuery adds the given fields into the GraphQL query's node
109128
// selection or fragment. Each field string (e.g. "author{login}",
110129
// "authorAssociation") is comma-joined and injected as a single block.
111-
func injectFieldsIntoQuery(query string, fields []string) string {
130+
// safeParents limits direct nodes injection (Step 3) to nodes blocks whose
131+
// parent connection field is in the set, preventing injection into User/Label
132+
// type nodes that don't support the injected fields.
133+
func injectFieldsIntoQuery(query string, fields []string, safeParents map[string]bool) string {
112134
injection := strings.Join(fields, ",")
113135

114136
// Step 1: Check if the query uses a fragment spread in the nodes.
@@ -134,16 +156,106 @@ func injectFieldsIntoQuery(query string, fields []string) string {
134156
}
135157

136158
// Step 3: No fragment — inject directly into nodes { ... }
137-
nodesPattern := regexp.MustCompile(`(nodes\s*\{)`)
138-
if nodesPattern.MatchString(query) {
139-
logGraphQLRewrite.Printf("Injecting into nodes selection: fields=%q", injection)
140-
return nodesPattern.ReplaceAllString(query, "${1}"+injection+",")
159+
// Only inject into nodes blocks whose parent connection field is in the
160+
// safeParents set. This prevents injecting fields like authorAssociation
161+
// into nodes of types that don't support them (e.g. User, Label, Team).
162+
nodesPattern := regexp.MustCompile(`nodes\s*\{`)
163+
matches := nodesPattern.FindAllStringIndex(query, -1)
164+
if len(matches) > 0 {
165+
var buf strings.Builder
166+
pos := 0
167+
injected := false
168+
for _, m := range matches {
169+
parent := findParentField(query, m[0])
170+
buf.WriteString(query[pos:m[1]])
171+
if safeParents[parent] {
172+
buf.WriteString(injection + ",")
173+
injected = true
174+
} else {
175+
logGraphQLRewrite.Printf("Skipping injection into nodes under %q (not a safe parent)", parent)
176+
}
177+
pos = m[1]
178+
}
179+
buf.WriteString(query[pos:])
180+
if injected {
181+
logGraphQLRewrite.Printf("Injecting into nodes selection: fields=%q", injection)
182+
return buf.String()
183+
}
141184
}
142185

143186
logGraphQLRewrite.Printf("No injection point found in query for fields=%q", injection)
144187
return query
145188
}
146189

190+
// findParentField returns the name of the field whose selection set directly
// encloses byte offset nodesIdx in query. For a `nodes { ... }` block this is
// the connection field, e.g. "comments" in `comments(first: 10) { nodes { ... } }`.
//
// The query is scanned forward from its start up to nodesIdx while tracking
// the open selection sets, so text that merely looks like structure is
// ignored: string literals (including """block strings"""), # comments,
// argument lists (including input-object braces inside them), insignificant
// commas, and directives such as @include(if: $x) placed between a field and
// its selection set. With an alias (`c: comments { ... }`) the real field
// name is returned, because the name closest to the opening brace wins.
//
// It returns "" when nodesIdx is out of range, when the enclosing selection
// set has no name in front of it (e.g. an anonymous `{ nodes { ... } }`
// operation), or when nodesIdx is not the start of a token at selection
// level: inside a string, comment or argument list, or in the middle of a
// longer identifier such as `subnodes`.
func findParentField(query string, nodesIdx int) string {
	if nodesIdx <= 0 || nodesIdx >= len(query) {
		return ""
	}

	var (
		parents    []string // names of the open selection sets, innermost last
		candidate  string   // latest selection-level name not yet followed by '{'
		parenDepth int      // > 0 while inside an argument / variable list
		directive  bool     // the next name belongs to an '@' directive
	)

	i := 0
	for i < nodesIdx {
		c := query[i]
		switch {
		case c == '#':
			// Comments run to the end of the line.
			end := strings.IndexAny(query[i:], "\r\n")
			if end < 0 {
				return "" // comment covers the rest of the query, including nodesIdx
			}
			i += end
		case c == '"':
			i = skipGraphQLString(query, i)
		case c == '(':
			parenDepth++
			i++
		case c == ')':
			if parenDepth > 0 {
				parenDepth--
			}
			i++
		case c == '{':
			// Braces inside an argument list are input-object values, not
			// selection sets.
			if parenDepth == 0 {
				parents = append(parents, candidate)
				candidate = ""
			}
			i++
		case c == '}':
			if parenDepth == 0 {
				if len(parents) > 0 {
					parents = parents[:len(parents)-1]
				}
				candidate = ""
			}
			i++
		case c == '@':
			directive = true
			i++
		case isGraphQLNameStart(c):
			start := i
			for i < len(query) && isGraphQLFieldNameChar(query[i]) {
				i++
			}
			switch {
			case directive:
				// Directive names never replace the field they decorate.
				directive = false
			case parenDepth == 0:
				candidate = query[start:i]
			}
		default:
			// Whitespace, commas, numbers and other punctuation.
			i++
		}
	}

	// i overshoots nodesIdx when a string, comment or identifier spans it, in
	// which case nodesIdx is not a real selection-level token.
	if i != nodesIdx || parenDepth != 0 || len(parents) == 0 {
		return ""
	}
	return parents[len(parents)-1]
}

// skipGraphQLString returns the index just past the string literal that
// starts at query[start] (which must be a double quote). It understands
// """block strings""" with the \""" escape and ordinary strings with
// backslash escapes. An unterminated ordinary string ends at the first line
// terminator; an unterminated block string ends at len(query).
func skipGraphQLString(query string, start int) int {
	if strings.HasPrefix(query[start:], `"""`) {
		i := start + 3
		for i < len(query) {
			switch {
			case strings.HasPrefix(query[i:], `\"""`):
				i += 4
			case strings.HasPrefix(query[i:], `"""`):
				return i + 3
			default:
				i++
			}
		}
		return len(query)
	}
	for i := start + 1; i < len(query); i++ {
		switch query[i] {
		case '\\':
			i++ // skip the escaped character
		case '"':
			return i + 1
		case '\n', '\r':
			return i
		}
	}
	return len(query)
}

// isGraphQLNameStart returns true for characters that may begin a GraphQL
// name (letters and underscore, but not digits).
func isGraphQLNameStart(c byte) bool {
	return c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

// isGraphQLFieldNameChar returns true for characters valid in a GraphQL field name.
func isGraphQLFieldNameChar(c byte) bool {
	return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_'
}
258+
147259
// injectIntoFragment adds a field to the end of a named fragment definition.
148260
// "fragment Name on Type { existing fields }" → "fragment Name on Type { existing fields field }"
149261
func injectIntoFragment(query, fragName, field string) string {

internal/proxy/graphql_rewrite_test.go

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"encoding/json"
5+
"strings"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
@@ -247,6 +248,133 @@ func TestInjectGuardFields_PullRequestRead(t *testing.T) {
247248
assert.Contains(t, gql.Query, "authorAssociation")
248249
}
249250

251+
func TestInjectGuardFields_SkipsAssigneesNodes(t *testing.T) {
252+
// Reproduces the bug from the issue: gh pr view with assignees causes
253+
// "Field 'authorAssociation' doesn't exist on type 'User'" because
254+
// injection was applied to ALL nodes blocks including assignees.nodes.
255+
query := `query PullRequestByNumber {
256+
repository(owner:"o", name:"r") {
257+
pullRequest(number: 1820) {
258+
number title author{login}
259+
assignees(first: 100) { nodes { login } }
260+
}
261+
}
262+
}`
263+
body, _ := json.Marshal(GraphQLRequest{Query: query})
264+
265+
result := InjectGuardFields(body, "pull_request_read")
266+
267+
var gql GraphQLRequest
268+
require.NoError(t, json.Unmarshal(result, &gql))
269+
// authorAssociation should NOT be injected because the only nodes block
270+
// is assignees.nodes which returns User objects.
271+
assert.NotContains(t, gql.Query, "authorAssociation",
272+
"authorAssociation must not be injected into assignees.nodes (User type)")
273+
}
274+
275+
func TestInjectGuardFields_MixedNodesBlocks(t *testing.T) {
276+
// Query has both a safe connection (comments) and an unsafe one (assignees).
277+
// Injection should only go into comments.nodes, not assignees.nodes.
278+
query := `query {
279+
repository(owner:"o", name:"r") {
280+
pullRequest(number: 1) {
281+
assignees(first: 10) { nodes { login } }
282+
comments(first: 10) { nodes { body } }
283+
}
284+
}
285+
}`
286+
body, _ := json.Marshal(GraphQLRequest{Query: query})
287+
288+
result := InjectGuardFields(body, "pull_request_read")
289+
290+
var gql GraphQLRequest
291+
require.NoError(t, json.Unmarshal(result, &gql))
292+
assert.Contains(t, gql.Query, "author{login}")
293+
assert.Contains(t, gql.Query, "authorAssociation")
294+
// Verify injection is in comments.nodes, not assignees.nodes
295+
assert.Contains(t, gql.Query, `assignees(first: 10) { nodes { login } }`,
296+
"assignees.nodes should remain unmodified")
297+
assert.Contains(t, gql.Query, `comments(first: 10) { nodes {author{login},authorAssociation,`,
298+
"comments.nodes should have injected fields")
299+
}
300+
301+
func TestInjectGuardFields_SkipsLabelsNodes(t *testing.T) {
302+
// labels.nodes returns Label objects — no authorAssociation field.
303+
query := `query { repository(owner:"o", name:"r") { issues(first:10) { nodes { number labels(first:5) { nodes { name } } } } } }`
304+
body, _ := json.Marshal(GraphQLRequest{Query: query})
305+
306+
result := InjectGuardFields(body, "list_issues")
307+
308+
var gql GraphQLRequest
309+
require.NoError(t, json.Unmarshal(result, &gql))
310+
assert.Contains(t, gql.Query, "author{login}")
311+
assert.Contains(t, gql.Query, "authorAssociation")
312+
// labels.nodes must not have injected fields
313+
assert.Contains(t, gql.Query, `labels(first:5) { nodes { name } }`,
314+
"labels.nodes should remain unmodified")
315+
}
316+
317+
func TestInjectGuardFields_SkipsParticipantsNodes(t *testing.T) {
318+
// participants.nodes returns User objects — no authorAssociation field.
319+
query := `query {
320+
repository(owner:"o", name:"r") {
321+
pullRequest(number: 1) {
322+
reviews(first: 5) { nodes { body } }
323+
participants(first: 10) { nodes { login } }
324+
}
325+
}
326+
}`
327+
body, _ := json.Marshal(GraphQLRequest{Query: query})
328+
329+
result := InjectGuardFields(body, "pull_request_read")
330+
331+
var gql GraphQLRequest
332+
require.NoError(t, json.Unmarshal(result, &gql))
333+
assert.Contains(t, gql.Query, "author{login}")
334+
assert.Contains(t, gql.Query, "authorAssociation")
335+
// participants.nodes should be unmodified
336+
assert.Contains(t, gql.Query, `participants(first: 10) { nodes { login } }`)
337+
}
338+
339+
func TestFindParentField(t *testing.T) {
340+
tests := []struct {
341+
name string
342+
query string
343+
nodesIdx int // index of "nodes" in the query
344+
want string
345+
}{
346+
{
347+
name: "simple connection",
348+
query: `pullRequests(first:10) { nodes { number } }`,
349+
want: "pullRequests",
350+
},
351+
{
352+
name: "connection with totalCount before nodes",
353+
query: `issues(first:5) { totalCount nodes { title } }`,
354+
want: "issues",
355+
},
356+
{
357+
name: "nested connection",
358+
query: `pullRequest(number:1) { assignees(first:10) { nodes { login } } }`,
359+
want: "assignees",
360+
},
361+
{
362+
name: "connection without args",
363+
query: `comments { nodes { body } }`,
364+
want: "comments",
365+
},
366+
}
367+
368+
for _, tt := range tests {
369+
t.Run(tt.name, func(t *testing.T) {
370+
idx := strings.Index(tt.query, "nodes")
371+
require.NotEqual(t, -1, idx, "query must contain 'nodes'")
372+
got := findParentField(tt.query, idx)
373+
assert.Equal(t, tt.want, got)
374+
})
375+
}
376+
}
377+
250378
func TestInjectGuardFields_NoNodesNoFragment(t *testing.T) {
251379
// A query with required tool but no "nodes" block and no fragment spread —
252380
// the injector cannot find a place to insert fields, so body is returned unchanged.
@@ -284,11 +412,13 @@ func TestFieldsForTool(t *testing.T) {
284412

285413
for _, tt := range tests {
286414
t.Run("tool: "+tt.toolName, func(t *testing.T) {
287-
fields := fieldsForTool(tt.toolName)
415+
fields, safeParents := fieldsForTool(tt.toolName)
288416
if tt.wantNil {
289417
assert.Nil(t, fields, "expected nil fields for tool %q", tt.toolName)
418+
assert.Nil(t, safeParents, "expected nil safeParents for tool %q", tt.toolName)
290419
} else {
291420
require.NotNil(t, fields, "expected non-nil fields for tool %q", tt.toolName)
421+
require.NotNil(t, safeParents, "expected non-nil safeParents for tool %q", tt.toolName)
292422
fieldStrings := make([]string, len(fields))
293423
for i, f := range fields {
294424
fieldStrings[i] = f.field

internal/proxy/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9-
"log"
109
"net/http"
1110

1211
"go.opentelemetry.io/otel/attribute"
@@ -141,7 +140,8 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
141140
defer difcSpan.End()
142141

143142
if !s.guardInitialized {
144-
log.Printf("[proxy] WARNING: guard not initialized, blocking request")
143+
logHandler.Printf("returning 503: proxy enforcement not configured (no --policy flag provided)")
144+
logger.LogError("proxy", "returning 503: proxy enforcement not configured (no --policy flag provided)")
145145
http.Error(w, "proxy enforcement not configured", http.StatusServiceUnavailable)
146146
return
147147
}

0 commit comments

Comments
 (0)