diff --git a/python/restate/__init__.py b/python/restate/__init__.py index e6b4f30..3b32d65 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -44,7 +44,6 @@ from .logging import getLogger, RestateLoggingFilter - try: from .harness import create_test_harness, test_harness # type: ignore except ImportError: diff --git a/python/restate/extensions.py b/python/restate/extensions.py new file mode 100644 index 0000000..3a65679 --- /dev/null +++ b/python/restate/extensions.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""This module contains internal extensions apis""" + +from .server_context import current_context + +__all__ = ["current_context"] diff --git a/python/restate/server_context.py b/python/restate/server_context.py index affbd80..26bc0a5 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -30,6 +30,7 @@ import time from restate.context import ( + Context, DurablePromise, AttemptFinishedEvent, HandlerType, @@ -302,6 +303,14 @@ def update_restate_context_is_replaying(vm: VMWrapper): restate_context_is_replaying.set(vm.is_replaying()) +_restate_context_var = contextvars.ContextVar[Context]("restate_context") + + +def current_context() -> Context | None: + """Get the current context.""" + return _restate_context_var.get() + + # pylint: disable=R0902 class ServerInvocationContext(ObjectContext): """This class implements the context for the restate framework based on the server.""" @@ -330,6 +339,7 @@ def __init__( async def enter(self): """Invoke the user code.""" update_restate_context_is_replaying(self.vm) + token = _restate_context_var.set(self) try: in_buffer = self.invocation.input_buffer out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer) @@ -356,6 +366,8 @@ async def enter(self): stacktrace = "\n".join(traceback.format_exception(e)) self.vm.notify_error(repr(e), stacktrace) raise e + finally: + _restate_context_var.reset(token) async def leave(self): """Leave the context.""" diff --git a/tests/ext.py b/tests/ext.py new file mode 100644 index 0000000..f7b4f0d --- /dev/null +++ b/tests/ext.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# + +import restate +from restate import ( + Context, + Service, + HarnessEnvironment, +) +import pytest + +# ----- Asyncio fixtures + + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + + +pytestmark = [ + pytest.mark.anyio, +] + +# -------- Restate services and restate fixture + +greeter = Service("greeter") + + +def magic_function(): + from restate.extensions import current_context + + ctx = current_context() + assert ctx is not None + return ctx.request().id + + +@greeter.handler() +async def greet(ctx: Context, name: str) -> str: + id = magic_function() + return f"Hello {id}!" + + +@pytest.fixture(scope="session") +async def restate_test_harness(): + async with restate.create_test_harness( + restate.app([greeter]), restate_image="ghcr.io/restatedev/restate:latest" + ) as harness: + yield harness + + +# ----- Tests + + +async def test_greeter(restate_test_harness: HarnessEnvironment): + greeting = await restate_test_harness.client.service_call(greet, arg="bob") + assert greeting.startswith("Hello ")