import logging
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, Union

import gym
import ray
from utils.annotations import Deprecated, override, PublicAPI
from utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, MultiEnvDict

if TYPE_CHECKING:
    from models.preprocessors import Preprocessor
    from env.external_env import ExternalEnv
    from env.multi_agent_env import MultiAgentEnv
    from env.vector_env import VectorEnv

# String marker constant; it is not referenced by the code visible in this
# module.
# NOTE(review): presumably a sentinel that envs return from a reset that
# completes asynchronously -- confirm against the callers that import it.
ASYNC_RESET_RETURN = "async_reset_return"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@PublicAPI
class BaseEnv:
    """The lowest-level env interface used by RLlib for sampling.

    BaseEnv models multiple agents executing asynchronously in multiple
    vectorized sub-environments. A call to `poll()` returns observations from
    ready agents keyed by their sub-environment ID and agent IDs, and
    actions for those agents can be sent back via `send_actions()`.

    All other RLlib supported env types can be converted to BaseEnv.
    RLlib handles these conversions internally in RolloutWorker, for example:

    gym.Env => rllib.VectorEnv => rllib.BaseEnv
    rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
    rllib.ExternalEnv => rllib.BaseEnv

    Examples:
        >>> env = MyBaseEnv()
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {
            "env_0": {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
            },
            "env_1": {
                "car_0": [8.0, 4.1],
            },
            "env_2": {
                "car_0": [2.3, 3.3],
                "car_1": [1.4, -0.2],
                "car_3": [1.2, 0.1],
            },
        }
        >>> env.send_actions({
        ...   "env_0": {
        ...     "car_0": 0,
        ...     "car_1": 1,
        ...   }, ...
        ... })
        >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
        >>> print(obs)
        {
            "env_0": {
                "car_0": [4.1, 1.7],
                "car_1": [3.2, -4.2],
            }, ...
        }
        >>> print(dones)
        {
            "env_0": {
                "__all__": False,
                "car_0": False,
                "car_1": True,
            }, ...
        }
    """

    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
    ) -> "BaseEnv":
        """Converts this env into a BaseEnv object.

        This object already is a BaseEnv, so this base implementation
        ignores every argument and returns `self`. The parameters exist so
        that all env types share one `to_base_env()` signature.

        Args:
            make_env: A callable taking an int as input (which indicates the
                number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment. Ignored here.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. Ignored here.
            remote_envs: Whether each sub-env should be a @ray.remote actor.
                Ignored here.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Ignored here.

        Returns:
            This BaseEnv object itself.
        """
        del make_env, num_envs, remote_envs, remote_env_batch_wait_ms
        return self

    @PublicAPI
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Returns observations from ready agents.

        All return values are two-level dicts mapping from EnvID to dicts
        mapping from AgentIDs to (observation/reward/etc..) values.
        The number of agents and sub-environments may vary over time.

        Returns:
            Tuple consisting of
            1) New observations for each ready agent.
            2) Reward values for each ready agent. If the episode is
            just started, the value will be None.
            3) Done values for each ready agent. The special key "__all__"
            is used to indicate env termination.
            4) Info values for each ready agent.
            5) Agents may take off-policy actions. When that
            happens, there will be an entry in this dict that contains the
            taken action. There is no need to send_actions() for agents that
            have already chosen off-policy actions.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    @PublicAPI
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Called to send actions back to running agents in this env.

        Actions should be sent for each ready agent that returned observations
        in the previous poll() call.

        Args:
            action_dict: Actions values keyed by env_id and agent_id.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    @PublicAPI
    def try_reset(
        self, env_id: Optional[EnvID] = None
    ) -> Optional[Union[MultiAgentDict, MultiEnvDict]]:
        """Attempt to reset the sub-env with the given id or all sub-envs.

        If the environment does not support synchronous reset, None can be
        returned here.

        Args:
            env_id: The sub-environment's ID if applicable. If None, reset
                the entire Env (i.e. all sub-environments).

        Note: A MultiAgentDict is returned when using the deprecated wrapper
        classes such as `env.base_env._MultiAgentEnvToBaseEnv`,
        however for consistency with the poll() method, a `MultiEnvDict` is
        returned from the new wrapper classes, such as
        `env.multi_agent_env.MultiAgentEnvWrapper`.

        Returns:
            The reset (multi-agent) observation dict. None if reset is not
            supported. This base implementation always returns None.
        """
        return None

    @PublicAPI
    def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        """Return a reference to the underlying sub environments, if any.

        Args:
            as_dict: If True, return a dict mapping from env_id to env.

        Returns:
            List or dictionary of the underlying sub environments or [] / {}.
            This base implementation has no sub-environments and therefore
            always returns an empty container.
        """
        if as_dict:
            return {}
        return []

    @PublicAPI
    def get_agent_ids(self) -> Dict[EnvID, List[AgentID]]:
        """Return the agent ids for each sub-environment.

        The base implementation logs a warning and returns an empty dict.

        Returns:
            A dict mapping from env_id to a list of agent_ids.
        """
        logger.warning("get_agent_ids() has not been implemented")
        return {}

    @PublicAPI
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Tries to render the sub-environment with the given id or all.

        Args:
            env_id: The sub-environment's ID, if applicable.
                If None, renders the entire Env (i.e. all sub-environments).
        """

        # By default, do nothing.
        pass

    @PublicAPI
    def stop(self) -> None:
        """Releases all resources used."""

        # Try calling `close` on all sub-environments.
        for env in self.get_sub_environments():
            if hasattr(env, "close"):
                env.close()

    @Deprecated(new="get_sub_environments", error=False)
    def get_unwrapped(self) -> List[EnvType]:
        """Deprecated alias for `get_sub_environments()` (list form)."""
        return self.get_sub_environments()

    @PublicAPI
    @property
    def observation_space(self) -> gym.spaces.Dict:
        """Returns the observation space for each environment.

        Note: samples from the observation space need to be preprocessed into a
            `MultiEnvDict` before being used by a policy.

        Returns:
            The observation space for each environment.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    @PublicAPI
    @property
    def action_space(self) -> gym.Space:
        """Returns the action space for each environment.

        Note: samples from the action space need to be preprocessed into a
            `MultiEnvDict` before being passed to `send_actions`.

        Returns:
            The action space for each environment.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    @PublicAPI
    def action_space_sample(self, agent_id: Optional[list] = None) -> MultiEnvDict:
        """Returns a random action for each environment, and potentially each
            agent in that environment.

        The base implementation ignores `agent_id` and returns an empty dict.

        Args:
            agent_id: List of agent ids to sample actions for. If None or empty
                list, sample actions for all agents in the environment.

        Returns:
            A random action for each environment.
        """
        del agent_id
        return {}

    @PublicAPI
    def observation_space_sample(
        self, agent_id: Optional[list] = None
    ) -> MultiEnvDict:
        """Returns a random observation for each environment, and potentially
            each agent in that environment.

        The base implementation logs a warning and returns an empty dict.

        Args:
            agent_id: List of agent ids to sample observations for. If None
                or empty list, sample observations for all agents in the
                environment.

        Returns:
            A random observation for each environment.
        """
        logger.warning("observation_space_sample() has not been implemented")
        return {}

    @PublicAPI
    def last(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Returns the values that were last returned by the environment.

        The base implementation logs a warning and returns five empty dicts.

        Returns:
            A 5-tuple with the same layout as the return value of `poll()`:
            the last observations, rewards, done flags, infos and off-policy
            actions for each environment.
        """
        logger.warning("last has not been implemented for this environment.")
        return {}, {}, {}, {}, {}

    @PublicAPI
    def observation_space_contains(self, x: MultiEnvDict) -> bool:
        """Checks if the given observation is valid for each environment.

        Args:
            x: Observations to check.

        Returns:
            True if the observations are contained within their respective
                spaces. False otherwise.
        """
        # The result must be returned; without the `return` this method
        # always evaluated to None (falsy) regardless of the check.
        return self._space_contains(self.observation_space, x)

    @PublicAPI
    def action_space_contains(self, x: MultiEnvDict) -> bool:
        """Checks if the given actions is valid for each environment.

        Args:
            x: Actions to check.

        Returns:
            True if the actions are contained within their respective
                spaces. False otherwise.
        """
        return self._space_contains(self.action_space, x)

    @staticmethod
    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains every per-agent value of x.

        Args:
            space: The space to check against. It is indexed by env_id, and
                the resulting sub-space's `contains()` is called per value.
            x: The two-level (env_id -> agent_id -> value) dict to check.

        Returns:
            True if every value of x is contained in the sub-space of its
            env_id (vacuously True if x holds no values), False otherwise.
        """
        # Agent ids are irrelevant for the check: only the values of the
        # inner dicts are tested. `all` stops at the first failing value.
        return all(
            space[env_id].contains(value)
            for env_id, agent_values in x.items()
            for value in agent_values.values()
        )


# Fixed agent identifier when there is only the single agent in the env.
# It is the default inner key used by the dummy-agent wrapper functions below
# and the key under which the single-agent adapters look up each action.
_DUMMY_AGENT_ID = "agent0"


@Deprecated(new="with_dummy_agent_id", error=False)
def _with_dummy_agent_id(
    env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID
) -> MultiEnvDict:
    """Wraps each per-env value into a single-entry agent dict.

    Args:
        env_id_to_values: Mapping from env id to an arbitrary value.
        dummy_id: The key under which each value is nested.

    Returns:
        A new dict mapping each env id to `{dummy_id: value}`.
    """
    nested = {}
    for env_id, value in env_id_to_values.items():
        nested[env_id] = {dummy_id: value}
    return nested


def with_dummy_agent_id(
    env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID
) -> MultiEnvDict:
    """Turns a per-env dict into a two-level (env id -> agent id) dict.

    Args:
        env_id_to_values: Mapping from env id to an arbitrary value.
        dummy_id: The agent key under which each value is nested.

    Returns:
        A new dict mapping each env id to `{dummy_id: value}`.
    """
    return dict(
        (env_id, {dummy_id: value}) for env_id, value in env_id_to_values.items()
    )


@Deprecated(
    old="env.base_env._ExternalEnvToBaseEnv",
    new="env.external.ExternalEnvWrapper",
    error=False,
)
class _ExternalEnvToBaseEnv(BaseEnv):
    """Internal adapter of ExternalEnv to BaseEnv."""

    def __init__(
        self, external_env: "ExternalEnv", preprocessor: "Preprocessor" = None
    ):
        """Wraps `external_env` and starts it.

        Args:
            external_env: The external env to adapt. Its `start()` method is
                called at the end of construction.
            preprocessor: Optional preprocessor. If given, its
                `observation_space` is exposed and its `transform()` is
                applied to every polled observation.
        """
        from env.external_multi_agent_env import ExternalMultiAgentEnv

        self.external_env = external_env
        self.prep = preprocessor
        self.multiagent = isinstance(external_env, ExternalMultiAgentEnv)
        # `action_space` / `observation_space` are read-only properties on
        # BaseEnv, so assigning to them directly would raise AttributeError.
        # Keep the values in private attributes and expose them through the
        # property overrides below.
        self._action_space = external_env.action_space
        if preprocessor:
            self._observation_space = preprocessor.observation_space
        else:
            self._observation_space = external_env.observation_space
        external_env.start()

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.Space:
        """The preprocessor's observation space if one was given, else the
        wrapped external env's observation space."""
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        """The wrapped external env's action space."""
        return self._action_space

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Blocks until at least one episode has data, then returns it.

        Returns:
            The (obs, rewards, dones, infos, off_policy_actions) tuple
            produced by `_poll()`.

        Raises:
            Exception: If the external env is no longer alive after a wait.
            AssertionError: If the number of episodes with data is not below
                the external env's `_max_concurrent_episodes`.
        """
        condition = self.external_env._results_avail_condition
        with condition:
            results = self._poll()
            while len(results[0]) == 0:
                # Sleep until the external env signals new results, then
                # collect them before checking on the serving thread.
                condition.wait()
                results = self._poll()
                if not self.external_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        limit = self.external_env._max_concurrent_episodes
        assert len(results[0]) < limit, (
            "Too many concurrent episodes, were some leaked? This "
            "ExternalEnv was created with max_concurrent={}".format(limit)
        )
        return results

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Puts the given actions onto the action queue of each episode.

        Args:
            action_dict: Actions keyed by env (episode) id and agent id. In
                the single-agent case only the `_DUMMY_AGENT_ID` entry of
                each inner dict is forwarded.
        """
        episodes = self.external_env._episodes
        if self.multiagent:
            for env_id, actions in action_dict.items():
                episodes[env_id].action_queue.put(actions)
        else:
            for env_id, action in action_dict.items():
                episodes[env_id].action_queue.put(action[_DUMMY_AGENT_ID])

    def _poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Collects pending data from all episodes without blocking.

        Episodes that report being done are removed from the external env's
        episode dict.

        Returns:
            The (obs, rewards, dones, infos, off_policy_actions) tuple, each
            keyed by episode id and then by agent id.
        """
        all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
        off_policy_actions = {}
        # Iterate over a copy because finished episodes are deleted below.
        for eid, episode in self.external_env._episodes.copy().items():
            data = episode.get_data()
            cur_done = (
                episode.cur_done_dict["__all__"]
                if self.multiagent
                else episode.cur_done
            )
            if cur_done:
                del self.external_env._episodes[eid]
            if data:
                if self.prep:
                    all_obs[eid] = self.prep.transform(data["obs"])
                else:
                    all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_dones[eid] = data["done"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        if self.multiagent:
            # Ensure a consistent set of keys
            # rely on all_obs having all possible keys for now.
            for eid, eid_dict in all_obs.items():
                for agent_id in eid_dict:
                    all_rewards[eid].setdefault(agent_id, 0.0)
                    all_dones[eid].setdefault(agent_id, False)
                    all_infos[eid].setdefault(agent_id, {})
            return (all_obs, all_rewards, all_dones, all_infos, off_policy_actions)

        # Single-agent: nest every value under the dummy agent id; done
        # flags go under "__all__" instead.
        return (
            _with_dummy_agent_id(all_obs),
            _with_dummy_agent_id(all_rewards),
            _with_dummy_agent_id(all_dones, "__all__"),
            _with_dummy_agent_id(all_infos),
            _with_dummy_agent_id(off_policy_actions),
        )


@Deprecated(
    old="env.base_env._VectorEnvToBaseEnv",
    new="env.vector_env.VectorEnvWrapper",
    error=False,
)
class _VectorEnvToBaseEnv(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: "VectorEnv"):
        """Wraps the given vector env.

        Args:
            vector_env: The vector env to adapt. Its spaces and `num_envs`
                are read once at construction time.
        """
        self.vector_env = vector_env
        # `action_space` / `observation_space` are read-only properties on
        # BaseEnv, so assigning to them directly would raise AttributeError.
        # Keep the values in private attributes and expose them through the
        # property overrides below.
        self._action_space = vector_env.action_space
        self._observation_space = vector_env.observation_space
        self.num_envs = vector_env.num_envs
        self.new_obs = None  # lazily initialized
        self.cur_rewards = [None for _ in range(self.num_envs)]
        self.cur_dones = [False for _ in range(self.num_envs)]
        self.cur_infos = [None for _ in range(self.num_envs)]

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.Space:
        """The wrapped vector env's observation space."""
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        """The wrapped vector env's action space."""
        return self._action_space

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Returns the buffered step results and clears the buffers.

        On the very first call the vector env is reset to obtain the initial
        observations.

        Returns:
            The (obs, rewards, dones, infos, off_policy_actions) tuple. The
            first four are keyed by the integer sub-env index and then by
            `_DUMMY_AGENT_ID` ("__all__" for dones); the last is always {}.
        """
        if self.new_obs is None:
            self.new_obs = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        dones = dict(enumerate(self.cur_dones))
        infos = dict(enumerate(self.cur_infos))
        # Hand each result out only once: a second poll() without an
        # intermediate send_actions() returns empty dicts.
        self.new_obs = []
        self.cur_rewards = []
        self.cur_dones = []
        self.cur_infos = []
        return (
            _with_dummy_agent_id(new_obs),
            _with_dummy_agent_id(rewards),
            _with_dummy_agent_id(dones, "__all__"),
            _with_dummy_agent_id(infos),
            {},
        )

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Steps all sub-envs with the given actions and buffers the results.

        Args:
            action_dict: Actions keyed by integer sub-env index and then by
                `_DUMMY_AGENT_ID`. An entry is required for every index in
                `range(self.num_envs)`.
        """
        action_vector = [
            action_dict[i][_DUMMY_AGENT_ID] for i in range(self.num_envs)
        ]
        (
            self.new_obs,
            self.cur_rewards,
            self.cur_dones,
            self.cur_infos,
        ) = self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> MultiAgentDict:
        """Resets one sub-env via the vector env's `reset_at()`.

        Args:
            env_id: Integer index of the sub-env, or None. The value is
                passed to `reset_at()` unchanged.

        Returns:
            The `reset_at()` result nested under `_DUMMY_AGENT_ID`.
        """
        assert env_id is None or isinstance(env_id, int)
        return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}

    @override(BaseEnv)
    def get_sub_environments(
        self, as_dict: bool = False
    ) -> Union[List[EnvType], dict]:
        """Returns the sub-environments of the wrapped vector env.

        Args:
            as_dict: If True, return a dict mapping each sub-env's integer
                index to the sub-env instead of a list.

        Returns:
            The list returned by the vector env's `get_sub_environments()`,
            or that list converted to an index-keyed dict.
        """
        sub_envs = self.vector_env.get_sub_environments()
        if as_dict:
            return dict(enumerate(sub_envs))
        return sub_envs

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Renders one sub-env via the vector env's `try_render_at()`.

        Args:
            env_id: Integer index of the sub-env, or None. The value is
                passed to `try_render_at()` unchanged.

        Returns:
            Whatever `try_render_at()` returns.
        """
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)


@Deprecated(
    old="env.base_env._MultiAgentEnvToBaseEnv",
    new="env.multi_agent_env.MultiAgentEnvWrapper",
    error=False,
)
class _MultiAgentEnvToBaseEnv(BaseEnv):
    """Internal adapter of MultiAgentEnv to BaseEnv.

    This also supports vectorization if num_envs > 1.
    """

    def __init__(
        self,
        make_env: Callable[[int], EnvType],
        existing_envs: List["MultiAgentEnv"],
        num_envs: int,
    ):
        """Wraps MultiAgentEnv(s) into the BaseEnv API.

        Args:
            make_env (Callable[[int], EnvType]): Factory that produces a new
                MultiAgentEnv instance. Must be defined, if the number of
                existing envs is less than num_envs.
            existing_envs (List[MultiAgentEnv]): List of already existing
                multi-agent envs. Note that newly created envs are appended
                to this very list object.
            num_envs (int): Desired num multiagent envs to have at the end in
                total. This will include the given (already created)
                `existing_envs`.
        """
        from env.multi_agent_env import MultiAgentEnv

        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        # Indices of sub-envs whose last step reported `__all__` done and
        # that have not been reset since.
        self.dones = set()
        # Top up with freshly made envs; the factory receives the index the
        # new env will occupy.
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self.env_states = [_MultiAgentEnvState(env) for env in self.envs]

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Polls every sub-env state.

        Returns:
            The (obs, rewards, dones, infos, off_policy_actions) tuple. The
            first four are keyed by integer sub-env index; the last is
            always {}.
        """
        obs, rewards, dones, infos = {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
        return obs, rewards, dones, infos, {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Steps each addressed sub-env and records the results.

        Args:
            action_dict: Per-agent action dicts keyed by integer sub-env
                index.

        Raises:
            ValueError: If a sub-env is already done, if `infos` contains
                keys that are neither in `obs` nor `__common__`, or if the
                `dones` dict lacks the `__all__` key.
            AssertionError: If any of the step results is not a dict.
        """
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError("Env {} is already done".format(env_id))
            env = self.envs[env_id]
            obs, rewards, dones, infos = env.step(agent_dict)
            assert isinstance(obs, dict), "Not a multi-agent obs"
            assert isinstance(rewards, dict), "Not a multi-agent reward"
            assert isinstance(dones, dict), "Not a multi-agent return"
            assert isinstance(infos, dict), "Not a multi-agent info"
            # Allow `__common__` entry in `infos` for data unrelated with any
            # agent, but rather with the environment itself.
            if set(infos).difference(set(obs) | {"__common__"}):
                raise ValueError(
                    "Key set for infos must be a subset of obs: "
                    "{} vs {}".format(infos.keys(), obs.keys())
                )
            if "__all__" not in dones:
                raise ValueError(
                    "In multi-agent environments, '__all__': True|False must "
                    "be included in the 'done' dict: got {}.".format(dones)
                )
            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)

    @override(BaseEnv)
    def try_reset(self, env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        """Resets the sub-env with the given index.

        Args:
            env_id: Integer index of the sub-env to reset. It is used
                directly as a list index, so despite the default it must not
                be None for this adapter.

        Returns:
            The multi-agent observation dict returned by the reset.
        """
        obs = self.env_states[env_id].reset()
        assert isinstance(obs, dict), "Not a multi-agent obs"
        if obs is not None and env_id in self.dones:
            self.dones.remove(env_id)
        return obs

    @override(BaseEnv)
    def get_sub_environments(
        self, as_dict: bool = False
    ) -> Union[List[EnvType], dict]:
        """Returns the wrapped multi-agent sub-environments.

        Args:
            as_dict: If True, return a dict mapping each sub-env's integer
                index to the sub-env instead of a list.

        Returns:
            A new list (or index-keyed dict) of the sub-environments.
        """
        if as_dict:
            return {i: state.env for i, state in enumerate(self.env_states)}
        return [state.env for state in self.env_states]

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Renders one sub-env.

        Args:
            env_id: Integer index of the sub-env to render. None selects
                the sub-env at index 0.

        Returns:
            Whatever the sub-env's `render()` returns.
        """
        if env_id is None:
            env_id = 0
        assert isinstance(env_id, int)
        return self.envs[env_id].render()


@Deprecated(
    old="env.base_env._MultiAgentEnvState",
    new="env.multi_agent_env._MultiAgentEnvState",
    error=False,
)
class _MultiAgentEnvState:
    """Buffers the step results of one MultiAgentEnv between polls."""

    def __init__(self, env: "MultiAgentEnv"):
        """Stores the env; the first reset is deferred to the first poll()."""
        from env.multi_agent_env import MultiAgentEnv

        assert isinstance(env, MultiAgentEnv)
        self.env = env
        self.initialized = False

    def poll(
        self,
    ) -> Tuple[MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict]:
        """Hands out buffered results and returns (obs, rewards, dones, infos).

        The env is reset on the first call. If the buffered `__all__` done
        flag is set, all buffers are returned and replaced by empty ones.
        Otherwise only the entries of agents present in the last observation
        are moved out of the buffers.
        """
        if not self.initialized:
            self.reset()
            self.initialized = True

        observations = self.last_obs
        all_done = self.last_dones["__all__"]
        common_info = self.last_infos.get("__common__")

        if all_done:
            # Episode is over: release everything we have and start over
            # with empty buffers.
            rewards, self.last_rewards = self.last_rewards, {}
            dones, self.last_dones = self.last_dones, {}
            self.last_obs = {}
            infos, self.last_infos = self.last_infos, {}
        else:
            # Release values only for agents whose observations we have.
            rewards = {}
            dones = {"__all__": all_done}
            infos = {"__common__": common_info}
            for ag in observations.keys():
                if ag in self.last_rewards:
                    rewards[ag] = self.last_rewards.pop(ag)
                if ag in self.last_dones:
                    dones[ag] = self.last_dones.pop(ag)
                if ag in self.last_infos:
                    infos[ag] = self.last_infos.pop(ag)

        self.last_dones["__all__"] = False
        return observations, rewards, dones, infos

    def observe(
        self,
        obs: MultiAgentDict,
        rewards: MultiAgentDict,
        dones: MultiAgentDict,
        infos: MultiAgentDict,
    ):
        """Records one step result.

        Observations and infos replace the buffered ones, rewards are added
        onto any still-buffered reward of the same agent, and done flags are
        OR-ed with any still-buffered flag of the same agent.
        """
        self.last_obs = obs
        for ag, r in rewards.items():
            if ag not in self.last_rewards:
                self.last_rewards[ag] = r
                continue
            self.last_rewards[ag] += r
        for ag, d in dones.items():
            if ag not in self.last_dones:
                self.last_dones[ag] = d
                continue
            self.last_dones[ag] = self.last_dones[ag] or d
        self.last_infos = infos

    def reset(self) -> MultiAgentDict:
        """Resets the env, re-initializes all buffers, returns the new obs."""
        first_obs = self.env.reset()
        self.last_obs = first_obs
        self.last_rewards = {}
        self.last_dones = {"__all__": False}
        self.last_infos = {"__common__": {}}
        return self.last_obs


def convert_to_base_env(
    env: EnvType,
    make_env: Optional[Callable[[int], EnvType]] = None,
    num_envs: int = 1,
    remote_envs: bool = False,
    remote_env_batch_wait_ms: int = 0,
) -> "BaseEnv":
    """Converts an RLlib-supported env into a BaseEnv object.

    Supported types for the `env` arg are gym.Env, BaseEnv,
    VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.

    The resulting BaseEnv is always vectorized (contains n
    sub-environments) to support batched forward passes, where n may also
    be 1. BaseEnv also supports async execution via the `poll` and
    `send_actions` methods and thus supports external simulators.

    TODO: Support gym3 environments, which are already vectorized.

    Args:
        env: An already existing environment of any supported env type
            to convert/wrap into a BaseEnv. Supported types are gym.Env,
            BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
            ExternalMultiAgentEnv.
        make_env: A callable taking an int as input (which indicates the
            number of individual sub-environments within the final
            vectorized BaseEnv) and returning one individual
            sub-environment.
        num_envs: The number of sub-environments to create in the
            resulting (vectorized) BaseEnv. The already existing `env`
            will be one of the `num_envs`.
        remote_envs: Whether each sub-env should be a @ray.remote actor.
            You can set this behavior in your config via the
            `remote_worker_envs=True` option.
        remote_env_batch_wait_ms: The wait time (in ms) to poll remote
            sub-environments for, if applicable. Only used if
            `remote_envs` is True.

    Returns:
        The resulting BaseEnv object.

    Raises:
        ValueError: If `remote_envs` is True while `num_envs` is 1.
    """

    from env.remote_base_env import RemoteBaseEnv
    from env.external_env import ExternalEnv
    from env.multi_agent_env import MultiAgentEnv
    from env.vector_env import VectorEnv, VectorEnvWrapper

    if remote_envs and num_envs == 1:
        raise ValueError(
            "Remote envs only make sense to use if num_envs > 1 "
            "(i.e. vectorization is enabled)."
        )

    # `env` is of a type that knows how to convert itself -> Delegate, and
    # forward the vectorization/remote settings so they are not silently
    # dropped.
    # NOTE(review): assumes MultiAgentEnv, VectorEnv and ExternalEnv expose
    # the same `to_base_env` keyword arguments as BaseEnv.to_base_env --
    # confirm against those classes.
    if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
        return env.to_base_env(
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
        )

    # `env` is not a BaseEnv yet -> Need to convert/vectorize.
    if remote_envs:
        # Sub-environments are ray.remote actors:
        # Determine, whether the already existing sub-env (could
        # be a ray.actor) is multi-agent or not.
        multiagent = (
            ray.get(env._is_multi_agent.remote())
            if hasattr(env, "_is_multi_agent")
            else False
        )
        env = RemoteBaseEnv(
            make_env,
            num_envs,
            multiagent=multiagent,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
            existing_envs=[env],
        )
    else:
        # Sub-environments are not ray.remote actors.
        # Convert gym.Env to VectorEnv ...
        env = VectorEnv.vectorize_gym_envs(
            make_env=make_env,
            existing_envs=[env],
            num_envs=num_envs,
            action_space=env.action_space,
            observation_space=env.observation_space,
        )
        # ... then the resulting VectorEnv to a BaseEnv.
        env = VectorEnvWrapper(env)

    # Make sure conversion went well.
    assert isinstance(env, BaseEnv), env

    return env
