# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code from https://github.com/ucl-dark/paired.

Implements the multi-agent version of the Grid and MultiGridEnv classes.

Note that at each step, the environment expects an array of actions equal to the
number of agents with which the class was initialized. Similarly, it will return
an array of observations, and an array of rewards.

In the competitive version, as soon as one agent finds the goal, the game is
over.

In the non-competitive case, all episodes have a fixed length based on the
maximum number of steps. To avoid issues with some agents finishing early and
therefore requiring support for non-scalar step types, if an agent finishes
before the step deadline it will be respawned in a new location. To make the
single-agent case comparable to this design, it should also run for a fixed
number of steps and allow the agent to find the goal as many times as possible
within this step budget.

Unlike Minigrid, Multigrid does not include the string text of the 'mission'
with each observation.
"""
import math

import gym
import gym_minigrid.minigrid as minigrid
import gym_minigrid.window as window
import gym_minigrid.rendering as rendering
import numpy as np

# RGB colour of each agent, indexed by agent ID. Used both to draw the agent
# triangle and to tint the cells highlighted for that agent.
AGENT_COLOURS = [
    np.array([60, 182, 234]),  # Blue
    np.array([229, 52, 52]),  # Red
    np.array([144, 32, 249]),  # Purple
    np.array([69, 196, 60]),  # Green
    np.array([252, 227, 35]),  # Yellow
]


class WorldObj(minigrid.WorldObj):
    """MiniGrid world object base class that also understands Agent objects."""

    def __init__(self, obj_type, color=None):
        """Create a world object.

        Args:
            obj_type: Object type name; must be a key of
                minigrid.OBJECT_TO_IDX.
            color: Optional colour name; when given it must be a key of
                minigrid.COLOR_TO_IDX. The `color` attribute is only set when
                a truthy colour is supplied.
        """
        assert obj_type in minigrid.OBJECT_TO_IDX, obj_type
        self.type = obj_type

        if color:
            assert color in minigrid.COLOR_TO_IDX, color
            self.color = color

        # Object held inside this one (if any)
        self.contains = None

        # Position where the object was first placed, and where it is now
        self.init_pos = None
        self.cur_pos = None

    @staticmethod
    def decode(type_idx, color_idx, state):
        """Build an object from its (type, color, state) integer encoding.

        Args:
            type_idx: Index into minigrid.IDX_TO_OBJECT.
            color_idx: Index into minigrid.IDX_TO_COLOR; for agents this slot
                is passed to Agent as the agent ID instead.
            state: Door state (0: open, 1: closed, 2: locked), or the
                direction for agents.

        Returns:
            The decoded object, or None for "empty" and "unseen" cells.
        """
        kind = minigrid.IDX_TO_OBJECT[type_idx]

        # Agents reuse the colour slot for their ID, so skip the colour lookup
        if kind != "agent":
            color = minigrid.IDX_TO_COLOR[color_idx]

        if kind in ("empty", "unseen"):
            return None

        # Object types whose constructor takes nothing but a colour
        colour_only = {
            "wall": minigrid.Wall,
            "floor": minigrid.Floor,
            "ball": minigrid.Ball,
            "key": minigrid.Key,
            "box": minigrid.Box,
        }.get(kind)

        if colour_only is not None:
            decoded = colour_only(color)
        elif kind == "door":
            # State, 0: open, 1: closed, 2: locked
            decoded = Door(color, state == 0, state == 2)
        elif kind == "goal":
            decoded = minigrid.Goal()
        elif kind == "lava":
            decoded = minigrid.Lava()
        elif kind == "agent":
            decoded = Agent(color_idx, state)
        else:
            assert False, "unknown object type in decode '%s'" % kind

        return decoded


class Door(minigrid.Door):
    """MiniGrid door whose toggle receives the acting agent's carried item."""

    def toggle(self, env, pos, carrying):
        """Open, close or unlock the door.

        Args:
            env: The environment (unused here).
            pos: Position of the door (unused here).
            carrying: Object held by the agent toggling the door, or None.

        Returns:
            True if the door changed state, False if it is locked and the
            agent is not carrying a key of the door's colour.
        """
        # An unlocked door simply flips between open and closed
        if not self.is_locked:
            self.is_open = not self.is_open
            return True

        # A locked door only yields to a key of the matching colour
        has_matching_key = (isinstance(carrying, minigrid.Key) and
                            carrying.color == self.color)
        if not has_matching_key:
            return False

        self.is_locked = False
        self.is_open = True
        return True


class Agent(WorldObj):
    """World object standing in for an agent placed on the grid."""

    def __init__(self, agent_id, state):
        """Create an agent object.

        Args:
            agent_id: Integer ID of the agent; also selects its render colour.
            state: Direction the agent is facing.
        """
        super().__init__("agent")
        self.agent_id = agent_id
        self.dir = state

    def can_contain(self):
        """Agents are always able to hold another object."""
        return True

    def encode(self):
        """Return the (object type index, agent ID, direction) encoding."""
        type_idx = minigrid.OBJECT_TO_IDX[self.type]
        return (type_idx, self.agent_id, self.dir)

    def render(self, img):
        """Draw the agent onto `img` as a triangle pointing in its direction."""
        triangle = rendering.point_in_triangle(
            (0.12, 0.19),
            (0.87, 0.50),
            (0.12, 0.81),
        )

        # Each direction step is a quarter turn about the tile centre
        rotated = rendering.rotate_fn(triangle,
                                      cx=0.5,
                                      cy=0.5,
                                      theta=0.5 * math.pi * self.dir)
        rendering.fill_coords(img, rotated, AGENT_COLOURS[self.agent_id])


class Grid(minigrid.Grid):
    """MiniGrid Grid with rendering, decoding and rotation that handle agents.

    Agents are stored in the grid as Agent objects, so tiles must be rendered
    with per-agent highlighting and rotations must update agent directions.
    """

    @classmethod
    def render_tile(
        cls,
        obj,
        highlight=None,
        tile_size=minigrid.TILE_PIXELS,
        subdivs=3,
        cell_type=None,
    ):
        """Render a tile and cache the result.

        Args:
            obj: Object occupying the tile, or None for an empty tile.
            highlight: Either a single truthy/falsy flag, or a list with one
                flag per agent saying whether that agent highlights the tile.
            tile_size: Tile size in pixels.
            subdivs: Supersampling factor used for anti-aliasing.
            cell_type: Type name of the object in the cell, or None. Walls are
                never highlighted.

        Returns:
            The rendered tile image. The cache key is built from the object's
            encoding, the highlight flag(s) and the tile size.
        """
        # Hash map lookup key for the cache
        if isinstance(highlight, list):
            key = (tuple(highlight), tile_size)
        else:
            key = (highlight, tile_size)
        key = obj.encode() + key if obj else key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        img = np.zeros(shape=(tile_size * subdivs, tile_size * subdivs, 3),
                       dtype=np.uint8)

        # Draw the grid lines (top and left edges)
        rendering.fill_coords(img, rendering.point_in_rect(0, 0.031, 0, 1),
                              (100, 100, 100))
        rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 0.031),
                              (100, 100, 100))

        if obj is not None and obj.type != "agent":
            obj.render(img)

        # Highlight the cell if needed (do not highlight walls)
        if highlight and not (cell_type is not None and cell_type == "wall"):
            if isinstance(highlight, list):
                # One tint per agent that can see this cell
                for a, agent_highlight in enumerate(highlight):
                    if agent_highlight:
                        rendering.highlight_img(img, color=AGENT_COLOURS[a])
            else:
                # Default highlighting for agent's partially observed views
                rendering.highlight_img(img)

        # Render agents after highlight to avoid highlighting agent triangle
        # (the combination of colours makes it difficult to ID agent)
        if obj is not None and obj.type == "agent":
            obj.render(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = rendering.downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(self, tile_size, highlight_mask=None):
        """Render this grid at a given scale.

        Args:
            tile_size: Tile size in pixels.
            highlight_mask: Either a list of boolean masks (one per agent,
                each indexed as [x, y]) showing which part of the grid should
                be highlighted for each agent, or a single boolean mask used
                for single-agent partial observations. Defaults to no
                highlighting.

        Returns:
            An image of the rendered Grid.
        """
        if highlight_mask is None:
            # Builtin bool: the np.bool alias is not available in every NumPy
            # release.
            highlight_mask = np.zeros(shape=(self.width, self.height),
                                      dtype=bool)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for y in range(0, self.height):
            for x in range(0, self.width):
                cell = self.get(x, y)
                cell_type = cell.type if cell else None

                if isinstance(highlight_mask, list):
                    # Figure out highlighting for each agent
                    n_agents = len(highlight_mask)
                    highlights = [
                        highlight_mask[a][x, y] for a in range(n_agents)
                    ]
                else:
                    highlights = highlight_mask[x, y]

                tile_img = Grid.render_tile(
                    cell,
                    highlight=highlights,
                    tile_size=tile_size,
                    cell_type=cell_type,
                )

                ymin = y * tile_size
                ymax = (y + 1) * tile_size
                xmin = x * tile_size
                xmax = (x + 1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

        return img

    @staticmethod
    def decode(array):
        """Decode an array grid encoding back into a grid.

        Args:
            array: Array of shape (width, height, 3) holding the
                (type, color, state) encoding of every cell.

        Returns:
            A (grid, vis_mask) pair, where vis_mask is a boolean array of
            shape (width, height) that is False for cells encoded as "unseen".
        """

        width, height, channels = array.shape
        assert channels == 3

        # Builtin bool: the np.bool alias is not available in every NumPy
        # release.
        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = type_idx != minigrid.OBJECT_TO_IDX["unseen"]

        return grid, vis_mask

    def rotate_left(self):
        """Rotate the grid counter-clockwise, including agents within it.

        Returns:
            A new Grid with swapped width and height. Agent objects are
            copied with their direction turned left; all other objects are
            shared with the original grid.
        """
        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)

                # Directions are relative to the agent so must be modified
                if v is not None and v.type == "agent":
                    # Make a new agent so original grid isn't modified
                    v = Agent(v.agent_id, v.dir)
                    v.dir -= 1
                    if v.dir < 0:
                        v.dir += 4

                grid.set(j, grid.height - 1 - i, v)

        return grid

    def slice(self, top_x, top_y, width, height, agent_pos=None):
        """Get a subset of the grid for agents' partial observations.

        Args:
            top_x: x coordinate of the top-left corner of the slice.
            top_y: y coordinate of the top-left corner of the slice.
            width: Width of the slice in tiles.
            height: Height of the slice in tiles.
            agent_pos: Currently unused.

        Returns:
            A new Grid of the requested size. Cells falling outside this grid
            are filled with walls.
        """

        grid = Grid(width, height)

        for j in range(0, height):
            for i in range(0, width):
                x = top_x + i
                y = top_y + j

                if x >= 0 and x < self.width and y >= 0 and y < self.height:
                    v = self.get(x, y)
                else:
                    # Everything beyond the grid boundary looks like a wall
                    v = minigrid.Wall()

                grid.set(i, j, v)

        return grid


class MultiGridEnv(minigrid.MiniGridEnv):
    """2D grid world game environment with multi-agent support."""

    def __init__(
        self,
        grid_size=None,
        width=None,
        height=None,
        max_steps=100,
        see_through_walls=False,
        seed=52,
        agent_view_size=7,
        n_agents=3,
        competitive=False,
        fixed_environment=False,
        minigrid_mode=False,
    ):
        """Constructor for multi-agent gridworld environment generator.

        Args:
            grid_size: Number of tiles for the width and height of the square
                grid. Mutually exclusive with width/height.
            width: Number of tiles across grid width.
            height: Number of tiles in height of grid.
            max_steps: Number of environment steps before the episode end (max
                episode length).
            see_through_walls: True if agents can see through walls.
            seed: Random seed used in generating environments.
            agent_view_size: Number of tiles in the agent's square, partially
                observed view of the world.
            n_agents: The number of agents playing in the world.
            competitive: If True, as soon as one agent locates the goal, the
                episode ends for all agents. If False, if one agent locates the
                goal it is respawned somewhere else in the grid, and the episode
                continues until max_steps is reached. Forced to True when
                there is a single agent.
            fixed_environment: If True, will use the same random seed each time
                the environment is generated, so it will remain constant / be
                the same environment each time.
            minigrid_mode: Set to True to maintain backwards compatibility with
                minigrid in the single agent case.

        Raises:
            ValueError: If neither grid_size nor both width and height are
                given.
        """
        # Can't set both grid_size and width/height
        if grid_size:
            assert width is None and height is None
            width = grid_size
            height = grid_size

        if width is None or height is None:
            raise ValueError(
                "Either grid_size or both width and height must be specified")

        # Set the number of agents
        self.n_agents = n_agents

        # If competitive, only one agent is allowed to reach the goal.
        self.competitive = competitive

        if self.n_agents == 1:
            self.competitive = True

        # Action enumeration for this environment
        self.actions = MultiGridEnv.Actions

        # Number of cells (width and height) in the agent view
        self.agent_view_size = agent_view_size

        # Range of possible rewards
        self.reward_range = (0, 1)

        # Compute observation and action spaces
        # Direction always has an extra dimension for tf-agents compatibility
        self.direction_obs_space = gym.spaces.Box(low=0,
                                                  high=3,
                                                  shape=(self.n_agents,),
                                                  dtype="uint8")

        # Bound positions by the larger grid dimension so the space is valid
        # whether the grid was specified via grid_size or via width/height
        # (grid_size itself is None in the latter case).
        self.pos_obs_space = gym.spaces.Box(low=0,
                                            high=max(width, height),
                                            shape=(self.n_agents,),
                                            dtype="uint8")

        # Maintain for backwards compatibility with minigrid.
        self.minigrid_mode = minigrid_mode
        if self.minigrid_mode:
            msg = "Backwards compatibility with minigrid only possible with 1 agent"
            assert self.n_agents == 1, msg

            # Single agent case
            # Actions are discrete integer values
            self.action_space = gym.spaces.Discrete(len(self.actions))
            # Images have three dimensions
            self.image_obs_space = gym.spaces.Box(
                low=0,
                high=255,
                shape=(self.agent_view_size, self.agent_view_size, 3),
                dtype="uint8",
            )
        else:
            # First dimension of all observations is the agent ID
            self.action_space = gym.spaces.Box(low=0,
                                               high=len(self.actions) - 1,
                                               shape=(self.n_agents,),
                                               dtype="int64")

            self.image_obs_space = gym.spaces.Box(
                low=0,
                high=255,
                shape=(self.n_agents, self.agent_view_size,
                       self.agent_view_size, 3),
                dtype="uint8",
            )

        # Observations are dictionaries containing an encoding of the grid and
        # the agent's direction. x, y are agents' position that are only used by
        # surrogate model.
        self.observation_space = gym.spaces.Dict({
            "image": self.image_obs_space,
            "direction": self.direction_obs_space,
            "x": self.pos_obs_space,
            "y": self.pos_obs_space,
        })

        # Window to use for human rendering mode
        self.window = None

        # Environment configuration
        self.width = width
        self.height = height
        self.max_steps = max_steps
        self.see_through_walls = see_through_walls

        # Current position and direction of the agent
        self.agent_pos = [None] * self.n_agents
        self.agent_dir = [None] * self.n_agents

        # Maintain a done variable for each agent
        self.done = [False] * self.n_agents

        # Initialize the RNG
        self.seed(seed=seed)
        self.fixed_environment = fixed_environment

        # Initialize the state
        self.reset()

    def seed(self, seed):
        """Seed the environment and remember the seed that was used.

        Args:
            seed: Seed value forwarded to the parent class.

        Returns:
            A single-element list containing the seed.
        """
        super().seed(seed=seed)
        # Kept so that reset() can re-apply it for fixed environments
        self.seed_value = seed
        return [self.seed_value]

    def reset(self):
        """Regenerate the grid and start a new episode.

        Re-seeds the RNG first when `fixed_environment` is set, so that the
        same grid is generated every time.

        Returns:
            The first observation of the new episode, as produced by
            `gen_obs`.
        """
        if self.fixed_environment:
            self.seed(self.seed_value)

        # Current position and direction of the agent
        self.agent_pos = [None] * self.n_agents
        self.agent_dir = [None] * self.n_agents
        self.done = [False] * self.n_agents

        # Generate the grid. Will be random by default, or same environment if
        # 'fixed_environment' is True.
        self._gen_grid(self.width, self.height)

        # These fields should be defined by _gen_grid
        for a in range(self.n_agents):
            assert self.agent_pos[a] is not None
            assert self.agent_dir[a] is not None

            # Check that the agent doesn't overlap with an object. The None
            # test must come first: an empty start cell has no `type`.
            start_cell = self.grid.get(*self.agent_pos[a])
            assert (start_cell is None or start_cell.type == "agent" or
                    start_cell.can_overlap())

        # Item picked up, being carried, initially nothing
        self.carrying = [None] * self.n_agents

        # Step count since episode start
        self.step_count = 0

        # Return first observation
        obs = self.gen_obs()

        return obs

    def __str__(self):
        """Render the grid and its agents as text.

        Every cell becomes a 2-character string. Agents are shown as their ID
        followed by a direction arrow; other objects as a type letter followed
        by the first letter of their colour. Open doors are "__" and empty
        cells are two spaces.

        Returns:
            String representation of the grid, one line per grid row.
        """
        # Short codes for object types
        type_codes = {
            "wall": "W",
            "floor": "F",
            "door": "D",
            "key": "K",
            "ball": "A",
            "box": "B",
            "goal": "G",
            "lava": "V",
        }

        # Arrow glyph for each agent direction
        dir_glyphs = {0: ">", 1: "V", 2: "<", 3: "^"}

        rows = []
        for j in range(self.grid.height):
            cells = []
            for i in range(self.grid.width):
                # Agents take precedence over whatever the grid cell holds
                agents_here = [
                    str(a) + dir_glyphs[self.agent_dir[a]]
                    for a in range(self.n_agents)
                    if self.agent_pos[a] is not None and (
                        i == self.agent_pos[a][0] and
                        j == self.agent_pos[a][1])
                ]
                if agents_here:
                    cells.extend(agents_here)
                    continue

                c = self.grid.get(i, j)

                if c is None:
                    cells.append("  ")
                elif c.type == "door":
                    if c.is_open:
                        cells.append("__")
                    else:
                        prefix = "L" if c.is_locked else "D"
                        cells.append(prefix + c.color[0].upper())
                else:
                    cells.append(type_codes[c.type] + c.color[0].upper())

            rows.append("".join(cells))

        return "\n".join(rows)

    def place_obj(self,
                  obj,
                  top=None,
                  size=None,
                  reject_fn=None,
                  max_tries=math.inf):
        """Place an object at a randomly sampled empty position in the grid.

        Args:
            obj: Instance of Minigrid WorldObj class (such as Door, Key, etc.),
                or None to just sample a free position.
            top: (x,y) position of the top-left corner of rectangle where to
                place. Negative coordinates are clamped to 0.
            size: Size of the rectangle where to place. Defaults to the whole
                grid.
            reject_fn: Optional function called as reject_fn(env, pos);
                positions for which it returns a truthy value are skipped.
            max_tries: Throw an error if a position can't be found after this
                many tries.

        Returns:
            Position where object was placed.

        Raises:
            gym.error.RetriesExceededError: If no valid position was found
                within the allowed number of tries.
        """
        top = (0, 0) if top is None else (max(top[0], 0), max(top[1], 0))

        if size is None:
            size = (self.grid.width, self.grid.height)

        attempts = 0
        while True:
            # Guards against rejection sampling never terminating
            if attempts > max_tries:
                raise gym.error.RetriesExceededError(
                    "Rejection sampling failed in place_obj")
            attempts += 1

            x = self._rand_int(top[0], min(top[0] + size[0], self.grid.width))
            y = self._rand_int(top[1], min(top[1] + size[1], self.grid.height))
            pos = (x, y)

            # Accept only a cell that holds no object, is not an agent's
            # position, and passes the caller's filter (checked in that order)
            if (self.grid.get(*pos) is None and
                    not any(
                        np.array_equal(pos, self.agent_pos[a])
                        for a in range(self.n_agents)) and
                    not (reject_fn and reject_fn(self, pos))):
                break

        self.grid.set(pos[0], pos[1], obj)

        if obj is not None:
            obj.init_pos = pos
            obj.cur_pos = pos

        return pos

    def place_agent(self,
                    top=None,
                    size=None,
                    rand_dir=True,
                    max_tries=math.inf):
        """Set the starting point of all agents in the world.

        Name chosen for backwards compatibility.

        Args:
            top: (x,y) position of the top-left corner of rectangle where agents
                can be placed.
            size: Size of the rectangle where to place.
            rand_dir: Choose a random direction for agents.
            max_tries: Throw an error if a position can't be found after this
                many tries. Applied to each agent separately.
        """
        for a in range(self.n_agents):
            # Forward the caller's retry budget rather than a hard-coded
            # infinity, so max_tries actually takes effect.
            self.place_one_agent(a,
                                 top=top,
                                 size=size,
                                 rand_dir=rand_dir,
                                 max_tries=max_tries)

    def place_one_agent(
        self,
        agent_id,
        top=None,
        size=None,
        rand_dir=True,
        max_tries=math.inf,
        agent_obj=None,
    ):
        """Put one agent at a randomly chosen empty position in the grid.

        Args:
            agent_id: Integer ID of the agent to place.
            top: (x,y) top-left corner of the rectangle to sample from.
            size: Size of the rectangle to sample from.
            rand_dir: Choose a random direction for the agent.
            max_tries: Maximum number of sampling attempts.
            agent_obj: Existing Agent object to reuse; a new one is created
                when omitted.

        Returns:
            Position where the agent was placed.
        """
        # Forget the previous position before sampling a new one
        self.agent_pos[agent_id] = None

        spawn_pos = self.place_obj(None, top, size, max_tries=max_tries)
        self.place_agent_at_pos(agent_id,
                                spawn_pos,
                                agent_obj=agent_obj,
                                rand_dir=rand_dir)
        return spawn_pos

    def place_agent_at_pos(self, agent_id, pos, agent_obj=None, rand_dir=True):
        """Put an agent at a specific grid position.

        Args:
            agent_id: Integer ID of the agent.
            pos: (x, y) position to place the agent at.
            agent_obj: Existing Agent object to reuse; a new one is created
                (and its init_pos set) when omitted.
            rand_dir: If True, draw a new random direction for the agent;
                otherwise its current direction is kept.
        """
        cell = (pos[0], pos[1])
        self.agent_pos[agent_id] = cell

        if rand_dir:
            self.agent_dir[agent_id] = self._rand_int(0, 4)
        facing = self.agent_dir[agent_id]

        # Reuse the supplied object if there is one, otherwise build a new one
        if agent_obj:
            agent_obj.dir = facing
        else:
            agent_obj = Agent(agent_id, facing)
            agent_obj.init_pos = cell

        agent_obj.cur_pos = cell
        self.grid.set(cell[0], cell[1], agent_obj)

    @property
    def dir_vec(self):
        """Direction vectors of all agents (pointing toward forward movement).

        Returns:
            A list with one minigrid.DIR_TO_VEC entry per agent.
        """
        agent_ids = range(self.n_agents)

        # Validate every direction before building the result
        for agent_id in agent_ids:
            assert 0 <= self.agent_dir[agent_id] < 4

        vectors = []
        for agent_id in agent_ids:
            vectors.append(minigrid.DIR_TO_VEC[self.agent_dir[agent_id]])
        return vectors

    @property
    def right_vec(self):
        """Vectors pointing to the right of each agent, one per agent."""
        rights = []
        for forward in self.dir_vec:
            dx, dy = forward
            # Rotate the forward vector by a quarter turn
            rights.append(np.array((-dy, dx)))
        return rights

    @property
    def front_pos(self):
        """Get the position of the cell that is right in front of each agent.

        Returns:
            A list with one entry per agent: the agent's position plus its
            direction vector.
        """
        # dir_vec is a property that rebuilds (and validates) the full list on
        # every access, so read it once instead of twice per agent.
        directions = self.dir_vec

        front_pos = [None] * self.n_agents
        for a in range(self.n_agents):
            assert self.agent_pos[a] is not None and directions[a] is not None
            front_pos[a] = self.agent_pos[a] + directions[a]
        return front_pos

    def get_view_coords(self, i, j, agent_id):
        """Map absolute grid coordinates into an agent's view coordinates.

        Translates and rotates the absolute grid coordinates (i, j) into the
        coordinate system of the agent's partially observable sub-grid.

        The result may be negative or exceed the agent's view size when the
        position lies outside the view.

        Args:
            i: Integer x coordinate.
            j: Integer y coordinate.
            agent_id: ID of the agent.

        Returns:
            Agent-centric (vx, vy) coordinates.
        """
        agent_x, agent_y = self.agent_pos[agent_id]
        fwd_x, fwd_y = self.dir_vec[agent_id]
        right_x, right_y = self.right_vec[agent_id]

        view = self.agent_view_size
        half = self.agent_view_size // 2

        # Absolute coordinates of the top-left corner of the agent's view
        corner_x = agent_x + (fwd_x * (view - 1)) - (right_x * half)
        corner_y = agent_y + (fwd_y * (view - 1)) - (right_y * half)

        # Offset of the queried cell from that corner
        off_x = i - corner_x
        off_y = j - corner_y

        # Project the offset onto the agent's right and (negated) forward axes
        view_x = right_x * off_x + right_y * off_y
        view_y = -(fwd_x * off_x + fwd_y * off_y)

        return view_x, view_y

    def get_view_exts(self, agent_id):
        """Get the extents of the square set of tiles visible to the agent.

        Note: the bottom extent indices are not included in the set

        Args:
            agent_id: Integer ID of the agent.

        Returns:
            Top left and bottom right (x,y) coordinates of set of visible tiles,
            as a (top_x, top_y, bot_x, bot_y) tuple.
        """
        view = self.agent_view_size
        half = self.agent_view_size // 2
        facing = self.agent_dir[agent_id]

        # How far (in x, y) the view's top-left corner sits before the agent
        if facing == 0:  # Facing right
            back = (0, half)
        elif facing == 1:  # Facing down
            back = (half, 0)
        elif facing == 2:  # Facing left
            back = (view - 1, half)
        elif facing == 3:  # Facing up
            back = (half, view - 1)
        else:
            assert False, "invalid agent direction"

        top_x = self.agent_pos[agent_id][0] - back[0]
        top_y = self.agent_pos[agent_id][1] - back[1]

        return (top_x, top_y, top_x + view, top_y + view)

    def relative_coords(self, x, y, agent_id):
        """Get a grid position in agent-centric coordinates, if it is in view.

        Args:
            x: Integer x coordinate.
            y: Integer y coordinate.
            agent_id: ID of the agent.

        Returns:
            The agent-centric (vx, vy) coordinates of the grid position, or
            None when the position lies outside the agent's field of view.
        """
        vx, vy = self.get_view_coords(x, y, agent_id)
        view = self.agent_view_size

        if 0 <= vx < view and 0 <= vy < view:
            return vx, vy
        return None

    def in_view(self, x, y, agent_id):
        """Return True if the grid position lies within the agent's view."""
        view_coords = self.relative_coords(x, y, agent_id)
        return view_coords is not None

    def agent_sees(self, x, y, agent_id):
        """Check if a non-empty grid position is visible to the agent.

        Args:
            x: Integer x coordinate.
            y: Integer y coordinate.
            agent_id: ID of the agent.

        Returns:
            True if the position is inside the agent's view and the object the
            agent observes there has the same type as the object in the world.
        """
        view_coords = self.relative_coords(x, y, agent_id)
        if view_coords is None:
            return False
        vx, vy = view_coords

        # Decode the agent's current observation back into a grid
        agent_image = self.gen_obs()["image"][agent_id]
        obs_grid, _ = Grid.decode(agent_image)

        seen_cell = obs_grid.get(vx, vy)
        world_cell = self.grid.get(x, y)

        if seen_cell is None:
            return False
        return seen_cell.type == world_cell.type

    def agent_is_done(self, agent_id):
        """Mark an agent as done and respawn it elsewhere in the grid.

        Any object the agent was carrying is released at a random empty
        position before the agent itself is respawned.

        Args:
            agent_id: Integer ID of the agent that finished.
        """
        # Remove corresponding agent object from the grid
        pos = self.agent_pos[agent_id]
        agent_obj = self.grid.get(pos[0], pos[1])
        self.grid.set(pos[0], pos[1], None)

        self.done[agent_id] = True

        # If an agent finishes the level while carrying an object, that object
        # is randomly respawned in a new position. Warning: this may break
        # dependencies for the level (e.g. if a key is spawned on the wrong side
        # of a door).
        if self.carrying[agent_id]:
            self.place_obj(obj=self.carrying[agent_id])
            self.carrying[agent_id] = None
            # Keep the agent object in sync with `carrying`, as the pickup and
            # drop actions do, so it does not keep a stale reference.
            if agent_obj is not None:
                agent_obj.contains = None

        # Respawn agent in new location
        self.place_one_agent(agent_id, agent_obj=agent_obj)

    def move_agent(self, agent_id, new_pos):
        """Move an agent's object to a new grid cell.

        Args:
            agent_id: Integer ID of the agent to move.
            new_pos: (x, y) position to move the agent to.
        """
        # Look up the agent object at its current position
        old_pos = self.agent_pos[agent_id]
        old_x, old_y = old_pos[0], old_pos[1]
        agent_obj = self.grid.get(old_x, old_y)
        assert agent_obj.agent_id == agent_id
        assert agent_obj.cur_pos == old_pos

        # Vacate the old cell, then record and occupy the new one
        self.grid.set(old_x, old_y, None)
        target = (new_pos[0], new_pos[1])
        self.agent_pos[agent_id] = target
        agent_obj.cur_pos = target
        self.grid.set(target[0], target[1], agent_obj)

        moved_obj = self.grid.get(target[0], target[1])
        assert moved_obj.cur_pos == self.agent_pos[agent_id]

    def rotate_agent(self, agent_id):
        """Copy an agent's direction onto its object in the grid.

        Args:
            agent_id: Integer ID of the agent whose direction changed.
        """
        # Look up the agent object at its current position
        x, y = self.agent_pos[agent_id][0], self.agent_pos[agent_id][1]
        agent_obj = self.grid.get(x, y)
        assert agent_obj.agent_id == agent_id

        # Bring the grid object in line with the environment's record
        facing = self.agent_dir[agent_id]
        agent_obj.dir = facing
        self.grid.set(x, y, agent_obj)
        assert self.grid.get(x, y).dir == self.agent_dir[agent_id]

    def step_one_agent(self, action, agent_id):
        """Apply a single agent's action to the environment.

        Args:
            action: One of the values in `self.actions`.
            agent_id: Integer ID of the acting agent.

        Returns:
            The value of `self._reward()` if the agent moved onto a goal,
            otherwise 0.
        """
        reward = 0

        # Get the position in front of the agent
        fwd_pos = self.front_pos[agent_id]

        # Get the contents of the cell in front of the agent
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir[agent_id] -= 1
            if self.agent_dir[agent_id] < 0:
                self.agent_dir[agent_id] += 4
            self.rotate_agent(agent_id)

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir[agent_id] = (self.agent_dir[agent_id] + 1) % 4
            self.rotate_agent(agent_id)

        # Move forward
        elif action == self.actions.forward:
            # Make sure agents can't walk into each other
            agent_blocking = False
            for a in range(self.n_agents):
                if a != agent_id and np.array_equal(self.agent_pos[a], fwd_pos):
                    agent_blocking = True

            # Deal with object interactions
            if not agent_blocking:
                if fwd_cell is not None and fwd_cell.type == "goal":
                    # Reaching the goal finishes (and respawns) the agent and
                    # is the only case that yields a reward
                    self.agent_is_done(agent_id)
                    reward = self._reward()
                elif fwd_cell is not None and fwd_cell.type == "lava":
                    # Lava also finishes the agent, but without any reward
                    self.agent_is_done(agent_id)
                elif fwd_cell is None or fwd_cell.can_overlap():
                    self.move_agent(agent_id, fwd_pos)

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                # An agent can only carry one object at a time
                if self.carrying[agent_id] is None:
                    self.carrying[agent_id] = fwd_cell
                    # Carried objects are no longer at any grid position
                    self.carrying[agent_id].cur_pos = np.array([-1, -1])
                    self.grid.set(fwd_pos[0], fwd_pos[1], None)
                    # Mirror the carried object on the agent's grid object
                    a_pos = self.agent_pos[agent_id]
                    agent_obj = self.grid.get(a_pos[0], a_pos[1])
                    agent_obj.contains = fwd_cell

        # Drop an object
        elif action == self.actions.drop:
            # Only possible when the cell in front is empty
            if not fwd_cell and self.carrying[agent_id]:
                self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying[agent_id])
                self.carrying[agent_id].cur_pos = fwd_pos
                self.carrying[agent_id] = None
                a_pos = self.agent_pos[agent_id]
                agent_obj = self.grid.get(a_pos[0], a_pos[1])
                agent_obj.contains = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                if fwd_cell.type == "door":
                    # Doors also need to know what this agent is carrying
                    fwd_cell.toggle(self, fwd_pos, self.carrying[agent_id])
                else:
                    fwd_cell.toggle(self, fwd_pos)

        # Done action -- by default acts as no-op.
        elif action == self.actions.done:
            pass

        else:
            assert False, "unknown action"

        return reward

    def step(self, actions):
        """Advance the environment by one step for every agent.

        Args:
            actions: List with one action per agent. When there is a single
                agent, a bare (non-list) action is also accepted.

        Returns:
            Tuple of (observation, rewards, collective_done, info), where info
            is an empty dict and rewards is a scalar in minigrid mode.
        """
        # Maintain backwards compatibility with MiniGrid when there is one agent
        if self.n_agents == 1 and not isinstance(actions, list):
            actions = [actions]

        self.step_count += 1

        # Agents act in a freshly shuffled order each step for fairness
        turn_order = np.arange(self.n_agents)
        np.random.shuffle(turn_order)

        rewards = [0] * self.n_agents
        for agent_id in turn_order:
            rewards[agent_id] = self.step_one_agent(actions[agent_id],
                                                    agent_id)

        obs = self.gen_obs()

        # Backwards compatibility
        if self.minigrid_mode:
            rewards = rewards[0]

        # In competitive version, if one agent finishes the episode is over.
        collective_done = (np.sum(self.done) >= 1
                           if self.competitive else False)

        # Running out of time applies to all agents
        if self.step_count >= self.max_steps:
            collective_done = True

        return obs, rewards, collective_done, {}

    def gen_obs_grid(self, agent_id):
        """Generate the sub-grid observed by the agent.

        This method also outputs a visibility mask telling us which grid cells
        the agent can actually see.

        Args:
            agent_id: Integer ID of the agent for which to generate the grid.

        Returns:
            Sub-grid and visibility mask.
        """
        view_size = self.agent_view_size
        top_x, top_y, _, _ = self.get_view_exts(agent_id)

        # Cut the square view window out of the full grid
        grid = self.grid.slice(top_x, top_y, view_size, view_size)

        # Rotate the window (agent_dir + 1) times to the left
        for _ in range(self.agent_dir[agent_id] + 1):
            grid = grid.rotate_left()

        # Process occluders and visibility
        # Note that this incurs some performance cost
        if not self.see_through_walls:
            vis_mask = grid.process_vis(agent_pos=(view_size // 2,
                                                   view_size - 1))
        else:
            # Use the builtin `bool`: the `np.bool` alias was deprecated in
            # NumPy 1.20 and removed in NumPy 1.24.
            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

        # Make it so the agent sees what it's carrying
        # We do this by placing the carried object at the agent's position
        # in the agent's partially observable view
        agent_pos = grid.width // 2, grid.height - 1
        carried = self.carrying[agent_id]
        grid.set(agent_pos[0], agent_pos[1], carried if carried else None)

        return grid, vis_mask

    def gen_obs(self):
        """Build the joint observation dictionary covering every agent.

        Returns:
            Dict with keys "image", "direction", "x" and "y". Each value is a
            list with one entry per agent, except that in minigrid mode
            "image" is the first agent's image rather than a list.
        """
        agent_ids = range(self.n_agents)

        # Each per-agent observation is an (image, direction) pair
        views = [self.gen_agent_obs(a) for a in agent_ids]
        images = [image for image, _ in views]
        dirs = [direction for _, direction in views]

        # Agent coordinates (only for the surrogate model)
        xs = [self.agent_pos[a][0] for a in agent_ids]
        ys = [self.agent_pos[a][1] for a in agent_ids]

        # Backwards compatibility: if there is a single agent do not return an
        # array
        if self.minigrid_mode:
            images = images[0]

        return {"image": images, "direction": dirs, "x": xs, "y": ys}

    def gen_agent_obs(self, agent_id):
        """Produce one agent's encoded partial view and its heading.

        Args:
            agent_id: ID of the agent for which to generate the observation.

        Returns:
            Tuple of the encoded agent-centric view (the result of
            `grid.encode` on the observed sub-grid) and the agent's direction.
        """
        view_grid, visible = self.gen_obs_grid(agent_id)

        # Encode the partially observable view, then pair it with the heading
        encoded_view = view_grid.encode(visible)
        heading = self.agent_dir[agent_id]
        return encoded_view, heading

    def get_obs_render(self, obs, tile_size=minigrid.TILE_PIXELS // 2):
        """Turn an encoded agent observation back into a rendered image.

        Args:
            obs: Encoded observation, as accepted by `Grid.decode`.
            tile_size: Tile size passed on to the grid renderer.

        Returns:
            The image produced by rendering the decoded grid, with the decoded
            visibility mask used as the highlight mask.
        """
        decoded_grid, visible = Grid.decode(obs)

        # Render the whole decoded grid, highlighting the visible cells
        return decoded_grid.render(tile_size, highlight_mask=visible)

    def compute_agent_visibility_mask(self, agent_id):
        """Project an agent's visibility mask onto world-grid coordinates.

        Args:
            agent_id: Integer ID of the agent whose view is projected.

        Returns:
            Boolean array of shape (self.width, self.height) that is True at
            every in-bounds world cell corresponding to a visible cell of the
            agent's view.
        """
        # Mask of which cells to highlight. The builtin `bool` is used as the
        # dtype: the `np.bool` alias was deprecated in NumPy 1.20 and removed
        # in NumPy 1.24.
        highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Compute which cells are visible to the agent
        _, vis_mask = self.gen_obs_grid(agent_id)

        # Compute the world coordinates of the view cell that lies farthest
        # forward and farthest to the agent's left (view index (0, 0))
        f_vec = self.dir_vec[agent_id]
        r_vec = self.right_vec[agent_id]
        top_left = (self.agent_pos[agent_id] + f_vec *
                    (self.agent_view_size - 1) - r_vec *
                    (self.agent_view_size // 2))

        # For each cell in the visibility mask
        for vis_j in range(self.agent_view_size):
            for vis_i in range(self.agent_view_size):
                # If this cell is not visible, don't highlight it
                if not vis_mask[vis_i, vis_j]:
                    continue

                # Compute the world coordinates of this cell: increasing vis_j
                # steps back toward the agent, increasing vis_i steps right
                abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)

                # Skip view cells that fall outside the world grid
                if abs_i < 0 or abs_i >= self.width:
                    continue
                if abs_j < 0 or abs_j >= self.height:
                    continue

                # Mark this cell to be highlighted
                highlight_mask[abs_i, abs_j] = True

        return highlight_mask

    def render(self,
               mode="human",
               close=False,
               highlight=True,
               tile_size=minigrid.TILE_PIXELS):
        """Render the whole-grid human view.

        Args:
            mode: Render mode; "human" additionally shows the image in a
                window, which is created on first use.
            close: If True, close any open window and return None without
                rendering.
            highlight: If True, highlight the cells visible to each agent that
                has a position.
            tile_size: Tile size passed on to the grid renderer.

        Returns:
            The rendered image, or None when `close` is True.
        """
        if close:
            if self.window:
                self.window.close()
            return None

        show_window = mode == "human"
        if show_window and not self.window:
            self.window = window.Window("gym_minigrid")

        # One visibility mask per agent that currently has a position
        highlight_mask = None
        if highlight:
            highlight_mask = [
                self.compute_agent_visibility_mask(a)
                for a in range(self.n_agents)
                if self.agent_pos[a] is not None
            ]

        # Render the whole grid
        img = self.grid.render(tile_size, highlight_mask=highlight_mask)

        if show_window:
            self.window.show_img(img)
            if hasattr(self, "mission"):
                self.window.set_caption(self.mission)

        return img
