import gym
from gym_minigrid.minigrid import Door


class MinigridThreeActionsWrapper(gym.ActionWrapper):
    """Expose only the first three discrete actions of the wrapped environment.

    Actions 0, 1 and 2 are forwarded unchanged; any other value is rejected.

    Args:
        env: Environment to wrap. Its unwrapped instance must expose a
            ``max_steps`` attribute.
        max_steps_multiplier: Factor applied to the unwrapped environment's
            ``max_steps``. The product is truncated to an integer.

    Attributes:
        action_space: ``gym.spaces.Discrete(3)``.
        equalities: Pairs ``(lhs, rhs)`` of action-index sequences.
            NOTE(review): presumably sequences declared to have the same net
            effect (an empty list meaning "no-op") -- confirm with whatever
            consumes this attribute.
    """

    def __init__(self, env, max_steps_multiplier=1):
        super().__init__(env)
        self.action_space = gym.spaces.Discrete(3)
        self.equalities = [
            ([0, 1], [1, 0]),
            ([0, 1], []),
            ([0, 0], [1, 1]),
            ([1], [0, 0, 0]),
            ([0], [1, 1, 1]),
            ([0, 0, 0, 0], []),
            ([1, 1, 1, 1], []),
        ]
        # Keep the step budget an integer even for fractional multipliers,
        # matching what MinigridDoorKeyWrapper does.
        base_env = self.env.unwrapped
        base_env.max_steps = int(base_env.max_steps * max_steps_multiplier)

    def action(self, action):
        """Validate ``action`` and pass it through unchanged.

        Raises:
            ValueError: If ``action`` is outside ``[0, 2]``.
        """
        # Check both bounds: a negative index is just as invalid as one
        # that is too large and must not reach the wrapped environment.
        if not 0 <= action <= 2:
            raise ValueError(
                f"Expected action in [0,1,2], {action} is invalid action for this environment"
            )
        return action


class MinigridDoorKeyWrapper(gym.ActionWrapper):
    """Expose a five-action discrete space on top of the wrapped environment.

    Actions 0-3 are forwarded unchanged and action 4 is forwarded as 5, so
    index 4 of the wrapped environment is never emitted.
    NOTE(review): in MiniGrid's action enumeration index 4 is presumably
    ``drop`` and 5 ``toggle`` -- confirm against the installed version.

    Args:
        env: Environment to wrap. Its unwrapped instance must expose a
            ``max_steps`` attribute.
        max_steps_multiplier: Factor applied to the unwrapped environment's
            ``max_steps``; the product is truncated to an integer.

    Attributes:
        action_space: ``gym.spaces.Discrete(5)``.
        num_actions: Number of actions exposed by this wrapper (5).
        equalities: Pairs ``(lhs, rhs)`` of action-index sequences, expressed
            in this wrapper's own action indices.
            NOTE(review): presumably sequences declared to have the same net
            effect -- confirm with whatever consumes this attribute.
    """

    def __init__(self, env, max_steps_multiplier=0.9):
        super().__init__(env)
        # Scale the step budget and keep it an integer.
        base_env = self.env.unwrapped
        base_env.max_steps *= max_steps_multiplier
        base_env.max_steps = int(base_env.max_steps)
        self.action_space = gym.spaces.Discrete(5)
        self.equalities = [
            ([1, 0], []),
            ([0, 1], []),
            ([0, 0], [1, 1]),
            ([3, 3], [3]),
            ([4, 4], [4]),
        ]
        self.num_actions = 5

    def action(self, action):
        """Translate a wrapper action index into the wrapped env's index.

        Returns:
            ``action`` unchanged for 0-3, and 5 for action 4.

        Raises:
            ValueError: If ``action`` is outside ``[0, 4]``.
        """
        # Check both bounds: a negative index must not be forwarded to the
        # wrapped environment any more than an overly large one.
        if not 0 <= action <= 4:
            raise ValueError(
                f"Expected action in [0,1,2,3,4], {action} is invalid action for this environment"
            )
        # The wrapper's last action skips over the wrapped env's index 4.
        if action == 4:
            action = 5

        return action


class MinigridDoorKeyTabularWrapper(gym.ObservationWrapper):
    """Replace observations with a single integer state index.

    The index is a mixed-radix encoding of, from least to most significant:

    * whether the agent carries nothing (1 if ``carrying is None``, else 0),
    * whether the door is open,
    * the agent direction (radix 4),
    * the agent cell, ``y * width + x``.

    Attributes:
        num_states: ``width * height * 4 * 2 * 2``, the size of the index
            space.
        observation_space: ``gym.spaces.Discrete(num_states)``.
        timeout: ``max_steps`` of the unwrapped environment at wrap time.
        gamma: Constant 0.99.
            NOTE(review): presumably a discount factor for downstream
            tabular algorithms -- confirm with the consumer.
        door: The ``Door`` object found in the current grid, or ``None`` if
            the grid contains no door. Refreshed on every observation.
    """

    def __init__(self, env):
        super().__init__(env)
        self.num_states = self.unwrapped.width * self.unwrapped.height * 4 * 2 * 2
        self.observation_space = gym.spaces.Discrete(self.num_states)
        self.timeout = self.env.unwrapped.max_steps
        self.gamma = 0.99
        self.door = self._find_door()

    def _find_door(self):
        """Return the last ``Door`` in the current grid, or ``None``."""
        found = None
        for cell in self.unwrapped.grid.grid:
            if isinstance(cell, Door):
                found = cell
        return found

    def observation(self, observation):
        """Return the integer state index for the environment's current state.

        The incoming ``observation`` is ignored; the index is computed from
        the unwrapped environment's attributes.

        Raises:
            AttributeError: If the current grid contains no door.
        """
        # Re-resolve the door each time: the wrapped environment may rebuild
        # its grid (and therefore its Door object) on reset, in which case a
        # reference captured once in __init__ would report a stale state.
        self.door = self._find_door()
        if self.door is None:
            raise AttributeError(
                "MinigridDoorKeyTabularWrapper requires a Door in the environment grid"
            )

        base_env = self.unwrapped
        cell_index = base_env.agent_pos[1] * base_env.width + base_env.agent_pos[0]
        pose_index = base_env.agent_dir + 4 * cell_index
        return (base_env.carrying is None) + 2 * (self.door.is_open + 2 * pose_index)
