import random
from collections import deque

import numpy as np
from gymnasium import spaces
from gymnasium.core import ObservationWrapper
from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX



class EmptyObject:
    """Stand-in placed in grid cells that hold no world object."""

    def __init__(self, cur_pos):
        # Exposes `type` and `color` so callers can treat it like any
        # other grid cell when inspecting those attributes.
        self.cur_pos = cur_pos
        self.color = None
        self.type = "empty"


class AgentObject:
    """Marker placed in the grid cell currently occupied by the agent."""

    def __init__(self, cur_pos):
        # Only `type` and `cur_pos` are set; unlike EmptyObject there is
        # no `color` attribute.
        self.cur_pos = cur_pos
        self.type = "agent"


class LanguageObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a language state representation.
    Example:
        >>> import gymnasium as gym
        >>> # assumes LanguageObsWrapper (this class) has been imported from this module
        >>> env = gym.make("BabyAI-GoToRedBlueBall-v0")
        >>> env_obs = LanguageObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        """
        Wrap ``env`` in a ``FullyObsWrapper`` and declare the "image" space.

        ``obj_grid`` starts as ``None``; it is filled in by ``observation``.
        """
        super().__init__(FullyObsWrapper(env))
        self.obj_grid = None

        base_env = self.env.unwrapped
        image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(base_env.width, base_env.height, 3),  # number of cells
            dtype="uint8",
        )
        # Keep every inherited sub-space and override only the "image" entry.
        self.observation_space = spaces.Dict(
            dict(self.observation_space.spaces, image=image_space)
        )

    @staticmethod
    def bfs(grid, start, visited):
        queue = [start]
        room = set()
        while queue:
            (x, y) = queue.pop(0)
            if grid[x][y].type not in ['wall', 'door']:
                visited.add((x, y))
                room.add((x, y))
                if x > 0 and (x - 1, y) not in visited:
                    queue.append((x - 1, y))
                if x < len(grid) - 1 and (x + 1, y) not in visited:
                    queue.append((x + 1, y))
                if y > 0 and (x, y - 1) not in visited:
                    queue.append((x, y - 1))
                if y < len(grid[0]) - 1 and (x, y + 1) not in visited:
                    queue.append((x, y + 1))
        return room

    @staticmethod
    def get_rooms(grid):
        rooms = []  # list of tuples, each tuple is (x, y, object)
        visited = set()
        for i in range(len(grid)):
            for j in range(len(grid[i])):
                if grid[i][j].type not in ['wall', 'door'] and (i, j) not in visited:
                    room = LanguageObsWrapper.bfs(grid, (i, j), visited)
                    rooms.append(room)
                    # visited.update(room)
                # elif grid[i][j] == 'door':
                #     doors.append((i, j))
        return rooms

    @staticmethod
    def get_object_of_type(obj_type, grid):
        objects = []  # list of tuples, each tuple is (x, y, object of type obj_type)
        for i in range(len(grid)):
            for j in range(len(grid[i])):
                obj = grid[i][j]
                if obj.type == obj_type:
                    objects.append((i, j, obj))
        return objects

    @staticmethod
    def get_adjoining_rooms(x, y, rooms):
        """
        get adjoining rooms given door pos (x, y)
        """
        adjoining_room_ids = []
        for del_x, del_y in zip([0, 1, -1, 0], [1, 0, 0, -1]):
            new_x, new_y = x + del_x, y + del_y
            for ind, room in enumerate(rooms):
                if (new_x, new_y) in room:
                    adjoining_room_ids.append(ind + 1)
                    if len(adjoining_room_ids) == 2:
                        return adjoining_room_ids
                    break

    def observation(self, obs):
        """
        Augment ``obs`` with a language description of the whole grid.

        Keys written into ``obs`` (which is mutated and returned):
            "language": one sentence per room listing its contents, followed
                by one sentence per door giving its state.
            "image": ``(width, height, 3)`` object array whose last axis holds
                ``(x, y, cell_object)``.
            "admissible_actions": sorted list of unique action strings.

        Side effects: stores the underlying grid in ``self.grid``, the
        per-cell object layer in ``self.obj_grid`` and, for every cell whose
        type equals the last word of ``obs['mission']``, assigns that cell to
        ``self.env.env.env.goal_obj`` (the last match in x-then-y order wins).
        """
        env = self.unwrapped
        self.grid = env.grid
        layer = self._object_layer()

        # The last word of the mission string is treated as the goal type.
        goal_type = obs['mission'].split()[-1]
        for column in layer:
            for cell in column:
                if not isinstance(cell, EmptyObject) and cell.type == goal_type:
                    # NOTE(review): hard-coded wrapper depth; this only hits
                    # the intended env for one particular wrapper stack.
                    self.env.env.env.goal_obj = cell

        # TODO: rooms only need to be computed once at the start
        rooms = LanguageObsWrapper.get_rooms(layer)
        descriptions = LanguageObsWrapper._describe_rooms(layer, rooms)
        self.obj_grid = layer
        descriptions.extend(LanguageObsWrapper._describe_doors(layer, rooms))
        obs["language"] = '. '.join(descriptions) + '. '

        # A second, independent layer is built so the cells stored in
        # obs["image"] are not the same wrapper objects as in self.obj_grid.
        coords = np.mgrid[:env.width, :env.height]
        stacked = np.concatenate([coords, self._object_layer()[np.newaxis]])
        obs["image"] = np.transpose(stacked, (1, 2, 0))

        obs["admissible_actions"] = LanguageObsWrapper._admissible_actions(
            self.grid.grid
        )
        return obs

    def _object_layer(self):
        """Build a fresh ``[x, y]`` object array: grid cells, EmptyObject fillers, agent."""
        env = self.unwrapped
        ncol, nrow = env.width, env.height
        layer = np.empty((ncol, nrow), dtype=object)
        # The flat cell list is laid out row by row: index = y * ncol + x.
        for idx, cell in enumerate(env.grid.grid):
            x, y = idx % ncol, idx // ncol
            layer[x, y] = EmptyObject(cur_pos=(x, y)) if cell is None else cell
        # The agent marker replaces whatever occupies the agent's cell.
        agent_pos = env.agent_pos
        layer[agent_pos[0], agent_pos[1]] = AgentObject(cur_pos=agent_pos)
        return layer

    @staticmethod
    def _describe_rooms(layer, rooms):
        """Return one sentence per room listing the objects it contains."""
        sentences = []
        for room_id, room in enumerate(rooms, start=1):
            contents = []
            for (x, y) in room:
                cell = layer[x][y]
                if cell.type not in ('empty', 'wall', 'door', 'agent'):
                    contents.append(f'{cell.color} {cell.type}')
                elif cell.type in ('door', 'agent'):
                    contents.append(f'{cell.type}')
                # TODO: key in box
            if contents:
                listing = ', '.join(contents)
                sentences.append(f"Room {room_id} has {listing}")
            else:
                sentences.append(f"Room {room_id} is empty")
        return sentences

    @staticmethod
    def _describe_doors(layer, rooms):
        """Return one sentence per door giving its state and the rooms it joins."""
        sentences = []
        for x, y, door in LanguageObsWrapper.get_object_of_type('door', layer):
            if door.is_open:
                status = "open"
            elif door.is_locked:
                status = "locked"
            else:
                status = "closed"
            adjoining = LanguageObsWrapper.get_adjoining_rooms(x, y, rooms)
            if adjoining and len(adjoining) >= 2:
                sentences.append(
                    f"The {door.color} door connecting Room {min(adjoining)} "
                    f"and Room {max(adjoining)} is {status}")
            else:
                # Fewer than two rooms touch this door, so there is no room
                # pair to name; report the state instead of failing.
                sentences.append(f"The {door.color} door is {status}")
        return sentences

    @staticmethod
    def _admissible_actions(cells):
        """Return the sorted, de-duplicated action strings for the given cells."""
        actions = {"done picking up"}
        for cell in cells:
            if cell is None:
                continue
            if cell.type == 'door':
                actions.add(f"toggle {cell.color} door")
            elif cell.type in ('box', 'ball', 'key'):
                actions.add(f"pick up {cell.color} {cell.type}")
                actions.add(f"drop {cell.type} in void")
        # Sorted so the list order does not depend on string hash seeding.
        return sorted(actions)

    def is_object_next_to_door(self):
        """
        Check if a specified object (ball, box, or key) is next to a door.

        Args:
            grid (list): The grid representing the environment.
            object_type (str): The type of object to check ('ball', 'box', or 'key').

        Returns:
            bool: True if the object is next to a door, False otherwise.
        """
        grid = self.obj_grid
        doors = LanguageObsWrapper.get_object_of_type('door', grid)
        box = LanguageObsWrapper.get_object_of_type("box", grid)
        key = LanguageObsWrapper.get_object_of_type("key", grid)
        ball = LanguageObsWrapper.get_object_of_type("ball", grid)
        if len(doors) > 0:
            for door in doors:
                x, y, door_obj = door
                for obj in [box, key, ball]:
                    if len(obj) > 0:
                        for o in obj:
                            if abs(x - o[0]) + abs(y - o[1]) == 1:
                                return True

        return False

    def same_objects(self):
        """
        Check if the same object is present in the grid.
        """
        grid = self.obj_grid
        balls = LanguageObsWrapper.get_object_of_type('ball', grid)
        boxes = LanguageObsWrapper.get_object_of_type('box', grid)
        keys = LanguageObsWrapper.get_object_of_type('key', grid)
        doors = LanguageObsWrapper.get_object_of_type('door', grid)
        for objects in [balls, boxes, keys, doors]:
            if len(objects) > 0:
                for i in range(len(objects)):
                    for j in range(i+1, len(objects)):
                        if objects[i][2].color == objects[j][2].color:
                            return True
        return False

class FullyObsWrapper(ObservationWrapper):
    """
    Replace the agent-view image with a compact encoding of the whole grid.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FullyObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = FullyObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
        """Declare an "image" space sized to the full grid (width x height x 3)."""
        super().__init__(env)

        base_env = self.env.unwrapped
        grid_shape = (base_env.width, base_env.height, 3)  # number of cells
        image_space = spaces.Box(low=0, high=255, shape=grid_shape, dtype="uint8")

        # Keep every inherited sub-space and override only the "image" entry.
        self.observation_space = spaces.Dict(
            dict(self.observation_space.spaces, image=image_space)
        )

    def observation(self, obs):
        """Return a copy of ``obs`` whose "image" is the encoded full grid with the agent drawn in."""
        base_env = self.unwrapped
        encoded = base_env.grid.encode()
        agent_x, agent_y = base_env.agent_pos[0], base_env.agent_pos[1]
        # The agent's cell is overwritten with (agent id, red, facing direction).
        encoded[agent_x][agent_y] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], base_env.agent_dir]
        )
        return dict(obs, image=encoded)


def get_random_obj(baby_ai_bot):
    """
    Draw a uniformly random non-wall object from the bot's grid.

    Returns:
        ``(color, type)`` of the chosen object.
    """
    cells = baby_ai_bot.mission.unwrapped.grid.grid
    candidates = [cell for cell in cells if cell is not None and cell.type != "wall"]
    chosen = candidates[random.randrange(0, len(candidates))]
    return chosen.color, chosen.type

def find_pos_everywhere(bot, obj_desc, adjacent=False):
    """Pick the most attractive object matching ``obj_desc`` anywhere on the grid.

    Unlike a visibility-restricted search, no visibility mask is consulted:
    every entry of ``obj_desc.obj_set`` is a candidate, except walls and the
    object the agent is currently carrying.

    Args:
        bot: object exposing ``mission.unwrapped`` (``grid``, ``carrying``) and
            whatever ``shortest_path_everywhere`` needs.
        obj_desc: description with parallel sequences ``obj_set`` (objects)
            and ``obj_poss`` (their positions).
        adjacent: when True, distances 0 and 1 are bumped to 3 so that an
            object the agent stands on / next to is not preferred over
            objects two steps away.

    Returns:
        ``(best_obj, best_pos)``; both are ``None`` if every candidate was
        skipped.

    Raises:
        AssertionError: if ``obj_desc.obj_set`` is empty, or if no path to a
            candidate is found (the assert is inside the ``try`` but only
            ``IndexError`` is caught).
    """

    assert len(obj_desc.obj_set) > 0

    # Sentinels: any real candidate beats these on the first comparison.
    best_distance_to_obj = 9999
    best_pos = None
    best_obj = None
    lowest_locked_doors_in_between = 9999

    for i in range(len(obj_desc.obj_set)):
        if obj_desc.obj_set[i].type == "wall":
            continue
        try:
            # The carried object is not on the grid, so it is never a target.
            if obj_desc.obj_set[i] == bot.mission.unwrapped.carrying:
                continue
            obj_pos = obj_desc.obj_poss[i]

            shortest_path_to_obj, _, with_blockers = shortest_path_everywhere(bot,
                                                                              lambda pos, cell: pos == obj_pos,
                                                                              try_with_blockers=True
                                                                              )
            assert shortest_path_to_obj is not None
            distance_to_obj = len(shortest_path_to_obj)

            if with_blockers:
                # The distance should take into account the steps necessary
                # to unblock the way. Instead of computing it exactly,
                # we use a lower bound on this number of steps:
                # 4 when the agent is not holding anything
                # (pick, turn, drop, turn back)
                # and 7 if the agent is carrying something
                # (turn, drop, turn back, pick,
                # turn to other direction, drop, turn back).
                distance_to_obj = len(shortest_path_to_obj) + (
                    7 if bot.mission.unwrapped.carrying else 4
                )

            # If we are looking for a door and we are currently in the cell
            # that contains the door, it will take us at least 2
            # (3 if `adjacent == True`) steps to reach the goal.
            if distance_to_obj == 0:
                distance_to_obj = 3 if adjacent else 2

            # If what we want is to face a location that is adjacent to an object,
            # and if we are already right next to this object,
            # then we should not prefer this object to those at distance 2
            if adjacent and distance_to_obj == 1:
                distance_to_obj = 3


            # Count the locked doors that lie on the path to this candidate.
            num_locked_doors_on_path_to_obj = 0
            for pos in shortest_path_to_obj:
                cell = bot.mission.unwrapped.grid.get(*pos)
                if cell and cell.type == "door" and cell.is_locked:
                    num_locked_doors_on_path_to_obj += 1

            # The candidate replaces the current best when ANY of these holds:
            #   * fewer locked doors on its path than the current best,
            #   * strictly shorter (adjusted) distance,
            #   * equal distance, the current best is a locked door and the
            #     candidate is not locked.
            # NOTE(review): because the clauses are OR-ed, a closer candidate
            # wins even when its path crosses more locked doors - confirm
            # that this is the intended priority.
            if (num_locked_doors_on_path_to_obj < lowest_locked_doors_in_between ) or (distance_to_obj < best_distance_to_obj or
                    (distance_to_obj == best_distance_to_obj and best_obj.type == "door" and best_obj.is_locked and not obj_desc.obj_set[i].is_locked)):
                best_distance_to_obj = distance_to_obj
                best_pos = obj_pos
                best_obj = obj_desc.obj_set[i]
                lowest_locked_doors_in_between = num_locked_doors_on_path_to_obj
        except IndexError:
            # Suppose we are tracking red keys, and we just used a red key to open a door,
            # then for the last i, accessing obj_desc.obj_poss[i] will raise an IndexError
            # -> Solution: ignore the red key we used to open the door
            pass

    return best_obj, best_pos

def shortest_path_everywhere(bot, accept_fn, try_with_blockers=False):
    """
    Find a path from the agent to any location accepted by ``accept_fn``.

    A blocker-free search is tried first. Only if it fails and
    ``try_with_blockers`` is set is a second search run that may pass through
    blockers, seeded from every position the first search reached, so the
    resulting path avoids blockers for as long as possible.

    Returns:
        ``(path, finish, with_blockers)``. ``path`` is ordered from the agent
        towards the target and excludes the starting cell (``None`` when no
        path was found); ``finish`` is the accepted position;
        ``with_blockers`` is only meaningful when ``path`` is not ``None``.
    """
    env = bot.mission.unwrapped
    start_state = (*env.agent_pos, *env.dir_vec)

    used_blockers = False
    route, goal, reached = _breadth_first_search_everywhere(
        bot, [start_state], accept_fn, ignore_blockers=False
    )

    if try_with_blockers and not route:
        used_blockers = True
        # Restart from everything reachable without blockers.
        seeds = [(i, j, 1, 0) for i, j in reached]
        route, goal, _ = _breadth_first_search_everywhere(
            bot, seeds, accept_fn, ignore_blockers=True
        )
        if route:
            # `route` ends at a cell reachable without blockers; append the
            # blocker-free way from that cell back to the agent.
            tail = []
            cursor = route[-1]
            while cursor:
                tail.append(cursor)
                cursor = reached[cursor]
            route = route + tail[1:]

    if route:
        # Searches return target-to-start order; flip it and drop the start.
        route = route[::-1][1:]

    return route, goal, used_blockers

def _breadth_first_search_everywhere(bot, initial_states, accept_fn, ignore_blockers):
    """Performs breadth first search.

    This is pretty much your textbook BFS. The state space is agent's locations,
    but the current direction is also added to the queue to slightly prioritize
    going straight over turning.

    """
    bot.bfs_counter += 1

    queue = [(state, None) for state in initial_states]
    grid = bot.mission.unwrapped.grid
    previous_pos = dict()

    while len(queue) > 0:
        state, prev_pos = queue[0]
        queue = queue[1:]
        i, j, di, dj = state

        if (i, j) in previous_pos:
            continue

        bot.bfs_step_counter += 1

        cell = grid.get(i, j)
        previous_pos[(i, j)] = prev_pos

        # If we reached a position satisfying the acceptance condition
        if accept_fn((i, j), cell):
            path = []
            pos = (i, j)
            while pos:
                path.append(pos)
                pos = previous_pos[pos]
            return path, (i, j), previous_pos

        # If this cell was not visually observed, don't expand from it
        #if not self.vis_mask[i, j]:
            #continue

        if cell:
            if cell.type == "wall":
                continue
            # If this is a door
            elif cell.type == "door":
                # If the door is closed, don't visit neighbors
                if not cell.is_open and not ignore_blockers:
                    continue
            elif not ignore_blockers:
                continue

        # Location to which the bot can get without turning
        # are put in the queue first
        for k, l in [(di, dj), (dj, di), (-dj, -di), (-di, -dj)]:
            next_pos = (i + k, j + l)
            next_dir_vec = (k, l)
            next_state = (*next_pos, *next_dir_vec)
            queue.append((next_state, (i, j)))

    # Path not found
    return None, None, previous_pos