Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mcpHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server {
return ghServer
}, &mcp.StreamableHTTPOptions{
Stateless: true,
Stateless: true,
CrossOriginProtection: h.config.CrossOriginProtection,
})

mcpHandler.ServeHTTP(w, r)
Expand Down
114 changes: 114 additions & 0 deletions pkg/http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http/httptest"
"slices"
"sort"
"strings"
"testing"

ghcontext "github.com/github/github-mcp-server/pkg/context"
Expand Down Expand Up @@ -660,3 +661,116 @@ func buildStaticInventoryFromTools(cfg *ServerConfig, tools []inventory.ServerTo
ctx := context.Background()
return inv.AvailableTools(ctx), inv.AvailableResourceTemplates(ctx), inv.AvailablePrompts(ctx)
}

func TestCrossOriginProtection(t *testing.T) {
jsonRPCBody := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}`

newHandler := func(t *testing.T, crossOriginProtection *http.CrossOriginProtection) http.Handler {
t.Helper()

apiHost, err := utils.NewAPIHost("https://api.githubcopilot.com")
require.NoError(t, err)

handler := NewHTTPMcpHandler(
context.Background(),
&ServerConfig{
Version: "test",
CrossOriginProtection: crossOriginProtection,
},
nil,
translations.NullTranslationHelper,
slog.Default(),
apiHost,
WithInventoryFactory(func(_ *http.Request) (*inventory.Inventory, error) {
return inventory.NewBuilder().Build()
}),
WithGitHubMCPServerFactory(func(_ *http.Request, _ github.ToolDependencies, _ *inventory.Inventory, _ *github.MCPServerConfig) (*mcp.Server, error) {
return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil
}),
WithScopeFetcher(allScopesFetcher{}),
)

r := chi.NewRouter()
handler.RegisterMiddleware(r)
handler.RegisterRoutes(r)
return r
}

tests := []struct {
name string
crossOriginProtection *http.CrossOriginProtection
secFetchSite string
origin string
expectedStatusCode int
}{
{
name: "SDK default rejects cross-site when no bypass configured",
secFetchSite: "cross-site",
origin: "https://evil.example.com",
expectedStatusCode: http.StatusForbidden,
},
{
name: "SDK default allows same-origin request",
secFetchSite: "same-origin",
expectedStatusCode: http.StatusOK,
},
{
name: "SDK default allows request without Sec-Fetch-Site (native client)",
secFetchSite: "",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection allows cross-site request",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "cross-site",
origin: "https://example.com",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection still allows same-origin request",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "same-origin",
expectedStatusCode: http.StatusOK,
},
{
name: "bypass protection allows request without Sec-Fetch-Site (native client)",
crossOriginProtection: func() *http.CrossOriginProtection {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
return p
}(),
secFetchSite: "",
expectedStatusCode: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := newHandler(t, tt.crossOriginProtection)

req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonRPCBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json, text/event-stream")
req.Header.Set(headers.AuthorizationHeader, "Bearer github_pat_xyz")
if tt.secFetchSite != "" {
req.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
}
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}

rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatusCode, rr.Code, "unexpected status code; body: %s", rr.Body.String())
})
}
}
43 changes: 43 additions & 0 deletions pkg/http/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package middleware

import (
"net/http"
"strings"

"github.com/github/github-mcp-server/pkg/http/headers"
)

// SetCorsHeaders is middleware that sets CORS headers to allow browser-based
// MCP clients to connect from any origin. This is safe because the server
// authenticates via bearer tokens (not cookies), so cross-origin requests
// cannot exploit ambient credentials.
func SetCorsHeaders(h http.Handler) http.Handler {
allowHeaders := strings.Join([]string{
"Content-Type",
"Mcp-Session-Id",
"Mcp-Protocol-Version",
"Last-Event-ID",
headers.AuthorizationHeader,
headers.MCPReadOnlyHeader,
headers.MCPToolsetsHeader,
headers.MCPToolsHeader,
headers.MCPExcludeToolsHeader,
headers.MCPFeaturesHeader,
headers.MCPLockdownHeader,
headers.MCPInsidersHeader,
}, ", ")

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Max-Age", "86400")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id, WWW-Authenticate")
w.Header().Set("Access-Control-Allow-Headers", allowHeaders)

if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
h.ServeHTTP(w, r)
})
}
45 changes: 45 additions & 0 deletions pkg/http/middleware/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package middleware_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/github/github-mcp-server/pkg/http/middleware"
"github.com/stretchr/testify/assert"
)

func TestSetCorsHeaders(t *testing.T) {
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := middleware.SetCorsHeaders(inner)

t.Run("OPTIONS preflight returns 200 with CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Methods"), "POST")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Authorization")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Content-Type")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Lockdown")
assert.Contains(t, rr.Header().Get("Access-Control-Allow-Headers"), "X-MCP-Insiders")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "Mcp-Session-Id")
assert.Contains(t, rr.Header().Get("Access-Control-Expose-Headers"), "WWW-Authenticate")
})

t.Run("POST request includes CORS headers", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set("Origin", "http://localhost:6274")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "*", rr.Header().Get("Access-Control-Allow-Origin"))
})
}
16 changes: 16 additions & 0 deletions pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/http/middleware"
"github.com/github/github-mcp-server/pkg/http/oauth"
"github.com/github/github-mcp-server/pkg/inventory"
"github.com/github/github-mcp-server/pkg/lockdown"
Expand Down Expand Up @@ -86,6 +87,11 @@ type ServerConfig struct {

// InsidersMode indicates if we should enable experimental features.
InsidersMode bool

// CrossOriginProtection configures the SDK's cross-origin request protection.
// If nil and using RunHTTPServer, cross-origin requests are allowed (auto-bypass).
// If nil and using the handler as a library, the SDK default (reject) applies.
CrossOriginProtection *http.CrossOriginProtection
}

func RunHTTPServer(cfg ServerConfig) error {
Expand Down Expand Up @@ -159,6 +165,14 @@ func RunHTTPServer(cfg ServerConfig) error {
serverOptions = append(serverOptions, WithScopeFetcher(scopeFetcher))
}

// Bypass cross-origin protection: this server uses bearer tokens, not
// cookies, so CSRF checks are unnecessary.
if cfg.CrossOriginProtection == nil {
p := http.NewCrossOriginProtection()
p.AddInsecureBypassPattern("/")
cfg.CrossOriginProtection = p
}

r := chi.NewRouter()
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost)
Expand All @@ -167,6 +181,8 @@ func RunHTTPServer(cfg ServerConfig) error {
}

r.Group(func(r chi.Router) {
r.Use(middleware.SetCorsHeaders)

// Register Middleware First, needs to be before route registration
handler.RegisterMiddleware(r)

Expand Down
Loading