import numpy as np

from gymnasium import spaces, Wrapper
from gymnasium.core import ObservationWrapper

from minigrid.core import constants
# Extend MiniGrid's colour palette with three extra colours (orange, cyan,
# pink) at indices 6-8, after the six built-in colours.
#
# The dicts in ``minigrid.core.constants`` are updated *in place* rather than
# rebound. Other minigrid modules bind these dicts by name
# (``from minigrid.core.constants import COLORS, ...``), possibly already
# during ``from minigrid.core import constants`` above. Rebinding
# ``constants.COLORS`` would leave those modules holding the old dicts, which
# do not know the new colours.
_PALETTE = {
    "red": np.array([255, 0, 0]),
    "green": np.array([0, 255, 0]),
    "blue": np.array([0, 0, 255]),
    "purple": np.array([112, 39, 195]),
    "yellow": np.array([255, 255, 0]),
    "grey": np.array([100, 100, 100]),
    "orange": np.array([255, 165, 0]),  # Bright orange, added
    "cyan": np.array([0, 255, 255]),  # Cyan, added
    "pink": np.array([255, 105, 180]),  # Hot pink, added
}
_PALETTE_TO_IDX = {
    "red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5,
    "orange": 6, "cyan": 7, "pink": 8,
}

constants.COLORS.update(_PALETTE)
constants.COLOR_TO_IDX.update(_PALETTE_TO_IDX)
# Keep the inverse index -> name mapping consistent with COLOR_TO_IDX.
constants.IDX_TO_COLOR.update({idx: name for name, idx in _PALETTE_TO_IDX.items()})
COLORS = constants.COLORS



from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX, COLORS


# A backwards compatibility wrapper so that RLlib can continue using the old deprecated Gym API
class GymCompatWrapper(Wrapper):
    """
    Backwards-compatibility shim exposing the deprecated (pre-Gymnasium) Gym API.

    RLlib still consumes the old interface, so:

    * ``reset`` returns only the observation and discards the info dict;
    * ``step`` returns ``(obs, reward, done, info)``, where ``done`` is
      ``terminated or truncated`` because the old API has no separate
      truncation flag.
    """

    def reset(self, **kwargs):
        initial_obs, _reset_info = self.env.reset(**kwargs)
        return initial_obs

    def step(self, action):
        step_result = self.env.step(action)
        obs, reward, terminated, truncated, info = step_result

        # Fold truncation into the single legacy "done" flag.
        done = terminated or truncated
        return obs, reward, done, info

class PartialObsWrapper(ObservationWrapper):
    """
    Expose only the partially observable ``image`` and the agent ``direction``.

    The incoming ``obs`` is ignored. Each observation is regenerated from the
    base environment's ``gen_obs()``, so changes made by wrappers applied
    beneath this one are discarded. Only the ``image`` and ``direction``
    entries are kept.
    """

    def __init__(self, env, tile_size=8):
        # NOTE: ``tile_size`` is accepted for signature parity with the other
        # wrappers in this file but is not used.
        super().__init__(env)

        inner_space = self.observation_space
        self.observation_space = spaces.Dict(
            {key: inner_space[key] for key in ("image", "direction")}
        )

    def observation(self, obs):
        regenerated = self.unwrapped.gen_obs()
        return {key: regenerated[key] for key in ("image", "direction")}

class GoalOffsetWrapper(ObservationWrapper):
    """
    Wrapper that adds a ``goal_offset`` entry to the observation dict.

    ``goal_offset`` is ``agent_pos - goal_pos`` as a ``float32`` array of shape
    ``(2,)``. The goal is the single object position recorded in the
    environment's instruction (``env.instrs.desc.obj_poss``). Only
    environments whose instruction has exactly one object position are
    supported.
    """

    def __init__(self, env, tile_size=8):
        # NOTE: ``tile_size`` is unused; kept for signature parity with the
        # other wrappers in this file.
        super().__init__(env)

        # Read the grid size from the base env rather than relying on
        # attribute forwarding through intermediate wrappers.
        width, height = self.unwrapped.width, self.unwrapped.height
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            # dtype is explicit and matches the array built in observation(),
            # so emitted values are actually contained in this space.
            "goal_offset": spaces.Box(
                low=np.array([-width, -height], dtype=np.float32),
                high=np.array([width, height], dtype=np.float32),
                shape=(2,),
                dtype=np.float32,
            ),
        })

    def observation(self, obs):
        env = self.unwrapped

        goal_positions = env.instrs.desc.obj_poss
        # currently only considering environments with one goal
        assert len(goal_positions) == 1, (
            f"GoalOffsetWrapper expects exactly one goal, got {len(goal_positions)}"
        )
        goal_x, goal_y = goal_positions[0]
        agent_x, agent_y = env.agent_pos

        offset = np.array([agent_x - goal_x, agent_y - goal_y], dtype=np.float32)
        return {**obs, "goal_offset": offset}

class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using the compact grid encoding.

    The observation is a dict with a single ``image`` entry: the whole grid's
    encoding, with the agent's own cell overwritten by
    ``(agent object index, red colour index, agent_dir)``. Every other key of
    the incoming observation is discarded.
    """

    def __init__(self, env):
        super().__init__(env)

        grid_shape = (self.env.width, self.env.height, 3)
        image_space = spaces.Box(low=0, high=255, shape=grid_shape, dtype="uint8")
        self.observation_space = spaces.Dict({"image": image_space})

    def observation(self, obs):
        base = self.unwrapped
        encoded = base.grid.encode()

        agent_marker = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], base.agent_dir]
        )
        x, y = base.agent_pos[0], base.agent_pos[1]
        encoded[x][y] = agent_marker

        return {"image": encoded}


class RGBImgObsWrapper(ObservationWrapper):
    """
    Fully observable RGB image observation that keeps only the ``image`` key.

    Almost the same as ``RGBImgObsWrapper`` in minigrid, except that the
    returned observation dict contains nothing but ``image``. This can be
    used to have the agent solve the gridworld in pixel space.

    Each observation is rendered from the base environment with
    ``get_frame(highlight=True, tile_size=tile_size)``.

    Args:
        env: A MiniGrid environment (possibly already wrapped).
        tile_size: Side length, in pixels, of one grid cell in the image.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        self.tile_size = tile_size

        # Use the base env directly rather than attribute forwarding through
        # intermediate wrappers.
        base = self.unwrapped
        # MiniGrid renders frames row-major as (height_px, width_px, 3), so
        # height comes first; (width, height) only matched square grids.
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(base.height * tile_size, base.width * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict({"image": new_image_space})

    def observation(self, obs):
        rgb_img = self.unwrapped.get_frame(highlight=True, tile_size=self.tile_size)

        return {"image": rgb_img}


class ActionMasking(ObservationWrapper):
    """
    Add an ``action_mask`` entry (1 = allowed, 0 = masked) to each observation.

    The mask depends on the object type in the cell directly in front of the
    agent:

    * ``forward`` is masked when that cell is a wall;
    * ``pickup`` is masked unless that cell holds a key;
    * ``toggle`` is masked unless that cell holds a door;
    * ``drop`` and ``done`` are always masked (never used).

    The mask is emitted as an ``int8`` array of length ``action_space.n``.
    """

    def __init__(self, env):
        super().__init__(env)

        num_actions = self.action_space.n
        mask_space = spaces.Box(0.0, 1.0, shape=(num_actions,))
        self.observation_space = spaces.Dict({
            **self.observation_space.spaces,
            "action_mask": mask_space,
        })

    def observation(self, obs):
        mask = np.ones(self.action_space.n)

        # Object type index of the cell the agent is facing.
        fx, fy = self.unwrapped.front_pos
        encoded_grid = self.unwrapped.grid.encode()
        ahead = encoded_grid[fx][fy][0]

        actions = self.env.actions
        disabled = []
        if ahead == OBJECT_TO_IDX["wall"]:
            disabled.append(actions.forward.value)
        if ahead != OBJECT_TO_IDX["key"]:
            disabled.append(actions.pickup.value)
        if ahead != OBJECT_TO_IDX["door"]:
            disabled.append(actions.toggle.value)
        disabled += [actions.drop.value, actions.done.value]

        mask[disabled] = 0.0
        return {**obs, "action_mask": mask.astype(np.int8)}


class DoorUnlockBonus(Wrapper):
    """
    Reward-shaping wrapper that adds a bonus of 0.5 when the agent unlocks a door.

    The bonus is granted when all of the following hold:

    * before the step, the cell in front of the agent is a locked door
      (encoded state 2);
    * the action is ``toggle``;
    * after the step, that same cell is an open door (encoded state 0).

    The wrapped env must use the legacy 4-tuple ``step`` API
    ``(obs, reward, done, info)``, e.g. by wrapping ``GymCompatWrapper``.
    """

    # Door state values in the third channel of the grid encoding.
    _DOOR_STATE_OPEN = 0
    _DOOR_STATE_LOCKED = 2
    _UNLOCK_BONUS = 0.5

    def _encoded_cell(self, pos):
        """Return the encoded ``(type, color, state)`` of the world-grid cell at ``pos``."""
        return self.unwrapped.grid.encode()[pos[0]][pos[1]]

    def step(self, action):
        base = self.unwrapped
        # The pre-step front position is reused after the step; a toggle is
        # expected not to move the agent.
        front_pos = base.front_pos

        pre_type, _, pre_state = self._encoded_cell(front_pos)
        is_locked_door = (
            pre_type == OBJECT_TO_IDX["door"]
            and pre_state == self._DOOR_STATE_LOCKED
        )

        obs, reward, done, info = self.env.step(action)

        bonus = 0.0
        # The env exposes its action enum as ``actions`` (there is no
        # ``Actions`` attribute on current MiniGrid envs).
        if is_locked_door and action == base.actions.toggle:
            # Read the world grid directly: ``obs["image"]`` is only indexable
            # by world coordinates under a full-grid encoding, not for an
            # egocentric partial view or an RGB frame.
            post_state = self._encoded_cell(front_pos)[2]
            if post_state == self._DOOR_STATE_OPEN:
                bonus = self._UNLOCK_BONUS

        reward += bonus

        return obs, reward, done, info

class ActionBonus(Wrapper):
    """
    Exploration bonus for rarely tried (state, action) pairs.

    After every step, the key ``(agent position, agent direction, action)``
    is counted, using the agent pose *after* the step. Then
    ``1 / sqrt(count)`` is added to the reward. Counts are never cleared:
    ``reset`` leaves them untouched, so they accumulate for the wrapper's
    lifetime.

    ``step`` expects and returns the legacy 4-tuple ``(obs, reward, done,
    info)``, e.g. when wrapping ``GymCompatWrapper``. ``reset`` is passed
    straight through to the wrapped env.
    """

    def __init__(self, env):
        """Wrap ``env`` with an empty visit-count table.

        Args:
            env: The environment to apply the wrapper to.
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Step the wrapped env with ``action`` and add the exploration bonus."""
        obs, reward, done, info = self.env.step(action)

        base = self.unwrapped
        key = (tuple(base.agent_pos), base.agent_dir, action)

        visits = self.counts.get(key, 0) + 1
        self.counts[key] = visits

        reward += 1 / np.sqrt(visits)
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env with ``kwargs``; visit counts are kept."""
        return self.env.reset(**kwargs)


class StateBonus(Wrapper):
    """
    Exploration bonus based on how often each grid position is visited.

    After every step, the agent's position (read *after* the step) is
    counted, and ``1 / sqrt(count)`` is added to the reward. Counts are never
    cleared: ``reset`` leaves them untouched, so they accumulate for the
    wrapper's lifetime.

    ``step`` expects and returns the legacy 4-tuple ``(obs, reward, done,
    info)``, e.g. when wrapping ``GymCompatWrapper``. ``reset`` is passed
    straight through to the wrapped env.
    """

    def __init__(self, env):
        """Wrap ``env`` with an empty position-count table.

        Args:
            env: The environment to apply the wrapper to.
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Step the wrapped env with ``action`` and add the exploration bonus."""
        obs, reward, done, info = self.env.step(action)

        position = tuple(self.unwrapped.agent_pos)

        visits = self.counts.get(position, 0) + 1
        self.counts[position] = visits

        reward += 1 / np.sqrt(visits)
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env with ``kwargs``; visit counts are kept."""
        return self.env.reset(**kwargs)