Skip to content
14 changes: 9 additions & 5 deletions pkg/http/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
package oauth

import (
"context"
"fmt"
"net/http"
"strings"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/oauthex"
Expand All @@ -16,9 +18,6 @@ import (
const (
// OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata.
OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource"

// DefaultAuthorizationServer is GitHub's OAuth authorization server.
DefaultAuthorizationServer = "https://github.com/login/oauth"
)

// SupportedScopes lists all OAuth scopes that may be required by MCP tools.
Expand Down Expand Up @@ -59,14 +58,19 @@ type AuthHandler struct {
}

// NewAuthHandler creates a new OAuth auth handler.
func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
func NewAuthHandler(ctx context.Context, cfg *Config, apiHost utils.APIHostResolver) (*AuthHandler, error) {
if cfg == nil {
cfg = &Config{}
}

// Default authorization server to GitHub
if cfg.AuthorizationServer == "" {
cfg.AuthorizationServer = DefaultAuthorizationServer
url, err := apiHost.AuthorizationServerURL(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get authorization server URL from API host: %w", err)
}

cfg.AuthorizationServer = url.String()
}
Comment thread
omgitsads marked this conversation as resolved.
Outdated

return &AuthHandler{
Expand Down
106 changes: 94 additions & 12 deletions pkg/http/oauth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,22 @@ import (
"testing"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
defaultAuthorizationServer = "https://github.com/login/oauth"
)

func TestNewAuthHandler(t *testing.T) {
t.Parallel()

dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

tests := []struct {
name string
cfg *Config
Expand All @@ -25,13 +33,13 @@ func TestNewAuthHandler(t *testing.T) {
{
name: "nil config uses defaults",
cfg: nil,
expectedAuthServer: DefaultAuthorizationServer,
expectedAuthServer: defaultAuthorizationServer,
expectedResourcePath: "",
},
{
name: "empty config uses defaults",
cfg: &Config{},
expectedAuthServer: DefaultAuthorizationServer,
expectedAuthServer: defaultAuthorizationServer,
expectedResourcePath: "",
},
{
Expand All @@ -48,7 +56,7 @@ func TestNewAuthHandler(t *testing.T) {
BaseURL: "https://example.com",
ResourcePath: "/mcp",
},
expectedAuthServer: DefaultAuthorizationServer,
expectedAuthServer: defaultAuthorizationServer,
expectedResourcePath: "/mcp",
},
}
Expand All @@ -57,11 +65,12 @@ func TestNewAuthHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(tc.cfg)
handler, err := NewAuthHandler(t.Context(), tc.cfg, dotcomHost)
require.NoError(t, err)
require.NotNil(t, handler)

assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer)
assert.Equal(t, tc.expectedResourcePath, handler.cfg.ResourcePath)
})
}
}
Comment thread
omgitsads marked this conversation as resolved.
Expand Down Expand Up @@ -372,7 +381,7 @@ func TestHandleProtectedResource(t *testing.T) {
authServers, ok := body["authorization_servers"].([]any)
require.True(t, ok)
require.Len(t, authServers, 1)
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
assert.Equal(t, defaultAuthorizationServer, authServers[0])
},
},
{
Expand Down Expand Up @@ -451,7 +460,10 @@ func TestHandleProtectedResource(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(tc.cfg)
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(t.Context(), tc.cfg, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -493,9 +505,12 @@ func TestHandleProtectedResource(t *testing.T) {
func TestRegisterRoutes(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(&Config{
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(t.Context(), &Config{
BaseURL: "https://api.example.com",
})
}, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -559,9 +574,12 @@ func TestSupportedScopes(t *testing.T) {
func TestProtectedResourceResponseFormat(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(&Config{
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(t.Context(), &Config{
BaseURL: "https://api.example.com",
})
}, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -598,7 +616,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
authServers, ok := response["authorization_servers"].([]any)
require.True(t, ok)
assert.Len(t, authServers, 1)
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
assert.Equal(t, defaultAuthorizationServer, authServers[0])
}

func TestOAuthProtectedResourcePrefix(t *testing.T) {
Expand All @@ -611,5 +629,69 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
func TestDefaultAuthorizationServer(t *testing.T) {
t.Parallel()

assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer)
assert.Equal(t, "https://github.com/login/oauth", defaultAuthorizationServer)
}

func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
t.Parallel()

tests := []struct {
name string
host string
expectedURL string
expectError bool
errorContains string
}{
{
name: "valid host returns authorization server URL",
Comment thread
omgitsads marked this conversation as resolved.
Outdated
host: "http://github.com",
expectedURL: "https://github.com/login/oauth",
expectError: false,
},
{
name: "invalid host returns error",
host: "://invalid-url",
expectedURL: "",
expectError: true,
errorContains: "could not parse host as URL",
},
{
name: "host without scheme returns error",
host: "github.com",
expectedURL: "",
expectError: true,
errorContains: "host must have a scheme",
},
{
name: "GHEC host returns correct authorization server URL",
host: "https://test.ghe.com",
expectedURL: "https://test.ghe.com/login/oauth",
},
{
name: "GHES host returns correct authorization server URL",
host: "https://ghe.example.com",
expectedURL: "https://ghe.example.com/login/oauth",
expectError: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

apiHost, err := utils.NewAPIHost(tc.host)
if tc.expectError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
return
}
require.NoError(t, err)

url, err := apiHost.AuthorizationServerURL(t.Context())
require.NoError(t, err)
assert.Equal(t, tc.expectedURL, url.String())
})
}
}
2 changes: 1 addition & 1 deletion pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func RunHTTPServer(cfg ServerConfig) error {

r := chi.NewRouter()
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
oauthHandler, err := oauth.NewAuthHandler(oauthCfg)
oauthHandler, err := oauth.NewAuthHandler(ctx, oauthCfg, apiHost)
if err != nil {
return fmt.Errorf("failed to create OAuth handler: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/scopes/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
return nil, nil
}
func (t testAPIHostResolver) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func TestParseScopeHeader(t *testing.T) {
tests := []struct {
Expand Down
57 changes: 41 additions & 16 deletions pkg/utils/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ type APIHostResolver interface {
GraphqlURL(ctx context.Context) (*url.URL, error)
UploadURL(ctx context.Context) (*url.URL, error)
RawURL(ctx context.Context) (*url.URL, error)
AuthorizationServerURL(ctx context.Context) (*url.URL, error)
}

type APIHost struct {
restURL *url.URL
gqlURL *url.URL
uploadURL *url.URL
rawURL *url.URL
restURL *url.URL
gqlURL *url.URL
uploadURL *url.URL
rawURL *url.URL
authorizationServerURL *url.URL
}

var _ APIHostResolver = APIHost{}
Expand Down Expand Up @@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) {
return a.rawURL, nil
}

func (a APIHost) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
return a.authorizationServerURL, nil
}

func newDotcomHost() (APIHost, error) {
baseRestURL, err := url.Parse("https://api.github.com/")
if err != nil {
Expand All @@ -73,11 +79,18 @@ func newDotcomHost() (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err)
}

// The authorization server for GitHub.com is at github.com/login/oauth, not api.github.com
authorizationServerURL, err := url.Parse("https://github.com/login/oauth")
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse dotcom Authorization Server URL: %w", err)
}

return APIHost{
restURL: baseRestURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
restURL: baseRestURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
authorizationServerURL: authorizationServerURL,
}, nil
}

Expand Down Expand Up @@ -112,11 +125,17 @@ func newGHECHost(hostname string) (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err)
}

authorizationServerURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname()))
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse GHEC Authorization Server URL: %w", err)
}

return APIHost{
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
authorizationServerURL: authorizationServerURL,
}, nil
}

Expand Down Expand Up @@ -164,11 +183,17 @@ func newGHESHost(hostname string) (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err)
}

authorizationServerURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname()))
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse GHES Authorization Server URL: %w", err)
}

return APIHost{
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
authorizationServerURL: authorizationServerURL,
}, nil
}

Expand Down