from __future__ import annotations

from minigrid.core.constants import COLOR_NAMES
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall, Box, Ball, Lava
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv

import numpy as np
import pygame


class ManyDoorsEnv(MiniGridEnv):
    """MiniGrid environment in which the agent chooses between paired doors.

    Nine junctions are arranged on a 3x3 lattice (rows and columns 2, 6, 10).
    Each junction is a wall cell with one door directly above it and one
    directly below it; which door of the pair ends up on top is drawn at
    random every time the grid is generated.  The first time a door of a pair
    is observed open, the reward configured for that door is paid out and the
    other door of the pair is locked.  A goal is placed at cell (11, 10).

    NOTE: all layout coordinates are fixed; they appear to be designed around
    the default ``size=13`` and are not rescaled for other sizes.

    Args:
        size: Width and height of the square grid.
        agent_start_pos: Fixed start cell, or ``None`` to let
            ``place_agent()`` pick one.
        agent_start_dir: Start direction used with ``agent_start_pos``.
        agent_view_size: Forwarded to ``MiniGridEnv``.
        max_steps: Episode step limit; ``None`` means 1000.
        identity: Selects the reward function. ``1`` pays 1 for the first
            door of each pair (built with ``COLOR_NAMES[3]``), ``2`` pays 1
            for the second door (built with ``COLOR_NAMES[5]``).
        position_reward: If true, add a small per-step bonus proportional to
            the agent's x position. If false, the reward returned by
            ``MiniGridEnv.step`` is replaced by 0 before the step penalty and
            door rewards are applied.
        **kwargs: Forwarded to ``MiniGridEnv.__init__``.

    Raises:
        ValueError: If ``identity`` is not 1 or 2.
    """

    # Reward paid for opening (first door, second door) of a pair.
    _DOOR_REWARDS_BY_IDENTITY = {1: (1, 0), 2: (0, 1)}

    # Row/column coordinates of the junction lattice (same on both axes).
    _JUNCTION_COORDS = (2, 6, 10)

    def __init__(
        self,
        size: int = 13,
        agent_start_pos: tuple[int, int] | None = (1, 2),
        agent_start_dir: int = 0,
        agent_view_size: int = 3,
        max_steps: int | None = None,
        identity: int = 1,  # identity of the reward function
        position_reward: bool = True,
        **kwargs,
    ):
        # Fail fast: previously an unknown identity left ``door_rewards``
        # undefined and only surfaced as an AttributeError inside ``step``.
        if identity not in self._DOOR_REWARDS_BY_IDENTITY:
            raise ValueError(
                f"identity must be one of "
                f"{sorted(self._DOOR_REWARDS_BY_IDENTITY)}, got {identity!r}"
            )

        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.agent_view_size = agent_view_size

        self.position_reward = position_reward
        self.identity = identity
        self.door_rewards = list(self._DOOR_REWARDS_BY_IDENTITY[identity])

        mission_space = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            max_steps = 1000

        # Door pairs are presented to the agent, who has to decide which to
        # enter.  The same Door objects are reused across episodes; ``reset``
        # restores their state.  Ten pairs are created (kept for backward
        # compatibility) although the 3x3 layout only places the first nine.
        self.door_pairs = []
        for _ in range(10):
            first = Door(COLOR_NAMES[3], is_open=False, is_locked=False)
            second = Door(COLOR_NAMES[5], is_open=False, is_locked=False)
            first.other_door = second
            second.other_door = first
            self.door_pairs.append([first, second])

        self.goal = Goal()

        # Not read anywhere in this class; kept so existing external code
        # that inspects them keeps working.
        self.door1_opened = False
        self.door2_opened = False

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            agent_view_size=agent_view_size,
            # Set this to True for maximum speed
            see_through_walls=True,
            max_steps=max_steps,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string."""
        return "What Doors Will you Go Through?"

    def _gen_grid(self, width, height):
        """Build the walls, door junctions and goal, then place the agent.

        Door order within each junction is drawn from ``self.np_random`` so
        that a seeded ``reset`` yields a reproducible layout.
        """
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Never placed on the grid by this class; retained for compatibility.
        self.door1 = Door(COLOR_NAMES[1], is_open=False, is_locked=False)
        self.door2 = Door(COLOR_NAMES[2], is_open=False, is_locked=False)
        self.goal = Goal()

        for row_idx, row in enumerate(self._JUNCTION_COORDS):
            for col_idx, col in enumerate(self._JUNCTION_COORDS):
                first, second = self.door_pairs[row_idx * 3 + col_idx]

                # Randomly decide which door of the pair sits on top.
                if self.np_random.random() > 0.5:
                    upper, lower = second, first
                else:
                    upper, lower = first, second
                self.grid.set(col, row - 1, upper)
                self.grid.set(col, row + 1, lower)

                # Wall cell between the two doors, plus two wall cells two
                # columns further on, level with the doors.
                self.grid.wall_rect(col, row, 1, 1)
                self.grid.wall_rect(col + 2, row - 1, 1, 1)
                self.grid.wall_rect(col + 2, row + 1, 1, 1)

            # Horizontal wall below this row of junctions.  It starts at
            # x=1 under the first row and at x=2 under the others, so the
            # uncovered cell alternates between the right and left side.
            wall_start_x = 1 if row == 2 else 2
            self.grid.wall_rect(wall_start_x, row + 2, 10, 1)

        self.grid.set(11, 10, self.goal)

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "What Doors Will you Go Through?"

    def step(self, action):
        """Advance one step and apply the shaping and door rewards.

        On top of ``MiniGridEnv.step`` the reward is adjusted as follows:

        * with ``position_reward``: add ``0.05 * x / grid_width``; otherwise
          the base reward is discarded and replaced by 0;
        * subtract a constant 0.1 step penalty;
        * for every door pair, the first time one of its doors is seen open
          add that door's entry of ``door_rewards`` and lock the other door
          (the lock also prevents the reward from being paid again).

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple from
            ``MiniGridEnv.step`` with the adjusted reward.
        """
        obs, reward, terminated, truncated, info = super().step(action)

        # Reward the agent slightly for being further to the right.
        if self.position_reward:
            reward += 0.05 * (self.agent_pos[0] / self.grid.width)
        else:
            # NOTE(review): this also drops any reward returned by the base
            # class - confirm that is intended.
            reward = 0

        reward -= 0.1

        # Reward the agent if it opened a door of a not-yet-decided pair.
        for first, second in self.door_pairs:
            if first.is_open and not second.is_locked:
                reward += self.door_rewards[0]
                second.is_locked = True
            elif second.is_open and not first.is_locked:
                reward += self.door_rewards[1]
                first.is_locked = True

        return obs, reward, terminated, truncated, info

    def reset(self, seed: int | None = None, options: dict | None = None):
        """Close and unlock every door, then start a new episode.

        The grid itself (including the random door order) is rebuilt by
        ``_gen_grid``, which ``MiniGridEnv.reset`` invokes, so no layout work
        is needed here.

        Args:
            seed: Optional RNG seed, forwarded to ``MiniGridEnv.reset``.
            options: Optional reset options, forwarded unchanged.

        Returns:
            The ``(obs, info)`` pair from ``MiniGridEnv.reset``.
        """
        # The Door objects persist across episodes, so their state has to be
        # restored before the new grid is generated.
        for first, second in self.door_pairs:
            first.is_open = False
            second.is_open = False
            first.is_locked = False
            second.is_locked = False

        return super().reset(seed=seed, options=options)


def main():
    """Open a ManyDoorsEnv window and hand control to the keyboard."""
    door_env = ManyDoorsEnv(render_mode="human", identity=2)

    # Presumably turns off the agent-view highlight overlay when rendering;
    # the attribute is consumed by the base environment, not by this file.
    door_env.highlight = False

    # Enable manual control for testing.
    ManualControl(door_env, seed=42).start()


# Start the manual-control session only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
