Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ async def handle(self, request: Request):
)

try:
form_data = await request.form()
# TODO(Marcelo): Can someone check if this `dict()` wrapper is necessary?
token_request = token_request_adapter.validate_python(dict(form_data))
form_data = dict(await request.form())
# client_id may have been supplied via HTTP Basic auth header instead of the
# request body (RFC 6749 §2.3.1). ClientAuthenticator already verified it,
# so we can safely populate it from client_info when absent from form data.
if "client_id" not in form_data:
form_data["client_id"] = client_info.client_id
token_request = token_request_adapter.validate_python(form_data)
except ValidationError as validation_error: # pragma: no cover
return self.response(
TokenErrorResponse(
Expand Down
10 changes: 10 additions & 0 deletions src/mcp/server/auth/middleware/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
"""
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
# RFC 6749 §2.3.1: client credentials MAY be sent via HTTP Basic auth
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
if ":" in decoded:
client_id = unquote(decoded.split(":", 1)[0])
except (ValueError, UnicodeDecodeError, binascii.Error):
pass
if not client_id:
raise AuthenticationError("Missing client_id")

Expand Down
9 changes: 7 additions & 2 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def main():

import anyio
from opentelemetry.trace import SpanKind, StatusCode
from pydantic import AnyHttpUrl
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
Expand Down Expand Up @@ -633,8 +634,11 @@ def streamable_http_app(
# Determine resource metadata URL
resource_metadata_url = None
if auth and auth.resource_server_url:
# RFC 9728: resource identifier must match the URL clients use to access
# the protected resource, including the transport path (e.g. /mcp)
actual_resource_url = AnyHttpUrl(str(auth.resource_server_url).rstrip("/") + streamable_http_path)
# Build compliant metadata URL for WWW-Authenticate header
resource_metadata_url = build_resource_metadata_url(auth.resource_server_url)
resource_metadata_url = build_resource_metadata_url(actual_resource_url)

routes.append(
Route(
Expand All @@ -653,9 +657,10 @@ def streamable_http_app(

# Add protected resource metadata endpoint if configured as RS
if auth and auth.resource_server_url: # pragma: no cover
actual_resource_url = AnyHttpUrl(str(auth.resource_server_url).rstrip("/") + streamable_http_path)
routes.extend(
create_protected_resource_routes(
resource_url=auth.resource_server_url,
resource_url=actual_resource_url,
authorization_servers=[auth.issuer_url],
scopes_supported=auth.required_scopes,
)
Expand Down
9 changes: 7 additions & 2 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import anyio
import pydantic_core
from pydantic import AnyHttpUrl
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
Expand Down Expand Up @@ -987,8 +988,11 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no
if self.settings.auth and self.settings.auth.resource_server_url:
from mcp.server.auth.routes import build_resource_metadata_url

# RFC 9728: resource identifier must match the URL clients use to access
# the protected resource, including the transport path (e.g. /sse)
actual_resource_url = AnyHttpUrl(str(self.settings.auth.resource_server_url).rstrip("/") + sse_path)
# Build compliant metadata URL for WWW-Authenticate header
resource_metadata_url = build_resource_metadata_url(self.settings.auth.resource_server_url)
resource_metadata_url = build_resource_metadata_url(actual_resource_url)

# Auth is enabled, wrap the endpoints with RequireAuthMiddleware
routes.append(
Expand Down Expand Up @@ -1028,9 +1032,10 @@ async def sse_endpoint(request: Request) -> Response: # pragma: no cover
if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover
from mcp.server.auth.routes import create_protected_resource_routes

actual_resource_url = AnyHttpUrl(str(self.settings.auth.resource_server_url).rstrip("/") + sse_path)
routes.extend(
create_protected_resource_routes(
resource_url=self.settings.auth.resource_server_url,
resource_url=actual_resource_url,
authorization_servers=[self.settings.auth.issuer_url],
scopes_supported=self.settings.auth.required_scopes,
)
Expand Down
65 changes: 65 additions & 0 deletions tests/server/auth/test_protected_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,68 @@ def test_route_consistency_consistent_paths_for_various_resources(resource_url:
assert url_path == expected_path
assert route_path == expected_path
assert url_path == route_path


# Tests for issue #1264: resource URL must include transport path


@pytest.mark.parametrize(
"resource_server_url,transport_path,expected_resource,expected_metadata_url",
[
(
"http://localhost:8000",
"/mcp",
"http://localhost:8000/mcp",
"http://localhost:8000/.well-known/oauth-protected-resource/mcp",
),
(
"http://localhost:8000/",
"/mcp",
"http://localhost:8000/mcp",
"http://localhost:8000/.well-known/oauth-protected-resource/mcp",
),
(
"https://mcp.example.com",
"/sse",
"https://mcp.example.com/sse",
"https://mcp.example.com/.well-known/oauth-protected-resource/sse",
),
],
)
def test_resource_url_includes_transport_path(
resource_server_url: str,
transport_path: str,
expected_resource: str,
expected_metadata_url: str,
):
"""Transport path must be appended to resource_server_url (issue #1264).

Per RFC 9728, the resource identifier must match the URL clients use to access
the protected resource — e.g. http://localhost:8000/mcp, not http://localhost:8000/.
"""
actual_resource_url = AnyHttpUrl(resource_server_url.rstrip("/") + transport_path)

assert str(actual_resource_url) == expected_resource

metadata_url = build_resource_metadata_url(actual_resource_url)
assert str(metadata_url) == expected_metadata_url


@pytest.mark.anyio
async def test_protected_resource_metadata_contains_transport_path():
"""Metadata endpoint returns resource URL with transport path, not bare server URL."""
resource_url = AnyHttpUrl("http://localhost:8000/mcp")
routes = create_protected_resource_routes(
resource_url=resource_url,
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
scopes_supported=["read", "write"],
)
app = Starlette(routes=routes)

async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://localhost:8000") as client:
response = await client.get("/.well-known/oauth-protected-resource/mcp")
assert response.status_code == 200
data = response.json()
# resource must be the full endpoint URL, not the bare server base
assert data["resource"] == "http://localhost:8000/mcp"
assert data["resource"] != "http://localhost:8000/"
87 changes: 87 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,93 @@ async def test_none_auth_method_public_client(
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: client_id supplied only via Basic auth header, not in body."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Only Header Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

auth_code = f"code_{int(time.time())}"
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
code=auth_code,
client_id=client_info["client_id"],
code_challenge=pkce_challenge["code_challenge"],
redirect_uri=AnyUrl("https://client.example.com/callback"),
redirect_uri_provided_explicitly=True,
scopes=["read", "write"],
expires_at=time.time() + 600,
)

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "authorization_code",
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_refresh_token_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: refresh_token grant with client_id only in Basic auth header."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Refresh Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

access_token_str = f"access_{secrets.token_hex(16)}"
refresh_token_str = f"refresh_{int(time.time())}"
mock_oauth_provider.tokens[access_token_str] = AccessToken(
token=access_token_str,
client_id=client_info["client_id"],
scopes=["read"],
expires_at=int(time.time()) + 3600,
)
mock_oauth_provider.refresh_tokens[refresh_token_str] = access_token_str

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token_str,
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response


class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""
Expand Down
Loading