from __future__ import annotations
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Wall
from minigrid.minigrid_env import MiniGridEnv
import time
import gymnasium as gym
import os

class SimpleEnv(MiniGridEnv):
    """A small MiniGrid environment returning RGB arrays as observations.

    Actions are global directions (up/down/left/right) irrespective of
    the agent's current facing. This makes the env convenient for
    visualizing option policies that are defined on the grid cells.

    Observations are built with ``get_frame`` rather than ``render`` so
    that an RGB array is returned regardless of ``render_mode``.

    NOTE(review): the observation space inherited from ``MiniGridEnv`` is
    not overridden here, so it presumably does not describe the RGB frames
    returned by ``reset``/``step`` -- confirm before using wrappers that
    validate observations.
    """

    # Global action index -> MiniGrid facing.
    # Facing encoding: 0 right (east), 1 down (south), 2 left (west), 3 up (north).
    _ACTION_TO_DIR = {
        0: 3,  # up    (Y-)
        1: 1,  # down  (Y+)
        2: 2,  # left  (X-)
        3: 0,  # right (X+)
    }

    # Cell that yields the toy reward of 1 while the agent stands on it.
    _TARGET_POS = (2, 4)

    def __init__(
            self,
            size=10,
            agent_start_pos=(1, 1),
            agent_start_dir=0,
            max_steps: int | None = None,
            **kwargs,
    ):
        """Create the environment.

        Args:
            size: Side length of the square grid, border walls included.
            agent_start_pos: Fixed ``(x, y)`` start cell, or ``None`` to let
                ``place_agent`` choose one.
            agent_start_dir: Initial facing (0 right, 1 down, 2 left, 3 up).
            max_steps: Step budget per episode; defaults to ``4 * size ** 2``.
            **kwargs: Forwarded unchanged to ``MiniGridEnv.__init__``.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.size = size
        mission_space = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            max_steps = 4 * size ** 2

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            see_through_walls=True,
            max_steps=max_steps,
            **kwargs,
        )
        # Replace the base action set with the four global directions.
        self.action_space = gym.spaces.Discrete(4)

    @staticmethod
    def _gen_mission():
        """Return the (constant) mission string."""
        return "grand mission"

    def _rgb_frame(self):
        """Return the current RGB frame, independent of ``render_mode``."""
        return self.get_frame(self.highlight, self.tile_size, self.agent_pov)

    def reset(self, *, seed=None, options=None):
        """Reset the episode and return ``(rgb_frame, info)``.

        ``seed`` and ``options`` are forwarded to the base ``reset``; the
        returned ``info`` dict is always empty.
        """
        super().reset(seed=seed, options=options)
        return self._rgb_frame(), {}

    def _gen_grid(self, width, height):
        """Build the walled grid and place the agent.

        The interior layout is a fixed set of wall cells. Cells that fall
        outside a smaller-than-default grid are skipped, so any ``size`` can
        be constructed instead of failing inside ``Grid.set``.
        """
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Corridor-like layout: a vertical wall at x=5 with a gap at y=6,
        # plus two horizontal wall segments at y=5 and y=7.
        wall_cells = (
            [(5, y) for y in range(0, 6)]
            + [(5, y) for y in range(7, 9)]
            + [(x, 5) for x in range(4, 8)]
            + [(x, 7) for x in range(4, 8)]
        )
        for x, y in wall_cells:
            if 0 <= x < width and 0 <= y < height:
                self.grid.set(x, y, Wall())

        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "grand mission"

    def step(self, action):
        """Take a global action and return an RGB frame observation.

        Action mapping:
            0 – up (Y−), 1 – down (Y+), 2 – left (X−), 3 – right (X+)

        The agent is turned to face the requested direction and then moved
        forward once, which counts as a single base-environment step.

        Returns:
            ``(rgb_frame, reward, terminated, truncated, info)``. ``reward``
            is 1 while the agent stands on the target cell, else 0;
            ``terminated`` is always ``False``; ``truncated`` is taken from
            the base environment, so ``max_steps`` is honoured; ``info`` is
            always empty.

        Raises:
            ValueError: If ``action`` is not one of 0, 1, 2, 3.
        """
        # Equality scan (not a dict lookup) so that every value the original
        # ``==`` chain accepted, e.g. NumPy scalars, still matches.
        for candidate, facing in self._ACTION_TO_DIR.items():
            if action == candidate:
                target_dir = facing
                break
        else:
            raise ValueError(
                f"Invalid action {action!r}; expected 0 (up), 1 (down), "
                "2 (left) or 3 (right)."
            )

        # Face the target direction directly instead of issuing repeated
        # rotate-left steps, each of which would consume the step budget.
        self.agent_dir = target_dir

        # Move forward once; only the truncation flag of the base step is used.
        _, _, _, truncated, _ = super().step(self.actions.forward)

        # Simple goal reward (toy signal for debugging/visualization)
        reward = 1 if tuple(self.agent_pos) == self._TARGET_POS else 0

        # Reaching the target does not end the episode.
        terminated = False

        return self._rgb_frame(), reward, terminated, truncated, {}


def main():
    """Open a human-rendered ``SimpleEnv``, print the start state and idle.

    The window is kept open for 1000 seconds; the environment is closed on
    the way out even if the wait is interrupted (e.g. with Ctrl-C).
    """
    env = SimpleEnv(render_mode="human", highlight=False)
    try:
        env.reset()
        print(f"Agent initial position: {env.agent_pos}, direction: {env.agent_dir}")

        # Example manual moves (commented out)
        # actions = [3, 1, 2, 0]
        # for action in actions:
        #     obs, reward, terminated, truncated, info = env.step(action)
        #     print(f"Agent moved to: {env.agent_pos}, facing: {env.agent_dir}")

        env.render()
        time.sleep(1000)
    finally:
        # Release the rendering window and any other env resources.
        env.close()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
