from __future__ import annotations
import numpy as np
from minigrid.core import constants
# Extend MiniGrid's color palette with three extra colors (orange, cyan, pink)
# that MultiRoomGrid uses to recolor walls, floors and goals. The minigrid
# modules reloaded further down in this file (grid, mission, world_object) are
# reloaded after this patch, presumably so their module-level imports of these
# tables pick up the new values.
constants.COLORS = {
    "red": np.array([255, 0, 0]),
    "green": np.array([0, 255, 0]),
    "blue": np.array([0, 0, 255]),
    "purple": np.array([112, 39, 195]),
    "yellow": np.array([255, 255, 0]),
    "grey": np.array([100, 100, 100]),
    "orange": np.array([255, 165, 0]),  # Bright orange, added
    "cyan": np.array([0, 255, 255]),  # Cyan, added
    "pink": np.array([255, 105, 180]),  # Hot pink, added
}

constants.COLOR_TO_IDX = {
    "red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5,
    "orange": 6, "cyan": 7, "pink": 8,
}
# Keep the inverse table in sync with COLOR_TO_IDX. Presumably minigrid builds
# IDX_TO_COLOR once at import time from the original six colors, so without
# this the new indices 6-8 could not be decoded back to color names.
constants.IDX_TO_COLOR = {idx: name for name, idx in constants.COLOR_TO_IDX.items()}

import importlib
from minigrid.core import grid, mission, world_object
importlib.reload(grid)
importlib.reload(mission)
importlib.reload(world_object)
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall, Floor
from minigrid.minigrid_env import MiniGridEnv


import gymnasium as gym
import numpy as np
from PIL import Image

from environments.grid.MultiRoomGridColor import MultiRoomGridColor


# An example of a customized MiniGrid environment.
class SimpleEnv(MiniGridEnv):
    """Demo environment split in two by a vertical wall with a locked door.

    The door sits in column 5 and its matching key lies at (3, 6), on the
    same side as the default start cell (1, 1). The goal occupies the
    bottom-right interior corner.
    """

    def __init__(
        self,
        size=10,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        # Fixed start pose; _gen_grid reads these (None position -> random placement).
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        mission_space = MissionSpace(mission_func=self._gen_mission)
        step_budget = 4 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            # Skipping occlusion computation gives the fastest observations.
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the constant mission string."""
        return "grand mission"

    def _gen_grid(self, width, height):
        """Lay out the walls, door, key, goal and agent for a new episode."""
        self.grid = Grid(width, height)

        # Outer boundary walls.
        self.grid.wall_rect(0, 0, width, height)

        # Vertical divider in column 5, spanning the full height.
        for row in range(height):
            self.grid.set(5, row, Wall())

        # Locked door in the divider and the key that opens it.
        door_color = COLOR_NAMES[0]
        self.grid.set(5, 6, Door(door_color, is_locked=True))
        self.grid.set(3, 6, Key(door_color))

        # Goal in the bottom-right interior corner.
        self.put_obj(Goal(), width - 2, height - 2)

        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = "grand mission"



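# Room layout specification. Each room is a 4-character string giving the wall
# type on its left, top, right and bottom side, in that order:
#   "w" = solid wall, "o" = opening (floor tile), "d" = door.
# "wwww" is a closed room; ["wwow", "owww"] is two side-by-side rooms joined by
# an opening in their shared wall.
# Adjacent rooms share a wall, so the two specs for that wall must match.
# Give the layout as a list of rows, e.g.
# [["wwow", "owwd"], ["wwow", "odww"]] is four rooms: the two rooms in each row
# share an opening, and the two right-hand rooms are connected by a door.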
class MultiRoomGrid(MiniGridEnv):
    """MiniGrid environment built from a rectangular layout of equal square rooms.

    Args:
        config: List of rows, each a list of 4-character room specs giving the
            wall type on the left, top, right and bottom side of the room
            ("w" wall, "o" floor opening, "d" door).
        start_rooms: Candidate ``[row, col]`` rooms for the agent's start cell.
        goal_rooms: Candidate ``[row, col]`` rooms for the goal.
        room_size: Interior side length of every room, in cells.
        max_steps: Episode step limit (also recorded on ``self.spec``).
        see_through_walls: Forwarded to ``MiniGridEnv``.
        modified_settings: Collection that may contain "wall", "floor" and/or
            "goal"; listed objects are drawn with the colors stored in
            ``OBJECT_TO_NEW_COLORS``. ``None`` keeps the default colors.
        modified_color: ``{"random": ..., "all_color": ...}`` controlling how
            those colors are chosen (see ``_new_object_color``). ``None`` means
            ``{"random": False, "all_color": False}``.
        agent_key_same_room: Place the key in the agent's start room.
        agent_goal_same_room: Place the goal in the agent's start room.
        wrong_direction: Keep the goal out of the agent's room, place no key,
            and (when no fixed direction is given) face the agent away from
            the next doorway.
        agent_position: Fixed ``(x, y)`` start cell; sampled when ``None``.
        agent_direction: Fixed start direction; drawn at random when ``None``.
        goal_position: Fixed ``(x, y)`` goal cell; sampled when ``None``.
        key_required: Whether a key may be placed. When ``None`` it is inferred
            on the first grid generation from whether ``config`` has any door.
        **kwargs: Forwarded to ``MiniGridEnv`` (e.g. ``render_mode``,
            ``highlight``).
    """

    def __init__(self, config, start_rooms, goal_rooms, room_size=3, max_steps=100, see_through_walls=False,
                 modified_settings=None, modified_color=None,
                 agent_key_same_room: bool = False, agent_goal_same_room: bool = False,
                 wrong_direction: bool = False,
                 agent_position=None,
                 agent_direction=None,
                 goal_position=None,
                 key_required=None,
                 **kwargs):
        if modified_color is None:
            # Build a fresh dict per instance instead of sharing a mutable default.
            modified_color = {"random": False, "all_color": False}

        self.num_rows = len(config)
        self.num_cols = len(config[0])
        self.room_size = room_size
        self.start_rooms = start_rooms
        self.goal_rooms = goal_rooms
        self.config = config
        self.max_tries = 100  # attempts per _sample_room call
        self.agent_key_same_room = agent_key_same_room  # agent and key in the same room
        self.agent_goal_same_room = agent_goal_same_room  # agent and goal in the same room
        self.wrong_direction = wrong_direction  # agent faces away from the next doorway
        self.init_agent_position = agent_position  # fixed start cell, or None
        self.init_agent_direction = agent_direction  # fixed start direction, or None
        self.default_goal_position = goal_position  # fixed goal cell, or None
        self.key_required = key_required

        # Rooms share boundary walls: n rooms of size s need n*s + (n+1) cells.
        self.width = (self.num_cols * room_size) + (self.num_cols + 1)
        self.height = (self.num_rows * room_size) + (self.num_rows + 1)
        self.modified_settings = modified_settings
        self.modified_color = modified_color
        # Replacement colors per object type; redrawn on every reset().
        self.OBJECT_TO_NEW_COLORS = {}
        self._initialize_object_colors()

        # Placeholder mission space; missions are not used by this environment.
        mission_space = MissionSpace(
            mission_func=lambda color, type: "Unused",
            ordered_placeholders=[COLOR_NAMES, ["box", "key"]],
        )

        super().__init__(
            mission_space=mission_space,
            max_steps=max_steps,
            width=self.width,
            height=self.height,
            see_through_walls=see_through_walls,
            # Previously dropped silently, so e.g. render_mode/highlight had no effect.
            **kwargs,
        )

        self.spec = gym.envs.registration.EnvSpec(
            id="multi_grid",
            max_episode_steps=max_steps,
            autoreset=False,
            disable_env_checker=False
        )

    def __str__(self):
        # `is not None` also works when agent_pos is a numpy array, whose truth
        # value would be ambiguous.
        if self.agent_pos is not None:
            return super().__str__()
        return 'str(env) invoked before the grid is initialized.'

    def _initialize_object_colors(self):
        """(Re)draw the replacement colors for goals, floors and walls, in that order."""
        for obj in ("goal", "floor", "wall"):
            self.OBJECT_TO_NEW_COLORS[obj] = self._new_object_color(obj)

    def _find_next_door(self):
        """Return the (x, y) doorway the agent should head for next.

        Rule-based lookup on ``self.agent_room`` (the sampled start room), not
        a shortest-path search. The coordinates are the doorway cells of a
        2x2 layout with ``room_size=3``.
        """
        room_to_door = {
            "upper": (4, 2),
            "middle": (6, 4),
            "lower": (4, 6),
        }
        # Compare as lists so rooms given as tuples or arrays also match.
        agent_room = list(self.agent_room)
        if agent_room == [0, 0]:  # upper-left room
            return room_to_door["upper"]
        if agent_room == [0, 1]:  # upper-right room
            return room_to_door["middle"]
        return room_to_door["lower"]  # any other (lower) room

    def _find_correct_directions(self, agent_pos):
        """Return the set of directions that move ``agent_pos`` toward the next doorway.

        Directions: 0 right, 1 down, 2 left, 3 up. When the agent is already
        aligned with the doorway along an axis, both directions on that axis
        count as correct.
        """
        next_door = self._find_next_door()
        dx = next_door[0] - agent_pos[0]
        dy = next_door[1] - agent_pos[1]

        valid_directions = set()

        if dx == 0:
            valid_directions.update((0, 2))
        elif dx > 0:
            valid_directions.add(0)  # right
        else:
            valid_directions.add(2)  # left

        if dy == 0:
            valid_directions.update((1, 3))
        elif dy > 0:
            valid_directions.add(1)  # down
        else:
            valid_directions.add(3)  # up

        return valid_directions

    def _get_wrong_direction(self, agent_pos):
        """Return a random direction that does NOT lead toward the next doorway."""
        possible_directions = {0, 1, 2, 3}  # 0: right, 1: down, 2: left, 3: up
        correct_directions = self._find_correct_directions(agent_pos)
        wrong_directions = list(possible_directions - correct_directions)
        return np.random.choice(wrong_directions)

    def _find_all_doors(self):
        """Return the (x, y) of every Door and Floor cell (i.e. doors and openings)."""
        doors = []
        for x in range(self.width):
            for y in range(self.height):
                obj = self.grid.get(x, y)
                if isinstance(obj, (Door, Floor)):
                    doors.append((x, y))
        return doors

    def reset(self, seed=None, **kwargs):
        """Redraw object colors and start a new episode.

        Grid generation draws from the global ``np.random`` state, so ``seed``
        seeds that state; it is also forwarded to the base class so the env's
        own generator is seeded consistently.
        """
        if seed is not None:
            np.random.seed(seed)
        self._initialize_object_colors()
        return super().reset(seed=seed, **kwargs)

    def _sample_room(self, ul):
        """Sample a free interior cell of the room whose top-left wall cell is ``ul``.

        A cell is free when the grid holds no object there and it is not the
        agent's position. Gives up after ``self.max_tries`` draws.

        Raises:
            RuntimeError: if no free cell was found.
        """
        low_x, low_y = ul[0] + 1, ul[1] + 1
        high_x, high_y = ul[0] + self.room_size + 1, ul[1] + self.room_size + 1
        for _ in range(self.max_tries):
            loc = (np.random.randint(low=low_x, high=high_x),
                   np.random.randint(low=low_y, high=high_y))
            if self.grid.get(*loc) is None and (
                self.agent_pos is None or not np.allclose(loc, self.agent_pos)
            ):
                return loc

        # The original `raise("...")` raised a str, which itself fails with TypeError.
        raise RuntimeError("Failed to sample point in room.")

    def _construct_room(self, room_config, ul):
        """Draw one room's boundary walls, then carve its openings and doors.

        ``room_config`` is the 4-character left/top/right/bottom spec and ``ul``
        the (x, y) of the room's top-left wall cell. Each opening or door is
        placed in the middle of its side. Doors are open and unlocked when the
        agent starts in a lower room, closed and locked otherwise.
        """
        if self.modified_settings is not None and "wall" in self.modified_settings:
            wall_color = self.OBJECT_TO_NEW_COLORS["wall"]
        else:
            wall_color = None
        span = self.room_size + 2  # outer side length, walls included
        mid = span // 2
        self.grid.wall_rect(*ul, span, span, c=wall_color)

        for side, wall in zip(("l", "t", "r", "b"), room_config):
            if wall not in ("o", "d"):
                continue

            if side == "l":
                cell = (ul[0], ul[1] + mid)
            elif side == "r":
                cell = (ul[0] + self.room_size + 1, ul[1] + mid)
            elif side == "t":
                cell = (ul[0] + mid, ul[1])
            else:  # "b"
                cell = (ul[0] + mid, ul[1] + self.room_size + 1)

            if wall == "o":
                if self.modified_settings is not None and "floor" in self.modified_settings:
                    obj = Floor(color=self.OBJECT_TO_NEW_COLORS["floor"])
                else:
                    obj = Floor()
            elif self.is_agent_in_lower_rooms:
                obj = Door("red", is_open=True, is_locked=False)
            else:
                obj = Door("red", is_open=False, is_locked=True)
            self.grid.set(*cell, obj)

    def _new_object_color(self, obj: str):
        """Pick the replacement color name for object type ``obj``.

        - ``random`` == False: fixed colors (wall pink, floor cyan, goal orange).
        - ``random`` == True and ``all_color`` == True: any palette color.
        - ``random`` == True and ``all_color`` == False: one of pink/cyan/orange.

        Raises:
            ValueError: for an unknown ``obj`` (fixed mode) or an unrecognized
                ``modified_color`` setting.
        """
        fixed_colors = {"wall": "pink", "floor": "cyan", "goal": "orange"}
        # Compared with == (not truthiness) on purpose: values other than
        # True/False-equivalents fall through to the ValueError below.
        if self.modified_color["random"] == False:  # noqa: E712
            if obj not in fixed_colors:
                raise ValueError(f"Unrecognized object type in the current settings: {obj}")
            return fixed_colors[obj]
        if self.modified_color["random"] == True:  # noqa: E712
            if self.modified_color["all_color"] == True:  # noqa: E712
                return np.random.choice(list(constants.COLORS.keys()))
            if self.modified_color["all_color"] == False:  # noqa: E712
                return np.random.choice(["pink", "cyan", "orange"])
        raise ValueError(f"Unrecognized color setting in the current settings: {self.modified_color}")

    def _gen_grid(self, width, height):
        """Build the rooms and place the agent, goal and (if needed) the key."""
        # Create this episode's grid *before* sampling the agent so that the
        # free-cell check in _sample_room does not consult the previous
        # episode's grid (with its stale goal/key).
        self.grid = MultiRoomGridColor(width, height)
        self.mission = ""

        # --- Agent placement ---
        agent_room_idx = np.random.choice(len(self.start_rooms))
        self.agent_room = self.start_rooms[agent_room_idx]
        # Room indices are [row, col]; reverse to (x, y) and scale by the room
        # pitch (interior plus one shared wall) to get the top-left wall cell.
        room_ul_agent = (self.room_size + 1) * np.array(self.agent_room[::-1])
        if self.init_agent_position is not None:
            self.agent_pos = self.init_agent_position
        else:
            self.agent_pos = self._sample_room(room_ul_agent)
        # Rows below the first start at the first shared horizontal wall,
        # y = room_size + 1. Agents there get open doors and no key.
        self.is_agent_in_lower_rooms = self.agent_pos[1] >= self.room_size + 1

        # --- Rooms ---
        ul = [0, 0]
        has_door = False
        for row in self.config:
            for room_config in row:
                if "d" in room_config:
                    has_door = True
                self._construct_room(room_config, ul)
                ul[0] += self.room_size + 1
            ul[0] = 0
            ul[1] += self.room_size + 1

        if self.key_required is None:
            self.key_required = has_door

        # --- Goal placement ---
        if self.agent_goal_same_room:
            # NOTE: uses the sampled start room; a fixed agent_position may lie elsewhere.
            room_ul_goal = room_ul_agent
        else:
            candidates = self.goal_rooms
            if self.wrong_direction:
                # Keep the goal out of the agent's room (compared as lists so
                # tuple/array room specs work too).
                candidates = [room for room in self.goal_rooms
                              if list(room) != list(self.agent_room)]
            room_idx_goal = np.array(candidates[np.random.choice(len(candidates))][::-1])
            room_ul_goal = ((self.room_size + 1) * room_idx_goal[0],
                            (self.room_size + 1) * room_idx_goal[1])

        if self.modified_settings is not None and "goal" in self.modified_settings:
            goal = Goal(color=self.OBJECT_TO_NEW_COLORS["goal"])
        else:
            goal = Goal()
        self._place_object(room_ul_goal, goal)
        self.goal_pos = self.find_goal_position()

        # --- Key placement ---
        # Only when a key is needed and the agent starts in an upper room.
        if self.key_required and not self.is_agent_in_lower_rooms and not self.wrong_direction:
            if self.agent_key_same_room:
                # NOTE: uses the sampled start room; a fixed agent_position may lie elsewhere.
                room_ul_key = room_ul_agent
            else:
                # Random start (upper) room. Drawing over len(start_rooms)
                # instead of a hard-coded 2 avoids an IndexError when only one
                # start room is configured.
                room_idx_key = np.random.choice(len(self.start_rooms))
                room_ul_key = (self.room_size + 1) * np.array(self.start_rooms[room_idx_key][::-1])
            self._place_object(room_ul_key, Key("red"))

        # --- Agent direction ---
        if self.init_agent_direction is not None:
            self.agent_dir = self.init_agent_direction
        elif self.wrong_direction:
            self.agent_dir = self._get_wrong_direction(self.agent_pos)
        else:
            self.agent_dir = np.random.randint(low=0, high=4)

    def _place_object(self, ul, obj):
        """Put ``obj`` on a free cell of the room at ``ul`` and return that cell.

        A Goal goes to ``default_goal_position`` when one is set. A cell is still
        sampled first in that case, which keeps the random stream unchanged.
        """
        loc = self._sample_room(ul)
        if isinstance(obj, Goal) and self.default_goal_position is not None:
            loc = self.default_goal_position
        self.put_obj(obj, *loc)
        return loc

    def find_goal_position(self):
        """Return the (x, y) of the first Goal found, scanning column by column.

        Raises:
            ValueError: if the grid contains no Goal.
        """
        for x in range(self.width):
            for y in range(self.height):
                if isinstance(self.grid.get(x, y), Goal):
                    return (x, y)

        raise ValueError("Goal position not found in the grid!")



def visualize_grid_env(env):
    """Reset ``env`` and return a rendering of its full grid as a PIL RGB image.

    The env's own ``highlight`` attribute is honored, so an env built with
    ``highlight=False`` renders without highlighting. If the attribute is
    missing, True is used.
    """
    env.reset()
    frame = env.get_frame(highlight=getattr(env, "highlight", True))
    return Image.fromarray(frame, "RGB")
    
def main():
    """Render two example MultiRoomGrid layouts and save them as PNG files.

    Both environments use the same 2x2 layout of 3x3 rooms joined only by
    openings (no doors). ``open_src.png`` pins the agent cell, its direction
    and the goal cell; ``open_target.png`` leaves all three to be sampled,
    with two candidate goal rooms.
    """
    common = {"room_size": 3, "render_mode": "rgb_array", "highlight": False}

    source_env = MultiRoomGrid(
        [["wwoo", "owwo"], ["woow", "ooww"]],
        start_rooms=[[0, 1]],
        goal_rooms=[[0, 0]],
        modified_settings=[],
        agent_position=(5, 6),
        agent_direction=1,
        goal_position=(1, 5),
        **common,
    )
    visualize_grid_env(source_env).save('open_src.png')

    target_env = MultiRoomGrid(
        [["wwoo", "owwo"], ["woow", "ooww"]],
        start_rooms=[[0, 1]],
        goal_rooms=[[0, 0], [1, 1]],
        modified_settings=[],
        **common,
    )
    visualize_grid_env(target_env).save('open_target.png')


    
# Script entry point: render the example layouts defined in main().
if __name__ == "__main__":
    main()