Skip to content

Commit c795223

Browse files
magicmarkclaude
andcommitted
Consolidate and clean up extensions tests
- Merge 6 unit tests into one test_graphql_request_extensions - Merge aiohttp execute + subscribe into one test - Drop redundant httpx sync batch test (same code path as async) - Tighten assertions and remove unnecessary comments - Parse WS logged message as JSON instead of fragile string matching Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3b26a26 commit c795223

8 files changed

Lines changed: 59 additions & 231 deletions

tests/test_aiohttp.py

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,16 @@ async def handler(request):
8888

8989

9090
@pytest.mark.asyncio
91-
async def test_aiohttp_query_with_extensions(aiohttp_server):
91+
async def test_aiohttp_request_extensions(aiohttp_server):
9292
from aiohttp import web
9393

9494
from gql.transport.aiohttp import AIOHTTPTransport
9595

96+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
97+
9698
async def handler(request):
9799
body = await request.json()
98-
assert "extensions" in body
99-
assert body["extensions"] == {
100-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
101-
}
100+
assert body["extensions"] == extensions
102101
return web.Response(
103102
text=query1_server_answer,
104103
content_type="application/json",
@@ -112,19 +111,17 @@ async def handler(request):
112111

113112
transport = AIOHTTPTransport(url=url, timeout=10)
114113

115-
async with Client(transport=transport) as session:
114+
request = GraphQLRequest(query1_str, extensions=extensions)
116115

117-
request = GraphQLRequest(
118-
query1_str,
119-
extensions={
120-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
121-
},
122-
)
116+
async with Client(transport=transport) as session:
123117

118+
# execute
124119
result = await session.execute(request)
120+
assert result["continents"][0]["code"] == "AF"
125121

126-
continents = result["continents"]
127-
assert continents[0]["code"] == "AF"
122+
# subscribe
123+
async for result in session.subscribe(request):
124+
assert result["continents"][0]["code"] == "AF"
128125

129126

130127
@pytest.mark.asyncio
@@ -619,46 +616,6 @@ def test_code():
619616
await run_sync_test(server, test_code)
620617

621618

622-
@pytest.mark.asyncio
623-
async def test_aiohttp_subscribe_with_extensions(aiohttp_server):
624-
from aiohttp import web
625-
626-
from gql.transport.aiohttp import AIOHTTPTransport
627-
628-
async def handler(request):
629-
body = await request.json()
630-
assert "extensions" in body
631-
assert body["extensions"] == {
632-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
633-
}
634-
return web.Response(
635-
text=query1_server_answer,
636-
content_type="application/json",
637-
)
638-
639-
app = web.Application()
640-
app.router.add_route("POST", "/", handler)
641-
server = await aiohttp_server(app)
642-
643-
url = server.make_url("/")
644-
645-
transport = AIOHTTPTransport(url=url, timeout=10)
646-
647-
request = GraphQLRequest(
648-
query1_str,
649-
extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}},
650-
)
651-
652-
async with Client(transport=transport) as session:
653-
654-
results = []
655-
async for result in session.subscribe(request):
656-
results.append(result)
657-
658-
assert len(results) == 1
659-
assert results[0]["continents"][0]["code"] == "AF"
660-
661-
662619
file_upload_mutation_1 = """
663620
mutation($file: Upload!) {
664621
uploadFile(input:{other_var:$other_var, file:$file}) {

tests/test_aiohttp_batch.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,17 @@ async def handler(request):
9090

9191

9292
@pytest.mark.asyncio
93-
async def test_aiohttp_batch_query_with_extensions(aiohttp_server):
93+
async def test_aiohttp_batch_request_extensions(aiohttp_server):
9494
from aiohttp import web
9595

9696
from gql.transport.aiohttp import AIOHTTPTransport
9797

98+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
99+
98100
async def handler(request):
99101
body = await request.json()
100102
assert isinstance(body, list)
101-
assert "extensions" in body[0]
102-
assert body[0]["extensions"] == {
103-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
104-
}
103+
assert body[0]["extensions"] == extensions
105104
return web.Response(
106105
text=query1_server_answer_list,
107106
content_type="application/json",
@@ -117,19 +116,9 @@ async def handler(request):
117116

118117
async with Client(transport=transport) as session:
119118

120-
query = [
121-
GraphQLRequest(
122-
query1_str,
123-
extensions={
124-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
125-
},
126-
)
127-
]
128-
119+
query = [GraphQLRequest(query1_str, extensions=extensions)]
129120
results = await session.execute_batch(query)
130-
131-
continents = results[0]["continents"]
132-
assert continents[0]["code"] == "AF"
121+
assert results[0]["continents"][0]["code"] == "AF"
133122

134123

135124
@pytest.mark.asyncio

tests/test_aiohttp_websocket_subscription.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,10 @@ async def test_aiohttp_websocket_subscription_with_extensions(
485485

486486
assert count == -1
487487

488-
# Check that the query contains the extensions
489-
assert '"persistedQuery"' in logged_messages[0]
490-
assert '"sha256Hash": "abc123"' in logged_messages[0]
488+
message = json.loads(logged_messages[0])
489+
assert message["payload"]["extensions"] == {
490+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
491+
}
491492

492493

493494
WITH_KEEPALIVE = True

tests/test_graphql_request.py

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -238,55 +238,26 @@ def test_graphql_request_init_with_graphql_request():
238238
assert request_3.variable_values["money"] == money_value_2
239239

240240

241-
def test_graphql_request_extensions_in_payload():
242-
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
243-
request = GraphQLRequest("{balance}", extensions=extensions)
244-
245-
payload = request.payload
246-
assert payload["extensions"] == extensions
247-
248-
249-
def test_graphql_request_extensions_not_in_payload_when_none():
250-
request = GraphQLRequest("{balance}")
251-
assert "extensions" not in request.payload
252-
253-
254-
def test_graphql_request_extensions_copied_from_graphql_request():
255-
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
256-
request_1 = GraphQLRequest("{balance}", extensions=extensions)
257-
request_2 = GraphQLRequest(request_1)
258-
259-
assert request_2.extensions == extensions
260-
assert request_2.payload["extensions"] == extensions
261-
262-
263-
def test_graphql_request_extensions_override_from_graphql_request():
241+
def test_graphql_request_extensions():
264242
extensions_1 = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
265243
extensions_2 = {"custom": "value"}
266-
request_1 = GraphQLRequest("{balance}", extensions=extensions_1)
267-
request_2 = GraphQLRequest(request_1, extensions=extensions_2)
244+
money_value = Money(10, "DM")
268245

269-
assert request_2.extensions == extensions_2
246+
assert "extensions" not in GraphQLRequest("{balance}").payload
270247

248+
request_1 = GraphQLRequest("{balance}", extensions=extensions_1)
249+
assert request_1.payload["extensions"] == extensions_1
271250

272-
def test_graphql_request_extensions_preserved_by_serialize_variable_values():
273-
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
274-
money_value = Money(10, "DM")
251+
request_2 = GraphQLRequest(request_1)
252+
assert request_2.extensions == extensions_1
275253

276-
request = GraphQLRequest(
254+
request_3 = GraphQLRequest(request_1, extensions=extensions_2)
255+
assert request_3.extensions == extensions_2
256+
257+
request_4 = GraphQLRequest(
277258
"query myquery($money: Money) {toEuros(money: $money)}",
278259
variable_values={"money": money_value},
279-
extensions=extensions,
260+
extensions=extensions_1,
280261
)
281-
282-
serialized = request.serialize_variable_values(schema)
283-
assert serialized.extensions == extensions
284-
assert serialized.payload["extensions"] == extensions
285-
286-
287-
def test_graphql_request_str_includes_extensions():
288-
extensions = {"key": "value"}
289-
request = GraphQLRequest("{balance}", extensions=extensions)
290-
result = str(request)
291-
assert "extensions" in result
292-
assert "'key': 'value'" in result
262+
serialized = request_4.serialize_variable_values(schema)
263+
assert serialized.extensions == extensions_1

tests/test_httpx.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,16 @@ def test_code():
8686

8787
@pytest.mark.aiohttp
8888
@pytest.mark.asyncio
89-
async def test_httpx_query_with_extensions(aiohttp_server, run_sync_test):
89+
async def test_httpx_request_extensions(aiohttp_server, run_sync_test):
9090
from aiohttp import web
9191

9292
from gql.transport.httpx import HTTPXTransport
9393

94+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
95+
9496
async def handler(request):
9597
body = await request.json()
96-
assert "extensions" in body
97-
assert body["extensions"] == {
98-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
99-
}
98+
assert body["extensions"] == extensions
10099
return web.Response(
101100
text=query1_server_answer,
102101
content_type="application/json",
@@ -112,18 +111,9 @@ def test_code():
112111
transport = HTTPXTransport(url=url)
113112

114113
with Client(transport=transport) as session:
115-
116-
request = GraphQLRequest(
117-
query1_str,
118-
extensions={
119-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
120-
},
121-
)
122-
114+
request = GraphQLRequest(query1_str, extensions=extensions)
123115
result = session.execute(request)
124-
125-
continents = result["continents"]
126-
assert continents[0]["code"] == "AF"
116+
assert result["continents"][0]["code"] == "AF"
127117

128118
await run_sync_test(server, test_code)
129119

tests/test_httpx_batch.py

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,17 @@ def test_code():
120120

121121
@pytest.mark.aiohttp
122122
@pytest.mark.asyncio
123-
async def test_httpx_async_batch_query_with_extensions(aiohttp_server):
123+
async def test_httpx_batch_request_extensions(aiohttp_server):
124124
from aiohttp import web
125125

126126
from gql.transport.httpx import HTTPXAsyncTransport
127127

128+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
129+
128130
async def handler(request):
129131
body = await request.json()
130132
assert isinstance(body, list)
131-
assert "extensions" in body[0]
132-
assert body[0]["extensions"] == {
133-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
134-
}
133+
assert body[0]["extensions"] == extensions
135134
return web.Response(
136135
text=query1_server_answer_list,
137136
content_type="application/json",
@@ -147,66 +146,9 @@ async def handler(request):
147146

148147
async with Client(transport=transport) as session:
149148

150-
query = [
151-
GraphQLRequest(
152-
query1_str,
153-
extensions={
154-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
155-
},
156-
)
157-
]
158-
149+
query = [GraphQLRequest(query1_str, extensions=extensions)]
159150
results = await session.execute_batch(query)
160-
161-
continents = results[0]["continents"]
162-
assert continents[0]["code"] == "AF"
163-
164-
165-
@pytest.mark.aiohttp
166-
@pytest.mark.asyncio
167-
async def test_httpx_sync_batch_query_with_extensions(aiohttp_server, run_sync_test):
168-
from aiohttp import web
169-
170-
from gql.transport.httpx import HTTPXTransport
171-
172-
async def handler(request):
173-
body = await request.json()
174-
assert isinstance(body, list)
175-
assert "extensions" in body[0]
176-
assert body[0]["extensions"] == {
177-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
178-
}
179-
return web.Response(
180-
text=query1_server_answer_list,
181-
content_type="application/json",
182-
)
183-
184-
app = web.Application()
185-
app.router.add_route("POST", "/", handler)
186-
server = await aiohttp_server(app)
187-
188-
url = str(server.make_url("/"))
189-
190-
transport = HTTPXTransport(url=url, timeout=10)
191-
192-
def test_code():
193-
with Client(transport=transport) as session:
194-
195-
query = [
196-
GraphQLRequest(
197-
query1_str,
198-
extensions={
199-
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
200-
},
201-
)
202-
]
203-
204-
results = session.execute_batch(query)
205-
206-
continents = results[0]["continents"]
207-
assert continents[0]["code"] == "AF"
208-
209-
await run_sync_test(server, test_code)
151+
assert results[0]["continents"][0]["code"] == "AF"
210152

211153

212154
@pytest.mark.aiohttp

0 commit comments

Comments
 (0)