From 3f4901e3b4b7e9dc643379764a5afda0f25cc672 Mon Sep 17 00:00:00 2001 From: pragnyanramtha Date: Wed, 20 May 2026 19:53:51 +0000 Subject: [PATCH] fix(auth): avoid SSE OAuth refresh deadlock --- src/mcp/client/auth/oauth2.py | 25 ++++ src/mcp/client/sse.py | 10 ++ tests/client/test_auth.py | 204 ++++++++++++++++++++++++++ tests/shared/test_sse.py | 268 ++++++++++++++++++++++++++++++++++ 4 files changed, 507 insertions(+) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 0ec087968..9d5cee459 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -483,6 +483,31 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + async def prepare_request_with_refresh(self, client: httpx.AsyncClient, request: httpx.Request) -> None: + """Refresh stored tokens and add an auth header for requests sent outside the auth flow.""" + async with self.context.lock: + if not self._initialized: + await self._initialize() + + protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + if protocol_version is not None: + self.context.protocol_version = protocol_version + + if self.context.is_token_valid(): + self._add_auth_header(request) + return + + if not self.context.can_refresh_token(): + return + + refresh_request = await self._refresh_token() + refresh_response = await client.send(refresh_request, auth=None) + + if not await self._handle_refresh_response(refresh_response): + return + + self._add_auth_header(request) + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: content = await response.aread() metadata = OAuthMetadata.model_validate_json(content) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 0d7fa0fb4..6855dd35f 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -11,6 +11,7 @@ from httpx_sse import SSEError, aconnect_sse import mcp.types as types +from mcp.client.auth import OAuthClientProvider from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage @@ -65,10 +66,19 @@ async def sse_client( async with httpx_client_factory( headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) ) as client: + sse_request_kwargs: dict[str, Any] = {} + if isinstance(auth, OAuthClientProvider): + sse_request = httpx.Request("GET", url, headers=headers) + await auth.prepare_request_with_refresh(client, sse_request) + if "Authorization" in sse_request.headers: + sse_request_kwargs["headers"] = dict(sse_request.headers) + sse_request_kwargs["auth"] = None + async with aconnect_sse( client, "GET", url, + **sse_request_kwargs, ) as event_source: event_source.response.raise_for_status() logger.debug("SSE connection established") diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 5f8bc1410..128e352e6 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -28,6 +28,7 @@ is_valid_client_metadata_url, should_use_client_metadata_url, ) +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -631,6 +632,209 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert "client_id=test_client" in content assert "client_secret=test_secret" in content + @pytest.mark.anyio + async def test_prepare_request_with_refresh_refreshes_expired_token( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken + ): + """Test preflight refresh for streaming requests that cannot drive OAuth inline.""" + + class FailingAuth(httpx.Auth): + async def async_auth_flow(self, request: httpx.Request): # pragma: no cover + raise AssertionError("preflight refresh should bypass client auth") + yield request + + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 1 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + oauth_provider._initialized = True + + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response( + 200, + json={ + "access_token": "refreshed_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refreshed_refresh_token", + }, + request=request, + ) + + request = httpx.Request( + "GET", + "https://api.example.com/v1/mcp/sse", + headers={MCP_PROTOCOL_VERSION: "2025-06-18"}, + ) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler), auth=FailingAuth()) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert len(requests) == 1 + assert requests[0].method == "POST" + assert str(requests[0].url) == "https://api.example.com/token" + assert "grant_type=refresh_token" in requests[0].content.decode() + assert "resource=" in requests[0].content.decode() + assert request.headers["Authorization"] == "Bearer refreshed_access_token" + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token" + assert mock_storage._tokens is not None + assert mock_storage._tokens.access_token == "refreshed_access_token" + + @pytest.mark.anyio + async def test_prepare_request_with_refresh_skips_valid_token( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """Test preflight refresh is a no-op while the current token is still valid.""" + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() + 1800 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover + requests.append(request) + return httpx.Response(500, request=request) + + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert requests == [] + assert request.headers["Authorization"] == "Bearer test_access_token" + assert oauth_provider.context.current_tokens is valid_tokens + + @pytest.mark.anyio + async def test_prepare_request_with_refresh_preserves_protocol_version_without_header( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """Test preflight refresh preserves an existing protocol version when the request has no header.""" + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 1 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + oauth_provider.context.protocol_version = "2025-06-18" + oauth_provider._initialized = True + + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response( + 200, + json={ + "access_token": "refreshed_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refreshed_refresh_token", + }, + request=request, + ) + + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert len(requests) == 1 + assert "resource=" in requests[0].content.decode() + assert oauth_provider.context.protocol_version == "2025-06-18" + assert request.headers["Authorization"] == "Bearer refreshed_access_token" + + @pytest.mark.anyio + async def test_prepare_request_with_refresh_initializes_storage( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken + ): + """Test preflight refresh loads persisted OAuth state before preparing the request.""" + client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(client_info) + + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") + + async with httpx.AsyncClient(transport=httpx.MockTransport(lambda request: httpx.Response(500))) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert request.headers["Authorization"] == "Bearer test_access_token" + assert oauth_provider.context.current_tokens is valid_tokens + assert oauth_provider.context.client_info is client_info + + @pytest.mark.anyio + async def test_prepare_request_with_refresh_skips_without_refresh_token(self, oauth_provider: OAuthClientProvider): + """Test preflight refresh leaves the request alone when refresh is not possible.""" + oauth_provider.context.current_tokens = OAuthToken( + access_token="expired_access_token", + refresh_token=None, + expires_in=1, + ) + oauth_provider.context.token_expiry_time = time.time() - 1 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: # pragma: no cover + requests.append(request) + return httpx.Response(500, request=request) + + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert requests == [] + assert "Authorization" not in request.headers + + @pytest.mark.anyio + async def test_prepare_request_with_refresh_keeps_request_unauthenticated_after_refresh_failure( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """Test failed preflight refresh does not add a stale bearer header.""" + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() - 1 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + oauth_provider._initialized = True + + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(400, request=request) + + request = httpx.Request("GET", "https://api.example.com/v1/mcp/sse") + + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: + await oauth_provider.prepare_request_with_refresh(client, request) + + assert "Authorization" not in request.headers + assert oauth_provider.context.current_tokens is None + @pytest.mark.anyio async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): """Test token exchange with client_secret_basic authentication.""" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7604450f8..413b3e7cb 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -20,11 +20,14 @@ import mcp.client.sse import mcp.types as types +from mcp.client.auth import OAuthClientProvider from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -602,3 +605,268 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message.root, types.JSONRPCResponse) assert msg.message.root.id == 1 + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +@pytest.mark.anyio +async def test_sse_client_preflights_oauth_refresh_before_streaming() -> None: + """Regression test for OAuth refresh deadlocks while opening SSE streams.""" + + class MemoryTokenStorage: + def __init__(self) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + class NoStreamAuthProvider(OAuthClientProvider): + async def async_auth_flow(self, request: httpx.Request): # pragma: no cover + if request.url.path.endswith("/sse"): + raise AssertionError("SSE stream should use the preflight bearer header") + async for auth_request in super().async_auth_flow(request): + yield auth_request + + storage = MemoryTokenStorage() + oauth_provider = NoStreamAuthProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=OAuthClientMetadata( + client_name="Test Client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + storage=storage, + ) + await storage.set_tokens( + OAuthToken( + access_token="expired_access_token", + refresh_token="refresh_token", + expires_in=1, + ) + ) + await storage.set_client_info( + OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + ) + oauth_provider.context.token_expiry_time = time.time() - 1 + + events: list[str] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/token": + events.append("refresh") + assert request.method == "POST" + assert "resource=" in request.content.decode() + return httpx.Response( + 200, + json={ + "access_token": "refreshed_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refreshed_refresh_token", + }, + request=request, + ) + + events.append("sse") + assert request.url.path == "/v1/mcp/sse" + assert request.headers["Authorization"] == "Bearer refreshed_access_token" + return httpx.Response( + 200, + headers={"Content-Type": "text/event-stream"}, + content=b"event: endpoint\ndata: /messages/?session_id=abc123\n\n", + request=request, + ) + + def client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + assert auth is oauth_provider + return httpx.AsyncClient( + headers=headers, + timeout=timeout, + auth=auth, + transport=httpx.MockTransport(handler), + ) + + with anyio.fail_after(5): + async with sse_client( + "https://api.example.com/v1/mcp/sse", + headers={MCP_PROTOCOL_VERSION: "2025-06-18"}, + auth=oauth_provider, + httpx_client_factory=client_factory, + ): + pass + + assert events == ["refresh", "sse"] + assert storage.tokens is not None + assert storage.tokens.access_token == "refreshed_access_token" + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +@pytest.mark.anyio +async def test_sse_client_keeps_oauth_on_stream_when_no_bearer_header() -> None: + class MemoryTokenStorage: + async def get_tokens(self) -> OAuthToken | None: + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: # pragma: no cover + raise AssertionError("no tokens should be stored") + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return None + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: # pragma: no cover + raise AssertionError("no client info should be stored") + + oauth_provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=OAuthClientMetadata( + client_name="Test Client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + storage=MemoryTokenStorage(), + ) + + mock_event_source = MagicMock() + mock_event_source.response = MagicMock() + mock_event_source.response.raise_for_status = MagicMock() + + async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc123") + + mock_event_source.aiter_sse.return_value = mock_aiter_sse() + + mock_aconnect_sse = MagicMock() + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + def connect_sse(client: httpx.AsyncClient, method: str, url: str, **kwargs: Any) -> MagicMock: + assert kwargs == {} + return mock_aconnect_sse + + with ( + patch("mcp.client.sse.create_mcp_http_client", return_value=mock_client), + patch("mcp.client.sse.aconnect_sse", side_effect=connect_sse), + ): + async with sse_client("https://api.example.com/v1/mcp/sse", auth=oauth_provider): + pass + + +@pytest.mark.filterwarnings("ignore::ResourceWarning") +@pytest.mark.anyio +async def test_sse_client_preflights_initialized_oauth_refresh_before_streaming() -> None: + """Regression test for OAuth refresh deadlocks while opening SSE streams with pre-loaded tokens.""" + + class MemoryTokenStorage: + def __init__(self) -> None: + self.tokens: OAuthToken | None = None + + async def get_tokens(self) -> OAuthToken | None: # pragma: no cover + return None + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: # pragma: no cover + return None + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: # pragma: no cover + raise AssertionError("client info should already be initialized") + + storage = MemoryTokenStorage() + oauth_provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=OAuthClientMetadata( + client_name="Test Client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + storage=storage, + ) + oauth_provider.context.current_tokens = OAuthToken( + access_token="expired_access_token", + refresh_token="refresh_token", + expires_in=1, + ) + oauth_provider.context.token_expiry_time = time.time() - 1 + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", + ) + oauth_provider._initialized = True + + events: list[str] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/token": + events.append("refresh") + assert request.method == "POST" + assert "resource=" in request.content.decode() + return httpx.Response( + 200, + json={ + "access_token": "refreshed_access_token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refreshed_refresh_token", + }, + request=request, + ) + + events.append("sse") + assert request.url.path == "/v1/mcp/sse" + assert request.headers["Authorization"] == "Bearer refreshed_access_token" + return httpx.Response( + 200, + headers={"Content-Type": "text/event-stream"}, + content=b"event: endpoint\ndata: /messages/?session_id=abc123\n\n", + request=request, + ) + + def client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + assert auth is oauth_provider + return httpx.AsyncClient( + headers=headers, + timeout=timeout, + auth=auth, + transport=httpx.MockTransport(handler), + ) + + with anyio.fail_after(5): + async with sse_client( + "https://api.example.com/v1/mcp/sse", + headers={MCP_PROTOCOL_VERSION: "2025-06-18"}, + auth=oauth_provider, + httpx_client_factory=client_factory, + ): + pass + + assert events == ["refresh", "sse"] + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "refreshed_access_token" + assert storage.tokens is not None + assert storage.tokens.access_token == "refreshed_access_token"