from __future__ import annotations

import time
import gymnasium as gym
import numpy as np
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Wall
from minigrid.minigrid_env import MiniGridEnv


class SimpleEnv(MiniGridEnv):
    """MiniGrid environment with hand-crafted wall layouts for experiments.

    Observations are RGB frames (``uint8`` arrays of shape ``(H, W, 3)``)
    obtained from ``get_frame``. They are returned regardless of
    ``render_mode`` (in ``human`` mode ``render()`` itself returns ``None``).
    The action space is ``Discrete(4)`` with *global* directional actions
    (0=up, 1=down, 2=left, 3=right) that do not depend on the agent's current
    orientation.

    NOTE(review): ``observation_space`` is left as configured by
    ``MiniGridEnv.__init__`` and does not describe the RGB frames returned
    here. Confirm whether any wrapper or checker relies on it.
    """

    def __init__(
        self,
        size: int = 15,
        agent_start_pos: tuple[int, int] | None = (1, 1),
        agent_start_dir: int = 0,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Create the environment.

        Args:
            size: Width and height of the square grid, including the outer
                wall. The active layout in ``_gen_grid`` hard-codes
                coordinates up to 14, i.e. it is drawn for ``size=15``.
            agent_start_pos: Fixed ``(x, y)`` start cell, or ``None`` to let
                ``place_agent()`` choose the start cell.
            agent_start_dir: Start direction code, used only when
                ``agent_start_pos`` is given.
            max_steps: Episode step limit; defaults to ``4 * size**2``.
                Each global action in ``step`` counts as exactly one step.
            **kwargs: Forwarded to ``MiniGridEnv.__init__`` (e.g.
                ``render_mode``, ``highlight``).
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.size = size
        mission_space = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            max_steps = 4 * size**2

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            see_through_walls=True,
            max_steps=max_steps,
            **kwargs,
        )
        # Only the four global moves handled by ``step`` are exposed.
        self.action_space = gym.spaces.Discrete(4)

    # ------------------------------------------------------------------ #
    #                        Minigrid hooks                               #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string required by ``MissionSpace``."""
        return "grand mission"

    def reset(self, *, seed=None, options=None):
        """Reset the episode and return ``(rgb_frame, info)``.

        ``seed`` and ``options`` are forwarded to ``MiniGridEnv.reset``.
        The returned info dict is always empty.
        """
        super().reset(seed=seed, options=options)
        return self._rgb_obs(), {}

    def _rgb_obs(self):
        """Return the current frame as an RGB ``uint8`` array.

        Calls ``get_frame`` with the same arguments ``render()`` uses, so the
        result equals ``render()`` in ``rgb_array`` mode. Unlike ``render()``,
        it also yields an array in ``human`` mode, where ``render()``
        returns ``None``.
        """
        return self.get_frame(self.highlight, self.tile_size, self.agent_pov)

    def _gen_grid(self, width, height):
        """Build the grid: outer wall plus the active hand-crafted layout.

        Also places the agent, at ``agent_start_pos``/``agent_start_dir``
        when given, otherwise via ``place_agent()``.
        """
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # ------------------------------------------------------------------
        # The following commented blocks illustrate alternative wall layouts.
        # Do NOT delete them; simply uncomment one block if you wish to use
        # that specific environment variant.
        # ------------------------------------------------------------------

        # NOTE: Env-3 indexes coordinates up to 19, so it needs size >= 20.
        # ---------------- Env-3: global information transfer --------------
        # for i in range(0, 4):
        #     self.grid.set(6, i, Wall())
        # for i in range(6, 15):
        #     self.grid.set(6, i, Wall())
        # for i in range(17, 20):
        #     self.grid.set(6, i, Wall())
        #
        # for i in range(3, 7):
        #     self.grid.set(8, i, Wall())
        # for i in range(3, 10):
        #     self.grid.set(11, i, Wall())
        # for j in range(8, 11):
        #     self.grid.set(j, 3, Wall())
        # for j in range(13, 17):
        #     self.grid.set(j, 3, Wall())
        # for i in range(3, 8):
        #     self.grid.set(17, i, Wall())
        # for j in range(7, 10):
        #     self.grid.set(j, 13, Wall())
        # for i in range(13, 18):
        #     self.grid.set(10, i, Wall())
        # for i in range(14, 19):
        #     self.grid.set(15, i, Wall())
        # for j in range(15, 18):
        #     self.grid.set(j, 14, Wall())
        # for i in range(15, 17):
        #     self.grid.set(17, i, Wall())
        # for i in range(13, 16):
        #     self.grid.set(2, i, Wall())
        # for j in range(2, 5):
        #     self.grid.set(j, 13, Wall())
        # self.grid.set(9, 17, Wall())
        #
        # for i in range(0, 3):
        #     self.grid.set(3, i, Wall())
        # for i in range(3, 6):
        #     self.grid.set(i, 4, Wall())
        # for i in range(3, 6):
        #     self.grid.set(i, 6, Wall())
        # for i in range(17, 20):
        #     self.grid.set(4, i, Wall())
        # for i in range(2, 5):
        #     self.grid.set(i, 16, Wall())
        #
        # for i in range(0, 4):
        #     self.grid.set(13, i, Wall())
        # for i in range(6, 15):
        #     self.grid.set(13, i, Wall())
        # for i in range(17, 20):
        #     self.grid.set(13, i, Wall())
        #
        # for i in range(0, 2):
        #     self.grid.set(i, 10, Wall())
        # for i in range(4, 9):
        #     self.grid.set(i, 10, Wall())
        # for i in range(11, 16):
        #     self.grid.set(i, 10, Wall())
        # for i in range(18, 20):
        #     self.grid.set(i, 10, Wall())
        #
        # for i in range(0, 5):
        #     self.grid.set(6, i, Wall())
        # for i in range(7, 15):
        #     self.grid.set(6, i, Wall())
        # for i in range(18, 20):
        #     self.grid.set(6, i, Wall())

        # ---------------- Small 4-room environment ------------------------
        # for i in range(0, 2):
        #     self.grid.set(4, i, Wall())
        # for i in range(3, 6):
        #     self.grid.set(4, i, Wall())
        # for i in range(7, 9):
        #     self.grid.set(4, i, Wall())
        #
        # for i in range(0, 2):
        #     self.grid.set(i, 4, Wall())
        # for i in range(3, 6):
        #     self.grid.set(i, 4, Wall())
        # for i in range(7, 9):
        #     self.grid.set(i, 4, Wall())

        # ---------------- Corridor environment ---------------------------
        # for i in range(0, 6):
        #     self.grid.set(5, i, Wall())
        # for i in range(7, 9):
        #     self.grid.set(5, i, Wall())
        #
        # for j in range(4, 8):
        #     self.grid.set(j, 5, Wall())
        # for j in range(4, 8):
        #     self.grid.set(j, 7, Wall())

        # ---------------- Key–lock environment (active) ------------------
        # Coordinates here go up to 14, i.e. drawn for the default size=15.
        for i in range(9, 15):
            self.grid.set(i, 5, Wall())
        for j in range(0, 3):
            self.grid.set(9, j, Wall())
        self.grid.set(9, 4, Wall())

        for i in range(9, 15):
            self.grid.set(i, 9, Wall())
        for j in range(10, 11):
            self.grid.set(9, j, Wall())
        self.grid.set(9, 13, Wall())

        for i in range(0, 3):
            self.grid.set(i, 9, Wall())
        for i in range(4, 6):
            self.grid.set(i, 9, Wall())
        for j in range(9, 15):
            self.grid.set(6, j, Wall())
        for j in range(7, 9):
            self.grid.set(2, j, Wall())
        for j in range(7, 9):
            self.grid.set(4, j, Wall())

        # ------------------------------------------------------------------

        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "grand mission"

    # ------------------------------------------------------------------ #
    #                               Step                                 #
    # ------------------------------------------------------------------ #

    def step(self, action):
        """
        Execute a *global* action (up, down, left, right) independent of
        the agent’s current orientation.

        Global action mapping:
            0 – Move up    (Y−)
            1 – Move down  (Y+)
            2 – Move left  (X−)
            3 – Move right (X+)

        The agent is turned to face the target direction directly. Then a
        single MiniGrid ``forward`` step is taken. Each global action
        therefore consumes exactly one step of the ``max_steps`` budget.
        Reward, ``terminated`` and ``truncated`` come from that forward step,
        so the episode truncates once ``max_steps`` is reached.

        Returns:
            ``(rgb_frame, reward, terminated, truncated, info)`` with an
            empty info dict.

        Raises:
            ValueError: If ``action`` is not one of 0, 1, 2, 3.
        """
        # Direction codes: 0=east, 1=south, 2=west, 3=north
        if action == 0:        # up
            target_dir = 3
        elif action == 1:      # down
            target_dir = 1
        elif action == 2:      # left
            target_dir = 2
        elif action == 3:      # right
            target_dir = 0
        else:
            raise ValueError("Invalid action index")

        # Turn in place without going through super().step(). Every
        # MiniGridEnv.step call, including turns, increments step_count and
        # would eat into max_steps.
        self.agent_dir = target_dir

        # Move forward once. The base step handles collisions, step counting,
        # max_steps truncation and human-mode rendering.
        _, reward, terminated, truncated, _ = super().step(self.actions.forward)

        return self._rgb_obs(), reward, terminated, truncated, {}


# ------------------------------------------------------------------ #
#                               Debug                                #
# ------------------------------------------------------------------ #

def main():
    """Debug helper: show the initial layout in a human-rendered window.

    Resets the environment, prints the agent's start position and direction,
    renders it, and keeps the window open for 1000 seconds. ``env.close()``
    runs even if the wait is interrupted (e.g. with Ctrl+C).
    """
    env = SimpleEnv(render_mode="human", highlight=False)
    try:
        env.reset()
        print(f"Agent initial position: {env.agent_pos}, direction: {env.agent_dir}")

        # Example usage of the step function:
        # actions = [3, 1, 2, 0]  # right, down, left, up
        # for act in actions:
        #     obs, reward, terminated, truncated, info = env.step(act)
        #     print(f"New position: {env.agent_pos}, facing: {env.agent_dir}")

        env.render()
        # Keep the window open for manual inspection.
        time.sleep(1000)
    finally:
        env.close()


if __name__ == "__main__":
    main()
