Skip to content

Commit 3b26a26

Browse files
magicmarkclaude
andcommitted
Add subscribe and batch tests for request extensions
Cover extensions across all transports and operation types: - aiohttp subscribe (HTTP + websocket) - aiohttp, httpx, and requests execute_batch Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c029117 commit 3b26a26

5 files changed

Lines changed: 252 additions & 1 deletion

File tree

tests/test_aiohttp.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,46 @@ def test_code():
619619
await run_sync_test(server, test_code)
620620

621621

622+
@pytest.mark.asyncio
623+
async def test_aiohttp_subscribe_with_extensions(aiohttp_server):
624+
from aiohttp import web
625+
626+
from gql.transport.aiohttp import AIOHTTPTransport
627+
628+
async def handler(request):
629+
body = await request.json()
630+
assert "extensions" in body
631+
assert body["extensions"] == {
632+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
633+
}
634+
return web.Response(
635+
text=query1_server_answer,
636+
content_type="application/json",
637+
)
638+
639+
app = web.Application()
640+
app.router.add_route("POST", "/", handler)
641+
server = await aiohttp_server(app)
642+
643+
url = server.make_url("/")
644+
645+
transport = AIOHTTPTransport(url=url, timeout=10)
646+
647+
request = GraphQLRequest(
648+
query1_str,
649+
extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}},
650+
)
651+
652+
async with Client(transport=transport) as session:
653+
654+
results = []
655+
async for result in session.subscribe(request):
656+
results.append(result)
657+
658+
assert len(results) == 1
659+
assert results[0]["continents"][0]["code"] == "AF"
660+
661+
622662
file_upload_mutation_1 = """
623663
mutation($file: Upload!) {
624664
uploadFile(input:{other_var:$other_var, file:$file}) {

tests/test_aiohttp_batch.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,49 @@ async def handler(request):
8989
assert transport.response_headers["dummy"] == "test1234"
9090

9191

92+
@pytest.mark.asyncio
93+
async def test_aiohttp_batch_query_with_extensions(aiohttp_server):
94+
from aiohttp import web
95+
96+
from gql.transport.aiohttp import AIOHTTPTransport
97+
98+
async def handler(request):
99+
body = await request.json()
100+
assert isinstance(body, list)
101+
assert "extensions" in body[0]
102+
assert body[0]["extensions"] == {
103+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
104+
}
105+
return web.Response(
106+
text=query1_server_answer_list,
107+
content_type="application/json",
108+
)
109+
110+
app = web.Application()
111+
app.router.add_route("POST", "/", handler)
112+
server = await aiohttp_server(app)
113+
114+
url = server.make_url("/")
115+
116+
transport = AIOHTTPTransport(url=url, timeout=10)
117+
118+
async with Client(transport=transport) as session:
119+
120+
query = [
121+
GraphQLRequest(
122+
query1_str,
123+
extensions={
124+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
125+
},
126+
)
127+
]
128+
129+
results = await session.execute_batch(query)
130+
131+
continents = results[0]["continents"]
132+
assert continents[0]["code"] == "AF"
133+
134+
92135
@pytest.mark.asyncio
93136
async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test):
94137
from aiohttp import web

tests/test_aiohttp_websocket_subscription.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from graphql import ExecutionResult
99
from parse import search
1010

11-
from gql import Client, gql
11+
from gql import Client, GraphQLRequest, gql
1212
from gql.client import AsyncClientSession
1313
from gql.transport.exceptions import TransportConnectionFailed, TransportServerError
1414

@@ -460,6 +460,36 @@ async def test_aiohttp_websocket_subscription_with_operation_name(
460460
assert '"operationName": "CountdownSubscription"' in logged_messages[0]
461461

462462

463+
@pytest.mark.asyncio
464+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
465+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
466+
async def test_aiohttp_websocket_subscription_with_extensions(
467+
aiohttp_client_and_server, subscription_str
468+
):
469+
470+
session, server = aiohttp_client_and_server
471+
472+
count = 10
473+
request = GraphQLRequest(
474+
subscription_str.format(count=count),
475+
extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}},
476+
)
477+
478+
async for result in session.subscribe(request):
479+
480+
number = result["number"]
481+
print(f"Number received: {number}")
482+
483+
assert number == count
484+
count -= 1
485+
486+
assert count == -1
487+
488+
# Check that the query contains the extensions
489+
assert '"persistedQuery"' in logged_messages[0]
490+
assert '"sha256Hash": "abc123"' in logged_messages[0]
491+
492+
463493
WITH_KEEPALIVE = True
464494

465495

tests/test_httpx_batch.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,97 @@ def test_code():
118118
await run_sync_test(server, test_code)
119119

120120

121+
@pytest.mark.aiohttp
122+
@pytest.mark.asyncio
123+
async def test_httpx_async_batch_query_with_extensions(aiohttp_server):
124+
from aiohttp import web
125+
126+
from gql.transport.httpx import HTTPXAsyncTransport
127+
128+
async def handler(request):
129+
body = await request.json()
130+
assert isinstance(body, list)
131+
assert "extensions" in body[0]
132+
assert body[0]["extensions"] == {
133+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
134+
}
135+
return web.Response(
136+
text=query1_server_answer_list,
137+
content_type="application/json",
138+
)
139+
140+
app = web.Application()
141+
app.router.add_route("POST", "/", handler)
142+
server = await aiohttp_server(app)
143+
144+
url = str(server.make_url("/"))
145+
146+
transport = HTTPXAsyncTransport(url=url, timeout=10)
147+
148+
async with Client(transport=transport) as session:
149+
150+
query = [
151+
GraphQLRequest(
152+
query1_str,
153+
extensions={
154+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
155+
},
156+
)
157+
]
158+
159+
results = await session.execute_batch(query)
160+
161+
continents = results[0]["continents"]
162+
assert continents[0]["code"] == "AF"
163+
164+
165+
@pytest.mark.aiohttp
166+
@pytest.mark.asyncio
167+
async def test_httpx_sync_batch_query_with_extensions(aiohttp_server, run_sync_test):
168+
from aiohttp import web
169+
170+
from gql.transport.httpx import HTTPXTransport
171+
172+
async def handler(request):
173+
body = await request.json()
174+
assert isinstance(body, list)
175+
assert "extensions" in body[0]
176+
assert body[0]["extensions"] == {
177+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
178+
}
179+
return web.Response(
180+
text=query1_server_answer_list,
181+
content_type="application/json",
182+
)
183+
184+
app = web.Application()
185+
app.router.add_route("POST", "/", handler)
186+
server = await aiohttp_server(app)
187+
188+
url = str(server.make_url("/"))
189+
190+
transport = HTTPXTransport(url=url, timeout=10)
191+
192+
def test_code():
193+
with Client(transport=transport) as session:
194+
195+
query = [
196+
GraphQLRequest(
197+
query1_str,
198+
extensions={
199+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
200+
},
201+
)
202+
]
203+
204+
results = session.execute_batch(query)
205+
206+
continents = results[0]["continents"]
207+
assert continents[0]["code"] == "AF"
208+
209+
await run_sync_test(server, test_code)
210+
211+
121212
@pytest.mark.aiohttp
122213
@pytest.mark.asyncio
123214
async def test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test):

tests/test_requests_batch.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,53 @@ def test_code():
9090
await run_sync_test(server, test_code)
9191

9292

93+
@pytest.mark.aiohttp
94+
@pytest.mark.asyncio
95+
async def test_requests_batch_query_with_extensions(aiohttp_server, run_sync_test):
96+
from aiohttp import web
97+
98+
from gql.transport.requests import RequestsHTTPTransport
99+
100+
async def handler(request):
101+
body = await request.json()
102+
assert isinstance(body, list)
103+
assert "extensions" in body[0]
104+
assert body[0]["extensions"] == {
105+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
106+
}
107+
return web.Response(
108+
text=query1_server_answer_list,
109+
content_type="application/json",
110+
)
111+
112+
app = web.Application()
113+
app.router.add_route("POST", "/", handler)
114+
server = await aiohttp_server(app)
115+
116+
url = server.make_url("/")
117+
118+
def test_code():
119+
transport = RequestsHTTPTransport(url=url)
120+
121+
with Client(transport=transport) as session:
122+
123+
query = [
124+
GraphQLRequest(
125+
query1_str,
126+
extensions={
127+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
128+
},
129+
)
130+
]
131+
132+
results = session.execute_batch(query)
133+
134+
continents = results[0]["continents"]
135+
assert continents[0]["code"] == "AF"
136+
137+
await run_sync_test(server, test_code)
138+
139+
93140
@pytest.mark.aiohttp
94141
@pytest.mark.asyncio
95142
async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test):

0 commit comments

Comments
 (0)