Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
workflow_event_to_nexus_link,
workflow_execution_started_event_link_from_workflow_handle,
)
from ._token import WorkflowHandle
from ._token import OperationToken, OperationTokenType, WorkflowHandle

if TYPE_CHECKING:
import temporalio.client
Expand Down Expand Up @@ -225,15 +225,14 @@ def get(cls) -> _TemporalStartOperationContext:
def set(self) -> None:
_temporal_start_operation_context.set(self)

def _get_callbacks(
self,
) -> list[temporalio.client.Callback]:
def _get_callbacks(self, token: str) -> list[temporalio.client.Callback]:
ctx = self.nexus_context
callback_headers = {**ctx.callback_headers, "nexus-operation-token": token}
return (
[
NexusCallback(
url=ctx.callback_url,
headers=ctx.callback_headers,
headers=callback_headers,
)
]
if ctx.callback_url
Expand Down Expand Up @@ -643,6 +642,11 @@ async def _start_nexus_backing_workflow(
# terminal state) and inbound links to the caller workflow (attached to history events of
# the workflow started in the handler namespace, and displayed in the UI).
with _nexus_backing_workflow_start_context():
token = OperationToken(
type=OperationTokenType.WORKFLOW,
namespace=temporal_context.client.namespace,
workflow_id=id,
).encode()
wf_handle = await temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
Expand All @@ -669,7 +673,7 @@ async def _start_nexus_backing_workflow(
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=temporal_context._get_callbacks(),
callbacks=temporal_context._get_callbacks(token),
links=temporal_context._get_links(),
request_id=temporal_context.nexus_context.request_id,
)
Expand Down
39 changes: 39 additions & 0 deletions tests/nexus/test_temporal_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from temporalio import nexus, workflow
from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError
from temporalio.common import NexusOperationExecutionStatus, WorkflowIDConflictPolicy
from temporalio.nexus._token import OperationToken, OperationTokenType
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers import EventType, assert_event_subsequence, assert_eventually
Expand Down Expand Up @@ -685,3 +686,41 @@ async def test_temporal_operation_overloads(
if op == "no_param"
else TemporalOperationOverloadTestValue(value=4)
)


async def test_temporal_operation_includes_token_in_callback(
client: Client, env: WorkflowEnvironment
):
task_queue = str(uuid.uuid4())
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
async with Worker(
env.client,
task_queue=task_queue,
nexus_service_handlers=[TestServiceHandler()],
workflows=[EchoWorkflow, EchoWorkflowCaller],
):
input_value = f"test-{uuid.uuid4()}"
wf_handle = await client.start_workflow(
EchoWorkflowCaller.run,
Input(value=input_value, task_queue=task_queue),
task_queue=task_queue,
id=str(uuid.uuid4()),
)
result = await wf_handle.result()
assert result == input_value

target_handle = client.get_workflow_handle(f"echo-{input_value}")

desc = await target_handle.describe()
token = desc.raw_description.callbacks[0].callback.nexus.header[
"nexus-operation-token"
]

expected_token = OperationToken(
type=OperationTokenType.WORKFLOW,
namespace=client.namespace,
workflow_id=target_handle.id,
).encode()

assert token == expected_token
53 changes: 48 additions & 5 deletions tests/nexus/test_workflow_run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from temporalio.client import Client
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler
from temporalio.nexus._token import OperationToken, OperationTokenType
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import make_nexus_endpoint_name
Expand Down Expand Up @@ -48,7 +49,7 @@ async def start(
handle = await tctx.start_workflow(
EchoWorkflow.run,
input.value,
id=str(uuid.uuid4()),
id=input.value,
)
return StartOperationResultAsync(handle.to_token())

Expand Down Expand Up @@ -78,7 +79,7 @@ async def op(
return await ctx.start_workflow(
EchoWorkflow.run,
input.value,
id=str(uuid.uuid4()),
id=input.value,
)


Expand Down Expand Up @@ -146,13 +147,14 @@ async def test_workflow_run_operation(
nexus_service_handlers=[service_handler_cls()],
workflows=[CallerWorkflow, EchoWorkflow],
):
input_value = str(uuid.uuid4())
result = await client.execute_workflow(
CallerWorkflow.run,
args=[Input(value="test"), service_defn.name, task_queue],
args=[Input(value=input_value), service_defn.name, task_queue],
id=str(uuid.uuid4()),
task_queue=task_queue,
)
assert result == "test"
assert result == input_value


async def test_request_deadline_is_accessible_in_workflow_run_operation(
Expand All @@ -173,9 +175,10 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation(
nexus_service_handlers=[service_handler],
workflows=[RequestDeadlineWorkflow, EchoWorkflow],
):
input_value = str(uuid.uuid4())
await client.execute_workflow(
RequestDeadlineWorkflow.run,
args=[Input(value="test"), task_queue],
args=[Input(value=input_value), task_queue],
task_queue=task_queue,
id=str(uuid.uuid4()),
)
Expand All @@ -186,3 +189,43 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation(
"request_deadline should be set in WorkflowRunOperationContext"
)
assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc"


async def test_workflow_run_operation_includes_token_in_callback(
client: Client,
env: WorkflowEnvironment,
):
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue)
async with Worker(
client,
task_queue=task_queue,
nexus_service_handlers=[SubclassingHappyPath()],
workflows=[CallerWorkflow, EchoWorkflow],
):
input_value = str(uuid.uuid4())
result = await client.execute_workflow(
CallerWorkflow.run,
args=[Input(value=input_value), "SubclassingHappyPath", task_queue],
id=str(uuid.uuid4()),
task_queue=task_queue,
)
assert result == input_value

target_handle = client.get_workflow_handle(input_value)

desc = await target_handle.describe()
token = desc.raw_description.callbacks[0].callback.nexus.header[
"nexus-operation-token"
]

expected_token = OperationToken(
type=OperationTokenType.WORKFLOW,
namespace=client.namespace,
workflow_id=target_handle.id,
).encode()

assert token == expected_token
Loading