Skip to content
Merged
149 changes: 137 additions & 12 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Optional, Sequence, TypeVar, Union
from typing import Any, List, Optional, Sequence, TypeVar, Union

import grpc
from google.protobuf import wrappers_pb2
from google.protobuf import wrappers_pb2 as pb2

from durabletask.entities import EntityInstanceId
from durabletask.entities.entity_metadata import EntityMetadata
Expand Down Expand Up @@ -57,6 +57,12 @@ def raise_if_failed(self):
self.failure_details)


class PurgeInstancesResult:
    """Result of a purge operation.

    Attributes:
        deleted_instance_count: Number of orchestration instances deleted.
        is_complete: Whether the backend reports the purge as complete;
            False suggests more matching instances remain (semantics come
            from the service response — confirm against the proto contract).
    """

    def __init__(self, deleted_instance_count: int, is_complete: bool):
        self.deleted_instance_count = deleted_instance_count
        self.is_complete = is_complete

    def __repr__(self) -> str:
        # Aid debugging/logging; the original class had no repr.
        return (f"{type(self).__name__}("
                f"deleted_instance_count={self.deleted_instance_count}, "
                f"is_complete={self.is_complete})")


class OrchestrationFailedError(Exception):
def __init__(self, message: str, failure_details: task.FailureDetails):
super().__init__(message)
Expand All @@ -73,6 +79,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op

state = res.orchestrationState

new_state = parse_orchestration_state(state)
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
return new_state


def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
failure_details = None
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
failure_details = task.FailureDetails(
Expand All @@ -81,7 +93,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)

return OrchestrationState(
instance_id,
state.instanceId,
state.name,
OrchestrationStatus(state.orchestrationStatus),
state.createdTimestamp.ToDatetime(),
Expand All @@ -93,7 +105,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op


class TaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
Expand Down Expand Up @@ -136,7 +147,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
input=helpers.get_string_value(shared.to_json(input)),
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=helpers.get_string_value(version if version else self.default_version),
orchestrationIdReusePolicy=reuse_id_policy,
Expand All @@ -152,6 +163,54 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

def get_all_orchestration_states(self,
                                 max_instance_count: Optional[int] = None,
                                 fetch_inputs_and_outputs: bool = False) -> List[OrchestrationState]:
    """Fetch every orchestration instance, applying no time or status filters.

    Convenience wrapper around get_orchestration_state_by; the omitted
    filter arguments fall back to that method's None defaults.

    Args:
        max_instance_count: Optional cap passed through to the query.
        fetch_inputs_and_outputs: Whether to populate serialized payloads.

    Returns:
        A list of OrchestrationState objects for all known instances.
    """
    return self.get_orchestration_state_by(
        max_instance_count=max_instance_count,
        fetch_inputs_and_outputs=fetch_inputs_and_outputs,
    )

def get_orchestration_state_by(self,
                               created_time_from: Optional[datetime] = None,
                               created_time_to: Optional[datetime] = None,
                               runtime_status: Optional[List[OrchestrationStatus]] = None,
                               max_instance_count: Optional[int] = None,
                               fetch_inputs_and_outputs: bool = False,
                               _continuation_token: Optional[pb2.StringValue] = None
                               ) -> List[OrchestrationState]:
    """Query orchestration instances matching the given filters.

    Transparently pages through all results using the backend's
    continuation token, accumulating every page into one list.

    Args:
        created_time_from: Only include instances created at/after this time.
        created_time_to: Only include instances created at/before this time.
        runtime_status: Only include instances in one of these statuses.
        max_instance_count: Maximum instance count sent per query page.
            None is sent as the max 32-bit signed value, since the DTS
            backend does not behave well with an unset count.
        fetch_inputs_and_outputs: Whether to populate serialized
            inputs/outputs on the returned states.
        _continuation_token: Internal-use token to resume a prior query.

    Returns:
        A list of matching OrchestrationState objects.
    """
    if max_instance_count is None:
        # Python ints are unbounded (there is no int.max), so use the
        # largest 32-bit signed value the proto field can carry.
        max_instance_count = (1 << 31) - 1
    states: List[OrchestrationState] = []
    continuation_token = _continuation_token
    # Iterate pages in a loop rather than recursing: the original recursive
    # form could hit Python's recursion limit on many pages and re-filtered
    # the whole accumulated list at every level.
    while True:
        req = pb.QueryInstancesRequest(
            query=pb.InstanceQuery(
                runtimeStatus=[status.value for status in runtime_status] if runtime_status else None,
                createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
                createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
                maxInstanceCount=max_instance_count,
                fetchInputsAndOutputs=fetch_inputs_and_outputs,
                continuationToken=continuation_token
            )
        )
        resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
        states.extend(parse_orchestration_state(res) for res in resp.orchestrationState)
        # A missing/empty continuation token, or the literal "0", indicates
        # that there are no more results.
        if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
            self._logger.info(
                f"Received continuation token with value {resp.continuationToken.value}, "
                "fetching next list of instances...")
            continuation_token = resp.continuationToken
        else:
            break
    # Defensive: drop any None entries before returning, matching the
    # original behavior.
    return [state for state in states if state is not None]

def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 60) -> Optional[OrchestrationState]:
Expand Down Expand Up @@ -199,7 +258,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
input=helpers.get_string_value(shared.to_json(data)))

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
self._stub.RaiseEvent(req)
Expand All @@ -209,7 +268,7 @@ def terminate_orchestration(self, instance_id: str, *,
recursive: bool = True):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
output=helpers.get_string_value(shared.to_json(output)),
recursive=recursive)

self._logger.info(f"Terminating instance '{instance_id}'.")
Expand All @@ -225,10 +284,27 @@ def resume_orchestration(self, instance_id: str):
self._logger.info(f"Resuming instance '{instance_id}'.")
self._stub.ResumeInstance(req)

def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
    """Purge stored state for a single orchestration instance.

    Args:
        instance_id: The orchestration instance to purge.
        recursive: Whether the purge also covers sub-orchestrations
            (passed through to the service request).

    Returns:
        A PurgeInstancesResult with the deleted count and completion flag.
    """
    self._logger.info(f"Purging instance '{instance_id}'.")
    request = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
    response: pb.PurgeInstancesResponse = self._stub.PurgeInstances(request)
    return PurgeInstancesResult(response.deletedInstanceCount, response.isComplete.value)

def purge_orchestrations_by(self,
                            created_time_from: Optional[datetime] = None,
                            created_time_to: Optional[datetime] = None,
                            runtime_status: Optional[List[OrchestrationStatus]] = None,
                            recursive: bool = False) -> PurgeInstancesResult:
    """Purge all orchestration instances matching the given filters.

    Args:
        created_time_from: Only purge instances created at/after this time.
        created_time_to: Only purge instances created at/before this time.
        runtime_status: Only purge instances in one of these statuses.
        recursive: Whether the purge also covers sub-orchestrations.

    Returns:
        A PurgeInstancesResult with the deleted count and completion flag.
    """
    req = pb.PurgeInstancesRequest(
        instanceId=None,
        purgeInstanceFilter=pb.PurgeInstanceFilter(
            createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
            createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
            runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
        ),
        recursive=recursive
    )
    # Log the operation for consistency with purge_orchestration.
    self._logger.info("Purging orchestrations matching filter.")
    resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
    return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

def signal_entity(self,
entity_instance_id: EntityInstanceId,
Expand All @@ -237,7 +313,7 @@ def signal_entity(self,
req = pb.SignalEntityRequest(
instanceId=str(entity_instance_id),
name=operation_name,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
input=helpers.get_string_value(shared.to_json(input)),
requestId=str(uuid.uuid4()),
scheduledTime=None,
parentTraceContext=None,
Expand All @@ -256,4 +332,53 @@ def get_entity(self,
if not res.exists:
return None

return EntityMetadata.from_entity_response(res, include_state)
return EntityMetadata.from_entity_metadata(res.entity, include_state)

def get_all_entities(self,
                     include_state: bool = True,
                     include_transient: bool = False,
                     page_size: Optional[int] = None) -> List[EntityMetadata]:
    """Fetch every entity, applying no instance-ID or last-modified filters.

    Convenience wrapper around get_entities_by; the omitted filter
    arguments fall back to that method's None defaults.

    Args:
        include_state: Whether to include serialized entity state.
        include_transient: Whether to include transient entities.
        page_size: Optional page size passed through to the query.

    Returns:
        A list of EntityMetadata for all matching entities.
    """
    return self.get_entities_by(
        include_state=include_state,
        include_transient=include_transient,
        page_size=page_size,
    )

def get_entities_by(self,
                    instance_id_starts_with: Optional[str] = None,
                    last_modified_from: Optional[datetime] = None,
                    last_modified_to: Optional[datetime] = None,
                    include_state: bool = True,
                    include_transient: bool = False,
                    page_size: Optional[int] = None,
                    _continuation_token: Optional[pb2.StringValue] = None
                    ) -> List[EntityMetadata]:
    """Query entities matching the given filters.

    Transparently pages through all results using the backend's
    continuation token, accumulating every page into one list.

    Args:
        instance_id_starts_with: Only include entities whose instance ID
            starts with this prefix.
        last_modified_from: Only include entities modified at/after this time.
        last_modified_to: Only include entities modified at/before this time.
        include_state: Whether to include serialized entity state.
        include_transient: Whether to include transient entities.
        page_size: Optional page size for each query request.
        _continuation_token: Internal-use token to resume a prior query.

    Returns:
        A list of EntityMetadata for all matching entities.
    """
    self._logger.info("Getting entities")
    entities: List[EntityMetadata] = []
    continuation_token = _continuation_token
    # Iterate pages in a loop rather than recursing, mirroring
    # get_orchestration_state_by.
    while True:
        query_request = pb.QueryEntitiesRequest(
            query=pb.EntityQuery(
                instanceIdStartsWith=helpers.get_string_value(instance_id_starts_with),
                lastModifiedFrom=helpers.new_timestamp(last_modified_from) if last_modified_from else None,
                lastModifiedTo=helpers.new_timestamp(last_modified_to) if last_modified_to else None,
                includeState=include_state,
                includeTransient=include_transient,
                pageSize=helpers.get_int_value(page_size),
                continuationToken=continuation_token
            )
        )
        resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
        entities.extend(
            EntityMetadata.from_entity_metadata(entity, include_state)
            for entity in resp.entities)
        # Treat a missing/empty token, as well as the literal "0", as
        # end-of-results. The original only checked != "0", so an
        # empty-string token would have kept fetching forever.
        if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
            self._logger.info(
                f"Received continuation token with value {resp.continuationToken.value}, "
                "fetching next page of entities...")
            continuation_token = resp.continuationToken
        else:
            break
    return entities
14 changes: 9 additions & 5 deletions durabletask/entities/entity_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,22 @@ def __init__(self,

@staticmethod
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
    """Build an EntityMetadata from a GetEntityResponse by delegating to from_entity_metadata on its `entity` field."""
    return EntityMetadata.from_entity_metadata(entity_response.entity, includes_state)

@staticmethod
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
    """Build an EntityMetadata from a protobuf EntityMetadata message.

    Args:
        entity: The protobuf entity metadata message.
        includes_state: Whether serialized state should be read from the
            message and exposed on the result.

    Returns:
        A populated EntityMetadata instance.

    Raises:
        ValueError: If the entity instance ID cannot be parsed.
    """
    try:
        entity_id = EntityInstanceId.parse(entity.instanceId)
    except ValueError as err:
        # Chain the cause so the original parse failure is not lost.
        raise ValueError("Invalid entity instance ID in entity response.") from err
    entity_state = None
    if includes_state:
        entity_state = entity.serializedState.value
    return EntityMetadata(
        id=entity_id,
        last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
        backlog_queue_size=entity.backlogQueueSize,
        locked_by=entity.lockedBy.value,
        includes_state=includes_state,
        state=entity_state
    )
Expand Down
7 changes: 7 additions & 0 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
return wrappers_pb2.StringValue(value=val)


def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
if val is None:
return None
else:
return wrappers_pb2.Int32Value(value=val)


def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
if val is None:
return wrappers_pb2.StringValue(value="")
Expand Down
57 changes: 57 additions & 0 deletions tests/durabletask-azuremanaged/test_dts_batch_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

import os

import pytest
from durabletask import client, task
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
from durabletask.azuremanaged.client import DurableTaskSchedulerClient

# Read the environment variables
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")

pytestmark = pytest.mark.dts


def empty_orchestrator(ctx: task.OrchestrationContext, _):
    # Minimal orchestrator fixture: does no work and completes immediately.
    return "Complete"


def test_get_all_orchestration_states():
    """E2E: querying all orchestration states returns the scheduled instance,
    with payloads only when fetch_inputs_and_outputs is requested."""
    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(empty_orchestrator)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                       taskhub=taskhub_name, token_credential=None)
        # Renamed from "id", which shadows the builtin.
        instance_id = c.schedule_new_orchestration(empty_orchestrator, input="Hello")
        c.wait_for_orchestration_completion(instance_id, timeout=30)

        this_orch = c.get_orchestration_state(instance_id)
        assert this_orch is not None
        assert this_orch.instance_id == instance_id

        def assert_contains_this_instance(states, expect_payloads: bool):
            # Shared assertions for both fetch modes; the original computed
            # the same list comprehension twice per mode.
            assert states is not None
            assert len(states) > 1
            print(f"Received {len(states)} orchestrations")
            matches = [o for o in states if o.instance_id == instance_id]
            assert len(matches) == 1
            state = matches[0]
            assert state.runtime_status == client.OrchestrationStatus.COMPLETED
            assert state.failure_details is None
            if expect_payloads:
                assert state.serialized_input == '"Hello"'
                assert state.serialized_output == '"Complete"'
            else:
                assert state.serialized_input is None
                assert state.serialized_output is None

        assert_contains_this_instance(c.get_all_orchestration_states(), expect_payloads=False)
        assert_contains_this_instance(
            c.get_all_orchestration_states(fetch_inputs_and_outputs=True), expect_payloads=True)
33 changes: 33 additions & 0 deletions tests/durabletask/test_batch_actions.py
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewers - I'm not sure these tests are worth adding, tbh. If we ever implement these gRPC endpoints into the Go sidecar, our tests will immediately break. Any objection to just commenting them out for now, and relying on the DTS tests?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, it looks like they're just a reminder for us to write tests once they are implemented. However, I'm confused why they rely on the go sidecar since we're in the python SDK? If they can break from a PR outside of this repo, I think they become less useful.

If this is truly just a test for the go sidecar, I would rather we have those in the go repo. Can you expand on this a little bit?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The durabletask (non-DTS) tests use the Go sidecar as the emulator for the durability layer, just like how the DTS tests use the DTS emulator docker image. Yes, this means each test exercises two things at once, but as far as I know (this predates my time on the project) it was the easiest way to set up E2E testing for this project.
I'll go ahead and remove these tests, with a note to add them when possible

Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

from durabletask import client, task, worker


def empty_orchestrator(ctx: task.OrchestrationContext, _):
    # Minimal orchestrator fixture: does no work and completes immediately.
    return "Complete"


def test_get_all_orchestration_states():
    """E2E: querying all orchestration states returns the scheduled instance
    with its completed status and serialized payloads."""
    # Start a worker, which will connect to the sidecar in a background thread
    with worker.TaskHubGrpcWorker() as w:
        w.add_orchestrator(empty_orchestrator)
        w.start()

        c = client.TaskHubGrpcClient()
        # Renamed from "id", which shadows the builtin.
        instance_id = c.schedule_new_orchestration(empty_orchestrator, input="Hello")
        c.wait_for_orchestration_completion(instance_id, timeout=30)

        this_orch = c.get_orchestration_state(instance_id)
        assert this_orch is not None
        assert this_orch.instance_id == instance_id

        all_orchestrations = c.get_all_orchestration_states()
        assert all_orchestrations is not None
        assert len(all_orchestrations) > 1
        print(f"Received {len(all_orchestrations)} orchestrations")
        # Compute the match list once; the original evaluated the same
        # comprehension twice.
        matches = [o for o in all_orchestrations if o.instance_id == instance_id]
        assert len(matches) == 1
        orchestration_state = matches[0]
        assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert orchestration_state.serialized_input == '"Hello"'
        assert orchestration_state.serialized_output == '"Complete"'
        assert orchestration_state.failure_details is None
Loading