Skip to content

Commit 76c02b6

Browse files
authored
Add extensions support to GraphQLRequest (#591)
1 parent 6d3ffad commit 76c02b6

10 files changed

Lines changed: 311 additions & 4 deletions

docs/usage/extensions.rst

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,38 @@
33
Extensions
44
----------
55

6+
Request extensions
7+
^^^^^^^^^^^^^^^^^^
8+
9+
The `GraphQL over HTTP spec <https://github.com/graphql/graphql-over-http>`_
10+
defines an optional :code:`extensions` field on requests. This is sent as a
11+
top-level key in the request payload alongside :code:`query`, :code:`variables`,
12+
and :code:`operationName`.
13+
14+
You can use this to pass protocol extensions such as
15+
`trusted documents <https://graphql.org/learn/security/#trusted-documents>`_:
16+
17+
.. code-block:: python
18+
19+
from gql import Client, GraphQLRequest
20+
from gql.transport.aiohttp import AIOHTTPTransport
21+
22+
transport = AIOHTTPTransport(url="https://example.com/graphql")
23+
24+
async with Client(transport=transport) as session:
25+
26+
request = GraphQLRequest(
27+
"query { viewer { name } }",
28+
extensions={
29+
"document-id": "155d6e8f5545...",
30+
},
31+
)
32+
33+
result = await session.execute(request)
34+
35+
Response extensions
36+
^^^^^^^^^^^^^^^^^^^
37+
638
When you execute (or subscribe) GraphQL requests, the server will send
739
responses which may have 3 fields:
840

gql/graphql_request.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(
1313
*,
1414
variable_values: Optional[Dict[str, Any]] = None,
1515
operation_name: Optional[str] = None,
16+
extensions: Optional[Dict[str, Any]] = None,
1617
):
1718
"""Initialize a GraphQL request.
1819
@@ -21,6 +22,9 @@ def __init__(
2122
:param variable_values: Dictionary of input parameters (Default: None).
2223
:param operation_name: Name of the operation that shall be executed.
2324
Only required in multi-operation documents (Default: None).
25+
:param extensions: Dictionary of protocol extensions (Default: None).
26+
This is passed as the top-level "extensions" key in the request
27+
payload, as defined in the GraphQL over HTTP spec.
2428
:return: a :class:`GraphQLRequest <gql.GraphQLRequest>`
2529
which can be later executed or subscribed by a
2630
:class:`Client <gql.client.Client>`, by an
@@ -42,9 +46,12 @@ def __init__(
4246
variable_values = request.variable_values
4347
if operation_name is None:
4448
operation_name = request.operation_name
49+
if extensions is None:
50+
extensions = request.extensions
4551

4652
self.variable_values: Optional[Dict[str, Any]] = variable_values
4753
self.operation_name: Optional[str] = operation_name
54+
self.extensions: Optional[Dict[str, Any]] = extensions
4855

4956
def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
5057

@@ -61,6 +68,7 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
6168
operation_name=self.operation_name,
6269
),
6370
operation_name=self.operation_name,
71+
extensions=self.extensions,
6472
)
6573

6674
@property
@@ -74,6 +82,9 @@ def payload(self) -> Dict[str, Any]:
7482
if self.variable_values:
7583
payload["variables"] = self.variable_values
7684

85+
if self.extensions:
86+
payload["extensions"] = self.extensions
87+
7788
return payload
7889

7990
def __str__(self):

tests/test_aiohttp.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from gql import Client, FileVar, gql
9+
from gql import Client, FileVar, GraphQLRequest, gql
1010
from gql.cli import get_parser, main
1111
from gql.transport.exceptions import (
1212
TransportAlreadyConnected,
@@ -87,6 +87,43 @@ async def handler(request):
8787
assert transport.response_headers["dummy"] == "test1234"
8888

8989

90+
@pytest.mark.asyncio
91+
async def test_aiohttp_request_extensions(aiohttp_server):
92+
from aiohttp import web
93+
94+
from gql.transport.aiohttp import AIOHTTPTransport
95+
96+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
97+
98+
async def handler(request):
99+
body = await request.json()
100+
assert body["extensions"] == extensions
101+
return web.Response(
102+
text=query1_server_answer,
103+
content_type="application/json",
104+
)
105+
106+
app = web.Application()
107+
app.router.add_route("POST", "/", handler)
108+
server = await aiohttp_server(app)
109+
110+
url = server.make_url("/")
111+
112+
transport = AIOHTTPTransport(url=url, timeout=10)
113+
114+
request = GraphQLRequest(query1_str, extensions=extensions)
115+
116+
async with Client(transport=transport) as session:
117+
118+
# execute
119+
result = await session.execute(request)
120+
assert result["continents"][0]["code"] == "AF"
121+
122+
# subscribe
123+
async for result in session.subscribe(request):
124+
assert result["continents"][0]["code"] == "AF"
125+
126+
90127
@pytest.mark.asyncio
91128
async def test_aiohttp_ignore_backend_content_type(aiohttp_server):
92129
from aiohttp import web

tests/test_aiohttp_batch.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,38 @@ async def handler(request):
8989
assert transport.response_headers["dummy"] == "test1234"
9090

9191

92+
@pytest.mark.asyncio
93+
async def test_aiohttp_batch_request_extensions(aiohttp_server):
94+
from aiohttp import web
95+
96+
from gql.transport.aiohttp import AIOHTTPTransport
97+
98+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
99+
100+
async def handler(request):
101+
body = await request.json()
102+
assert isinstance(body, list)
103+
assert body[0]["extensions"] == extensions
104+
return web.Response(
105+
text=query1_server_answer_list,
106+
content_type="application/json",
107+
)
108+
109+
app = web.Application()
110+
app.router.add_route("POST", "/", handler)
111+
server = await aiohttp_server(app)
112+
113+
url = server.make_url("/")
114+
115+
transport = AIOHTTPTransport(url=url, timeout=10)
116+
117+
async with Client(transport=transport) as session:
118+
119+
query = [GraphQLRequest(query1_str, extensions=extensions)]
120+
results = await session.execute_batch(query)
121+
assert results[0]["continents"][0]["code"] == "AF"
122+
123+
92124
@pytest.mark.asyncio
93125
async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test):
94126
from aiohttp import web

tests/test_aiohttp_websocket_subscription.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from graphql import ExecutionResult
99
from parse import search
1010

11-
from gql import Client, gql
11+
from gql import Client, GraphQLRequest, gql
1212
from gql.client import AsyncClientSession
1313
from gql.transport.exceptions import TransportConnectionFailed, TransportServerError
1414

@@ -460,6 +460,37 @@ async def test_aiohttp_websocket_subscription_with_operation_name(
460460
assert '"operationName": "CountdownSubscription"' in logged_messages[0]
461461

462462

463+
@pytest.mark.asyncio
464+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
465+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
466+
async def test_aiohttp_websocket_subscription_with_extensions(
467+
aiohttp_client_and_server, subscription_str
468+
):
469+
470+
session, server = aiohttp_client_and_server
471+
472+
count = 10
473+
request = GraphQLRequest(
474+
subscription_str.format(count=count),
475+
extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}},
476+
)
477+
478+
async for result in session.subscribe(request):
479+
480+
number = result["number"]
481+
print(f"Number received: {number}")
482+
483+
assert number == count
484+
count -= 1
485+
486+
assert count == -1
487+
488+
message = json.loads(logged_messages[0])
489+
assert message["payload"]["extensions"] == {
490+
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
491+
}
492+
493+
463494
WITH_KEEPALIVE = True
464495

465496

tests/test_graphql_request.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,31 @@ def test_graphql_request_init_with_graphql_request():
236236
assert request_1.variable_values["money"] == money_value_1
237237
assert request_2.variable_values["money"] == money_value_1
238238
assert request_3.variable_values["money"] == money_value_2
239+
240+
241+
def test_graphql_request_extensions():
242+
extensions_1 = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
243+
extensions_2 = {"custom": "value"}
244+
money_value = Money(10, "DM")
245+
246+
assert "extensions" not in GraphQLRequest("{balance}").payload
247+
248+
request_1 = GraphQLRequest("{balance}", extensions=extensions_1)
249+
assert request_1.payload["extensions"] == extensions_1
250+
251+
# Copied from another GraphQLRequest
252+
request_2 = GraphQLRequest(request_1)
253+
assert request_2.extensions == extensions_1
254+
255+
# Explicit extensions override the copied value
256+
request_3 = GraphQLRequest(request_1, extensions=extensions_2)
257+
assert request_3.extensions == extensions_2
258+
259+
# Preserved through serialize_variable_values
260+
request_4 = GraphQLRequest(
261+
"query myquery($money: Money) {toEuros(money: $money)}",
262+
variable_values={"money": money_value},
263+
extensions=extensions_1,
264+
)
265+
serialized = request_4.serialize_variable_values(schema)
266+
assert serialized.extensions == extensions_1

tests/test_httpx.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from gql import Client, FileVar, gql
6+
from gql import Client, FileVar, GraphQLRequest, gql
77
from gql.transport.exceptions import (
88
TransportAlreadyConnected,
99
TransportClosed,
@@ -84,6 +84,40 @@ def test_code():
8484
await run_sync_test(server, test_code)
8585

8686

87+
@pytest.mark.aiohttp
88+
@pytest.mark.asyncio
89+
async def test_httpx_request_extensions(aiohttp_server, run_sync_test):
90+
from aiohttp import web
91+
92+
from gql.transport.httpx import HTTPXTransport
93+
94+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
95+
96+
async def handler(request):
97+
body = await request.json()
98+
assert body["extensions"] == extensions
99+
return web.Response(
100+
text=query1_server_answer,
101+
content_type="application/json",
102+
)
103+
104+
app = web.Application()
105+
app.router.add_route("POST", "/", handler)
106+
server = await aiohttp_server(app)
107+
108+
url = str(server.make_url("/"))
109+
110+
def test_code():
111+
transport = HTTPXTransport(url=url)
112+
113+
with Client(transport=transport) as session:
114+
request = GraphQLRequest(query1_str, extensions=extensions)
115+
result = session.execute(request)
116+
assert result["continents"][0]["code"] == "AF"
117+
118+
await run_sync_test(server, test_code)
119+
120+
87121
@pytest.mark.aiohttp
88122
@pytest.mark.asyncio
89123
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])

tests/test_httpx_batch.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,39 @@ def test_code():
118118
await run_sync_test(server, test_code)
119119

120120

121+
@pytest.mark.aiohttp
122+
@pytest.mark.asyncio
123+
async def test_httpx_batch_request_extensions(aiohttp_server):
124+
from aiohttp import web
125+
126+
from gql.transport.httpx import HTTPXAsyncTransport
127+
128+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
129+
130+
async def handler(request):
131+
body = await request.json()
132+
assert isinstance(body, list)
133+
assert body[0]["extensions"] == extensions
134+
return web.Response(
135+
text=query1_server_answer_list,
136+
content_type="application/json",
137+
)
138+
139+
app = web.Application()
140+
app.router.add_route("POST", "/", handler)
141+
server = await aiohttp_server(app)
142+
143+
url = str(server.make_url("/"))
144+
145+
transport = HTTPXAsyncTransport(url=url, timeout=10)
146+
147+
async with Client(transport=transport) as session:
148+
149+
query = [GraphQLRequest(query1_str, extensions=extensions)]
150+
results = await session.execute_batch(query)
151+
assert results[0]["continents"][0]["code"] == "AF"
152+
153+
121154
@pytest.mark.aiohttp
122155
@pytest.mark.asyncio
123156
async def test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test):

tests/test_requests.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from gql import Client, FileVar, gql
7+
from gql import Client, FileVar, GraphQLRequest, gql
88
from gql.transport.exceptions import (
99
TransportAlreadyConnected,
1010
TransportClosed,
@@ -85,6 +85,40 @@ def test_code():
8585
await run_sync_test(server, test_code)
8686

8787

88+
@pytest.mark.aiohttp
89+
@pytest.mark.asyncio
90+
async def test_requests_request_extensions(aiohttp_server, run_sync_test):
91+
from aiohttp import web
92+
93+
from gql.transport.requests import RequestsHTTPTransport
94+
95+
extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
96+
97+
async def handler(request):
98+
body = await request.json()
99+
assert body["extensions"] == extensions
100+
return web.Response(
101+
text=query1_server_answer,
102+
content_type="application/json",
103+
)
104+
105+
app = web.Application()
106+
app.router.add_route("POST", "/", handler)
107+
server = await aiohttp_server(app)
108+
109+
url = server.make_url("/")
110+
111+
def test_code():
112+
transport = RequestsHTTPTransport(url=url)
113+
114+
with Client(transport=transport) as session:
115+
request = GraphQLRequest(query1_str, extensions=extensions)
116+
result = session.execute(request)
117+
assert result["continents"][0]["code"] == "AF"
118+
119+
await run_sync_test(server, test_code)
120+
121+
88122
@pytest.mark.aiohttp
89123
@pytest.mark.asyncio
90124
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])

0 commit comments

Comments
 (0)