Skip to content

Commit 81b3a88

Browse files
Copilotlpcox
andauthored
refactor sys and guard logging duplication
Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/ac7dc66f-a385-4612-8f8a-66a15160508c Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com>
1 parent f92f874 commit 81b3a88

4 files changed

Lines changed: 95 additions & 40 deletions

File tree

internal/guard/wasm.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ import (
2020

2121
var logWasm = logger.New("guard:wasm")
2222

23+
func logMarshaledForDebug(value interface{}, onMarshalSuccess func(string), onMarshalFailure func(error)) {
24+
resultJSON, err := json.Marshal(value)
25+
if err != nil {
26+
onMarshalFailure(err)
27+
return
28+
}
29+
onMarshalSuccess(string(resultJSON))
30+
}
31+
2332
// globalCompilationCache is a process-level compilation cache shared across all
2433
// WasmGuard instances. wazero's cache is goroutine-safe and eliminates redundant
2534
// JIT compilation when multiple guards load the same WASM binary.
@@ -695,12 +704,15 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend
695704
logWasm.Printf("LabelAgent normalizePolicyPayload failed: guard=%s, error=%v", g.name, err)
696705
return nil, err
697706
}
698-
normalizedPolicyJSON, normalizeMarshalErr := json.Marshal(normalizedPolicy)
699-
if normalizeMarshalErr != nil {
700-
logWasm.Printf("LabelAgent normalized policy (marshal failed): guard=%s, error=%v", g.name, normalizeMarshalErr)
701-
} else {
702-
logWasm.Printf("LabelAgent normalized policy: guard=%s, policy=%s", g.name, string(normalizedPolicyJSON))
703-
}
707+
logMarshaledForDebug(
708+
normalizedPolicy,
709+
func(policyJSON string) {
710+
logWasm.Printf("LabelAgent normalized policy: guard=%s, policy=%s", g.name, policyJSON)
711+
},
712+
func(marshalErr error) {
713+
logWasm.Printf("LabelAgent normalized policy (marshal failed): guard=%s, error=%v", g.name, marshalErr)
714+
},
715+
)
704716
_ = caps
705717

706718
input, err := buildStrictLabelAgentPayload(normalizedPolicy)
@@ -727,12 +739,15 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend
727739
return nil, err
728740
}
729741

730-
resultLogJSON, resultMarshalErr := json.Marshal(result)
731-
if resultMarshalErr != nil {
732-
logWasm.Printf("LabelAgent parsed response (marshal failed): guard=%s, error=%v", g.name, resultMarshalErr)
733-
} else {
734-
logWasm.Printf("LabelAgent parsed response: guard=%s, response=%s", g.name, string(resultLogJSON))
735-
}
742+
logMarshaledForDebug(
743+
result,
744+
func(responseJSON string) {
745+
logWasm.Printf("LabelAgent parsed response: guard=%s, response=%s", g.name, responseJSON)
746+
},
747+
func(marshalErr error) {
748+
logWasm.Printf("LabelAgent parsed response (marshal failed): guard=%s, error=%v", g.name, marshalErr)
749+
},
750+
)
736751

737752
return result, nil
738753
}

internal/server/guard_init.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ import (
1717

1818
var logGuardInit = logger.New("server:guard_init")
1919

20+
func logMarshaledForDebug(value interface{}, onMarshalSuccess func(string), onMarshalFailure func(error)) {
21+
resultJSON, err := json.Marshal(value)
22+
if err != nil {
23+
onMarshalFailure(err)
24+
return
25+
}
26+
onMarshalSuccess(string(resultJSON))
27+
}
28+
2029
// hasServerGuardPolicies reports whether any server in cfg has per-server guard policies
2130
// configured. This is used during DIFC auto-detection to enable enforcement when policies
2231
// are present even if no non-noop guard was registered (e.g., guard missing or failed to load).
@@ -351,12 +360,15 @@ func (us *UnifiedServer) ensureGuardInitialized(
351360
log.Printf("[DIFC] label_agent returned nil result: server=%s, session=%s, guard=%s", serverID, sessionID, g.Name())
352361
return defaultMode, fmt.Errorf("label_agent returned nil result")
353362
}
354-
resultJSON, marshalErr := json.Marshal(labelAgentResult)
355-
if marshalErr != nil {
356-
log.Printf("[DIFC] label_agent returned result (failed to serialize for logging): server=%s, session=%s, guard=%s, error=%v", serverID, sessionID, g.Name(), marshalErr)
357-
} else {
358-
log.Printf("[DIFC] label_agent response: server=%s, session=%s, guard=%s, response=%s", serverID, sessionID, g.Name(), string(resultJSON))
359-
}
363+
logMarshaledForDebug(
364+
labelAgentResult,
365+
func(resultJSON string) {
366+
log.Printf("[DIFC] label_agent response: server=%s, session=%s, guard=%s, response=%s", serverID, sessionID, g.Name(), resultJSON)
367+
},
368+
func(marshalErr error) {
369+
log.Printf("[DIFC] label_agent returned result (failed to serialize for logging): server=%s, session=%s, guard=%s, error=%v", serverID, sessionID, g.Name(), marshalErr)
370+
},
371+
)
360372

361373
mode := defaultMode
362374
if labelAgentResult.DIFCMode != "" {

internal/server/tool_registry.go

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
285285
if err != nil {
286286
logger.LogError("client", "MCP tool call error, session=%s, tool=%s, error=%v", sessionID, toolNameCopy, err)
287287
} else {
288-
resultJSON, _ := json.Marshal(data)
289-
sanitizedResult := sanitize.SanitizeString(string(resultJSON))
290-
logger.LogInfo("client", "MCP tool call response, session=%s, tool=%s, result=%s", sessionID, toolNameCopy, sanitizedResult)
288+
logger.LogInfo("client", "MCP tool call response, session=%s, tool=%s, result=%s", sessionID, toolNameCopy, marshalAndSanitizeForLog(data))
291289
}
292290

293291
return result, data, err
@@ -351,6 +349,22 @@ func (us *UnifiedServer) callSysServer(toolName string) (interface{}, error) {
351349
return result, nil
352350
}
353351

352+
func marshalAndSanitizeForLog(value interface{}) string {
353+
resultJSON, _ := json.Marshal(value)
354+
return sanitize.SanitizeString(string(resultJSON))
355+
}
356+
357+
func (us *UnifiedServer) callAndLogSysTool(sessionID, operationName, sysToolName string) (*sdk.CallToolResult, interface{}, error) {
358+
result, err := us.callSysServer(sysToolName)
359+
if err != nil {
360+
logger.LogError("client", "MCP %s call failed, session=%s, error=%v", operationName, sessionID, err)
361+
return mcp.NewErrorCallToolResult(err)
362+
}
363+
364+
logger.LogInfo("client", "MCP %s response, session=%s, result=%s", operationName, sessionID, marshalAndSanitizeForLog(result))
365+
return nil, result, nil
366+
}
367+
354368
// registerSysTools registers built-in sys tools
355369
func (us *UnifiedServer) registerSysTools() error {
356370
// Create sys_init handler
@@ -392,16 +406,7 @@ func (us *UnifiedServer) registerSysTools() error {
392406
logger.LogInfo("client", "MCP session initialized successfully, session=%s, available_servers=%v", sessionID, us.launcher.ServerIDs())
393407

394408
// Call sys_init
395-
result, err := us.callSysServer("sys_init")
396-
if err != nil {
397-
logger.LogError("client", "MCP session initialization: sys_init call failed, session=%s, error=%v", sessionID, err)
398-
return mcp.NewErrorCallToolResult(err)
399-
}
400-
401-
resultJSON, _ := json.Marshal(result)
402-
sanitizedResult := sanitize.SanitizeString(string(resultJSON))
403-
logger.LogInfo("client", "MCP session initialization complete, session=%s, result=%s", sessionID, sanitizedResult)
404-
return nil, result, nil
409+
return us.callAndLogSysTool(sessionID, "session initialization", "sys_init")
405410
}
406411

407412
// Register sys_init tool using helper
@@ -431,15 +436,7 @@ func (us *UnifiedServer) registerSysTools() error {
431436
return mcp.NewErrorCallToolResult(err)
432437
}
433438

434-
result, err := us.callSysServer("sys_list_servers")
435-
if err != nil {
436-
logger.LogError("client", "MCP sys_list_servers error, session=%s, error=%v", sessionID, err)
437-
return mcp.NewErrorCallToolResult(err)
438-
}
439-
440-
resultJSON, _ := json.Marshal(result)
441-
logger.LogInfo("client", "MCP sys_list_servers response, session=%s, result=%s", sessionID, string(resultJSON))
442-
return nil, result, nil
439+
return us.callAndLogSysTool(sessionID, "sys_list_servers", "sys_list_servers")
443440
}
444441

445442
// Register sys_list_servers tool using helper

internal/server/tool_registry_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,37 @@ func TestCallSysServer_UnknownTool(t *testing.T) {
496496
require.Error(err, "callSysServer with unknown tool should return an error")
497497
}
498498

499+
func TestMarshalAndSanitizeForLog_RedactsSecrets(t *testing.T) {
500+
assert := assert.New(t)
501+
502+
const secret = "ghp_1234567890123456789012345678901234567890"
503+
sanitized := marshalAndSanitizeForLog(map[string]interface{}{
504+
"token": secret,
505+
})
506+
507+
assert.Contains(sanitized, "[REDACTED]")
508+
assert.NotContains(sanitized, secret)
509+
}
510+
511+
func TestCallAndLogSysTool_UnknownToolReturnsErrorResult(t *testing.T) {
512+
assert := assert.New(t)
513+
require := require.New(t)
514+
515+
cfg := &config.Config{
516+
Servers: map[string]*config.ServerConfig{},
517+
}
518+
519+
us, err := NewUnified(context.Background(), cfg)
520+
require.NoError(err)
521+
defer us.Close()
522+
523+
result, data, callErr := us.callAndLogSysTool("session-id", "sys test", "nonexistent_tool")
524+
require.Error(callErr)
525+
require.NotNil(result)
526+
assert.Nil(data)
527+
assert.True(result.IsError)
528+
}
529+
499530
// TestRegisterAllToolsParallel_EmptyList verifies that parallel registration with no
500531
// servers does not block and returns immediately.
501532
func TestRegisterAllToolsParallel_EmptyList(t *testing.T) {

0 commit comments

Comments
 (0)