import ray

"""Only works with ray 2.6.3"""


def patch():
    """Monkey-patch Ray Data's actor-pool map operator to tolerate actor deaths.

    Only works with ray 2.6.3: the replacement methods below mirror that
    version's private internals of ``ActorPoolMapOperator`` and ``_ActorPool``.

    After patching:

    * a map task whose actor raised ``RayActorError`` is re-queued on
      ``self._failed_tasks`` and re-dispatched to another actor instead of
      failing the operator;
    * the dead actor is removed from the pool and new actors are started until
      the pool holds at least ``min_workers`` actors;
    * an actor that dies while still starting up triggers the same
      replacement logic.

    Every replaced method is first saved on its class as
    ``_path_fault_tolerant_old_<name>``. The original is saved only once, so
    calling ``patch()`` more than once is safe. Previously a second call saved
    the already-patched ``__init__`` as the "old" one, and that method then
    called itself forever.
    """
    from ray.data._internal.execution.operators.actor_pool_map_operator import (
        ActorPoolMapOperator,
        _ActorPool,
        logger,
    )

    import collections
    from typing import Union
    from ray._raylet import ObjectRefGenerator
    from ray.data._internal.execution.interfaces import TaskContext
    from ray.data._internal.execution.operators.map_operator import _TaskState
    from ray.types import ObjectRef

    def _replace(cls, name, new):
        """Install ``new`` as ``cls.<name>``, keeping the unpatched original.

        The original is stored as ``cls._path_fault_tolerant_old_<name>`` only
        the first time. Later ``patch()`` calls therefore never overwrite it
        with an already-patched function, which would make wrappers such as
        ``__init__`` call themselves.
        """
        saved_name = f"_path_fault_tolerant_old_{name}"
        if saved_name not in vars(cls):
            setattr(cls, saved_name, getattr(cls, name))
        setattr(cls, name, new)

    # ------------------------------------------------------------------ #
    # _ActorPool
    # ------------------------------------------------------------------ #

    def _kill_died_actor(self, actor: ray.actor.ActorHandle):
        """Remove an actor that raised ``RayActorError`` from the running pool.

        Does nothing if the actor is no longer tracked in
        ``self._num_tasks_in_flight``. For example, an earlier failed task of the
        same actor may already have removed it, assuming
        ``_kill_running_actor`` drops the actor from that dict (not visible
        here; verify against ray 2.6.3).
        """
        if actor in self._num_tasks_in_flight:
            logger.get_logger().warning(f"An actor died: {actor._actor_id.hex()}")
            self._kill_running_actor(actor)

    setattr(_ActorPool, "_kill_died_actor", _kill_died_actor)

    def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool:
        """Mark the actor corresponding to the provided ready future as running, making
        the actor pickable.

        Args:
            ready_ref: The ready future for the actor that we wish to mark as running.

        Returns:
            Whether the actor was still pending. This can return False if the actor had
            already been killed.

        Raises:
            ray.exceptions.RayActorError: If the actor died before becoming ready.
                By then the actor has already been removed from the pending set.
                The patched ``ActorPoolMapOperator.notify_work_completed``
                catches this error and starts replacement actors.
        """
        if ready_ref not in self._pending_actors:
            # We assume that there was a race between killing the actor and the actor
            # ready future resolving. Since we can rely on ray.kill() eventually killing
            # the actor, we can safely drop this reference.
            return False
        actor = self._pending_actors.pop(ready_ref)
        # If the actor died during startup, ray.get() raises here. The error is
        # deliberately propagated rather than swallowed so that the operator
        # can start a replacement actor.
        self._actor_locations[actor] = ray.get(ready_ref)
        self._num_tasks_in_flight[actor] = 0
        return True

    _replace(_ActorPool, "pending_to_running", pending_to_running)

    def add_pending_actor(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef):
        """Adds a pending actor to the pool.

        This actor won't be pickable until it is marked as running via a
        pending_to_running() call.

        Unlike the unpatched version, this does not assert
        ``not self._should_kill_idle_actors``. A replacement for a dead actor
        may need to be started after the pool has already begun killing idle
        actors.

        Args:
            actor: The not-yet-ready actor to add as pending to the pool.
            ready_ref: The ready future for the actor.
        """
        self._pending_actors[ready_ref] = actor

    _replace(_ActorPool, "add_pending_actor", add_pending_actor)

    # ------------------------------------------------------------------ #
    # ActorPoolMapOperator
    # ------------------------------------------------------------------ #

    def __init__(self, *args, **kwargs):
        self._path_fault_tolerant_old___init__(*args, **kwargs)
        # Tasks that were dispatched to an actor which died before the task
        # completed. Each element is a ``(task, ctx)`` tuple, and
        # _dispatch_tasks() retries these before taking new bundles.
        self._failed_tasks = collections.deque()

    _replace(ActorPoolMapOperator, "__init__", __init__)

    def _dispatch_tasks(self):
        """Try to dispatch tasks from the failed-task queue and the bundle buffer
        to the actor pool.

        Previously failed tasks are retried before any new bundle is started.
        They reuse their original ``_TaskState`` and ``TaskContext``, so
        ``_handle_task_submitted`` is not called for them a second time.

        This is called when:
            * a new input bundle is added,
            * a task finishes (or its actor dies),
            * a new worker has been created.
        """
        while self._bundle_queue or self._failed_tasks:
            # Pick an actor from the pool, using the locality hint of whichever
            # work item will be dispatched next.
            if self._actor_locality_enabled:
                if self._failed_tasks:
                    bundle = self._failed_tasks[0][0].inputs
                else:
                    bundle = self._bundle_queue[0]
                actor = self._actor_pool.pick_actor(bundle)
            else:
                actor = self._actor_pool.pick_actor()
            if actor is None:
                # No actors available for executing the next task.
                break
            # Submit the map task.
            if self._failed_tasks:
                task, ctx = self._failed_tasks.popleft()
                bundle = task.inputs
            else:
                bundle = self._bundle_queue.popleft()
                task = _TaskState(bundle)
                self._handle_task_submitted(task)
                ctx = TaskContext(task_idx=self._next_task_idx)
                self._next_task_idx += 1
            input_blocks = [block for block, _ in bundle.blocks]
            ref = actor.submit.options(num_returns="dynamic", name=self.name).remote(
                self._transform_fn_ref, ctx, *input_blocks
            )
            # Keep ctx alongside the task so it can be re-queued unchanged if
            # this actor dies.
            self._tasks[ref] = (task, actor, ctx)

        # Needed in the bulk execution path for triggering autoscaling. This is a
        # no-op in the streaming execution case.
        if self._bundle_queue or self._failed_tasks:
            # Try to scale up if work remains in the work queue.
            self._scale_up_if_needed()
        else:
            # Only try to scale down if the work queue has been fully consumed.
            self._scale_down_if_needed()

    _replace(ActorPoolMapOperator, "_dispatch_tasks", _dispatch_tasks)

    def notify_work_completed(
        self, ref: Union[ObjectRef[ObjectRefGenerator], ray.ObjectRef]
    ):
        """Handle a completed task-output future or a ready-actor future.

        If a task's actor died (``RayActorError``), the task is re-queued on
        ``self._failed_tasks`` instead of failing the operator. The dead actor
        is removed, and new actors are started up to the autoscaling policy's
        ``min_workers``. An actor that dies during startup triggers the same
        replacement.
        """

        def ensure_min_worker():
            # Start actors until num_total_actors() reaches min_workers.
            for _ in range(
                self._actor_pool.num_total_actors(),
                self._autoscaling_policy.min_workers,
            ):
                self._start_actor()

        # This actor pool MapOperator implementation has both task output futures AND
        # worker started futures to handle here.
        if ref in self._tasks:
            # Get task state and set output.
            task, actor, ctx = self._tasks.pop(ref)
            try:
                task.output = self._map_ref_to_ref_bundle(ref)
            except ray.exceptions.RayActorError:
                # The actor died before the task completed.
                # 1. Remove the dead actor from the pool.
                self._actor_pool._kill_died_actor(actor)
                # 2. Make sure there are enough actors left to retry on.
                ensure_min_worker()
                # 3. Work remains again, so stop the pool from killing idle or
                #    new actors. The flag is presumably set by
                #    kill_all_inactive_actors() once the queues had drained.
                self._actor_pool._should_kill_idle_actors = False
                # 4. Re-queue the task; _dispatch_tasks() below retries it.
                self._failed_tasks.append((task, ctx))
            else:
                self._handle_task_done(task)
                # Return the actor that was running the task to the pool.
                self._actor_pool.return_actor(actor)
        else:
            # ref is a future for a now-ready actor; move actor from pending to the
            # active actor pool.
            try:
                has_actor = self._actor_pool.pending_to_running(ref)
            except ray.exceptions.RayActorError:
                # The actor died during startup; replace it instead of failing.
                logger.get_logger().warning(
                    "An actor died before it became ready; "
                    "starting replacement actors if needed."
                )
                ensure_min_worker()
            else:
                if not has_actor:
                    # Actor has already been killed.
                    return
        # For either a completed task or ready worker, we try to dispatch queued tasks.
        self._dispatch_tasks()

    _replace(ActorPoolMapOperator, "notify_work_completed", notify_work_completed)

    def internal_queue_size(self) -> int:
        """Return the number of queued work items, including re-queued failed tasks."""
        return len(self._bundle_queue) + len(self._failed_tasks)

    _replace(ActorPoolMapOperator, "internal_queue_size", internal_queue_size)

    def _kill_inactive_workers_if_done(self):
        """Kill all current and future inactive actors once no work can remain.

        Work is only finished when inputs are done AND both the bundle queue and
        the failed-task queue are empty. Otherwise a re-queued task could be
        left without an actor to run it.
        """
        if self._inputs_done and not self._bundle_queue and not self._failed_tasks:
            # No more tasks will be submitted, so we kill all current and future
            # inactive workers.
            self._actor_pool.kill_all_inactive_actors()

    _replace(
        ActorPoolMapOperator,
        "_kill_inactive_workers_if_done",
        _kill_inactive_workers_if_done,
    )
