import time
import warnings
from enum import IntEnum
from typing import Union
from pprint import pformat

import gym
import numpy as np
from gym.spaces import Box

import gym_montezuma.envs.wrappers as wrap
from gym_montezuma.envs.conditional_discrete import ConditionalDiscrete
from gym_montezuma.envs.errors import *
from pix2sym.ataritools.envs import MontezumaEnv
from pix2sym.montezuma.skills import *
from pix2sym.montezuma.newplans import plan, Room01

# Frame-rate constant. Not referenced by any code visible in this module;
# presumably a frames-per-second value used by importers for playback or
# rendering -- TODO confirm against callers.
FPS = 180


class Option(IntEnum):
    """Integer identifiers for the skills (options) exposed as actions.

    Each value is the index of the corresponding skill instance in
    ``MontezumasRevengeEnv.options``, so the numbering here must stay in the
    same order as that list.
    """
    # Respawn handling
    WAIT_RESPAWN = 0
    # Horizontal movement
    RUN_LEFT = 1
    RUN_RIGHT = 2
    # Ropes
    DOWN_ROPE = 3
    UP_ROPE = 4
    DROP_ROPE = 5
    # Ladders
    DOWN_LADDER = 6
    UP_LADDER = 7
    # Jumps
    JUMP = 8
    JUMP_LEFT = 9
    JUMP_RIGHT = 10
    CLIMB_PLATFORMS = 11
    # Waiting on timed obstacles
    WAIT_LASER_DISAPPEAR = 12
    WAIT_LASER_APPEAR = 13
    WAIT_BRIDGE_DISAPPEAR = 14
    WAIT_BRIDGE_APPEAR = 15
    # Waiting on / passing enemies
    WAIT_SPIDER_TOWARDS = 16
    WAIT_SPIDER_AWAY = 17
    WAIT_SKULL_TOWARDS = 18
    WAIT_SKULL_AWAY = 19
    WAIT_JUMP_SKULL_TOWARDS = 20
    WAIT_JUMP_SKULL_AWAY = 21
    PASS_LEFT_JUMP_SKULL = 22
    PASS_RIGHT_JUMP_SKULL = 23
    CHARGE_LEFT = 24
    CHARGE_RIGHT = 25
    # Level transition
    WAIT_LEVEL = 26


class MontezumasRevengeEnv(gym.Env):
    """Gym environment for Montezuma's Revenge whose actions are skills (options).

    Every call to :meth:`step` runs one skill from ``self.options`` on the
    underlying ``MontezumaEnv`` until the skill reports it is done (or the
    episode ends), summing the primitive rewards along the way. The position of
    a skill in ``self.options`` matches the corresponding ``Option`` value.

    This class is not a ``gym.Wrapper``: attributes it does not define itself
    are forwarded to the underlying env through :meth:`__getattr__`.
    """
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self,
                 single_life=False,
                 single_screen=False,
                 seed=None,
                 noop_wrapper=False,
                 render_mode="rgb_array",
                 observation_mode="rgb_array",
                 max_timesteps=-1,
                 eps=0.0,
                 clip_eps=0.5):
        """Create the underlying env, the skill instances and the gym spaces.

        Args:
            single_life: Forwarded to ``MontezumaEnv``.
            single_screen: Forwarded to ``MontezumaEnv``; also selects which
                scripted plan :meth:`reset` installs.
            seed: Forwarded to ``MontezumaEnv`` and reused on every reset.
            noop_wrapper: If true, wrap the underlying env in
                ``wrap.NoopResetEnv`` with ``noop_max=30``.
            render_mode: Forwarded to ``MontezumaEnv``.
            observation_mode: ``"rgb_array"`` or ``"ram"``.
            max_timesteps: Maximum number of options per episode; ``-1``
                disables the limit.
            eps: Probability that :meth:`reset` installs an empty plan.
            clip_eps: Probability that the installed plan is truncated at a
                random index.

        Raises:
            ValueError: If ``observation_mode`` is not one of the two
                supported values.
        """
        self.observation_mode = observation_mode
        self.max_timesteps = max_timesteps
        self.executed_opt_count = 0
        self._eps = eps
        self._clip_eps = clip_eps
        self._current_plan = []
        self._env = MontezumaEnv(seed=seed, single_life=single_life, single_screen=single_screen, render_mode=render_mode)
        if noop_wrapper:
            self._env = wrap.NoopResetEnv(self._env, noop_max=30)
        self._seed = seed
        # Order matters: index i in this list is the skill for Option value i.
        self.options = [
            skill() for skill in [
                WaitForRespawn, RunLeftSkill, RunRightSkill, ClimbDownRopeSkill, ClimbUpRopeSkill,
                DropFromRopeSkill, ClimbDownLadderSkill, ClimbUpLadderSkill, JumpInPlaceSkill,
                JumpLeftSkill, JumpRightSkill, ClimbPlatforms, WaitForLaserToDisappear,
                WaitForLaserToAppear, WaitForBridgeToDisappear, WaitForBridgeToAppear,
                WaitForSpiderMovingTowards, WaitForSpiderMovingAway, WaitForSkullMovingTowards,
                WaitForSkullMovingAway, WaitForJumpSkullMovingTowards, WaitForJumpSkullMovingAway,
                PassToLeftOfJumpSkull, PassToRightOfJumpSkull, ChargeEnemyLeft, ChargeEnemyRight,
                WaitForLevelChange
            ]
        ]
        self._option_names = [type(x).__name__ for x in self.options]
        self._state = None
        self._frame = None
        if observation_mode == "rgb_array":
            self.observation_space = Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)
        elif observation_mode == "ram":
            self.observation_space = Box(low=0, high=1, shape=(128,), dtype=float)
        else:
            raise ValueError(f"Invalid observation mode: {self.observation_mode}. Pick 'rgb_array' or 'ram'.")
        # The bound method is handed over (not its result) so the space can
        # query the currently runnable options later.
        self.action_space = ConditionalDiscrete(len(self.options), self.available_options)

    def __getattr__(self, name):
        """Forward attributes not found on this object to the underlying env.

        Raises:
            AttributeError: If the underlying env is not set yet, or it does
                not provide ``name`` either.
        """
        # NOTE: Forwarding to underlying env's attributes as this is not a proper wrapper.
        # `_env` is fetched via __getattribute__ because a plain `self._env` would
        # re-enter __getattr__ when `_env` is missing and recurse forever.
        try:
            internal_env = self.__class__.__getattribute__(self, '_env')
        except AttributeError:
            # `_env` does not exist yet (e.g. instance created without __init__):
            # report the attribute that was actually requested.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}") from None
        # getattr (rather than internal_env.__getattribute__) lets the underlying
        # env apply its own __getattr__ forwarding, if it defines one.
        return getattr(internal_env, name)

    @property
    def unwrapped(self):
        """Return ``self.env.unwrapped``.

        ``env`` is not defined on this class, so the lookup is forwarded to the
        underlying env by :meth:`__getattr__`.
        """
        return self.env.unwrapped

    @property
    def action_names(self):
        """Class names of the skills, in option-index order."""
        return self.get_action_meanings()

    @property
    def available_mask(self):
        """Tuple form of :meth:`available_options` for the current state."""
        return tuple(self.available_options().tolist())

    @property
    def observation(self):
        """Current observation according to ``observation_mode``.

        Returns the RGB frame in ``"rgb_array"`` mode, otherwise the RAM
        contents divided by 255.0.
        """
        if self.observation_mode == "rgb_array":
            obs = self._env.getRGBFrame()
        else:
            obs = self._env.getRAM() / 255.0
        return obs

    def sample_action(self) -> np.int64:
        """Return the next scripted option, or a sampled one once the plan is exhausted.

        While fewer options have been executed than the current plan holds,
        the plan entry at that position is returned; afterwards the action
        space is sampled.
        """
        if self.executed_opt_count < len(self._current_plan):
            return self._current_plan[self.executed_opt_count]
        else:
            return self.action_space.sample()

    def reset(self):
        """Reset the underlying env and choose the scripted plan for the episode.

        With probability ``eps`` the plan is empty. Otherwise the plan is taken
        from ``Room01.go("left", 0)`` (single-screen mode) or ``plan``, and with
        probability ``clip_eps`` it is cut off at a random index.

        Returns:
            Tuple ``(observation, {})``.
        """
        self._env.reset(seed=self._seed)
        self._state = self._env.getState()
        self._frame = self._env.getFrame()
        self.executed_opt_count = 0
        if np.random.rand() < self._eps:
            self._current_plan = []
        else:
            if self._env.single_screen:
                self._current_plan = [x.value for x in Room01.go("left", 0)]
            else:
                self._current_plan = [x.value for x in plan]

            # The random draw comes first so the RNG stream is consumed exactly as
            # before; the emptiness check avoids randint(0), which raises ValueError.
            if np.random.rand() < self._clip_eps and self._current_plan:
                clip_idx = np.random.randint(len(self._current_plan))
                self._current_plan = self._current_plan[:clip_idx]

        return self.observation, {}

    def step(self, option):
        """Run the skill with index ``option`` until it finishes or the episode ends.

        Args:
            option: Index into ``self.options`` (an ``Option`` value works).

        Returns:
            ``(observation, cum_reward, done, False, info)`` where ``cum_reward``
            sums the primitive rewards and ``info["steps"]`` is the number of
            primitive steps taken. If the per-episode option limit has already
            been reached, returns ``(observation, 0, True, True, {})`` without
            running anything.

        Raises:
            RuntimeError: If the skill cannot run in the current state.
            SkillPolicyFailed: If the skill's policy raises.
            SkillCheckDoneFailed: If the skill's termination check raises.
            SkillRanTooLong: If the skill exceeds the primitive-step limit.
        """
        if self.max_timesteps != -1 and (self.executed_opt_count >= self.max_timesteps):
            return self.observation, 0, True, True, {}

        skill = self.options[option]
        if not skill.can_run(self._state):
            raise RuntimeError('Cannot execute skill {} in state:\n{}'.format(
                skill, pformat(self._state)))
        # An option that runs this long most likely indicates a bug, so abort it.
        n_step_limit = 1000
        n_steps = 0
        cum_reward = 0
        while True:
            try:
                action = skill.policy(self._state)
            except Exception as e:
                raise SkillPolicyFailed(e) from e
            _, reward, done, info = self._env.step(action)
            cum_reward += reward
            self._state = self._env.getState()
            self._frame = self._env.getFrame()

            try:
                terminated = skill.is_done(self._state)
            except Exception as e:
                raise SkillCheckDoneFailed(e) from e

            n_steps += 1
            if done or terminated:
                break

            if n_steps > n_step_limit:
                raise SkillRanTooLong('Option execution exceeded {} steps'.format(n_step_limit))

        self.executed_opt_count += 1
        info["steps"] = n_steps

        return self.observation, cum_reward, done, False, info

    def close(self):
        """Close the underlying env."""
        self._env.close()

    def available_options(self) -> np.ndarray:
        """
        Return a binary array specifying which options can be run at the current state

        Before the first reset (no state yet) every entry is 0. The array has an
        integer dtype in both cases.
        """
        if self._state is None:
            return np.zeros(shape=(len(self.options), ), dtype=int)
        return np.array([int(option.can_run(self._state)) for option in self.options], dtype=int)

    def can_execute(self, option: Union[int, Option]) -> bool:
        """
        Returns whether the given option is runnable at the current state
        """
        # bool() converts the NumPy comparison result into a plain Python bool.
        return bool(self.available_options()[option] == 1)

    def get_action_meanings(self):
        """Return the list of skill class names, in option-index order."""
        return self._option_names

def make_monte_env_as_atari_deepmind(max_episode_steps=None,
                                     episode_life=True,
                                     clip_rewards=False,
                                     frame_skip=4,
                                     frame_stack=4,
                                     frame_warp=(84, 84),
                                     **kwargs):
    """
    Build a skills-based Montezuma env layered with wrappers that emulate the deepmind style
    wrappers seen in:
    https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py#L275

    Args:
        max_episode_steps: When not None, a ``wrap.TimeLimit`` with this limit is added.
        episode_life: When truthy, ``wrap.EpisodicLifeEnv`` is added.
        clip_rewards: When truthy, ``wrap.ClipRewardEnv`` is added.
        frame_skip: ``skip`` argument of ``wrap.FrameSkipStackPool``.
        frame_stack: ``stack`` argument of ``wrap.FrameSkipStackPool``.
        frame_warp: Second argument of ``wrap.WarpFrame``.
        **kwargs: Passed unchanged to ``MontezumasRevengeEnv``.

    Returns:
        The outermost wrapper (``wrap.FrameSkipStackPool``).
    """
    # Intended layering:
    #   Wrappers --> SteveEnv -> MonteSkillsEnv => AtariSkillsEnv -> NoopWrapper --> BaseEnv (Gym)
    # where the double arrow (=>) is subclassing, the long arrows (-->) are true wrappers,
    # and the short arrows (->) are attributes (self.env).
    #
    # NOTE: NOOP is done in the internal env in gym_montezuma/montezuma_env.py
    layered = MontezumasRevengeEnv(**kwargs)
    layered = wrap.WarpFrame(layered, frame_warp)
    layered = wrap.PyTorchFrame(layered)

    if max_episode_steps is not None:
        layered = wrap.TimeLimit(layered, max_episode_steps=max_episode_steps)

    if episode_life:
        layered = wrap.EpisodicLifeEnv(layered)

    # NOTE: no FireResetEnv here -- there are no primitive actions, so that wrapper does not
    # make sense; https://github.com/openai/baselines/issues/240 also claims it is not
    # necessary anymore.

    if clip_rewards:
        layered = wrap.ClipRewardEnv(layered)

    return wrap.FrameSkipStackPool(layered, skip=frame_skip, stack=frame_stack)
