Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,31 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

async def prepare_request_with_refresh(self, client: httpx.AsyncClient, request: httpx.Request) -> None:
"""Refresh stored tokens and add an auth header for requests sent outside the auth flow."""
async with self.context.lock:
if not self._initialized:
await self._initialize()

protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
if protocol_version is not None:
self.context.protocol_version = protocol_version

if self.context.is_token_valid():
self._add_auth_header(request)
return

if not self.context.can_refresh_token():
return

refresh_request = await self._refresh_token()
refresh_response = await client.send(refresh_request, auth=None)

if not await self._handle_refresh_response(refresh_response):
return

self._add_auth_header(request)

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
Expand Down
10 changes: 10 additions & 0 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from httpx_sse import SSEError, aconnect_sse

import mcp.types as types
from mcp.client.auth import OAuthClientProvider
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage

Expand Down Expand Up @@ -65,10 +66,19 @@ async def sse_client(
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
sse_request_kwargs: dict[str, Any] = {}
if isinstance(auth, OAuthClientProvider):
sse_request = httpx.Request("GET", url, headers=headers)
await auth.prepare_request_with_refresh(client, sse_request)
if "Authorization" in sse_request.headers:
sse_request_kwargs["headers"] = dict(sse_request.headers)
sse_request_kwargs["auth"] = None
Comment thread
pragnyanramtha marked this conversation as resolved.

async with aconnect_sse(
client,
"GET",
url,
**sse_request_kwargs,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
Expand Down
204 changes: 204 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_valid_client_metadata_url,
should_use_client_metadata_url,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
Expand Down Expand Up @@ -631,6 +632,209 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
assert "client_id=test_client" in content
assert "client_secret=test_secret" in content

@pytest.mark.anyio
async def test_prepare_request_with_refresh_refreshes_expired_token(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
):
"""Test preflight refresh for streaming requests that cannot drive OAuth inline."""

class FailingAuth(httpx.Auth):
async def async_auth_flow(self, request: httpx.Request): # pragma: no cover
raise AssertionError("preflight refresh should bypass client auth")
yield request

oauth_provider.context.current_tokens = valid_tokens
oauth_provider.context.token_expiry_time = time.time() - 1
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
token_endpoint_auth_method="client_secret_post",
)
oauth_provider._initialized = True

requests: list[httpx.Request] = []

async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
return httpx.Response(
200,
json={
"access_token": "refreshed_access_token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "refreshed_refresh_token",
},
request=request,
)

request = httpx.Request(
"GET",
"https://api.example.com/v1/mcp/sse",
headers={MCP_PROTOCOL_VERSION: "2025-06-18"},
)
Comment thread
pragnyanramtha marked this conversation as resolved.

async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert len(requests) == 1
assert requests[0].method == "POST"
assert str(requests[0].url) == "https://api.example.com/token"
assert "grant_type=refresh_token" in requests[0].content.decode()
assert "resource=" in requests[0].content.decode()
assert request.headers["Authorization"] == "Bearer refreshed_access_token"
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token"
assert mock_storage._tokens is not None
assert mock_storage._tokens.access_token == "refreshed_access_token"

@pytest.mark.anyio
async def test_prepare_request_with_refresh_skips_valid_token(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""Test preflight refresh is a no-op while the current token is still valid."""
oauth_provider.context.current_tokens = valid_tokens
oauth_provider.context.token_expiry_time = time.time() + 1800
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider._initialized = True

requests: list[httpx.Request] = []

async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
requests.append(request)
return httpx.Response(500, request=request)

request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")

async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert requests == []
assert request.headers["Authorization"] == "Bearer test_access_token"
assert oauth_provider.context.current_tokens is valid_tokens

@pytest.mark.anyio
async def test_prepare_request_with_refresh_preserves_protocol_version_without_header(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""Test preflight refresh preserves an existing protocol version when the request has no header."""
oauth_provider.context.current_tokens = valid_tokens
oauth_provider.context.token_expiry_time = time.time() - 1
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
token_endpoint_auth_method="client_secret_post",
)
oauth_provider.context.protocol_version = "2025-06-18"
oauth_provider._initialized = True

requests: list[httpx.Request] = []

async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
return httpx.Response(
200,
json={
"access_token": "refreshed_access_token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "refreshed_refresh_token",
},
request=request,
)

request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")

async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert len(requests) == 1
assert "resource=" in requests[0].content.decode()
assert oauth_provider.context.protocol_version == "2025-06-18"
assert request.headers["Authorization"] == "Bearer refreshed_access_token"

@pytest.mark.anyio
async def test_prepare_request_with_refresh_initializes_storage(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
):
"""Test preflight refresh loads persisted OAuth state before preparing the request."""
client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
await mock_storage.set_tokens(valid_tokens)
await mock_storage.set_client_info(client_info)

request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")

async with httpx.AsyncClient(transport=httpx.MockTransport(lambda request: httpx.Response(500))) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert request.headers["Authorization"] == "Bearer test_access_token"
assert oauth_provider.context.current_tokens is valid_tokens
assert oauth_provider.context.client_info is client_info

@pytest.mark.anyio
async def test_prepare_request_with_refresh_skips_without_refresh_token(self, oauth_provider: OAuthClientProvider):
"""Test preflight refresh leaves the request alone when refresh is not possible."""
oauth_provider.context.current_tokens = OAuthToken(
access_token="expired_access_token",
refresh_token=None,
expires_in=1,
)
oauth_provider.context.token_expiry_time = time.time() - 1
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider._initialized = True

requests: list[httpx.Request] = []

async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover
requests.append(request)
return httpx.Response(500, request=request)

request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")

async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert requests == []
assert "Authorization" not in request.headers

@pytest.mark.anyio
async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure(
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
):
"""Test failed preflight refresh does not add a stale bearer header."""
oauth_provider.context.current_tokens = valid_tokens
oauth_provider.context.token_expiry_time = time.time() - 1
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
token_endpoint_auth_method="client_secret_post",
)
oauth_provider._initialized = True

async def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(400, request=request)

request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse")

async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
await oauth_provider.prepare_request_with_refresh(client, request)

assert "Authorization" not in request.headers
assert oauth_provider.context.current_tokens is None

@pytest.mark.anyio
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
"""Test token exchange with client_secret_basic authentication."""
Expand Down
Loading
Loading