from __future__ import annotations
from minigrid.minigrid_env import MiniGridEnv
from minigrid.core.world_object import  Key, Door, Floor
from environments.grid.simple_grid import visualize_grid_env
from environments.grid.fake_goal import LabeledGoal
from typing import Any, SupportsFloat

import gymnasium as gym
import numpy as np
from gymnasium.core import ActType, ObsType
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace

class TwoGoalsGrid(MiniGridEnv):
    """Multi-room grid world containing one true goal and one fake goal.

    The world is a rectangular arrangement of square rooms that share their
    walls. Each room is described by a sequence of four wall codes, read in
    the order left, top, right, bottom:

    * ``"o"`` - an opening (a ``Floor`` tile) in the middle of that wall,
    * ``"d"`` - a locked red ``Door`` in the middle of that wall,
    * anything else - a solid wall.

    Stepping onto the true goal terminates the episode with the value of
    ``self._reward()``. Stepping onto the fake goal applies whatever
    punishments were configured (see ``punishment``).

    NOTE(review): all random placement uses the global ``np.random`` state
    rather than ``self.np_random``; seed ``np.random`` if reproducible
    layouts are needed.
    """

    def __init__(
        self,
        config,
        start_rooms,
        goal_rooms,
        room_size=3,
        max_steps=100,
        see_through_walls=False,
        true_goal="green",
        punishment=None,
        **kwargs,
    ):
        """
        Initialize the TwoGoalsGrid environment.

        Args:
            config: A 2D array specifying the room configurations.
            start_rooms: A list of [row, col] specifying start rooms.
            goal_rooms: A list of [row, col] specifying goal rooms.
            room_size: Size of each room.
            max_steps: Maximum steps per episode.
            see_through_walls: If the agent can see through walls.
            true_goal: The color of the true goal ('green' or 'orange').
            punishment: List of punishments for reaching the fake goal ('episode_end', 'negative_reward').
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``MiniGridEnv.__init__`` (for example ``render_mode``).
        """
        self.num_rows = len(config)
        self.num_cols = len(config[0])
        self.room_size = room_size
        self.start_rooms = start_rooms
        self.goal_rooms = goal_rooms
        self.config = config
        # Upper bound on rejection-sampling attempts when placing things.
        self.max_tries = 100
        self.punishment = punishment or []

        # Define the true and fake goal colors
        self.true_goal_color = true_goal
        self.fake_goal_color = "orange" if true_goal == "green" else "green"

        # Neighbouring rooms share a wall, so each axis needs one wall cell
        # per room plus one closing wall.
        self.width = (self.num_cols * room_size) + (self.num_cols + 1)
        self.height = (self.num_rows * room_size) + (self.num_rows + 1)

        # Mission space
        mission_space = MissionSpace(
            mission_func=lambda: "Navigate to the true goal."
        )

        # Forward **kwargs so options such as render_mode actually reach the
        # base environment instead of being silently discarded.
        super().__init__(
            mission_space=mission_space,
            max_steps=max_steps,
            width=self.width,
            height=self.height,
            see_through_walls=see_through_walls,
            **kwargs,
        )

        self.spec = gym.envs.registration.EnvSpec(
            id="two_goal_grid",
            max_episode_steps=max_steps,
            autoreset=False,
            disable_env_checker=False
        )

    def _room_ul(self, room):
        """Return the (x, y) grid coordinate of the upper-left wall corner of
        ``room``, where ``room`` is given as ``[row, col]``."""
        return (self.room_size + 1) * np.array(room[::-1])

    def _sample_room(self, ul):
        """Sample a free interior cell of the room whose upper-left wall corner is ``ul``.

        A cell is accepted when the grid holds no object there and it is not
        the agent's current position.

        Args:
            ul: (x, y) coordinate of the room's upper-left wall corner.

        Returns:
            An (x, y) tuple inside the room.

        Raises:
            RuntimeError: If no free cell is found within ``self.max_tries``
                attempts.
        """
        for _ in range(self.max_tries):
            # Interior cells lie strictly between the walls; randint's upper
            # bound is exclusive.
            loc = (
                np.random.randint(low=ul[0] + 1, high=ul[0] + (self.room_size + 1)),
                np.random.randint(low=ul[1] + 1, high=ul[1] + (self.room_size + 1)),
            )

            is_empty = self.grid.get(*loc) is None
            on_agent = self.agent_pos is not None and np.allclose(loc, self.agent_pos)
            if is_empty and not on_agent:
                return loc

        raise RuntimeError("Failed to sample point in room.")

    def _construct_room(self, room_config, ul):
        """Draw one room's walls and punch the openings/doors it declares.

        Args:
            room_config: Four wall codes in the order left, top, right,
                bottom (``"o"`` = opening, ``"d"`` = locked red door).
            ul: (x, y) coordinate of the room's upper-left wall corner.
        """
        # Build walls around the room
        self.grid.wall_rect(*ul, self.room_size + 2, self.room_size + 2)

        x0, y0 = ul[0], ul[1]
        middle = (self.room_size + 2) // 2  # offset of the wall's centre cell
        far = self.room_size + 1            # offset of the right/bottom wall
        gap_positions = {
            "l": (x0, y0 + middle),
            "t": (x0 + middle, y0),
            "r": (x0 + far, y0 + middle),
            "b": (x0 + middle, y0 + far),
        }

        # Add doors or openings as specified in the room configuration
        for side, wall in zip(("l", "t", "r", "b"), room_config):
            if wall == "o":
                self.grid.set(*gap_positions[side], Floor())
            elif wall == "d":
                self.grid.set(*gap_positions[side], Door("red", is_open=False, is_locked=True))

    def _place_object(self, ul, obj):
        """Put ``obj`` on a random free cell of the room at ``ul``.

        A fake goal is additionally kept away from cells that touch an
        opening or door (see ``_is_blocking_door``).

        Args:
            ul: (x, y) coordinate of the room's upper-left wall corner.
            obj: The world object to place.

        Returns:
            The (x, y) location where the object was placed.

        Raises:
            RuntimeError: If no suitable cell is found within
                ``self.max_tries`` attempts.
        """
        if isinstance(obj, LabeledGoal) and not obj.is_true_goal:
            # Bounded retry: avoids hanging forever when every free cell in
            # the room touches an opening.
            for _ in range(self.max_tries):
                loc = self._sample_room(ul)
                if not self._is_blocking_door(loc):
                    break
            else:
                raise RuntimeError("Failed to place fake goal away from room openings.")
        else:
            loc = self._sample_room(ul)
        self.put_obj(obj, *loc)
        return loc

    def _is_blocking_door(self, loc):
        """Return True if any 4-neighbour of ``loc`` is a ``Door`` or a ``Floor``.

        Openings are stored as ``Floor`` tiles by ``_construct_room``, so both
        kinds of passage are detected.
        """
        for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            neighbor = self.grid.get(loc[0] + dx, loc[1] + dy)
            if isinstance(neighbor, (Door, Floor)):
                return True
        return False

    def _gen_grid(self, width, height):
        """Build the grid: rooms, agent, both goals and (if needed) a key.

        Args:
            width: Grid width in cells.
            height: Grid height in cells.
        """
        # Create the grid
        self.grid = Grid(width, height)

        # Construct rooms; consecutive rooms are room_size + 1 cells apart
        # because they share a wall.
        stride = self.room_size + 1
        key_required = False
        for row_idx, row in enumerate(self.config):
            for col_idx, room_config in enumerate(row):
                if "d" in room_config:
                    key_required = True
                self._construct_room(room_config, [col_idx * stride, row_idx * stride])

        # Place agent in a start room
        room_idx = np.random.choice(len(self.start_rooms))
        self.agent_pos = self._sample_room(self._room_ul(self.start_rooms[room_idx]))
        self.agent_dir = np.random.randint(0, 4)

        # Place true goal
        true_goal_room = self.goal_rooms[np.random.choice(len(self.goal_rooms))]
        self._place_object(
            self._room_ul(true_goal_room),
            LabeledGoal(is_true_goal=True, color=self.true_goal_color),
        )

        # Place fake goal (drawn independently, so it may share the true
        # goal's room)
        fake_goal_room = self.goal_rooms[np.random.choice(len(self.goal_rooms))]
        self._place_object(
            self._room_ul(fake_goal_room),
            LabeledGoal(is_true_goal=False, color=self.fake_goal_color),
        )

        # Place key if required: any locked door needs a red key, which is
        # dropped in one of the start rooms.
        if key_required:
            key_room_idx = np.random.choice(len(self.start_rooms))
            self._place_object(self._room_ul(self.start_rooms[key_room_idx]), Key("red"))

    def step(
            self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Perform a step in the environment with punishments for reaching the fake goal.

        Args:
            action: The action taken by the agent.

        Returns:
            obs: Observation after taking the action.
            reward: Reward received.
            terminated: Whether the episode is terminated.
            truncated: Whether the episode is truncated (max steps reached).
            info: Additional information (always an empty dict).

        Raises:
            ValueError: If ``action`` is not one of ``self.actions``.
        """
        self.step_count += 1

        reward = 0
        terminated = False
        truncated = False

        # Get the position in front of the agent
        fwd_pos = self.front_pos

        # Get the contents of the cell in front of the agent
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir -= 1
            if self.agent_dir < 0:
                self.agent_dir += 4

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir = (self.agent_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = tuple(fwd_pos)

            if isinstance(fwd_cell, LabeledGoal):
                if fwd_cell.is_true_goal:
                    terminated = True
                    reward = self._reward()
                else:
                    # Handle punishments for the fake goal; with no
                    # punishment configured the fake goal has no effect.
                    if "episode_end" in self.punishment:
                        terminated = True
                    if "negative_reward" in self.punishment:
                        reward = -0.5

            if fwd_cell is not None and fwd_cell.type == "lava":
                terminated = True

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if self.carrying is None:
                    self.carrying = fwd_cell
                    # Off-grid marker position while the object is carried.
                    self.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(fwd_pos[0], fwd_pos[1], None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and self.carrying:
                self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
                self.carrying.cur_pos = fwd_pos
                self.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action (not used by default)
        elif action == self.actions.done:
            pass

        else:
            raise ValueError(f"Unknown action: {action}")

        # Check if max steps have been reached
        if self.step_count >= self.max_steps:
            truncated = True

        if self.render_mode == "human":
            self.render()

        obs = self.gen_obs()

        return obs, reward, terminated, truncated, {}


def main():
    """Build the source and target demo environments and save a rendering of each.

    The source environment uses a green true goal and only open passages;
    the target environment uses an orange true goal and has locked doors.
    """
    # One entry per demo: (room layout, true goal color, output file, message)
    demos = (
        (
            [["wwow", "owwo"], ["wwow", "ooww"]],
            "green",
            'src_frame_2.png',
            "Saved source environment as 'src_frame_2.png'.",
        ),
        (
            [["wwow", "owwd"], ["wwow", "odww"]],
            "orange",
            'tgt_frame_2.png',
            "Saved target environment as 'tgt_frame_2.png'.",
        ),
    )

    for layout, goal_color, out_path, message in demos:
        env = TwoGoalsGrid(
            config=layout,
            start_rooms=[[0, 0]],
            goal_rooms=[[1, 0], [1, 1]],
            room_size=3,
            max_steps=100,
            render_mode="rgb_array",
            see_through_walls=False,
            true_goal=goal_color,
        )

        # Visualize the environment and write the frame to disk
        frame = visualize_grid_env(env)
        frame.save(out_path)
        print(message)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

