File tree Expand file tree Collapse file tree 2 files changed +33
-3
lines changed
Expand file tree Collapse file tree 2 files changed +33
-3
lines changed Original file line number Diff line number Diff line change @@ -425,19 +425,21 @@ async def execute_batch(
425425 async def subscribe (
426426 self ,
427427 request : GraphQLRequest ,
428+ * ,
429+ extra_args : Optional [Dict [str , Any ]] = None ,
428430 ) -> AsyncGenerator [ExecutionResult , None ]:
429431 """Execute a GraphQL subscription and yield results from multipart response.
430432
431433 :param request: GraphQL request to execute
434+ :param extra_args: additional arguments to send to the aiohttp post method
432435 :yields: ExecutionResult objects as they arrive in the multipart stream
433436 """
434437 if self .session is None :
435438 raise TransportClosed ("Transport is not connected" )
436439
437- post_args = self ._prepare_request (request )
440+ post_args = self ._prepare_request (request , extra_args )
438441
439- # Add headers for multipart subscription
440- headers = post_args .get ("headers" , {})
442+ headers = dict (post_args .get ("headers" , {}))
441443 headers .update (
442444 {
443445 "Content-Type" : "application/json" ,
Original file line number Diff line number Diff line change @@ -636,3 +636,31 @@ async def test_aiohttp_multipart_actually_invalid_utf8(multipart_server):
636636
637637 # Should skip invalid part and not crash
638638 assert len (results ) == 0
639+
640+
641+ @pytest .mark .asyncio
642+ async def test_aiohttp_multipart_subscribe_extra_args (multipart_server ):
643+ """Test that extra_args are passed through to the post method."""
644+ from gql .transport .aiohttp import AIOHTTPTransport
645+
646+ custom_header_received = False
647+
648+ def check_custom_header (request ):
649+ nonlocal custom_header_received
650+ if request .headers .get ("X-Custom-Header" ) == "custom-value" :
651+ custom_header_received = True
652+
653+ parts = create_multipart_response ([book1 ])
654+ server = await multipart_server (parts , request_handler = check_custom_header )
655+ url = server .make_url ("/" )
656+ transport = AIOHTTPTransport (url = url )
657+
658+ query = gql (subscription_str )
659+
660+ async with Client (transport = transport ) as session :
661+ async for result in session .subscribe (
662+ query , extra_args = {"headers" : {"X-Custom-Header" : "custom-value" }}
663+ ):
664+ pass
665+
666+ assert custom_header_received
You can’t perform that action at this time.
0 commit comments