From 64b7b233cb6ce97e984199b7194161f67d4c77f2 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Wed, 3 Jun 2026 14:51:39 -0700 Subject: [PATCH] Add nexus-operation-token header to nexus callback headers for TemporalOperationHandler and WorkflowRunOperationHandler --- temporalio/nexus/_operation_context.py | 16 ++++--- tests/nexus/test_temporal_operation.py | 39 ++++++++++++++++ tests/nexus/test_workflow_run_operation.py | 53 ++++++++++++++++++++-- 3 files changed, 97 insertions(+), 11 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 01d209a9f..0d9d11449 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -47,7 +47,7 @@ workflow_event_to_nexus_link, workflow_execution_started_event_link_from_workflow_handle, ) -from ._token import WorkflowHandle +from ._token import OperationToken, OperationTokenType, WorkflowHandle if TYPE_CHECKING: import temporalio.client @@ -225,15 +225,14 @@ def get(cls) -> _TemporalStartOperationContext: def set(self) -> None: _temporal_start_operation_context.set(self) - def _get_callbacks( - self, - ) -> list[temporalio.client.Callback]: + def _get_callbacks(self, token: str) -> list[temporalio.client.Callback]: ctx = self.nexus_context + callback_headers = {**ctx.callback_headers, "nexus-operation-token": token} return ( [ NexusCallback( url=ctx.callback_url, - headers=ctx.callback_headers, + headers=callback_headers, ) ] if ctx.callback_url @@ -643,6 +642,11 @@ async def _start_nexus_backing_workflow( # terminal state) and inbound links to the caller workflow (attached to history events of # the workflow started in the handler namespace, and displayed in the UI). with _nexus_backing_workflow_start_context(): + token = OperationToken( + type=OperationTokenType.WORKFLOW, + namespace=temporal_context.client.namespace, + workflow_id=id, + ).encode() wf_handle = await temporal_context.client.start_workflow( # type: ignore workflow=workflow, arg=arg, @@ -669,7 +673,7 @@ async def _start_nexus_backing_workflow( request_eager_start=request_eager_start, priority=priority, versioning_override=versioning_override, - callbacks=temporal_context._get_callbacks(), + callbacks=temporal_context._get_callbacks(token), links=temporal_context._get_links(), request_id=temporal_context.nexus_context.request_id, ) diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index c101ede3b..c97792c8d 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -12,6 +12,7 @@ from temporalio import nexus, workflow from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError from temporalio.common import NexusOperationExecutionStatus, WorkflowIDConflictPolicy +from temporalio.nexus._token import OperationToken, OperationTokenType from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import EventType, assert_event_subsequence, assert_eventually @@ -685,3 +686,41 @@ async def test_temporal_operation_overloads( if op == "no_param" else TemporalOperationOverloadTestValue(value=4) ) + + +async def test_temporal_operation_includes_token_in_callback( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + workflows=[EchoWorkflow, EchoWorkflowCaller], + ): + input_value = f"test-{uuid.uuid4()}" + wf_handle = await client.start_workflow( + EchoWorkflowCaller.run, + Input(value=input_value, task_queue=task_queue), + task_queue=task_queue, + id=str(uuid.uuid4()), + ) + result = await wf_handle.result() + assert result == input_value + + target_handle = client.get_workflow_handle(f"echo-{input_value}") + + desc = await target_handle.describe() + token = desc.raw_description.callbacks[0].callback.nexus.header[ + "nexus-operation-token" + ] + + expected_token = OperationToken( + type=OperationTokenType.WORKFLOW, + namespace=client.namespace, + workflow_id=target_handle.id, + ).encode() + + assert token == expected_token diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 3ba9545fc..851f408ec 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -18,6 +18,7 @@ from temporalio.client import Client from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler +from temporalio.nexus._token import OperationToken, OperationTokenType from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import make_nexus_endpoint_name @@ -48,7 +49,7 @@ async def start( handle = await tctx.start_workflow( EchoWorkflow.run, input.value, - id=str(uuid.uuid4()), + id=input.value, ) return StartOperationResultAsync(handle.to_token()) @@ -78,7 +79,7 @@ async def op( return await ctx.start_workflow( EchoWorkflow.run, input.value, - id=str(uuid.uuid4()), + id=input.value, ) @@ -146,13 +147,14 @@ async def test_workflow_run_operation( nexus_service_handlers=[service_handler_cls()], workflows=[CallerWorkflow, EchoWorkflow], ): + input_value = str(uuid.uuid4()) result = await client.execute_workflow( CallerWorkflow.run, - args=[Input(value="test"), service_defn.name, task_queue], + args=[Input(value=input_value), service_defn.name, task_queue], id=str(uuid.uuid4()), task_queue=task_queue, ) - assert result == "test" + assert result == input_value async def test_request_deadline_is_accessible_in_workflow_run_operation( @@ -173,9 +175,10 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation( nexus_service_handlers=[service_handler], workflows=[RequestDeadlineWorkflow, EchoWorkflow], ): + input_value = str(uuid.uuid4()) await client.execute_workflow( RequestDeadlineWorkflow.run, - args=[Input(value="test"), task_queue], + args=[Input(value=input_value), task_queue], task_queue=task_queue, id=str(uuid.uuid4()), ) @@ -186,3 +189,43 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation( "request_deadline should be set in WorkflowRunOperationContext" ) assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc" + + +async def test_workflow_run_operation_includes_token_in_callback( + client: Client, + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue) + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[SubclassingHappyPath()], + workflows=[CallerWorkflow, EchoWorkflow], + ): + input_value = str(uuid.uuid4()) + result = await client.execute_workflow( + CallerWorkflow.run, + args=[Input(value=input_value), "SubclassingHappyPath", task_queue], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert result == input_value + + target_handle = client.get_workflow_handle(input_value) + + desc = await target_handle.describe() + token = desc.raw_description.callbacks[0].callback.nexus.header[ + "nexus-operation-token" + ] + + expected_token = OperationToken( + type=OperationTokenType.WORKFLOW, + namespace=client.namespace, + workflow_id=target_handle.id, + ).encode() + + assert token == expected_token