"""
Copied and adapted from https://github.com/mila-iqia/babyai
"""
from __future__ import annotations

from minigrid.core.roomgrid import RoomGrid
from minigrid.envs.babyai.core.verifier import (
    ActionInstr,
    AfterInstr,
    AndInstr,
    BeforeInstr,
    PutNextInstr,
    SeqInstr,
)
from minigrid.minigrid_env import MissionSpace


class RejectSampling(Exception):
    """
    Exception used for rejection sampling
    """

    pass


class BabyAIMissionSpace(MissionSpace):
    """
    Class that mimics the behavior required by minigrid.minigrid_env.MissionSpace,
    but does not change how missions are generated for BabyAI. It silences
    the gymnasium.utils.passive_env_checker given that it considers all strings to be
    plausible samples.
    """

    def __init__(self):
        super().__init__(mission_func=self._gen_mission)

    @staticmethod
    def _gen_mission():
        return "go"

    def contains(self, x: str):
        return True


class RoomGridLevel(RoomGrid):
    """
    Base for levels based on RoomGrid.
    A level, generates missions generated from
    one or more patterns. Levels should produce a family of missions
    of approximately similar difficulty.
    """

    def __init__(self, room_size=8, max_steps: int | None = None, **kwargs):
        mission_space = BabyAIMissionSpace()

        # If `max_steps` arg is passed it will be fixed for every episode,
        # if not it will vary after reset depending on the maze size.
        self.fixed_max_steps = False
        if max_steps is not None:
            self.fixed_max_steps = True
        else:
            max_steps = 0  # only for initialization
        super().__init__(
            room_size=room_size,
            mission_space=mission_space,
            max_steps=max_steps,
            **kwargs,
        )

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.instrs.reset_verifier(self)

        # Compute the time step limit based on the maze size and instructions
        nav_time_room = self.room_size**2
        nav_time_maze = nav_time_room * self.num_rows * self.num_cols
        num_navs = self.num_navs_needed(self.instrs)

        if not self.fixed_max_steps:
            self.max_steps = num_navs * nav_time_maze

        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)

        # If we drop an object, we need to update its position in the environment
        if action == self.actions.drop:
            self.update_objs_poss()

        # If we've successfully completed the mission
        status = self.instrs.verify(action)

        if status == "success":
            terminated = True
            reward = self._reward()
        elif status == "failure":
            terminated = True
            reward = 0

        return obs, reward, terminated, truncated, info

    def update_objs_poss(self, instr=None):
        if instr is None:
            instr = self.instrs
        if (
            isinstance(instr, BeforeInstr)
            or isinstance(instr, AndInstr)
            or isinstance(instr, AfterInstr)
        ):
            self.update_objs_poss(instr.instr_a)
            self.update_objs_poss(instr.instr_b)
        else:
            instr.update_objs_poss()

    def _gen_grid(self, width, height):
        # We catch RecursionError to deal with rare cases where
        # rejection sampling gets stuck in an infinite loop
        while True:
            try:
                super()._gen_grid(width, height)

                # Generate the mission
                self.gen_mission()

                # Validate the instructions
                self.validate_instrs(self.instrs)

            except RecursionError as error:
                print("Timeout during mission generation:", error)
                continue

            except RejectSampling as error:
                print("Sampling rejected:", error)
                continue

            break

        # Generate the surface form for the instructions
        self.surface = self.instrs.surface(self)
        self.mission = self.surface

    def validate_instrs(self, instr):
        """
        Perform some validation on the generated instructions
        """
        # Gather the colors of locked doors
        colors_of_locked_doors = []
        if hasattr(self, "unblocking") and self.unblocking:
            for i in range(self.num_cols):
                for j in range(self.num_rows):
                    room = self.get_room(i, j)
                    for door in room.doors:
                        if door and door.is_locked:
                            colors_of_locked_doors.append(door.color)

        if isinstance(instr, PutNextInstr):
            # Resolve the objects referenced by the instruction
            instr.reset_verifier(self)

            # Check that the objects are not already next to each other
            if set(instr.desc_move.obj_set).intersection(set(instr.desc_fixed.obj_set)):
                raise RejectSampling(
                    "there are objects that match both lhs and rhs of PutNext"
                )
            if instr.objs_next():
                raise RejectSampling("objs already next to each other")

            # Check that we are not asking to move an object next to itself
            move = instr.desc_move
            fixed = instr.desc_fixed
            if len(move.obj_set) == 1 and len(fixed.obj_set) == 1:
                if move.obj_set[0] is fixed.obj_set[0]:
                    raise RejectSampling("cannot move an object next to itself")

        if isinstance(instr, ActionInstr):
            if not hasattr(self, "unblocking") or not self.unblocking:
                return
            # TODO: either relax this a bit or make the bot handle this super corner-y scenarios
            # Check that the instruction doesn't involve a key that matches the color of a locked door
            potential_objects = ("desc", "desc_move", "desc_fixed")
            for attr in potential_objects:
                if hasattr(instr, attr):
                    obj = getattr(instr, attr)
                    if obj.type == "key" and obj.color in colors_of_locked_doors:
                        raise RejectSampling(
                            "cannot do anything with/to a key that can be used to open a door"
                        )
            return

        if isinstance(instr, SeqInstr):
            self.validate_instrs(instr.instr_a)
            self.validate_instrs(instr.instr_b)
            return

        assert False, "unhandled instruction type"

    def gen_mission(self):
        """
        Generate a mission (instructions and matching environment)
        Derived level classes should implement this method
        """
        raise NotImplementedError

    @property
    def level_name(self):
        return self.__class__.level_name

    @property
    def gym_id(self):
        return self.__class__.gym_id

    def num_navs_needed(self, instr) -> int:
        """
        Compute the maximum number of navigations needed to perform
        a simple or complex instruction
        """

        if isinstance(instr, PutNextInstr):
            return 2

        elif isinstance(instr, ActionInstr):
            return 1

        elif isinstance(instr, SeqInstr):
            na = self.num_navs_needed(instr.instr_a)
            nb = self.num_navs_needed(instr.instr_b)
            return na + nb

        else:
            raise NotImplementedError(
                "instr needs to be an instance of PutNextInstr, ActionInstr, or SeqInstr"
            )

    def open_all_doors(self):
        """
        Open all the doors in the maze
        """

        for i in range(self.num_cols):
            for j in range(self.num_rows):
                room = self.get_room(i, j)
                for door in room.doors:
                    if door:
                        door.is_open = True

    def check_objs_reachable(self, raise_exc=True):
        """
        Check that all objects are reachable from the agent's starting
        position without requiring any other object to be moved
        (without unblocking)
        """

        # Reachable positions
        reachable = set()

        # Work list
        stack = [self.agent_pos]

        while len(stack) > 0:
            i, j = stack.pop()

            if i < 0 or i >= self.grid.width or j < 0 or j >= self.grid.height:
                continue

            if (i, j) in reachable:
                continue

            # This position is reachable
            reachable.add((i, j))

            cell = self.grid.get(i, j)

            # If there is something other than a door in this cell, it
            # blocks reachability
            if cell and cell.type != "door":
                continue

            # Visit the horizontal and vertical neighbors
            stack.append((i + 1, j))
            stack.append((i - 1, j))
            stack.append((i, j + 1))
            stack.append((i, j - 1))

        # Check that all objects are reachable
        for i in range(self.grid.width):
            for j in range(self.grid.height):
                cell = self.grid.get(i, j)

                if not cell or cell.type == "wall":
                    continue

                if (i, j) not in reachable:
                    if not raise_exc:
                        return False
                    raise RejectSampling("unreachable object at " + str((i, j)))

        # All objects reachable
        return True
