from typing import Dict, Tuple, List
from minatar import Environment
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs import register
from gymnasium.vector import AsyncVectorEnv
from gymnasium.wrappers import RecordEpisodeStatistics, TimeLimit


class TransposeObservation(gym.ObservationWrapper):
    """
    Observation wrapper that moves the channel axis to the front,
    turning HxWxC observations into CxHxW ones.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        source_space = env.observation_space
        height, width, channels = source_space.shape
        channel_first = (2, 0, 1)
        # The bounds are permuted the same way as the observations so the
        # declared space still matches what ``observation`` returns.
        self.observation_space = spaces.Box(
            low=np.transpose(source_space.low, channel_first),
            high=np.transpose(source_space.high, channel_first),
            shape=(channels, height, width),
            dtype=source_space.dtype,
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """
        Permute the axes of one observation from HxWxC to CxHxW.

        :param frame: observation produced by the wrapped environment
        :return: the same data with axes reordered to (2, 0, 1)
        """
        channel_first = (2, 0, 1)
        return np.transpose(frame, channel_first)


def make_env(env: str):
    """
    Build a single wrapped environment.

    The environment created by ``gym.make`` is wrapped, innermost first, in
    ``TimeLimit`` (108000 steps), ``RecordEpisodeStatistics`` and
    ``TransposeObservation``.

    :param env: environment id passed to ``gym.make``
    :return: the wrapped environment
    """
    base = gym.make(env)
    limited = TimeLimit(base, 108000)
    tracked = RecordEpisodeStatistics(limited)
    return TransposeObservation(tracked)


def make_vec_env(env: str, n_envs: int):
    """
    Build an ``AsyncVectorEnv`` made of ``n_envs`` copies of one environment.

    :param env: environment id handed to ``make_env`` for every copy
    :param n_envs: number of environment constructors given to the vector env
    :return: the ``AsyncVectorEnv`` instance
    """
    env_fns = [(lambda: make_env(env)) for _ in range(n_envs)]
    return AsyncVectorEnv(env_fns)


class BaseEnv(gym.Env):
    """
    MinAtar with gymnasium API

    Thin adapter around ``minatar.Environment`` exposing ``step``, ``reset``,
    ``render`` and ``close``.

    :param game: name of the MinAtar game, forwarded as ``env_name``
    :param display_time: value passed to ``display_state`` when rendering in
        "human" mode
    :param use_minimal_action_set: if True, use the game's minimal action
        set; otherwise use every action index the game reports
    :param render_mode: mode used by ``render`` when it is called without an
        explicit ``mode``. Accepted explicitly so that a ``render_mode``
        supplied at construction time (e.g. through ``gym.make``) is not
        forwarded to ``minatar.Environment`` via ``**kwargs``.
    :param kwargs: extra keyword arguments forwarded to
        ``minatar.Environment``
    """

    metadata = {"render_modes": ["human", "array"]}

    def __init__(self, game, display_time=50, use_minimal_action_set=False,
                 render_mode=None, **kwargs):
        self.game_name = game
        self.display_time = display_time
        self.render_mode = render_mode

        self.game_kwargs = kwargs
        # Creates ``self.game``; everything below reads from it.
        self.seed()

        if use_minimal_action_set:
            self.action_set = self.game.minimal_action_set()
        else:
            self.action_set = list(range(self.game.num_actions()))

        self.action_space = spaces.Discrete(len(self.action_set))
        self.observation_space = spaces.Box(
            0.0, 1.0, shape=self.game.state_shape(), dtype=bool
        )

    def step(self, action):
        """
        Apply one action to the game.

        :param action: index into ``self.action_set``
        :return: ``(state, reward, done, False, {})``; the fourth element
            (truncation) is always False and the info dict is always empty
        """
        action = self.action_set[action]
        reward, done = self.game.act(action)
        return self.game.state(), reward, done, False, {}

    def reset(self, seed=None, options=None):
        """
        Reset the game and return the first observation.

        :param seed: when not None, the underlying game is rebuilt with this
            value as ``random_seed`` before being reset
        :param options: accepted for API compatibility; unused
        :return: ``(state, {})``
        """
        # Let gymnasium seed its own ``np_random`` as well.
        super().reset(seed=seed)
        if seed is not None:
            self.seed(seed)
        self.game.reset()
        return self.game.state(), {}

    def seed(self, seed=None):
        """
        (Re)create the underlying game with the given random seed.

        :param seed: value passed as ``random_seed`` to
            ``minatar.Environment``
        :return: the ``seed`` argument, unchanged
        """
        self.game = Environment(
            env_name=self.game_name,
            random_seed=seed,
            **self.game_kwargs
        )
        return seed

    def render(self, mode=None):
        """
        Render the current state.

        :param mode: "array" returns the current state, "human" calls
            ``display_state``. When None, ``self.render_mode`` is used, and
            "human" if that is None too (the previous default).
        :return: the state for "array", otherwise None
        """
        if mode is None:
            mode = self.render_mode if self.render_mode is not None else "human"
        if mode == "array":
            return self.game.state()
        if mode == "human":
            self.game.display_state(self.display_time)
        return None

    def close(self):
        """
        Close the display if the game reports it as visualized.

        :return: always 0
        """
        if self.game.visualized:
            self.game.close_display()
        return 0


# Register every game under the id "<Name>-MinAtar-v0".
for game in ("asterix", "breakout", "freeway", "seaquest", "space_invaders"):
    # e.g. "space_invaders" -> "Space_Invaders" -> "SpaceInvaders"
    name = game.title().replace("_", "")
    register(
        id=f"{name}-MinAtar-v0",
        entry_point=BaseEnv,
        kwargs={
            "game": game,
            "display_time": 50,
            "use_minimal_action_set": True,
            "sticky_action_prob": 0,
            "difficulty_ramping": False,
        },
    )
