from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Lava, Wall
from minigrid.minigrid_env import MiniGridEnv


class Corridor(MiniGridEnv):
    """A straight corridor flanked by lava, with a goal at its far end.

    The grid is ``size`` cells wide and always 5 cells tall. The border is
    wall. Rows 1 and 3 are lava across every interior column, leaving row 2
    as the only walkable row. The goal sits in the last interior column of
    that row.
    """

    def __init__(
        self,
        size: int = 8,
        agent_start_pos: tuple[int, int] | None = (1, 2),
        agent_start_dir: int = 0,
        max_steps: int | None = None,
        render_mode: str = "rgb_array",
        **kwargs,
    ):
        """Create the environment.

        Args:
            size: Grid width; the height is fixed at 5.
            agent_start_pos: Fixed (x, y) start cell, or ``None`` to fall back
                to ``self.place_agent()`` when the grid is generated.
            agent_start_dir: Direction index assigned to the agent when
                ``agent_start_pos`` is given.
            max_steps: Episode step limit; ``None`` (or 0) falls back to 256.
            render_mode: Forwarded to ``MiniGridEnv``.
            **kwargs: Forwarded to ``MiniGridEnv``.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            width=size,
            height=5,
            # Honour the caller's limit (previously ignored and hard-coded to
            # 256); the fallback matches DoorFork.
            max_steps=max_steps or 256,
            render_mode=render_mode,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        return "reach da goal"

    def _gen_grid(self, width: int, height: int) -> None:
        """Build the walls, the lava strips, the goal and the agent."""
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Generate lava to the sides of the corridor (rows 1 and 3), so the
        # only safe path is along row 2.
        for i in range(1, width - 1):
            self.grid.set(i, 1, Lava())
            self.grid.set(i, 3, Lava())

        # Place the goal at the right end of the corridor row (y == 2 for the
        # fixed height of 5).
        self.put_obj(Goal(), width - 2, height - 3)

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        # Reuse the mission generator so the text is defined in one place.
        self.mission = self._gen_mission()


class Fork(MiniGridEnv):
    """A corridor that splits into an upper and a lower branch.

    The grid is ``size`` cells wide and always 5 cells tall. From ``x = 2``
    onward, row 2 is wall, which splits the interior into an upper branch
    (row 1) and a lower branch (row 3). Row 1 contains lava from ``x = 3`` to
    ``x = width - 3``. The goal is in the last interior column of row 1.
    """

    def __init__(
        self,
        size: int = 8,
        agent_start_pos: tuple[int, int] | None = (1, 2),
        agent_start_dir: int = 0,
        max_steps: int | None = None,
        render_mode: str = "rgb_array",
        **kwargs,
    ):
        """Create the environment.

        Args:
            size: Grid width; the height is fixed at 5.
            agent_start_pos: Fixed (x, y) start cell, or ``None`` to fall back
                to ``self.place_agent()`` when the grid is generated.
            agent_start_dir: Direction index assigned to the agent when
                ``agent_start_pos`` is given.
            max_steps: Episode step limit; ``None`` (or 0) falls back to 256.
            render_mode: Forwarded to ``MiniGridEnv``.
            **kwargs: Forwarded to ``MiniGridEnv``.
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            width=size,
            height=5,
            # Honour the caller's limit (previously ignored and hard-coded to
            # 256); the fallback matches DoorFork.
            max_steps=max_steps or 256,
            render_mode=render_mode,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        return "reach da goal"

    def _gen_grid(self, width: int, height: int) -> None:
        """Build the walls, the fork, the lava, the goal and the agent."""
        # Create an empty grid
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.wall_rect(0, 0, width, height)

        # Generate the fork: wall off row 2 from x = 2 to the last interior
        # column, splitting the space into an upper (row 1) and a lower
        # (row 3) branch.
        for i in range(2, width - 1):
            self.grid.set(i, 2, Wall())

        # Generate lava on the upper branch before the goal
        for i in range(3, width - 2):
            self.grid.set(i, 1, Lava())

        # Place the goal in the top-right interior cell (y == 1 for the fixed
        # height of 5), at the end of the upper branch.
        # NOTE(review): the goal's only non-border neighbours are (width-3, 1),
        # which is lava whenever width >= 6, and (width-2, 2), which is wall.
        # The lower branch is therefore a dead end, and the goal cannot be
        # reached without crossing lava. Confirm this is intended.
        self.put_obj(Goal(), width - 2, height - 4)

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        # Reuse the mission generator so the text is defined in one place.
        self.mission = self._gen_mission()


class DoorFork(MiniGridEnv):
    """A vertical corridor opening into a T-junction with a goal at each end.

    The grid is ``2 * corridor_length + 3`` cells wide and ``size`` cells
    tall. Below the top interior row, every interior cell is wall except the
    single column ``x = corridor_length + 1``. That column is the corridor,
    and the agent starts at its bottom. The top interior row (y = 1) is open
    and has a goal at each end.
    """

    def __init__(
        self,
        size: int = 9,
        corridor_length: int = 5,
        agent_start_dir: int = 3,
        max_steps: int | None = None,
        render_mode: str = "rgb_array",
        **kwargs,
    ):
        """Create the environment.

        Args:
            size: Grid height.
            corridor_length: Number of wall columns on each side of the
                central corridor; the width is ``2 * corridor_length + 3``.
            agent_start_dir: Direction index the agent starts with.
            max_steps: Episode step limit; ``None`` (or 0) falls back to 256.
            render_mode: Forwarded to ``MiniGridEnv``.
            **kwargs: Forwarded to ``MiniGridEnv``.
        """
        assert corridor_length < size, "Corridor length must be less than the grid size."

        grid_width = 2 * corridor_length + 3
        step_limit = max_steps or 256

        self.corridor_length = corridor_length
        self.agent_start_dir = agent_start_dir
        # The agent begins on the bottom interior cell of the corridor column.
        self.agent_start_pos = (corridor_length + 1, size - 2)

        super().__init__(
            mission_space=MissionSpace(mission_func=self._gen_mission),
            width=grid_width,
            height=size,
            max_steps=step_limit,
            render_mode=render_mode,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        return "reach da goal"

    def _gen_grid(self, width: int, height: int) -> None:
        """Lay out the walls, the two goals and the agent for a new episode."""
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Every interior column except the corridor column is filled with wall
        # from row 2 down to the row just above the bottom border.
        corridor_x = self.corridor_length + 1
        wall_columns = [*range(1, corridor_x), *range(corridor_x + 1, width - 1)]
        for col in wall_columns:
            for row in range(2, height - 1):
                self.grid.set(col, row, Wall())

        # One goal at each end of the open top interior row.
        for goal_x in (1, width - 2):
            self.put_obj(Goal(), goal_x, 1)

        if self.agent_start_pos is None:
            self.place_agent()
        else:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir

        self.mission = self._gen_mission()
