Skip to content

Commit 4a18d49

Browse files
authored
Add extra_args parameter to subscribe() for aiohttp transport (#584)
1 parent d2eddba commit 4a18d49

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

gql/transport/aiohttp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,21 @@ async def execute_batch(
425425
async def subscribe(
426426
self,
427427
request: GraphQLRequest,
428+
*,
429+
extra_args: Optional[Dict[str, Any]] = None,
428430
) -> AsyncGenerator[ExecutionResult, None]:
429431
"""Execute a GraphQL subscription and yield results from multipart response.
430432
431433
:param request: GraphQL request to execute
434+
:param extra_args: additional arguments to send to the aiohttp post method
432435
:yields: ExecutionResult objects as they arrive in the multipart stream
433436
"""
434437
if self.session is None:
435438
raise TransportClosed("Transport is not connected")
436439

437-
post_args = self._prepare_request(request)
440+
post_args = self._prepare_request(request, extra_args)
438441

439-
# Add headers for multipart subscription
440-
headers = post_args.get("headers", {})
442+
headers = dict(post_args.get("headers", {}))
441443
headers.update(
442444
{
443445
"Content-Type": "application/json",

tests/test_aiohttp_multipart.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,31 @@ async def test_aiohttp_multipart_actually_invalid_utf8(multipart_server):
636636

637637
# Should skip invalid part and not crash
638638
assert len(results) == 0
639+
640+
641+
@pytest.mark.asyncio
642+
async def test_aiohttp_multipart_subscribe_extra_args(multipart_server):
643+
"""Test that extra_args are passed through to the post method."""
644+
from gql.transport.aiohttp import AIOHTTPTransport
645+
646+
custom_header_received = False
647+
648+
def check_custom_header(request):
649+
nonlocal custom_header_received
650+
if request.headers.get("X-Custom-Header") == "custom-value":
651+
custom_header_received = True
652+
653+
parts = create_multipart_response([book1])
654+
server = await multipart_server(parts, request_handler=check_custom_header)
655+
url = server.make_url("/")
656+
transport = AIOHTTPTransport(url=url)
657+
658+
query = gql(subscription_str)
659+
660+
async with Client(transport=transport) as session:
661+
async for result in session.subscribe(
662+
query, extra_args={"headers": {"X-Custom-Header": "custom-value"}}
663+
):
664+
pass
665+
666+
assert custom_header_received

0 commit comments

Comments
 (0)