from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import logging
import numpy as np
import queue
import threading
import time
import tree  # pip install dm_tree
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Iterator,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from ray.util.debug import log_once
from worker.collectors.sample_collector import SampleCollector
from worker.collectors.simple_list_collector import SimpleListCollector
from worker.episode import Episode
from worker.metrics import RolloutMetrics
from worker.sample_batch_builder import MultiAgentSampleBatchBuilder
from env.base_env import BaseEnv, convert_to_base_env, ASYNC_RESET_RETURN
from env.wrappers.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from models.preprocessors import Preprocessor
from worker.offline import InputReader
from policy.policy import Policy
from policy.policy_map import PolicyMap
from policy.sample_batch import SampleBatch
from utils.annotations import override, DeveloperAPI
from utils.debug import summarize
from utils.deprecation import deprecation_warning
from utils.filter import Filter
from utils.numpy import convert_to_numpy
from utils.spaces.space_utils import clip_action, unsquash_action, unbatch
from utils.typing import (
    SampleBatchType,
    AgentID,
    PolicyID,
    EnvObsType,
    EnvInfoDict,
    EnvID,
    MultiEnvDict,
    EnvActionType,
    TensorStructType,
)

if TYPE_CHECKING:
    from agents.callbacks import DefaultCallbacks
    from worker.observation_function import ObservationFunction
    from worker.rollout_worker import RolloutWorker

    from gym.envs.classic_control.rendering import SimpleImageViewer

logger = logging.getLogger(__name__)

# Bundle of the per-agent inputs that get queued up (grouped by policy ID)
# for batched policy evaluation.
PolicyEvalData = namedtuple(
    "PolicyEvalData",
    "env_id agent_id obs info rnn_state prev_action prev_reward",
)

# RNN state batch, laid out as [state_index][batch][state_object].
StateBatch = List[List[Any]]


class NewEpisodeDefaultDict(defaultdict):
    """A `defaultdict` whose factory is called with the missing key.

    A plain `defaultdict` calls its factory with no arguments; this variant
    passes the missing key (here: an env ID) through, so the created value
    can depend on it. The new value is stored before being returned.
    """

    def __missing__(self, env_id):
        factory = self.default_factory
        if factory is None:
            raise KeyError(env_id)
        value = factory(env_id)
        self[env_id] = value
        return value


class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self):
        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def get(self):
        # Mean multiplicator (1000 = ms -> sec).
        factor = 1000 / self.iters
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }


@DeveloperAPI
class SamplerInput(InputReader, metaclass=ABCMeta):
    """InputReader that draws its experiences from a running sampler.

    Subclasses supply the primary data source through `get_data()`. Any
    extra batches reported by `get_extra_batches()` are concatenated onto
    that primary batch by `next()`.
    """

    @override(InputReader)
    def next(self) -> SampleBatchType:
        main_batch = self.get_data()
        all_batches = [main_batch, *self.get_extra_batches()]
        if len(all_batches) == 1:
            # Nothing extra to merge in.
            return main_batch
        return main_batch.concat_samples(all_batches)

    @abstractmethod
    @DeveloperAPI
    def get_data(self) -> SampleBatchType:
        """Produces the next batch of data; invoked by `self.next()`.

        Child classes must override this.

        Returns:
            The next batch of data.
        """
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_metrics(self) -> List[RolloutMetrics]:
        """Returns the episode metrics gathered since this was last called.

        One RolloutMetrics object is returned for every episode that
        completed in that time span.

        Returns:
            List of RolloutMetrics objects, one per episode completed since
            the previous call to this method.
        """
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_extra_batches(self) -> List[SampleBatchType]:
        """Returns the extra batches collected since this was last called.

        The result holds every SampleBatch or MultiAgentBatch that the user
        has handed in so far. Users contribute such "extra batches" to an
        episode through the episode's `add_extra_batch([SampleBatchType])`
        method, e.g. from an overridden
        `Policy.compute_actions_from_input_dict(..., episodes)` or from a
        custom callback's `on_episode_[start|step|end]()` methods.

        Returns:
            List of SamplesBatches or MultiAgentBatches the user provided
            since the previous call to this method.
        """
        raise NotImplementedError


@DeveloperAPI
class SyncSampler(SamplerInput):
    """Sync SamplerInput that collects experiences when `get_data()` is called."""

    def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: Union[bool, float],
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: Optional[int] = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: Optional["ObservationFunction"] = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
    ):
        """Initializes a SyncSampler instance.

        Args:
            worker: The RolloutWorker that will use this Sampler for sampling.
            env: Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards: True for +/-1.0 clipping,
                actual float value for +/- value clipping. False for no
                clipping.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            callbacks: The Callbacks object to use when episode
                events happen during rollout.
            horizon: Hard-reset the Env after this many timesteps.
                None means "not set" (the env's own limit, if any, is used).
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions: Whether to normalize actions to the
                action space's bounds.
            clip_actions: Whether to clip actions according to the
                given action_space's bounds.
            soft_horizon: If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end: Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn: Optional multi-agent observation func to use for
                preprocessing observations.
            sample_collector_class: An optional SampleCollector sub-class to
                use to collect, store, and retrieve environment-, model-,
                and sampler data. Defaults to SimpleListCollector.
            render: Whether to try to render the environment after each step.
        """
        self.base_env = convert_to_base_env(env)
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.extra_batches = queue.Queue()
        self.perf_stats = _PerfStats()
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            worker.policy_map,
            clip_rewards,
            callbacks,
            multiple_episodes_in_batch,
            rollout_fragment_length,
            count_steps_by=count_steps_by,
        )
        self.render = render

        # Create the rollout generator to use for calls to `get_data()`.
        # Nothing runs yet: the generator only advances inside `get_data()`.
        self._env_runner = _env_runner(
            worker,
            self.base_env,
            self.extra_batches.put,
            self.horizon,
            normalize_actions,
            clip_actions,
            multiple_episodes_in_batch,
            callbacks,
            self.perf_stats,
            soft_horizon,
            no_done_at_end,
            observation_fn,
            self.sample_collector,
            self.render,
        )
        self.metrics_queue = queue.Queue()

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        # Step the generator until it yields a batch; metrics yielded along
        # the way are stashed for `get_metrics()`.
        while True:
            item = next(self._env_runner)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                return item

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        # Drain all queued metrics, attaching the current perf-stats means.
        completed = []
        while True:
            try:
                completed.append(
                    self.metrics_queue.get_nowait()._replace(
                        perf_stats=self.perf_stats.get()
                    )
                )
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        # Drain all batches the episodes pushed via `extra_batches.put`.
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


@DeveloperAPI
class AsyncSampler(threading.Thread, SamplerInput):
    """Async SamplerInput that collects experiences in thread and queues them.

    Once started, experiences are continuously collected in the background
    and put into a Queue, from where they can be unqueued by the caller
    of `get_data()`.
    """

    def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: Union[bool, float],
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "DefaultCallbacks",
        horizon: Optional[int] = None,
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        soft_horizon: bool = False,
        no_done_at_end: bool = False,
        observation_fn: Optional["ObservationFunction"] = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        blackhole_outputs: bool = False,
    ):
        """Initializes an AsyncSampler instance.

        Args:
            worker: The RolloutWorker that will use this Sampler for sampling.
            env: Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards: True for +/-1.0 clipping,
                actual float value for +/- value clipping. False for no
                clipping.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            callbacks: The Callbacks object to use when episode
                events happen during rollout.
            horizon: Hard-reset the Env after this many timesteps.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions: Whether to normalize actions to the
                action space's bounds.
            clip_actions: Whether to clip actions according to the
                given action_space's bounds.
            blackhole_outputs: Whether to collect samples, but then
                not further process or store them (throw away all samples).
            soft_horizon: If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end: Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn: Optional multi-agent observation func to use for
                preprocessing observations.
            sample_collector_class: An optional SampleCollector sub-class to
                use to collect, store, and retrieve environment-, model-,
                and sampler data. Defaults to SimpleListCollector.
            render: Whether to try to render the environment after each step.
        """
        self.worker = worker

        # Filters are updated from the sampling thread while other threads
        # may read them, so they must declare concurrency support.
        for _, f in worker.filters.items():
            assert getattr(
                f, "is_concurrent", False
            ), "Observation Filter must support concurrent updates."

        self.base_env = convert_to_base_env(env)
        threading.Thread.__init__(self)
        # Bounded: the sampling thread blocks once 5 batches are pending.
        self.queue = queue.Queue(5)
        self.extra_batches = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.clip_rewards = clip_rewards
        self.daemon = True
        self.multiple_episodes_in_batch = multiple_episodes_in_batch
        self.callbacks = callbacks
        self.normalize_actions = normalize_actions
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.soft_horizon = soft_horizon
        self.no_done_at_end = no_done_at_end
        self.perf_stats = _PerfStats()
        # Checked once per env-runner item in `_run()`; set to stop the loop.
        self.shutdown = False
        self.observation_fn = observation_fn
        self.render = render
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            self.worker.policy_map,
            self.clip_rewards,
            self.callbacks,
            self.multiple_episodes_in_batch,
            self.rollout_fragment_length,
            count_steps_by=count_steps_by,
        )

    @override(threading.Thread)
    def run(self):
        try:
            self._run()
        except BaseException as e:
            # Hand the error over to the consumer; `get_data()` re-raises it.
            self.queue.put(e)
            raise

    def _run(self):
        if self.blackhole_outputs:
            queue_putter = lambda x: None
            extra_batches_putter = lambda x: None
        else:
            queue_putter = self.queue.put
            # The timeout exists because apparently, if one worker dies, the
            # other workers won't die with it, unless the timeout is set to
            # some large number. This is an empirical observation.
            extra_batches_putter = lambda x: self.extra_batches.put(x, timeout=600.0)
        env_runner = _env_runner(
            self.worker,
            self.base_env,
            extra_batches_putter,
            self.horizon,
            self.normalize_actions,
            self.clip_actions,
            self.multiple_episodes_in_batch,
            self.callbacks,
            self.perf_stats,
            self.soft_horizon,
            self.no_done_at_end,
            self.observation_fn,
            self.sample_collector,
            self.render,
        )
        while not self.shutdown:
            item = next(env_runner)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                queue_putter(item)

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        if not self.is_alive():
            # The thread may have died *after* queueing its exception (see
            # `run()`), possibly leaving already-collected batches in front
            # of it. Drain what is pending so the real error (or the data)
            # is not masked by a generic message.
            try:
                rollout = self.queue.get_nowait()
            except queue.Empty:
                raise RuntimeError("Sampling thread has died") from None
        else:
            rollout = self.queue.get(timeout=600.0)

        # Propagate errors.
        if isinstance(rollout, BaseException):
            raise rollout

        return rollout

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        # Drain all queued metrics, attaching the current perf-stats means.
        completed = []
        while True:
            try:
                completed.append(
                    self.metrics_queue.get_nowait()._replace(
                        perf_stats=self.perf_stats.get()
                    )
                )
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        # Drain all batches pushed by the sampling thread's episodes.
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


def _resolve_horizon(
    base_env: BaseEnv, horizon: Optional[int]
) -> Union[int, float]:
    """Reconciles the requested `horizon` with the env's `max_episode_steps`.

    Args:
        base_env: Env implementing BaseEnv.
        horizon: The Trainer's horizon setting (falsy = not set).

    Returns:
        `horizon` if set; otherwise the env's `max_episode_steps` if that
        can be read; otherwise float("inf").

    Raises:
        ValueError: If `horizon` exceeds the env's limit and the limit cannot
            be overridden on the sub-environments.
    """
    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    sub_envs = []
    max_episode_steps = None
    try:
        sub_envs = base_env.get_sub_environments()
        max_episode_steps = sub_envs[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit.
        if max_episode_steps and horizon > max_episode_steps:
            # Try to override the env's own max-step setting with our horizon
            # on *every* sub-env (not just the first one), so that all
            # vectorized copies use the same limit. If this won't work, throw
            # an error.
            try:
                for sub_env in sub_envs:
                    sub_env.spec.max_episode_steps = horizon
                    sub_env._max_episode_steps = horizon
            except Exception as e:
                raise ValueError(
                    "Your `horizon` setting ({}) is larger than the Env's own "
                    "timestep limit ({}), which seems to be unsettable! Try "
                    "to increase the Env's built-in limit to be at least as "
                    "large as your wanted `horizon`.".format(horizon, max_episode_steps)
                ) from e
        return horizon
    # Otherwise, set Trainer's horizon to env's max-steps.
    if max_episode_steps:
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".format(
                max_episode_steps
            )
        )
        return max_episode_steps
    # No horizon/max_episode_steps -> Episodes may be infinitely long.
    logger.debug("No episode horizon specified, assuming inf.")
    return float("inf")


def _env_runner(
    worker: "RolloutWorker",
    base_env: BaseEnv,
    extra_batch_callback: Callable[[SampleBatchType], None],
    horizon: Optional[int],
    normalize_actions: bool,
    clip_actions: bool,
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    perf_stats: _PerfStats,
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: Optional[SampleCollector] = None,
    render: bool = None,
) -> Iterator[SampleBatchType]:
    """This implements the common experience collection logic.

    Runs forever: each loop iteration polls the env, processes observations,
    yields any finished outputs, runs batched policy inference, sends actions
    back to the env and optionally renders.

    Args:
        worker: Reference to the current rollout worker.
        base_env: Env implementing BaseEnv.
        extra_batch_callback: function to send extra batch data to.
        horizon: Horizon of the episode.
        normalize_actions: Whether to normalize actions to the action
            space's bounds.
        clip_actions: Whether to clip actions to the space range.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks: User callbacks to run on episode events.
        perf_stats: Record perf stats into this object.
        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn: Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector: An optional
            SampleCollector object to use.
        render: Whether to try to render the environment after each
            step.

    Yields:
        Object containing state, action, reward, terminal condition,
        and other fields as dictated by `policy`. RolloutMetrics objects
        are yielded as well (see `_process_observations` outputs).

    Raises:
        ValueError: If `horizon` cannot be applied to the env, or if the
            env's `try_render()` returns an unsupported value.
    """

    # May be populated with used for image rendering
    simple_image_viewer: Optional["SimpleImageViewer"] = None

    horizon = _resolve_horizon(base_env, horizon)

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return None

    def new_episode(env_id):
        episode = Episode(
            worker.policy_map,
            worker.policy_mapping_fn,
            get_batch_builder,
            extra_batch_callback,
            env_id=env_id,
            worker=worker,
        )
        # Call each policy's Exploration.on_episode_start method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                )
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=worker.policy_map,
            episode=episode,
            env_index=env_id,
        )
        return episode

    # Lazily creates (and announces via callbacks) a new Episode the first
    # time an env ID is looked up.
    active_episodes: Dict[EnvID, Episode] = NewEpisodeDefaultDict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = _process_observations(
            worker=worker,
            base_env=base_env,
            active_episodes=active_episodes,
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            dones=dones,
            infos=infos,
            horizon=horizon,
            multiple_episodes_in_batch=multiple_episodes_in_batch,
            callbacks=callbacks,
            soft_horizon=soft_horizon,
            no_done_at_end=no_done_at_end,
            observation_fn=observation_fn,
            sample_collector=sample_collector,
        )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=worker.policy_map,
            sample_collector=sample_collector,
            active_episodes=active_episodes,
        )
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[
            EnvID, Dict[AgentID, EnvActionType]
        ] = _process_policy_eval_results(
            to_eval=to_eval,
            eval_results=eval_results,
            active_episodes=active_episodes,
            active_envs=active_envs,
            off_policy_actions=off_policy_actions,
            policies=worker.policy_map,
            normalize_actions=normalize_actions,
            clip_actions=clip_actions,
        )
        perf_stats.action_processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

        # Try to render the env, if required.
        if render:
            t5 = time.time()
            # Render can either return an RGB image (uint8 [w x h x 3] numpy
            # array) or take care of rendering itself (returning True).
            rendered = base_env.try_render()
            # Rendering returned an image -> Display it in a SimpleImageViewer.
            if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                # ImageViewer not defined yet, try to create one.
                if simple_image_viewer is None:
                    try:
                        from gym.envs.classic_control.rendering import SimpleImageViewer

                        simple_image_viewer = SimpleImageViewer()
                    except ImportError:  # Also covers ModuleNotFoundError.
                        render = False  # disable rendering
                        logger.warning(
                            "Could not import gym.envs.classic_control."
                            "rendering! Try `pip install gym[all]`."
                        )
                if simple_image_viewer:
                    simple_image_viewer.imshow(rendered)
            # Arrays of any other rank are unsupported. Test for them
            # explicitly: `ndarray in [...]` compares element-wise and would
            # raise numpy's ambiguous-truth error instead of this message.
            elif isinstance(rendered, np.ndarray) or rendered not in [
                True,
                False,
                None,
            ]:
                raise ValueError(
                    f"The env's ({base_env}) `try_render()` method returned an"
                    " unsupported value! Make sure you either return a "
                    "uint8/w x h x 3 (RGB) image or handle rendering in a "
                    "window and then return `True`."
                )
            perf_stats.env_render_time += time.time() - t5


def _process_observations(
    *,
    worker: "RolloutWorker",
    base_env: BaseEnv,
    active_episodes: Dict[EnvID, Episode],
    unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
    rewards: Dict[EnvID, Dict[AgentID, float]],
    dones: Dict[EnvID, Dict[AgentID, bool]],
    infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
    horizon: int,
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: SampleCollector,
) -> Tuple[
    Set[EnvID],
    Dict[PolicyID, List[PolicyEvalData]],
    List[Union[RolloutMetrics, SampleBatchType]],
]:
    """Record new data from the environment and prepare for policy evaluation.

    For every sub-environment that returned data from the last
    `BaseEnv.poll()`: advance its episode, pass the observations through
    `observation_fn` (if given), the policy's preprocessor and filter, hand
    the resulting transitions to the sample collector, and queue every agent
    that still needs an action for policy evaluation. Episodes that ended
    (all agents done, or the horizon was hit) are postprocessed, the
    `on_episode_end` hooks are run, and the env is either soft-reset or
    hard-reset; the new initial observations are queued for evaluation too.

    Args:
        worker: Reference to the current rollout worker.
        base_env: Env implementing BaseEnv.
        active_episodes: Mapping from env ID to the currently ongoing
            Episode object. This function looks up env IDs that are not
            (or no longer) present, so the container is expected to create
            a fresh Episode for a missing key (defaultdict-like).
        unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards: Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones: Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos: Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon: Horizon of the episode.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks: User callbacks to run on episode events.
        soft_horizon: Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end: Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn: Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector: The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple consisting of 1) active_envs: Set of non-terminated env ids.
        2) to_eval: Map of policy_id to list of agent PolicyEvalData.
        3) outputs: List of metrics and samples to return from the sampler.

    Raises:
        ValueError: If `observation_fn` does not return a dict, if an env
            that cannot be reset is used with a finite `horizon`, or if a
            policy ID has no preprocessor/filter registered on the worker.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # types: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        # For a new env_id this lookup is what creates the Episode.
        episode: Episode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = episode.length >= horizon and not dones[env_id]["__all__"]
            all_agents_done = True
            atari_metrics: Optional[List[RolloutMetrics]] = _fetch_atari_metrics(
                base_env
            )
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(
                        episode.length,
                        episode.total_reward,
                        dict(episode.agent_rewards),
                        episode.custom_metrics,
                        {},
                        episode.hist_data,
                        episode.media,
                    )
                )
            # Check whether we have to create a fake-last observation
            # for some agents (the environment is not required to do so if
            # dones[__all__]=True).
            for ag_id in episode.get_agents():
                if not episode.last_done_for(ag_id) and ag_id not in all_agents_obs:
                    # Create a fake (all-0s) observation.
                    obs_sp = worker.policy_map[
                        episode.policy_for(ag_id)
                    ].observation_space
                    obs_sp = getattr(obs_sp, "original_space", obs_sp)
                    all_agents_obs[ag_id] = tree.map_structure(
                        np.zeros_like, obs_sp.sample()
                    )
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
            )
            if not isinstance(all_agents_obs, dict):
                raise ValueError("observe() must return a dict of agent observations")

        common_infos = infos[env_id].get("__common__", {})
        episode._set_last_info("__common__", common_infos)

        # For each agent in the environment.
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            preprocessor = _get_or_raise(worker.preprocessors, policy_id)
            prep_obs: EnvObsType = raw_obs
            if preprocessor is not None:
                prep_obs = preprocessor.transform(raw_obs)
                if log_once("prep_obs"):
                    logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(
                prep_obs
            )
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_done(agent_id, agent_done)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                # First observation of this agent: nothing to pair it with
                # yet, so it is stored as the initial observation.
                sample_collector.add_init_obs(
                    episode,
                    agent_id,
                    env_id,
                    policy_id,
                    episode.length - 1,
                    filtered_obs,
                )
            elif agent_infos is None or agent_infos.get("training_enabled", True):
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    SampleBatch.T: episode.length - 1,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    SampleBatch.ACTIONS: episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.DONES: (
                        False
                        if (no_done_at_end or (hit_horizon and soft_horizon))
                        else agent_done
                    ),
                    # Next observation.
                    SampleBatch.NEXT_OBS: filtered_obs,
                }
                # Add extra-action-fetches (policy-inference infos) to
                # collectors.
                pol = worker.policy_map[policy_id]
                for key, value in episode.last_extra_action_outs_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id,
                    agent_id,
                    env_id,
                    policy_id,
                    agent_done,
                    values_dict,
                )

            if not agent_done:
                item = PolicyEvalData(
                    env_id,
                    agent_id,
                    filtered_obs,
                    agent_infos,
                    None
                    if last_observation is None
                    else episode.rnn_state_for(agent_id),
                    None
                    if last_observation is None
                    else episode.last_action_for(agent_id),
                    rewards[env_id].get(agent_id, 0.0),
                )
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id,
            )

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If, we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch,
            )
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each (in-memory) policy's Exploration.on_episode_end
            # method.
            # Note: This may break the exploration (e.g. ParameterNoise) of
            # policies in the `policy_map` that have not been recently used
            # (and are therefore stashed to disk). However, we certainly do not
            # want to loop through all (even stashed) policies here as that
            # would counter the purpose of the LRU policy caching.
            for p in worker.policy_map.cache.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                    )
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = {env_id: all_agents_obs}
                # `all_agents_obs` has already been passed through
                # `observation_fn` above; applying it again would transform
                # the same observations twice.
                apply_observation_fn = False
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
                apply_observation_fn = True
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment."
                    )
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                # After a hard reset this lookup creates the new Episode; after
                # a soft reset it returns the (soft-reset) current episode.
                new_episode: Episode = active_episodes[env_id]
                reset_agents_obs: Dict[AgentID, EnvObsType] = resetted_obs[env_id]
                if observation_fn and apply_observation_fn:
                    reset_agents_obs = observation_fn(
                        agent_obs=reset_agents_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=worker.policy_map,
                        episode=new_episode,
                    )
                    if not isinstance(reset_agents_obs, dict):
                        raise ValueError(
                            "observe() must return a dict of agent observations"
                        )
                # types: AgentID, EnvObsType
                for agent_id, raw_obs in reset_agents_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    preprocessor = _get_or_raise(worker.preprocessors, policy_id)

                    prep_obs: EnvObsType = raw_obs
                    if preprocessor is not None:
                        prep_obs = preprocessor.transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(
                        prep_obs
                    )
                    new_episode._set_last_raw_obs(agent_id, raw_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(
                        new_episode,
                        agent_id,
                        env_id,
                        policy_id,
                        new_episode.length - 1,
                        filtered_obs,
                    )

                    # Info and RNN state must come from the episode these
                    # initial observations belong to (`new_episode`), not
                    # from the episode that just ended.
                    item = PolicyEvalData(
                        env_id,
                        agent_id,
                        filtered_obs,
                        new_episode.last_info_for(agent_id) or {},
                        new_episode.rnn_state_for(agent_id),
                        None,
                        0.0,
                    )
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = (
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        )
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs


def _do_policy_eval(
    *,
    to_eval: Dict[PolicyID, List[PolicyEvalData]],
    policies: PolicyMap,
    sample_collector: SampleCollector,
    active_episodes: Dict[EnvID, Episode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Run one batched forward pass per policy over the queued agents.

    For each policy ID in `to_eval`, the policy object is looked up and its
    `compute_actions_from_input_dict()` is called on the inference input
    dict that the sample collector assembled for that policy ID.

    Args:
        to_eval: Policy ID -> PolicyEvalData items that form the batch for
            that policy's forward pass.
        policies: Mapping from policy ID to Policy obj.
        sample_collector: Provides the per-policy inference input dicts.
        active_episodes: Env ID -> its currently active episode.

    Returns:
        Dict mapping the (possibly re-mapped, see below) policy IDs to the
        outputs of `compute_actions_from_input_dict()`.
    """
    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(summarize(to_eval)))

    results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]] = {}

    for policy_id, eval_data in to_eval.items():
        try:
            policy: Policy = _get_or_raise(policies, policy_id)
        except ValueError:
            # The policy ID is gone from this worker: ask the episode's own
            # policy_mapping_fn (it stays fixed for the whole episode, unlike
            # the worker's, which may already have been swapped out) which
            # policy the first queued agent should use instead.
            first_item = eval_data[0]
            episode = active_episodes[first_item.env_id]
            policy_id = episode.policy_mapping_fn(
                first_item.agent_id, episode, worker=episode.worker
            )
            policy = _get_or_raise(policies, policy_id)
            # NOTE(review): results for a re-mapped policy are stored under
            # the new ID, while `to_eval` stays keyed by the old one; callers
            # that index these results by `to_eval`'s keys will not find
            # them. Confirm this path is handled downstream.

        inference_input = sample_collector.get_inference_input_dict(policy_id)
        results[policy_id] = policy.compute_actions_from_input_dict(
            inference_input,
            timestep=policy.global_timestep,
            episodes=[active_episodes[item.env_id] for item in eval_data],
        )

    if log_once("compute_actions_result"):
        logger.info(
            "Outputs of compute_actions():\n\n{}\n".format(summarize(results))
        )

    return results


def _process_policy_eval_results(
    *,
    to_eval: Dict[PolicyID, List[PolicyEvalData]],
    eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]],
    active_episodes: Dict[EnvID, Episode],
    active_envs: Set[int],
    off_policy_actions: MultiEnvDict,
    policies: Dict[PolicyID, Policy],
    normalize_actions: bool,
    clip_actions: bool,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
    """Turn batched policy outputs into per-agent env actions.

    Writes each agent's RNN state outputs, extra action outputs and last
    action into its episode, and collects the action to send back to the
    env for every evaluated agent.

    Args:
        to_eval: Policy ID -> PolicyEvalData items, in the same order as the
            rows of that policy's batched outputs.
        eval_results: Policy ID -> (actions, rnn-out state columns,
            extra-action-fetches dict).
        active_episodes: Env ID -> currently ongoing Episode object.
        active_envs: Set of non-terminated env ids; each of them gets an
            entry in the result even if no agent of it acts.
        off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
            off-policy action, returned by a `BaseEnv.poll()` call. When
            present, it is recorded as the agent's last action instead of
            the computed one (the computed action is still sent).
        policies: Mapping from policy ID to Policy.
        normalize_actions: Unsquash actions into the action space's bounds
            (takes precedence over `clip_actions`).
        clip_actions: Clip actions to the action space's bounds.

    Returns:
        Nested dict of env id -> agent id -> action to send to the env.
    """
    actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)

    # Every still-running env receives a reply, even if it is an empty one.
    for running_env_id in active_envs:
        actions_to_send[running_env_id] = {}

    for policy_id, eval_data in to_eval.items():
        policy_output = eval_results[policy_id]
        batched_actions: TensorStructType = convert_to_numpy(policy_output[0])
        rnn_out_cols: StateBatch = policy_output[1]
        extra_action_out_cols: dict = policy_output[2]

        # A plain list stands for the 0th (batch) dim of primitive actions.
        if isinstance(batched_actions, list):
            batched_actions = np.array(batched_actions)

        # Expose the RNN state outputs as extra-action fetches as well
        # (this adds keys to the policy's returned dict in place).
        for state_idx, state_column in enumerate(rnn_out_cols):
            extra_action_out_cols["state_out_{}".format(state_idx)] = state_column

        policy: Policy = _get_or_raise(policies, policy_id)

        # One row per evaluated agent.
        for row, action in enumerate(unbatch(batched_actions)):
            if normalize_actions:
                env_action = unsquash_action(action, policy.action_space_struct)
            elif clip_actions:
                env_action = clip_action(action, policy.action_space_struct)
            else:
                env_action = action

            eval_item = eval_data[row]
            env_id: int = eval_item.env_id
            agent_id: AgentID = eval_item.agent_id
            episode: Episode = active_episodes[env_id]
            episode._set_rnn_state(agent_id, [col[row] for col in rnn_out_cols])
            episode._set_last_extra_action_outs(
                agent_id,
                {key: column[row] for key, column in extra_action_out_cols.items()},
            )
            has_off_policy_action = (
                env_id in off_policy_actions
                and agent_id in off_policy_actions[env_id]
            )
            episode._set_last_action(
                agent_id,
                off_policy_actions[env_id][agent_id]
                if has_off_policy_action
                else action,
            )

            env_replies = actions_to_send[env_id]
            assert agent_id not in env_replies
            env_replies[agent_id] = env_action

    return actions_to_send


def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.

    Args:
        base_env: The BaseEnv whose sub-environments are inspected for a
            `MonitorEnv` wrapper.

    Returns:
        A list with one RolloutMetrics(length, reward) per episode result
        reported by the sub-environments' MonitorEnv wrappers (possibly
        empty). None if `base_env` has no sub-environments or if any
        sub-environment is not wrapped in a MonitorEnv.
    """
    sub_environments = base_env.get_sub_environments()
    if not sub_environments:
        return None

    # Resolve every monitor before reading any results. If one sub-env is
    # unmonitored we give up, and we must do so without having already pulled
    # results out of the other monitors (`next_episode_results()` presumably
    # hands out each result only once -- verify against MonitorEnv).
    monitors = []
    for sub_env in sub_environments:
        monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
        if not monitor:
            return None
        monitors.append(monitor)

    atari_out = []
    for monitor in monitors:
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out


def _to_column_format(rnn_state_rows: List[List[Any]]) -> "StateBatch":
    """Transpose per-item RNN-state rows into state columns.

    Args:
        rnn_state_rows: One row per batch item; each row holds that item's
            state objects, one per state index. The number of columns is
            taken from the first row.

    Returns:
        A StateBatch laid out as [state_index][batch]. An empty input yields
        an empty StateBatch (previously this raised IndexError).
    """
    if not rnn_state_rows:
        return []
    num_cols = len(rnn_state_rows[0])
    return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]


def _get_or_raise(
    mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
    policy_id: PolicyID,
) -> Union[Policy, Preprocessor, Filter]:
    """Look up `policy_id` in `mapping`, failing loudly if it is absent.

    Membership is checked with `in` before indexing, so the container's own
    `__contains__` decides whether the key counts as present.

    Args:
        mapping: Dict from policy ID to a per-policy object (Policy,
            Preprocessor, Filter, ...).
        policy_id: The policy ID to look up.

    Returns:
        The object stored under `policy_id`.

    Raises:
        ValueError: If `policy_id` is not contained in `mapping`.
    """
    if policy_id in mapping:
        return mapping[policy_id]

    raise ValueError(
        "Could not find policy for agent: PolicyID `{}` not found "
        "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
    )
