Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,15 @@ async def handle(self, request: Request):
)

try:
form_data = await request.form()
# TODO(Marcelo): Can someone check if this `dict()` wrapper is necessary?
token_request = token_request_adapter.validate_python(dict(form_data))
form_data = dict(await request.form())
# client_id may have been supplied via HTTP Basic auth header instead of the
# request body (RFC 6749 §2.3.1). ClientAuthenticator already verified it,
# so we can safely populate it from client_info when absent from form data.
# The truthiness check narrows `str | None` to `str` for the type checker;
# ClientAuthenticator guarantees a non-empty client_id reached this point.
if "client_id" not in form_data and client_info.client_id:
form_data["client_id"] = client_info.client_id
token_request = token_request_adapter.validate_python(form_data)
except ValidationError as validation_error: # pragma: no cover
return self.response(
TokenErrorResponse(
Expand Down
10 changes: 10 additions & 0 deletions src/mcp/server/auth/middleware/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
"""
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
# RFC 6749 §2.3.1: client credentials MAY be sent via HTTP Basic auth
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
if ":" in decoded:
client_id = unquote(decoded.split(":", 1)[0])
except (ValueError, UnicodeDecodeError, binascii.Error):
pass
if not client_id:
raise AuthenticationError("Missing client_id")

Expand Down
115 changes: 115 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,121 @@ async def test_none_auth_method_public_client(
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: client_id supplied only via Basic auth header, not in body."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Only Header Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

auth_code = f"code_{int(time.time())}"
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
code=auth_code,
client_id=client_info["client_id"],
code_challenge=pkce_challenge["code_challenge"],
redirect_uri=AnyUrl("https://client.example.com/callback"),
redirect_uri_provided_explicitly=True,
scopes=["read", "write"],
expires_at=time.time() + 600,
)

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "authorization_code",
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_refresh_token_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: refresh_token grant with client_id only in Basic auth header."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Refresh Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

access_token_str = f"access_{secrets.token_hex(16)}"
refresh_token_str = f"refresh_{int(time.time())}"
mock_oauth_provider.tokens[access_token_str] = AccessToken(
token=access_token_str,
client_id=client_info["client_id"],
scopes=["read"],
expires_at=int(time.time()) + 3600,
)
mock_oauth_provider.refresh_tokens[refresh_token_str] = access_token_str

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token_str,
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_fallback_invalid_base64_falls_through(self, test_client: httpx.AsyncClient):
"""Basic auth fallback (no client_id in body): invalid base64 is silently ignored,
request continues without client_id and surfaces a normal 'invalid_request' error."""
# client_id missing from body AND the Basic header is malformed base64
response = await test_client.post(
"/token",
headers={"Authorization": "Basic !!!not-valid-base64!!!"},
data={"grant_type": "authorization_code", "code": "irrelevant"},
)
# The malformed base64 is swallowed; client_id remains missing → standard 401
assert response.status_code == 401
assert response.json()["error"] == "invalid_client"

@pytest.mark.anyio
async def test_basic_auth_fallback_no_colon_falls_through(self, test_client: httpx.AsyncClient):
"""Basic auth fallback (no client_id in body): decoded credentials without a colon
are not treated as a valid client_id, and the request fails with 'invalid_client'."""
# b64("no_colon_here") decodes cleanly but contains no ':' separator
encoded = base64.b64encode(b"no_colon_here").decode()
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded}"},
data={"grant_type": "authorization_code", "code": "irrelevant"},
)
assert response.status_code == 401
assert response.json()["error"] == "invalid_client"


class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""
Expand Down
Loading