diff --git a/pkg/github/tools.go b/pkg/github/tools.go index e5e9502800..02b86a9d9a 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -145,7 +145,23 @@ var ( // When active, consolidated tools are replaced by single-purpose granular tools. FeatureFlagIssuesGranular = "issues_granular" FeatureFlagPullRequestsGranular = "pull_requests_granular" +) + +// headerAllowedFeatureFlags are the feature flags that clients may enable via the +// X-MCP-Features header. Only these flags are accepted from headers; unknown flags +// are silently ignored. +var headerAllowedFeatureFlags = []string{ + FeatureFlagIssuesGranular, + FeatureFlagPullRequestsGranular, +} +// HeaderAllowedFeatureFlags returns the feature flags that clients may enable via +// the X-MCP-Features header. +func HeaderAllowedFeatureFlags() []string { + return slices.Clone(headerAllowedFeatureFlags) +} + +var ( // Remote-only toolsets - these are only available in the remote MCP server // but are documented here for consistency and to enable automated documentation. ToolsetMetadataCopilotSpaces = inventory.ToolsetMetadata{ diff --git a/pkg/http/server.go b/pkg/http/server.go index 38ea0de301..83586509bc 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -27,7 +27,7 @@ import ( // knownFeatureFlags are the feature flags that can be enabled via X-MCP-Features header. // Only these flags are accepted from headers. -var knownFeatureFlags = []string{} +var knownFeatureFlags = github.HeaderAllowedFeatureFlags() type ServerConfig struct { // Version of the server diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go new file mode 100644 index 0000000000..7aeabc5823 --- /dev/null +++ b/pkg/http/server_test.go @@ -0,0 +1,86 @@ +package http + +import ( + "context" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/github" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateHTTPFeatureChecker_Whitelist(t *testing.T) { + checker := createHTTPFeatureChecker() + + tests := []struct { + name string + flagName string + headerFeatures []string + wantEnabled bool + }{ + { + name: "whitelisted issues_granular flag accepted from header", + flagName: github.FeatureFlagIssuesGranular, + headerFeatures: []string{github.FeatureFlagIssuesGranular}, + wantEnabled: true, + }, + { + name: "whitelisted pull_requests_granular flag accepted from header", + flagName: github.FeatureFlagPullRequestsGranular, + headerFeatures: []string{github.FeatureFlagPullRequestsGranular}, + wantEnabled: true, + }, + { + name: "unknown flag in header is ignored", + flagName: "unknown_flag", + headerFeatures: []string{"unknown_flag"}, + wantEnabled: false, + }, + { + name: "whitelisted flag not in header returns false", + flagName: github.FeatureFlagIssuesGranular, + headerFeatures: nil, + wantEnabled: false, + }, + { + name: "whitelisted flag with different flag in header returns false", + flagName: github.FeatureFlagIssuesGranular, + headerFeatures: []string{github.FeatureFlagPullRequestsGranular}, + wantEnabled: false, + }, + { + name: "multiple whitelisted flags in header", + flagName: github.FeatureFlagIssuesGranular, + headerFeatures: []string{github.FeatureFlagIssuesGranular, github.FeatureFlagPullRequestsGranular}, + wantEnabled: true, + }, + { + name: "empty header features", + flagName: github.FeatureFlagIssuesGranular, + headerFeatures: []string{}, + wantEnabled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + if len(tt.headerFeatures) > 0 { + ctx = ghcontext.WithHeaderFeatures(ctx, tt.headerFeatures) + } + + enabled, err := checker(ctx, tt.flagName) + require.NoError(t, err) + assert.Equal(t, tt.wantEnabled, enabled) + }) + } +} + +func TestKnownFeatureFlagsMatchesHeaderAllowed(t *testing.T) { + // Ensure knownFeatureFlags stays in sync with HeaderAllowedFeatureFlags + allowed := github.HeaderAllowedFeatureFlags() + assert.Equal(t, allowed, knownFeatureFlags, + "knownFeatureFlags should match github.HeaderAllowedFeatureFlags()") + assert.NotEmpty(t, knownFeatureFlags, "knownFeatureFlags should not be empty") +}