Skip to content

Commit 46495a3

Browse files
authored
Fix GraphQL authorAssociation injection into User-type nodes; log 503 on missing policy (#3413)
The DIFC proxy's GraphQL rewriter blindly injected `authorAssociation` into **all** `nodes {}` blocks via `ReplaceAllString`. Queries with `assignees`, `participants`, `labels`, etc. failed with `Field 'authorAssociation' doesn't exist on type 'User'`. Separately, 503 responses from missing `--policy` were not logged, making debugging costly. ### GraphQL injection fix - Added safe-parent allowlists per field group — only inject into `nodes {}` blocks whose parent connection returns types that actually have the field: - Issue/PR fields (`author{login}`, `authorAssociation`): `pullRequests`, `issues`, `comments`, `reviews`, `search` - Commit fields (`author{user{login}}`): `history` - Added `findParentField()` — walks backward from each `nodes {` match to extract the enclosing connection field name, handling args and nested braces - `fieldsForTool()` now returns `([]guardFieldSet, map[string]bool)` to thread safe parents through to injection Before (breaks on `assignees.nodes → User`): ```graphql assignees(first:100) { nodes {authorAssociation, login } } ``` After (skips unsafe parents): ```graphql assignees(first:100) { nodes { login } } comments(first:10) { nodes {author{login},authorAssociation, body } } ``` ### 503 logging fix - Replaced `log.Printf` with `logHandler.Printf` + `logger.LogError` so missing-policy 503s appear in both debug output and `proxy.log` - Removed unused `log` import > [!WARNING] > > <details> > <summary>Firewall rules blocked me from connecting to one or more addresses (expand for details)</summary> > > #### I tried to connect to the following addresses, but was blocked by firewall rules: > > - `example.com` > - Triggering command: `/tmp/go-build2109691296/b514/launcher.test /tmp/go-build2109691296/b514/launcher.test -test.testlogfile=/tmp/go-build2109691296/b514/testlog.txt -test.paniconexit0 -test.timeout=10m0s -I olang.org/grpc@v1.80.0/internal/go1.25.8 .cfg x_amd64/vet --gdwarf-5 --64 -o x_amd64/vet -I 2335907/b437/_pkg_.a dU_c/JEVDElkKgVJXLgEedU_c x_amd64/vet --gdwarf-5 g/grpc/encoding/-atomic -o x_amd64/vet` (dns block) > - `invalid-host-that-does-not-exist-12345.com` > - Triggering command: `/tmp/go-build2109691296/b496/config.test /tmp/go-build2109691296/b496/config.test -test.testlogfile=/tmp/go-build2109691296/b496/testlog.txt -test.paniconexit0 -test.timeout=10m0s ortc�� H/bin/golangci-lint run --timeout=5m || echo &#34;��� Warning: golangci-lint failed (compatibility issue with Go 1.25.0). Continuing with other checks...&#34;; \ elif command -v golan 1.80.0/balancer_wrapper.go 64/pkg/tool/linux_amd64/vet . -imultiarch x86_64-linux-gnu/tmp/go-build547184664/b347/vet.cfg 64/pkg/tool/linux_amd64/vet 2335�� olang.org/grpc@v1.80.0/internal/-I olang.org/grpc@v1.80.0/internal//tmp/go-build3905974454/b444/ x_amd64/vet . --gdwarf2 --64 x_amd64/vet` (dns block) > - `nonexistent.local` > - Triggering command: `/tmp/go-build2109691296/b514/launcher.test /tmp/go-build2109691296/b514/launcher.test -test.testlogfile=/tmp/go-build2109691296/b514/testlog.txt -test.paniconexit0 -test.timeout=10m0s -I olang.org/grpc@v1.80.0/internal/go1.25.8 .cfg x_amd64/vet --gdwarf-5 --64 -o x_amd64/vet -I 2335907/b437/_pkg_.a dU_c/JEVDElkKgVJXLgEedU_c x_amd64/vet --gdwarf-5 g/grpc/encoding/-atomic -o x_amd64/vet` (dns block) > - `slow.example.com` > - Triggering command: `/tmp/go-build2109691296/b514/launcher.test /tmp/go-build2109691296/b514/launcher.test -test.testlogfile=/tmp/go-build2109691296/b514/testlog.txt -test.paniconexit0 -test.timeout=10m0s -I olang.org/grpc@v1.80.0/internal/go1.25.8 .cfg x_amd64/vet --gdwarf-5 --64 -o x_amd64/vet -I 2335907/b437/_pkg_.a dU_c/JEVDElkKgVJXLgEedU_c x_amd64/vet --gdwarf-5 g/grpc/encoding/-atomic -o x_amd64/vet` (dns block) > - `this-host-does-not-exist-12345.com` > - Triggering command: `/tmp/go-build2109691296/b523/mcp.test /tmp/go-build2109691296/b523/mcp.test -test.testlogfile=/tmp/go-build2109691296/b523/testlog.txt -test.paniconexit0 -test.timeout=10m0s` (dns block) > > If you need me to access, download, or install something from one of these locations, you can either: > > - Configure [Actions setup steps](https://gh.io/copilot/actions-setup-steps) to set up my environment, which run before the firewall is enabled > - Add the appropriate URLs or hosts to the custom allowlist in this repository's [Copilot coding agent settings](https://github.com/github/gh-aw-mcpg/settings/copilot/coding_agent) (admins only) > > </details>
2 parents 3ad6880 + cee726f commit 46495a3

File tree

3 files changed

+254
-16
lines changed

3 files changed

+254
-16
lines changed

internal/proxy/graphql_rewrite.go

Lines changed: 120 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,36 @@ var commitFields = []guardFieldSet{
3232
{"author{user{login}}", regexp.MustCompile(`\bauthor\s*\{[^}]*\buser\s*\{[^}]*\blogin\b`)},
3333
}
3434

35-
// fieldsForTool returns the guard fields applicable to the given tool name,
36-
// or nil if no injection is needed.
37-
func fieldsForTool(toolName string) []guardFieldSet {
35+
// issueAndPRSafeParents are connection field names whose node types support
36+
// author{login} and authorAssociation (i.e., types implementing the Comment
37+
// interface: Issue, PullRequest, IssueComment, PullRequestReview, etc.).
38+
// Injection into other connection nodes (e.g. assignees→User, labels→Label)
39+
// would cause GraphQL validation errors.
40+
var issueAndPRSafeParents = map[string]bool{
41+
"pullRequests": true,
42+
"issues": true,
43+
"comments": true,
44+
"reviews": true,
45+
"search": true,
46+
}
47+
48+
// commitSafeParents are connection field names whose node types support
49+
// author{user{login}} (Commit type).
50+
var commitSafeParents = map[string]bool{
51+
"history": true,
52+
}
53+
54+
// fieldsForTool returns the guard fields and safe parent connection names
55+
// applicable to the given tool name, or nil if no injection is needed.
56+
func fieldsForTool(toolName string) ([]guardFieldSet, map[string]bool) {
3857
switch toolName {
3958
case "list_issues", "list_pull_requests", "issue_read", "pull_request_read",
4059
"search_issues":
41-
return issueAndPRFields
60+
return issueAndPRFields, issueAndPRSafeParents
4261
case "list_commits":
43-
return commitFields
62+
return commitFields, commitSafeParents
4463
default:
45-
return nil
64+
return nil, nil
4665
}
4766
}
4867

@@ -73,7 +92,7 @@ func missingFields(query string, fields []guardFieldSet) []string {
7392
// Returns the (possibly modified) body. If injection is not needed or fails,
7493
// the original body is returned unchanged.
7594
func InjectGuardFields(body []byte, toolName string) []byte {
76-
fields := fieldsForTool(toolName)
95+
fields, safeParents := fieldsForTool(toolName)
7796
if fields == nil {
7897
logGraphQLRewrite.Printf("No guard field injection needed for tool=%s", toolName)
7998
return body
@@ -90,7 +109,7 @@ func InjectGuardFields(body []byte, toolName string) []byte {
90109
}
91110

92111
missing := missingFields(gql.Query, fields)
93-
modified := injectFieldsIntoQuery(gql.Query, missing)
112+
modified := injectFieldsIntoQuery(gql.Query, missing, safeParents)
94113
if modified == gql.Query {
95114
return body
96115
}
@@ -108,7 +127,10 @@ func InjectGuardFields(body []byte, toolName string) []byte {
108127
// injectFieldsIntoQuery adds the given fields into the GraphQL query's node
109128
// selection or fragment. Each field string (e.g. "author{login}",
110129
// "authorAssociation") is comma-joined and injected as a single block.
111-
func injectFieldsIntoQuery(query string, fields []string) string {
130+
// safeParents limits direct nodes injection (Step 3) to nodes blocks whose
131+
// parent connection field is in the set, preventing injection into User/Label
132+
// type nodes that don't support the injected fields.
133+
func injectFieldsIntoQuery(query string, fields []string, safeParents map[string]bool) string {
112134
injection := strings.Join(fields, ",")
113135

114136
// Step 1: Check if the query uses a fragment spread in the nodes.
@@ -134,16 +156,101 @@ func injectFieldsIntoQuery(query string, fields []string) string {
134156
}
135157

136158
// Step 3: No fragment — inject directly into nodes { ... }
137-
nodesPattern := regexp.MustCompile(`(nodes\s*\{)`)
138-
if nodesPattern.MatchString(query) {
139-
logGraphQLRewrite.Printf("Injecting into nodes selection: fields=%q", injection)
140-
return nodesPattern.ReplaceAllString(query, "${1}"+injection+",")
159+
// Only inject into nodes blocks whose parent connection field is in the
160+
// safeParents set. This prevents injecting fields like authorAssociation
161+
// into nodes of types that don't support them (e.g. User, Label, Team).
162+
nodesPattern := regexp.MustCompile(`nodes\s*\{`)
163+
matches := nodesPattern.FindAllStringIndex(query, -1)
164+
if len(matches) > 0 {
165+
var buf strings.Builder
166+
pos := 0
167+
injected := false
168+
for _, m := range matches {
169+
parent := findParentField(query, m[0])
170+
buf.WriteString(query[pos:m[1]])
171+
if safeParents[parent] {
172+
buf.WriteString(injection + ",")
173+
injected = true
174+
} else {
175+
logGraphQLRewrite.Printf("Skipping injection into nodes under %q (not a safe parent)", parent)
176+
}
177+
pos = m[1]
178+
}
179+
buf.WriteString(query[pos:])
180+
if injected {
181+
logGraphQLRewrite.Printf("Injecting into nodes selection: fields=%q", injection)
182+
return buf.String()
183+
}
141184
}
142185

143186
logGraphQLRewrite.Printf("No injection point found in query for fields=%q", injection)
144187
return query
145188
}
146189

190+
// findParentField extracts the GraphQL connection field name that contains
191+
// the given nodes block. It walks backward from idx to find the enclosing
192+
// opening brace, then extracts the field name before it (skipping any
193+
// arguments in parentheses).
194+
func findParentField(query string, nodesIdx int) string {
195+
// Walk backward from nodesIdx to find the enclosing `{`
196+
depth := 0
197+
i := nodesIdx - 1
198+
for i >= 0 {
199+
switch query[i] {
200+
case '{':
201+
if depth == 0 {
202+
goto foundBrace
203+
}
204+
depth--
205+
case '}':
206+
depth++
207+
}
208+
i--
209+
}
210+
return "" // no enclosing brace found
211+
212+
foundBrace:
213+
214+
// i now points to the `{` of the enclosing block.
215+
// Walk backward past whitespace.
216+
i--
217+
for i >= 0 && (query[i] == ' ' || query[i] == '\n' || query[i] == '\t' || query[i] == '\r') {
218+
i--
219+
}
220+
// If there are parenthesized args, skip them.
221+
if i >= 0 && query[i] == ')' {
222+
parenDepth := 1
223+
i--
224+
for i >= 0 && parenDepth > 0 {
225+
switch query[i] {
226+
case ')':
227+
parenDepth++
228+
case '(':
229+
parenDepth--
230+
}
231+
i--
232+
}
233+
// Skip whitespace between the argument list and the field name.
234+
for i >= 0 && (query[i] == ' ' || query[i] == '\n' || query[i] == '\t' || query[i] == '\r') {
235+
i--
236+
}
237+
}
238+
// Extract the field name (alphanumeric + underscore)
239+
end := i + 1
240+
for i >= 0 && isGraphQLFieldNameChar(query[i]) {
241+
i--
242+
}
243+
if i+1 >= end {
244+
return ""
245+
}
246+
return query[i+1 : end]
247+
}
248+
249+
// isGraphQLFieldNameChar returns true for characters valid in a GraphQL field name.
250+
func isGraphQLFieldNameChar(c byte) bool {
251+
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_'
252+
}
253+
147254
// injectIntoFragment adds a field to the end of a named fragment definition.
148255
// "fragment Name on Type { existing fields }" → "fragment Name on Type { existing fields field }"
149256
func injectIntoFragment(query, fragName, field string) string {

internal/proxy/graphql_rewrite_test.go

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"encoding/json"
5+
"strings"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
@@ -247,6 +248,133 @@ func TestInjectGuardFields_PullRequestRead(t *testing.T) {
247248
assert.Contains(t, gql.Query, "authorAssociation")
248249
}
249250

251+
func TestInjectGuardFields_SkipsAssigneesNodes(t *testing.T) {
252+
// Reproduces the bug from the issue: gh pr view with assignees causes
253+
// "Field 'authorAssociation' doesn't exist on type 'User'" because
254+
// injection was applied to ALL nodes blocks including assignees.nodes.
255+
query := `query PullRequestByNumber {
256+
repository(owner:"o", name:"r") {
257+
pullRequest(number: 1820) {
258+
number title author{login}
259+
assignees(first: 100) { nodes { login } }
260+
}
261+
}
262+
}`
263+
body, _ := json.Marshal(GraphQLRequest{Query: query})
264+
265+
result := InjectGuardFields(body, "pull_request_read")
266+
267+
var gql GraphQLRequest
268+
require.NoError(t, json.Unmarshal(result, &gql))
269+
// authorAssociation should NOT be injected because the only nodes block
270+
// is assignees.nodes which returns User objects.
271+
assert.NotContains(t, gql.Query, "authorAssociation",
272+
"authorAssociation must not be injected into assignees.nodes (User type)")
273+
}
274+
275+
func TestInjectGuardFields_MixedNodesBlocks(t *testing.T) {
276+
// Query has both a safe connection (comments) and an unsafe one (assignees).
277+
// Injection should only go into comments.nodes, not assignees.nodes.
278+
query := `query {
279+
repository(owner:"o", name:"r") {
280+
pullRequest(number: 1) {
281+
assignees(first: 10) { nodes { login } }
282+
comments(first: 10) { nodes { body } }
283+
}
284+
}
285+
}`
286+
body, _ := json.Marshal(GraphQLRequest{Query: query})
287+
288+
result := InjectGuardFields(body, "pull_request_read")
289+
290+
var gql GraphQLRequest
291+
require.NoError(t, json.Unmarshal(result, &gql))
292+
assert.Contains(t, gql.Query, "author{login}")
293+
assert.Contains(t, gql.Query, "authorAssociation")
294+
// Verify injection is in comments.nodes, not assignees.nodes
295+
assert.Contains(t, gql.Query, `assignees(first: 10) { nodes { login } }`,
296+
"assignees.nodes should remain unmodified")
297+
assert.Contains(t, gql.Query, `comments(first: 10) { nodes {author{login},authorAssociation,`,
298+
"comments.nodes should have injected fields")
299+
}
300+
301+
func TestInjectGuardFields_SkipsLabelsNodes(t *testing.T) {
302+
// labels.nodes returns Label objects — no authorAssociation field.
303+
query := `query { repository(owner:"o", name:"r") { issues(first:10) { nodes { number labels(first:5) { nodes { name } } } } } }`
304+
body, _ := json.Marshal(GraphQLRequest{Query: query})
305+
306+
result := InjectGuardFields(body, "list_issues")
307+
308+
var gql GraphQLRequest
309+
require.NoError(t, json.Unmarshal(result, &gql))
310+
assert.Contains(t, gql.Query, "author{login}")
311+
assert.Contains(t, gql.Query, "authorAssociation")
312+
// labels.nodes must not have injected fields
313+
assert.Contains(t, gql.Query, `labels(first:5) { nodes { name } }`,
314+
"labels.nodes should remain unmodified")
315+
}
316+
317+
func TestInjectGuardFields_SkipsParticipantsNodes(t *testing.T) {
318+
// participants.nodes returns User objects — no authorAssociation field.
319+
query := `query {
320+
repository(owner:"o", name:"r") {
321+
pullRequest(number: 1) {
322+
reviews(first: 5) { nodes { body } }
323+
participants(first: 10) { nodes { login } }
324+
}
325+
}
326+
}`
327+
body, _ := json.Marshal(GraphQLRequest{Query: query})
328+
329+
result := InjectGuardFields(body, "pull_request_read")
330+
331+
var gql GraphQLRequest
332+
require.NoError(t, json.Unmarshal(result, &gql))
333+
assert.Contains(t, gql.Query, "author{login}")
334+
assert.Contains(t, gql.Query, "authorAssociation")
335+
// participants.nodes should be unmodified
336+
assert.Contains(t, gql.Query, `participants(first: 10) { nodes { login } }`)
337+
}
338+
339+
func TestFindParentField(t *testing.T) {
340+
tests := []struct {
341+
name string
342+
query string
343+
nodesIdx int // index of "nodes" in the query
344+
want string
345+
}{
346+
{
347+
name: "simple connection",
348+
query: `pullRequests(first:10) { nodes { number } }`,
349+
want: "pullRequests",
350+
},
351+
{
352+
name: "connection with totalCount before nodes",
353+
query: `issues(first:5) { totalCount nodes { title } }`,
354+
want: "issues",
355+
},
356+
{
357+
name: "nested connection",
358+
query: `pullRequest(number:1) { assignees(first:10) { nodes { login } } }`,
359+
want: "assignees",
360+
},
361+
{
362+
name: "connection without args",
363+
query: `comments { nodes { body } }`,
364+
want: "comments",
365+
},
366+
}
367+
368+
for _, tt := range tests {
369+
t.Run(tt.name, func(t *testing.T) {
370+
idx := strings.Index(tt.query, "nodes")
371+
require.NotEqual(t, -1, idx, "query must contain 'nodes'")
372+
got := findParentField(tt.query, idx)
373+
assert.Equal(t, tt.want, got)
374+
})
375+
}
376+
}
377+
250378
func TestInjectGuardFields_NoNodesNoFragment(t *testing.T) {
251379
// A query with required tool but no "nodes" block and no fragment spread —
252380
// the injector cannot find a place to insert fields, so body is returned unchanged.
@@ -284,11 +412,13 @@ func TestFieldsForTool(t *testing.T) {
284412

285413
for _, tt := range tests {
286414
t.Run("tool: "+tt.toolName, func(t *testing.T) {
287-
fields := fieldsForTool(tt.toolName)
415+
fields, safeParents := fieldsForTool(tt.toolName)
288416
if tt.wantNil {
289417
assert.Nil(t, fields, "expected nil fields for tool %q", tt.toolName)
418+
assert.Nil(t, safeParents, "expected nil safeParents for tool %q", tt.toolName)
290419
} else {
291420
require.NotNil(t, fields, "expected non-nil fields for tool %q", tt.toolName)
421+
require.NotNil(t, safeParents, "expected non-nil safeParents for tool %q", tt.toolName)
292422
fieldStrings := make([]string, len(fields))
293423
for i, f := range fields {
294424
fieldStrings[i] = f.field

internal/proxy/handler.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9-
"log"
109
"net/http"
1110

1211
"go.opentelemetry.io/otel/attribute"
@@ -141,7 +140,9 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
141140
defer difcSpan.End()
142141

143142
if !s.guardInitialized {
144-
log.Printf("[proxy] WARNING: guard not initialized, blocking request")
143+
errMsg := "returning 503: proxy enforcement not configured (no --policy flag provided)"
144+
logHandler.Print(errMsg)
145+
logger.LogError("proxy", "%s", errMsg)
145146
http.Error(w, "proxy enforcement not configured", http.StatusServiceUnavailable)
146147
return
147148
}

0 commit comments

Comments
 (0)