-
Notifications
You must be signed in to change notification settings - Fork 25
Add batch actions (purge, query orchestrations/entities) #111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
4494708
93da9bf
3e26a68
b274398
e610386
8213512
1791b65
9fcdc6d
867bf3b
ae20682
d6775f7
2b52e98
dd2548a
540421c
5059b82
5406856
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,12 @@ | |
| import logging | ||
| import uuid | ||
| from dataclasses import dataclass | ||
| from datetime import datetime, timezone | ||
| from datetime import datetime, timedelta, timezone | ||
| from enum import Enum | ||
| from typing import Any, Optional, Sequence, TypeVar, Union | ||
| from typing import Any, List, Optional, Sequence, TypeVar, Union | ||
|
|
||
| import grpc | ||
| from google.protobuf import wrappers_pb2 | ||
| from google.protobuf import wrappers_pb2 as pb2 | ||
|
|
||
| from durabletask.entities import EntityInstanceId | ||
| from durabletask.entities.entity_metadata import EntityMetadata | ||
|
|
@@ -57,6 +57,12 @@ def raise_if_failed(self): | |
| self.failure_details) | ||
|
|
||
|
|
||
| class PurgeInstancesResult: | ||
| def __init__(self, deleted_instance_count: int, is_complete: bool): | ||
| self.deleted_instance_count = deleted_instance_count | ||
| self.is_complete = is_complete | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class OrchestrationFailedError(Exception): | ||
| def __init__(self, message: str, failure_details: task.FailureDetails): | ||
| super().__init__(message) | ||
|
|
@@ -73,6 +79,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op | |
|
|
||
| state = res.orchestrationState | ||
|
|
||
| new_state = parse_orchestration_state(state) | ||
| new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior | ||
| return new_state | ||
|
|
||
|
|
||
| def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState: | ||
| failure_details = None | ||
| if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '': | ||
| failure_details = task.FailureDetails( | ||
|
|
@@ -81,7 +93,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op | |
| state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None) | ||
|
|
||
| return OrchestrationState( | ||
| instance_id, | ||
| state.instanceId, | ||
| state.name, | ||
| OrchestrationStatus(state.orchestrationStatus), | ||
| state.createdTimestamp.ToDatetime(), | ||
|
|
@@ -93,7 +105,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op | |
|
|
||
|
|
||
| class TaskHubGrpcClient: | ||
|
|
||
| def __init__(self, *, | ||
| host_address: Optional[str] = None, | ||
| metadata: Optional[list[tuple[str, str]]] = None, | ||
|
|
@@ -136,7 +147,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu | |
| req = pb.CreateInstanceRequest( | ||
| name=name, | ||
| instanceId=instance_id if instance_id else uuid.uuid4().hex, | ||
| input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None, | ||
| input=helpers.get_string_value(shared.to_json(input)), | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, | ||
| version=helpers.get_string_value(version if version else self.default_version), | ||
| orchestrationIdReusePolicy=reuse_id_policy, | ||
|
|
@@ -152,6 +163,54 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr | |
| res: pb.GetInstanceResponse = self._stub.GetInstance(req) | ||
| return new_orchestration_state(req.instanceId, res) | ||
|
|
||
| def get_all_orchestration_states(self, | ||
| max_instance_count: Optional[int] = None, | ||
| fetch_inputs_and_outputs: bool = False) -> List[OrchestrationState]: | ||
| return self.get_orchestration_state_by( | ||
| created_time_from=None, | ||
| created_time_to=None, | ||
| runtime_status=None, | ||
| max_instance_count=max_instance_count, | ||
| fetch_inputs_and_outputs=fetch_inputs_and_outputs | ||
| ) | ||
|
|
||
| def get_orchestration_state_by(self, | ||
|
||
| created_time_from: Optional[datetime] = None, | ||
| created_time_to: Optional[datetime] = None, | ||
| runtime_status: Optional[List[OrchestrationStatus]] = None, | ||
| max_instance_count: Optional[int] = None, | ||
| fetch_inputs_and_outputs: bool = False, | ||
| _continuation_token: Optional[pb2.StringValue] = None | ||
| ) -> List[OrchestrationState]: | ||
| if max_instance_count is None: | ||
| # DTS backend does not behave well with max_instance_count = None, so we set to max 32-bit signed value | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| max_instance_count = (1 << 31) - 1 | ||
|
||
| req = pb.QueryInstancesRequest( | ||
| query=pb.InstanceQuery( | ||
| runtimeStatus=[status.value for status in runtime_status] if runtime_status else None, | ||
| createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None, | ||
| createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None, | ||
| maxInstanceCount=max_instance_count, | ||
| fetchInputsAndOutputs=fetch_inputs_and_outputs, | ||
| continuationToken=_continuation_token | ||
| ) | ||
| ) | ||
| resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) | ||
| states = [parse_orchestration_state(res) for res in resp.orchestrationState] | ||
| # Check the value for continuationToken - none or "0" indicates that there are no more results. | ||
| if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0": | ||
| self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...") | ||
| states += self.get_orchestration_state_by( | ||
| created_time_from, | ||
| created_time_to, | ||
| runtime_status, | ||
| max_instance_count, | ||
| fetch_inputs_and_outputs, | ||
| _continuation_token=resp.continuationToken | ||
| ) | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| states = [state for state in states if state is not None] # Filter out any None values | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return states | ||
|
|
||
| def wait_for_orchestration_start(self, instance_id: str, *, | ||
| fetch_payloads: bool = False, | ||
| timeout: int = 60) -> Optional[OrchestrationState]: | ||
|
|
@@ -199,7 +258,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, | |
| req = pb.RaiseEventRequest( | ||
| instanceId=instance_id, | ||
| name=event_name, | ||
| input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None) | ||
| input=helpers.get_string_value(shared.to_json(data))) | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") | ||
| self._stub.RaiseEvent(req) | ||
|
|
@@ -209,7 +268,7 @@ def terminate_orchestration(self, instance_id: str, *, | |
| recursive: bool = True): | ||
| req = pb.TerminateRequest( | ||
| instanceId=instance_id, | ||
| output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, | ||
| output=helpers.get_string_value(shared.to_json(output)), | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| recursive=recursive) | ||
|
|
||
| self._logger.info(f"Terminating instance '{instance_id}'.") | ||
|
|
@@ -225,10 +284,27 @@ def resume_orchestration(self, instance_id: str): | |
| self._logger.info(f"Resuming instance '{instance_id}'.") | ||
| self._stub.ResumeInstance(req) | ||
|
|
||
| def purge_orchestration(self, instance_id: str, recursive: bool = True): | ||
| def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult: | ||
| req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) | ||
| self._logger.info(f"Purging instance '{instance_id}'.") | ||
| self._stub.PurgeInstances(req) | ||
| resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req) | ||
| return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) | ||
|
|
||
| def purge_orchestrations_by(self, | ||
| created_time_from: Optional[datetime] = None, | ||
| created_time_to: Optional[datetime] = None, | ||
| runtime_status: Optional[List[OrchestrationStatus]] = None, | ||
| recursive: bool = False) -> PurgeInstancesResult: | ||
| resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest( | ||
| instanceId=None, | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| purgeInstanceFilter=pb.PurgeInstanceFilter( | ||
| createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None, | ||
| createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None, | ||
| runtimeStatus=[status.value for status in runtime_status] if runtime_status else None | ||
| ), | ||
andystaples marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| recursive=recursive | ||
| )) | ||
| return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) | ||
|
|
||
| def signal_entity(self, | ||
| entity_instance_id: EntityInstanceId, | ||
|
|
@@ -237,7 +313,7 @@ def signal_entity(self, | |
| req = pb.SignalEntityRequest( | ||
| instanceId=str(entity_instance_id), | ||
| name=operation_name, | ||
| input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None, | ||
| input=helpers.get_string_value(shared.to_json(input)), | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| requestId=str(uuid.uuid4()), | ||
| scheduledTime=None, | ||
| parentTraceContext=None, | ||
|
|
@@ -256,4 +332,53 @@ def get_entity(self, | |
| if not res.exists: | ||
| return None | ||
|
|
||
| return EntityMetadata.from_entity_response(res, include_state) | ||
| return EntityMetadata.from_entity_metadata(res.entity, include_state) | ||
|
|
||
| def get_all_entities(self, | ||
| include_state: bool = True, | ||
| include_transient: bool = False, | ||
| page_size: Optional[int] = None) -> List[EntityMetadata]: | ||
| return self.get_entities_by( | ||
| instance_id_starts_with=None, | ||
| last_modified_from=None, | ||
| last_modified_to=None, | ||
| include_state=include_state, | ||
| include_transient=include_transient, | ||
| page_size=page_size | ||
| ) | ||
|
|
||
| def get_entities_by(self, | ||
| instance_id_starts_with: Optional[str] = None, | ||
| last_modified_from: Optional[datetime] = None, | ||
| last_modified_to: Optional[datetime] = None, | ||
| include_state: bool = True, | ||
| include_transient: bool = False, | ||
| page_size: Optional[int] = None, | ||
| _continuation_token: Optional[pb2.StringValue] = None | ||
| ) -> List[EntityMetadata]: | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._logger.info(f"Getting entities") | ||
| query_request = pb.QueryEntitiesRequest( | ||
| query=pb.EntityQuery( | ||
| instanceIdStartsWith=helpers.get_string_value(instance_id_starts_with), | ||
| lastModifiedFrom=helpers.new_timestamp(last_modified_from) if last_modified_from else None, | ||
| lastModifiedTo=helpers.new_timestamp(last_modified_to) if last_modified_to else None, | ||
| includeState=include_state, | ||
| includeTransient=include_transient, | ||
| pageSize=helpers.get_int_value(page_size), | ||
| continuationToken=_continuation_token | ||
| ) | ||
| ) | ||
| resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request) | ||
| entities = [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] | ||
| if resp.continuationToken and resp.continuationToken.value != "0": | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...") | ||
| entities += self.get_entities_by( | ||
| instance_id_starts_with=instance_id_starts_with, | ||
| last_modified_from=last_modified_from, | ||
| last_modified_to=last_modified_to, | ||
| include_state=include_state, | ||
| include_transient=include_transient, | ||
| page_size=page_size, | ||
| _continuation_token=resp.continuationToken | ||
| ) | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return entities | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| from durabletask import client, task | ||
| from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker | ||
| from durabletask.azuremanaged.client import DurableTaskSchedulerClient | ||
|
|
||
| # Read the environment variables | ||
| taskhub_name = os.getenv("TASKHUB", "default") | ||
| endpoint = os.getenv("ENDPOINT", "http://localhost:8080") | ||
|
|
||
| pytestmark = pytest.mark.dts | ||
|
|
||
|
|
||
| def empty_orchestrator(ctx: task.OrchestrationContext, _): | ||
| return "Complete" | ||
|
|
||
|
|
||
| def test_get_all_orchestration_states(): | ||
| # Start a worker, which will connect to the sidecar in a background thread | ||
| with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, | ||
| taskhub=taskhub_name, token_credential=None) as w: | ||
| w.add_orchestrator(empty_orchestrator) | ||
| w.start() | ||
|
|
||
| c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, | ||
| taskhub=taskhub_name, token_credential=None) | ||
| id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") | ||
| c.wait_for_orchestration_completion(id, timeout=30) | ||
|
|
||
| all_orchestrations = c.get_all_orchestration_states() | ||
| all_orchestrations_with_state = c.get_all_orchestration_states(fetch_inputs_and_outputs=True) | ||
| this_orch = c.get_orchestration_state(id) | ||
|
|
||
| assert this_orch is not None | ||
| assert this_orch.instance_id == id | ||
|
|
||
| assert all_orchestrations is not None | ||
| assert len(all_orchestrations) > 1 | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print(f"Received {len(all_orchestrations)} orchestrations") | ||
| assert len([o for o in all_orchestrations if o.instance_id == id]) == 1 | ||
| orchestration_state = [o for o in all_orchestrations if o.instance_id == id][0] | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED | ||
| assert orchestration_state.serialized_input is None | ||
| assert orchestration_state.serialized_output is None | ||
| assert orchestration_state.failure_details is None | ||
|
|
||
| assert all_orchestrations_with_state is not None | ||
| assert len(all_orchestrations_with_state) > 1 | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print(f"Received {len(all_orchestrations_with_state)} orchestrations") | ||
| assert len([o for o in all_orchestrations_with_state if o.instance_id == id]) == 1 | ||
| orchestration_state = [o for o in all_orchestrations_with_state if o.instance_id == id][0] | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED | ||
| assert orchestration_state.serialized_input == '"Hello"' | ||
| assert orchestration_state.serialized_output == '"Complete"' | ||
| assert orchestration_state.failure_details is None | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reviewers - I'm not sure these tests are worth adding, tbh. If we ever implement these gRPC endpoints into the Go sidecar, our tests will immediately break. Any objection to just commenting them out for now, and relying on the DTS tests?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now, it looks like they're just a reminder for us to write tests once they are implemented. However, I'm confused why they rely on the go sidecar since we're in the python SDK? If they can break from a PR outside of this repo, I think they become less useful. If this is truly just a test for the go sidecar, I would rather we have those in the go repo. Can you expand on this a little bit?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
|
|
||
| from durabletask import client, task, worker | ||
|
|
||
|
|
||
| def empty_orchestrator(ctx: task.OrchestrationContext, _): | ||
| return "Complete" | ||
|
|
||
|
|
||
| def test_get_all_orchestration_states(): | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Start a worker, which will connect to the sidecar in a background thread | ||
| with worker.TaskHubGrpcWorker() as w: | ||
| w.add_orchestrator(empty_orchestrator) | ||
| w.start() | ||
|
|
||
| c = client.TaskHubGrpcClient() | ||
| id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") | ||
| c.wait_for_orchestration_completion(id, timeout=30) | ||
|
|
||
| all_orchestrations = c.get_all_orchestration_states() | ||
| this_orch = c.get_orchestration_state(id) | ||
|
|
||
| assert this_orch is not None | ||
| assert this_orch.instance_id == id | ||
|
|
||
| assert all_orchestrations is not None | ||
| assert len(all_orchestrations) > 1 | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print(f"Received {len(all_orchestrations)} orchestrations") | ||
| assert len([o for o in all_orchestrations if o.instance_id == id]) == 1 | ||
| orchestration_state = [o for o in all_orchestrations if o.instance_id == id][0] | ||
andystaples marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED | ||
| assert orchestration_state.serialized_input == '"Hello"' | ||
| assert orchestration_state.serialized_output == '"Complete"' | ||
| assert orchestration_state.failure_details is None | ||
Uh oh!
There was an error while loading. Please reload this page.