import gymnasium
import numpy as np
from collections import deque
import cv2
from einops import rearrange
import copy


class LifeLossInfo(gymnasium.Wrapper):
    """Add a boolean ``"life_loss"`` entry to ``info`` on every reset and step.

    The flag is derived from ``info["lives"]``, which the wrapped environment
    must provide on both ``reset`` and ``step``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Life count seen on the most recent reset/step; None until first seen.
        self.lives_info = None

    def step(self, action):
        """Step the wrapped env and flag whether a life was lost on this step.

        Returns the wrapped env's 5-tuple unchanged except that
        ``info["life_loss"]`` is True when ``info["lives"]`` is lower than the
        previously observed value, and False otherwise.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        current_lives = info["lives"]
        previous_lives = self.lives_info

        # Without a baseline (step called before reset) no loss can be reported.
        info["life_loss"] = bool(
            previous_lives is not None and current_lives < previous_lives
        )
        # Always track the latest count: if the baseline were only updated on
        # a decrease, a gained life would hide the next real loss.
        self.lives_info = current_lives

        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env and record the starting life count."""
        observation, info = self.env.reset(**kwargs)
        self.lives_info = info["lives"]
        info["life_loss"] = False
        return observation, info


class SeedEnvWrapper(gymnasium.Wrapper):
    """Pin the wrapped environment to a single fixed seed.

    The action space is seeded once at construction time, and every call to
    ``reset`` forwards the stored seed, replacing any ``seed`` keyword the
    caller supplied.
    """

    def __init__(self, env, seed):
        super().__init__(env)
        self.env.action_space.seed(seed)
        self.seed = seed

    def reset(self, **kwargs):
        """Reset the wrapped env, always passing the stored seed."""
        reset_kwargs = {**kwargs, "seed": self.seed}
        observation, info = self.env.reset(**reset_kwargs)
        return observation, info

    def step(self, action):
        """Forward the action to the wrapped env unchanged."""
        transition = self.env.step(action)
        return transition


class MaxLast2FrameSkipWrapper(gymnasium.Wrapper):
    """Repeat each action ``skip`` times and max-pool the last two frames.

    Rewards are summed over the repeated steps. The returned observation is
    the element-wise maximum of the last two observations collected, or the
    single observation if the episode ended on the first repeated step.
    """

    def __init__(self, env, skip=4):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            skip: Number of times each action is repeated; must be >= 1.

        Raises:
            ValueError: If ``skip`` is smaller than 1 (``step`` could not
                produce any observation in that case).
        """
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip
        # Holds at most the two most recent frames of the current step() call.
        self.obs_buffer = deque(maxlen=2)

    def reset(self, **kwargs):
        """Reset the wrapped env; no frame skipping is applied on reset."""
        obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, action):
        """Apply ``action`` up to ``skip`` times, stopping early on episode end.

        Returns:
            ``(obs, total_reward, done, truncated, info)`` where ``done``,
            ``truncated`` and ``info`` come from the last inner step executed.
        """
        total_reward = 0
        # Fresh buffer per call so frames from a previous step never leak in.
        self.obs_buffer = deque(maxlen=2)
        for _ in range(self.skip):
            obs, reward, done, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            total_reward += reward
            if done or truncated:
                break

        if len(self.obs_buffer) == 1:
            # Episode ended on the first repeated step: nothing to pool with.
            obs = self.obs_buffer[0]
        else:
            obs = np.max(np.stack(self.obs_buffer), axis=0)
        return obs, total_reward, done, truncated, info


class NoopResetEnv(gymnasium.Wrapper):
    """Sample initial states by taking a random number of no-ops on reset.

    No-op is assumed to be action 0.
    """

    def __init__(self, env, noop_max=30):
        """Wrap ``env``.

        Args:
            env: Environment whose first action meaning is ``'NOOP'``.
            noop_max: Upper bound (inclusive) on the number of no-op steps
                taken after each reset. Values <= 0 disable the no-ops.
        """
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """Reset, then do the no-op action for a number of steps in [1, noop_max].

        Keyword arguments (e.g. ``seed``, ``options``) are forwarded to the
        wrapped env's ``reset``. If the episode ends during the no-ops the env
        is reset again and the remaining no-ops continue from there.

        Returns:
            ``(obs, info)`` from the last no-op step, or from the last reset
            when no step followed it.
        """
        obs, info = self.env.reset(**kwargs)
        # randint's upper bound is exclusive; guard noop_max <= 0, for which
        # randint(1, noop_max + 1) would raise.
        noops = np.random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                # reset() returns (obs, info); unpack both so the observation
                # is not a tuple and the info is not stale.
                obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, action):
        """Forward the action to the wrapped env unchanged."""
        return self.env.step(action)


def build_single_env(env_name, image_size):
    """Create one Atari environment wrapped in ``AtariPreprocessing``.

    The base env is made with the full action space and ``frameskip=1``; the
    preprocessing wrapper is configured with ``screen_size=image_size`` and
    colour (non-grayscale) observations. All other preprocessing options are
    left at the library defaults.
    """
    raw_env = gymnasium.make(env_name, full_action_space=True, frameskip=1)

    # Imported lazily, only when an environment is actually built.
    from gymnasium.wrappers import AtariPreprocessing

    return AtariPreprocessing(
        raw_env,
        screen_size=image_size,
        grayscale_obs=False,
    )


def build_vec_env(env_list, image_size, num_envs):
    """Build an ``AsyncVectorEnv`` with an equal number of copies of each game.

    ``num_envs`` must be divisible by ``len(env_list)``. Sub-environments are
    grouped by game in the order given by ``env_list``.

    Returns:
        ``(vec_env, vec_env_names)`` where ``vec_env_names[i]`` is the env name
        used to build sub-environment ``i``.
    """
    assert num_envs % len(env_list) == 0
    copies_per_game = num_envs // len(env_list)

    # A factory function binds the arguments per call, avoiding the
    # late-binding lambda pitfall described at:
    # https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7
    def make_factory(name, size):
        return lambda: build_single_env(name, size)

    env_fns = []
    vec_env_names = []
    for env_name in env_list:
        for _ in range(copies_per_game):
            env_fns.append(make_factory(env_name, image_size))
            vec_env_names.append(env_name)

    vec_env = gymnasium.vector.AsyncVectorEnv(env_fns=env_fns)
    return vec_env, vec_env_names


if __name__ == "__main__":
    # Manual smoke test: run four Atari games (two sub-envs each) with random
    # actions and show the first sub-env of each game in an OpenCV window.
    # The loop runs until the process is interrupted.
    game_ids = ['ALE/Pong-v5', 'ALE/IceHockey-v5', 'ALE/Breakout-v5', 'ALE/Tennis-v5']
    vec_env, vec_env_names = build_vec_env(game_ids, 64, num_envs=8)

    # (window title, index of the first sub-env belonging to that game).
    windows = (("Pong", 0), ("IceHockey", 2), ("Breakout", 4), ("Tennis", 6))

    current_obs, _ = vec_env.reset()
    while True:
        action = vec_env.action_space.sample()
        obs, reward, done, truncated, info = vec_env.step(action)
        # Only `done` triggers the report; `truncated` is not checked here.
        if done.any():
            print("---------")
            print(reward)
            print(info["episode_frame_number"])
        # Show the observation from before this step; the new one is shown
        # on the next iteration.
        for title, env_index in windows:
            cv2.imshow(title, current_obs[env_index])
        cv2.waitKey(40)
        current_obs = obs
