from gym_minigrid.window import Window
from gym_minigrid import wrappers as mg_wrappers
from gym import wrappers

import gym
import minihack
import numpy as np

import xxhash
import cv2


class VectorizedEnvironment:
    """Drive a list of environments in lockstep.

    This base class only creates empty ``envs`` and ``windows`` lists.
    Subclasses are expected to populate ``self.envs``; several of them also
    override ``reset``/``step`` with a different return convention.
    """

    def __init__(self, num_levels):
        """Store ``num_levels`` and create empty environment/window lists.

        Args:
            num_levels: Number of levels. ``show_windows`` iterates over this
                many entries, and subclasses use it to size per-level outputs.
        """
        self.num_levels = num_levels

        # Environments; this class leaves the list empty for subclasses to fill.
        self.envs = []

        # Display windows indexed in parallel with the ``state`` passed to
        # ``show_windows``. NOTE(review): nothing in this class populates it.
        self.windows = []

    def show_windows(self, state):
        """Show ``state[i]`` in ``self.windows[i]`` for each of ``num_levels`` levels."""
        for i in range(self.num_levels):
            self.windows[i].show_img(state[i])

    def get_action_space(self):
        """Return ``action_space.n`` of the first environment.

        Assumes the action space is discrete (exposes ``n``) — TODO confirm.
        """
        return self.envs[0].action_space.n

    def reset(self):
        """Reset every environment and return their observations, in ``self.envs`` order."""
        return [env.reset() for env in self.envs]

    def step(self, action):
        """Apply the same ``action`` to every environment.

        Args:
            action: Action forwarded unchanged to each ``env.step``.

        Returns:
            Tuple ``(states, rewards, dones, infos)`` of lists, one entry per
            environment in ``self.envs`` order.
        """
        state_array = []
        reward_array = []
        done_array = []
        info_array = []

        # Iterate the instance's environments (previously referenced an
        # undefined global ``envs``, which raised NameError on every call).
        for env in self.envs:
            next_state, reward, done, info = env.step(action)
            state_array.append(next_state)
            reward_array.append(reward)
            done_array.append(done)
            info_array.append(info)

        return state_array, reward_array, done_array, info_array


class HashMHMVectorizedEnvironment(VectorizedEnvironment):
    """Wrap a single MiniHack environment and hash its observations into keys.

    The environment is created with the ``pixel_crop``, ``glyphs_crop``,
    ``message`` and ``inv_glyphs`` observation keys. ``reset`` and ``step``
    iterate those four entries in lockstep (``zip``, so the shortest one
    bounds the count) and emit one hex-digest string key per position.
    """

    def __init__(self, num_levels, env_name):
        """Create the environment ``env_name`` with the four observation keys above."""
        super().__init__(num_levels)

        wanted_keys = ("pixel_crop", "glyphs_crop", "message", "inv_glyphs")
        self.envs.append(gym.make(env_name, observation_keys=wanted_keys))

    @staticmethod
    def _hash_rows(observation):
        """Zip the four observation entries; return ``(keys, pixel_items)``.

        Each key is the concatenation of xxh64 hex digests of the pixel item,
        the message item and a contiguous copy of the glyph item. The
        ``inv_glyphs`` entry takes part in the zip only: it limits how many
        positions are produced but is not hashed.
        """
        keys = []
        pixel_items = []
        for pixel, glyphs, message, _inventory in zip(
            observation["pixel_crop"],
            observation["glyphs_crop"],
            observation["message"],
            observation["inv_glyphs"],
        ):
            pixel_items.append(pixel)
            keys.append(
                xxhash.xxh64(pixel).hexdigest()
                + xxhash.xxh64(message).hexdigest()
                + xxhash.xxh64(np.ascontiguousarray(glyphs)).hexdigest()
            )
        return keys, pixel_items

    def reset(self):
        """Reset the environment; return ``(state_keys, pixel_items)``."""
        first_observation = self.envs[0].reset()
        keys, pixel_items = self._hash_rows(first_observation)
        return keys, pixel_items

    def step(self, action):
        """Step the environment once with ``action``.

        Returns:
            ``(state_keys, rewards, dones, infos, pixel_items)``. The reward,
            done flag and info dict are repeated ``num_levels`` times.
            NOTE(review): that count is independent of ``len(state_keys)``.
        """
        observation, reward, done, infos = self.envs[0].step(action)
        keys, pixel_items = self._hash_rows(observation)

        copies = self.num_levels
        rewards = [reward] * copies
        dones = [done] * copies
        infos_per_level = [infos] * copies

        return keys, rewards, dones, infos_per_level, pixel_items


class HashVectorizedPartialEnvironment(VectorizedEnvironment):
    """Observe one MiniGrid environment through three wrappers and hash the images.

    All three wrappers share the same underlying environment:
    ``envs[0]`` is ``RGBImgObsWrapper``, ``envs[1]`` is
    ``RGBImgPartialObsWrapper`` (the one that is reset/stepped), and
    ``envs[2]`` is ``RGBImgNxNObsWrapper``.
    """

    def __init__(self, num_levels, env_name):
        """Build the base environment ``env_name`` and its three wrappers."""
        super().__init__(num_levels)

        base_env = gym.make(env_name)

        partial_view = mg_wrappers.RGBImgPartialObsWrapper(base_env)
        full_view = mg_wrappers.RGBImgObsWrapper(base_env)
        nxn_view = mg_wrappers.RGBImgNxNObsWrapper(base_env)

        # Index order matters to reset/step: 0 = full, 1 = partial, 2 = NxN.
        self.envs.extend([full_view, partial_view, nxn_view])

    def reset(self):
        """Reset via the partial wrapper; return ``(state_keys, images)``.

        ``images`` holds each item of the NxN wrapper's ``'image'`` followed by
        the partial wrapper's ``'image'``. ``state_keys[i]`` is the xxh64 hex
        digest of ``images[i]``.
        """
        partial_image = self.envs[1].reset()['image']
        nxn_images = self.envs[2].observation(self.envs[1].gen_obs())['image']

        images = list(nxn_images) + [partial_image]
        keys = [xxhash.xxh64(image).hexdigest() for image in images]
        return keys, images

    def step(self, action):
        """Step the partial wrapper once with ``action``.

        Returns:
            ``(state_keys, rewards, dones, infos, images)``, with ``images`` and
            ``state_keys`` laid out as in ``reset``. The reward, done flag and
            info dict are repeated ``num_levels`` times.
            NOTE(review): that count is independent of ``len(images)``.
        """
        partial_obs, reward, done, infos = self.envs[1].step(action)
        nxn_images = self.envs[2].observation(self.envs[1].gen_obs())['image']

        images = list(nxn_images) + [partial_obs['image']]
        keys = [xxhash.xxh64(image).hexdigest() for image in images]

        copies = self.num_levels
        rewards = [reward] * copies
        dones = [done] * copies
        infos_per_level = [infos] * copies

        return keys, rewards, dones, infos_per_level, images


class HashVectorizedEnvironment(VectorizedEnvironment):
    """Observe one MiniGrid environment through three wrappers and hash the images.

    All three wrappers share the same underlying environment:
    ``envs[0]`` is ``RGBImgObsWrapper``, ``envs[1]`` is
    ``ImgObsWrapper(RGBImgPartialObsWrapper(...))`` (the one that is
    reset/stepped), and ``envs[2]`` is ``RGBImgNxNObsWrapper``.
    """

    def __init__(self, num_levels, env_name):
        """Build the base environment ``env_name`` and its three wrappers."""
        super().__init__(num_levels)

        base_env = gym.make(env_name)

        partial_view = mg_wrappers.ImgObsWrapper(
            mg_wrappers.RGBImgPartialObsWrapper(base_env)
        )
        full_view = mg_wrappers.RGBImgObsWrapper(base_env)
        nxn_view = mg_wrappers.RGBImgNxNObsWrapper(base_env)

        # Index order matters to reset/step: 0 = full, 1 = partial, 2 = NxN.
        self.envs.extend([full_view, partial_view, nxn_view])

    def reset(self):
        """Reset via the partial wrapper; return ``(state_keys, images)``.

        ``images`` is ``[full image, partial observation, *NxN image items]``.
        The partial observation is used as returned by ``envs[1].reset()``,
        without ``['image']`` indexing. ``state_keys[i]`` is the xxh64 hex
        digest of ``images[i]``.
        """
        partial_image = self.envs[1].reset()
        full_image = self.envs[0].observation(self.envs[1].gen_obs())['image']
        nxn_images = self.envs[2].observation(self.envs[1].gen_obs())['image']

        images = [full_image, partial_image, *nxn_images]
        keys = [xxhash.xxh64(image).hexdigest() for image in images]
        return keys, images

    def step(self, action):
        """Step the partial wrapper once with ``action``.

        Returns:
            ``(state_keys, rewards, dones, infos, images)``, with ``images`` and
            ``state_keys`` laid out as in ``reset``. The reward, done flag and
            info dict are repeated ``num_levels`` times.
            NOTE(review): that count is independent of ``len(images)``.
        """
        partial_image, reward, done, infos = self.envs[1].step(action)

        full_image = self.envs[0].observation(self.envs[1].gen_obs())['image']
        nxn_images = self.envs[2].observation(self.envs[1].gen_obs())['image']

        images = [full_image, partial_image, *nxn_images]
        keys = [xxhash.xxh64(image).hexdigest() for image in images]

        copies = self.num_levels
        rewards = [reward] * copies
        dones = [done] * copies
        infos_per_level = [infos] * copies

        return keys, rewards, dones, infos_per_level, images
