diff --git a/src/integrations/prefect-ray/prefect_ray/task_runners.py b/src/integrations/prefect-ray/prefect_ray/task_runners.py index e6308482a7c3..512abd6d2c5a 100644 --- a/src/integrations/prefect-ray/prefect_ray/task_runners.py +++ b/src/integrations/prefect-ray/prefect_ray/task_runners.py @@ -73,7 +73,6 @@ def count_to(highest_number): from __future__ import annotations import asyncio # noqa: I001 -import sys from typing import ( TYPE_CHECKING, Any, @@ -131,12 +130,23 @@ def result( raise_on_failure: bool = True, ) -> R: if not self._final_state: - self.wait(timeout=timeout) - if not self._final_state: - raise RuntimeError("No final state could be retrieved.") + try: + object_ref_result = ray.get(self.wrapped_future, timeout=timeout) + except ray.exceptions.GetTimeoutError as exc: + raise TimeoutError( + f"Task run {self.task_run_id} did not complete within {timeout} seconds" + ) from exc + + if isinstance(object_ref_result, State): + self._final_state = object_ref_result + else: + return object_ref_result + _result = self._final_state.result( raise_on_failure=raise_on_failure, fetch=True ) + # state.result is a `sync_compatible` function that may or may not return an awaitable + # depending on whether the parent frame is sync or not if asyncio.iscoroutine(_result): _result = run_coro_as_sync(_result) return _result @@ -156,14 +166,6 @@ def __del__(self): if self._final_state: return - # If the Python interpreter is shutting down, skip - if sys.is_finalizing(): - return - - # If Ray is not initialized, skip - if not ray.is_initialized(): - return - try: ray.get(self.wrapped_future, timeout=0) except ray.exceptions.GetTimeoutError: @@ -355,7 +357,7 @@ def _run_prefect_task( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> State: + ) -> Any: """Resolves Ray futures before calling the actual Prefect task function. Passing upstream_ray_obj_refs directly as args enables Ray to wait for @@ -385,15 +387,12 @@ def resolve_ray_future(expr: Any): "return_type": "state", } - try: - # Ray does not support the submission of async functions and we must create a - # sync entrypoint - if task.isasync: - return asyncio.run(run_task_async(**run_task_kwargs)) - else: - return run_task_sync(**run_task_kwargs) - except Exception as exc: - return run_coro_as_sync(exception_to_crashed_state(exc)) + # Ray does not support the submission of async functions and we must create a + # sync entrypoint + if task.isasync: + return asyncio.run(run_task_async(**run_task_kwargs)) + else: + return run_task_sync(**run_task_kwargs) def __enter__(self) -> Self: super().__enter__()