import os
import gym
from gym import wrappers
from gym.spaces import Box, Discrete, MultiDiscrete, Space
import logging
import numpy as np
import ray

from env.base_env import BaseEnv
from env.external_env import ExternalEnv
from env.vector_env import VectorEnv
from env.env_context import EnvContext
from env.multi_agent_env import MultiAgentEnv
from utils import add_mixins
from utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
from utils.typing import EnvType

# Module-level logger used for this file's validation warnings/info messages.
logger = logging.getLogger(__name__)


def gym_env_creator(env_context: EnvContext, env_descriptor: str):
    """Tries to create a gym env given an EnvContext object and descriptor.

    Note: This function tries to construct the env from a string descriptor
    only using possibly installed RL env packages (such as gym, pybullet_envs,
    vizdoomgym, etc..). These packages are not installation requirements for
    RLlib. In case you would like to support more such env packages, add the
    necessary imports and construction logic below.

    Args:
        env_context (EnvContext): The env context object to configure the env.
            Note that this is a config dict, plus the properties:
            `worker_index`, `vector_index`, and `remote`. Its items are
            forwarded to `gym.make()` as keyword arguments.
        env_descriptor (str): The env descriptor, e.g. CartPole-v0,
            MsPacmanNoFrameskip-v4, VizdoomBasic-v0, or
            CartPoleContinuousBulletEnv-v0.

    Returns:
        gym.Env: The actual gym environment object.

    Raises:
        EnvError: If `gym.make()` raises a `gym.error.Error` for the given
            descriptor. The original gym error is chained as `__cause__`.
    """
    # Allow for PyBullet or VizdoomGym envs to be used as well
    # (via string), when those optional packages are installed. This allows
    # for doing things like `env=CartPoleContinuousBulletEnv-v0` or
    # `env=VizdoomBasic-v0`. A missing package is deliberately ignored.
    # (`ImportError` also covers its subclass `ModuleNotFoundError`.)
    try:
        import pybullet_envs

        pybullet_envs.getList()
    except ImportError:
        pass
    try:
        import vizdoomgym

        vizdoomgym.__name__  # trick LINTer.
    except ImportError:
        pass

    # Try creating a gym env. If this fails we can output a
    # decent error message, while keeping gym's own error as the cause.
    try:
        return gym.make(env_descriptor, **env_context)
    except gym.error.Error as err:
        raise EnvError(
            ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor)
        ) from err


class VideoMonitor(wrappers.Monitor):
    """gym `Monitor` variant for multi-agent envs that only records video.

    Same as the original `Monitor._after_step`, except that the StatsRecorder
    is not fed: it would try to add up multi-agent reward dicts, which throws
    errors. Episode rollover and frame capture are kept.
    """

    def _after_step(self, observation, reward, done, info):
        """Post-step hook: rolls the video over on episode end, grabs a frame.

        Args:
            observation: Not used by this override.
            reward: Not used by this override.
            done: Multi-agent done dict; its "__all__" entry is used as the
                episode-end flag.
            info: Not used by this override.

        Returns:
            `done`, passed through unchanged.
        """
        if self.enabled:
            # This is a multi-agent dict, so look at the "__all__" flag.
            episode_over = done["__all__"]
            if episode_over and self.env_semantics_autoreset:
                # For envs with BlockingReset wrapping VNCEnv, this
                # observation will be the first one of the new episode.
                self.reset_video_recorder()
                self.episode_id += 1
                self._flush()
            # Record video.
            self.video_recorder.capture_frame()
        return done


def record_env_wrapper(env, record_env, log_dir, policy_config):
    """Optionally wraps `env` in a gym Monitor that records videos.

    Args:
        env: The env to (possibly) wrap.
        record_env: Falsy to disable recording. If a str, it is the recording
            directory (joined onto `log_dir` when relative). Any other truthy
            value records directly into `log_dir`.
        log_dir: Base directory for recordings.
        policy_config: Config dict; its "in_evaluation" entry selects the
            Monitor mode ("evaluation" if truthy, else "training").

    Returns:
        The Monitor-wrapped env if recording is enabled, else `env` unchanged.
    """
    if not record_env:
        return env

    if isinstance(record_env, str):
        path_ = record_env
        # Relative path: Add logdir here, otherwise, this would
        # not work for non-local workers.
        if not os.path.isabs(path_):
            path_ = os.path.join(log_dir, path_)
    else:
        # Record straight into log_dir (never join log_dir onto itself).
        path_ = log_dir
    logger.info("Setting the path for recording to %s", path_)

    if isinstance(env, MultiAgentEnv):
        # VideoMonitor skips the StatsRecorder, which cannot add up
        # multi-agent reward dicts. The MultiAgentEnv mixin is presumably what
        # keeps the wrapper recognized as a MultiAgentEnv.
        wrapper_cls = add_mixins(VideoMonitor, [MultiAgentEnv], reversed=True)
    else:
        # Do NOT mix MultiAgentEnv into single-agent envs: that would make
        # `isinstance(env, MultiAgentEnv)` true for the wrapped env.
        wrapper_cls = wrappers.Monitor

    return wrapper_cls(
        env,
        path_,
        resume=True,
        force=True,
        video_callable=lambda _: True,
        mode="evaluation" if policy_config["in_evaluation"] else "training",
    )


def default_validate_env(env: EnvType, env_context: EnvContext = None):
    """Checks that an env returned by an env creator function is usable.

    Verifies the env's type. For single-agent `gym.Env`s it also checks that
    both spaces are set, and that the observation returned by `reset()` is
    contained in `observation_space`.

    Args:
        env: The env (or ray actor handle) to validate.
        env_context: The context the env was created with. Only its
            `vector_index` is read, for log messages. If None, index 0 is
            reported.

    Raises:
        EnvError: If the env has an unsupported type and lacks
            `observation_space`/`action_space`. Also raised if a single-agent
            gym.Env has an unset (None/missing) space, or if its reset
            observation is not in its `observation_space`.
    """
    vector_index = env_context.vector_index if env_context is not None else 0
    # Base message for checking the env at this vector index.
    msg = f"Validating sub-env at vector index={vector_index} ..."

    allowed_types = [gym.Env, ExternalEnv, VectorEnv, BaseEnv, ray.actor.ActorHandle]
    if not any(isinstance(env, tpe) for tpe in allowed_types):
        # Allow this as a special case (assumed gym.Env).
        # TODO: Disallow this early-out. Everything should conform to a few
        #  supported classes, i.e. gym.Env/MultiAgentEnv/etc...
        if hasattr(env, "observation_space") and hasattr(env, "action_space"):
            logger.warning(msg + f" (warning; invalid env-type={type(env)})")
            return
        logger.warning(msg + " (NOT OK)")
        raise EnvError(
            "Returned env should be an instance of gym.Env (incl. "
            "MultiAgentEnv), ExternalEnv, VectorEnv, or BaseEnv. "
            f"The provided env creator function returned {env} "
            f"(type={type(env)})."
        )

    # Do some test runs with the provided env.
    if isinstance(env, gym.Env) and not isinstance(env, MultiAgentEnv):
        # Make sure the gym.Env has the two space attributes properly set.
        # A bare `hasattr` check is not enough: an attribute left at None
        # passes it and only fails later inside `.contains()`. Use a real
        # error rather than `assert`, which is stripped under `python -O`.
        obs_space = getattr(env, "observation_space", None)
        act_space = getattr(env, "action_space", None)
        if obs_space is None or act_space is None:
            logger.warning(msg + " (NOT OK)")
            raise EnvError(
                "Env must define both `observation_space` and `action_space` "
                f"(got observation_space={obs_space}, "
                f"action_space={act_space})!"
            )

        # Get a dummy observation by resetting the env.
        dummy_obs = env.reset()
        obs_space = env.observation_space
        if isinstance(obs_space, Box):
            # Convert lists to np.ndarrays.
            if isinstance(dummy_obs, list):
                dummy_obs = np.array(dummy_obs)
            # Ignore float32/float64 diffs. Only numpy values carry a dtype.
            # Other values (e.g. Python scalars) are left for `contains()`.
            if (
                isinstance(dummy_obs, (np.ndarray, np.generic))
                and dummy_obs.dtype != obs_space.dtype
            ):
                dummy_obs = dummy_obs.astype(obs_space.dtype)
        # Check, if observation is ok (part of the observation space). If not,
        # error.
        if not obs_space.contains(dummy_obs):
            logger.warning(msg + " (NOT OK)")
            raise EnvError(
                f"Env's `observation_space` {obs_space} does not "
                f"contain returned observation after a reset ({dummy_obs})!"
            )

    # Log that everything is ok.
    logger.info(msg + " (ok)")
