import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
    remote_requests_in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]],
    actors: List[ActorHandle],
    ray_wait_timeout_s: Optional[float] = None,
    max_remote_requests_in_flight_per_actor: int = 2,
    remote_fn: Optional[Callable[[Any, Optional[Any], Optional[Any]], Any]] = None,
    remote_args: Optional[List[List[Any]]] = None,
    remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous requests on the given remote actors.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    results that could be gathered in the timeout window. Allows a maximum
    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
    per remote actor.
    Alternatively to calling `actor.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the actor(s) instead.

    At most one result per actor is returned per call. If several requests
    of the same actor are ready at the same time, only one of them is
    fetched; the others stay recorded in `remote_requests_in_flight`, so a
    subsequent call returns them instead of them being dropped.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). If you have an RLlib Trainer that calls
            this function, you can use its `self.remote_requests_in_flight`
            property here. This dict is updated in place.
        actors: The List of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

    Returns:
        A dict mapping actor handles to the (single) result received from
        each of these actors in this call. An empty dict, if no results
        are ready (e.g. after a timeout) or if `actors` is empty.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> results = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> # Expect a timeout to have happened -> no results.
        >>> len(results)
        0
    """

    if remote_args is not None:
        assert len(remote_args) == len(
            actors
        ), "`remote_args` must contain exactly one entry per actor!"
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(
            actors
        ), "`remote_kwargs` must contain exactly one entry per actor!"

    # Nothing to send to / wait for.
    if not actors:
        return {}

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes: Set[ray.ObjectRef] = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor: Dict[ray.ObjectRef, ActorHandle] = {}
    for actor, in_flight in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= in_flight
            for ref in in_flight:
                remote_to_actor[ref] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if (
            len(remote_requests_in_flight[actor])
            < max_remote_requests_in_flight_per_actor
        ):
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First try to do a non-blocking `ray.wait` (timeout=0) to collect
        # everything that is already available.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0
        )
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait() call` w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s,
        )

        # Return empty results if nothing ready after the timeout.
        if not ready:
            return {}

    # Pick at most one ready ref per actor. The returned dict can hold only
    # one result per actor, so fetching more would silently drop results.
    # Extra ready refs remain in `remote_requests_in_flight` (and thus count
    # towards the per-actor limit) and are returned by a later call.
    refs_to_get: List[ray.ObjectRef] = []
    served_actors: Set[ActorHandle] = set()
    for obj_ref in ready:
        actor = remote_to_actor[obj_ref]
        if actor in served_actors:
            continue
        served_actors.add(actor)
        refs_to_get.append(obj_ref)

    # Remove in-flight records for the refs we are about to fetch.
    for obj_ref in refs_to_get:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(refs_to_get)
    assert len(refs_to_get) == len(results)

    # Return mapping from (ready) actors to their results.
    return {
        remote_to_actor[obj_ref]: result
        for obj_ref, result in zip(refs_to_get, results)
    }
