import math

import gym
import numpy as np
from gym import spaces
from enum import IntEnum
from gym.utils import seeding
import os
import os.path as osp
try:
    import moviepy.editor as mpy
except:
    pass


def save_mp4(frames, vid_dir, name, fps=10.0, no_frame_drop=False):
    """Encode a sequence of frames as ``<vid_dir>/<name>.mp4`` with moviepy.

    :param frames: sequence of frames, or a sequence of 4-D frame batches;
        batches are concatenated along their first axis into one sequence
    :param vid_dir: output directory, created if it does not exist
    :param name: file name without the ``.mp4`` extension
    :param fps: frame rate passed to the video writer
    :param no_frame_drop: when True every frame is shown; otherwise the clip
        is shortened by a factor of 1.5 and frames are sub-sampled accordingly
    :raises ImportError: if ``moviepy.editor`` could not be imported
    """
    # The module-level import of moviepy is optional and silently skipped on
    # failure, so fail here with an explicit message instead of a NameError.
    if 'mpy' not in globals():
        raise ImportError('moviepy is required by save_mp4 but could not be imported')

    frames = np.array(frames)
    if len(frames[0].shape) == 4:
        # Every entry is itself a batch of frames: flatten in one call rather
        # than re-allocating the growing array once per batch.
        frames = np.concatenate(list(frames))

    frame_length = len(frames)

    if no_frame_drop:
        duration = frame_length / fps

        def make_frame(t):
            return frames[min(int(t * fps), frame_length - 1)]
    else:
        drop_frame = 1.5
        new_fps = 1. / (1. / fps + 1. / frame_length)
        duration = frame_length / fps / drop_frame

        def make_frame(t):
            idx = min(int(t * new_fps), frame_length - 1)
            return frames[int(drop_frame * idx)]

    os.makedirs(vid_dir, exist_ok=True)

    # Remove a stale file of the same name before writing the new one
    vid_file = osp.join(vid_dir, name + '.mp4')
    if osp.exists(vid_file):
        os.remove(vid_file)

    video = mpy.VideoClip(make_frame, duration=duration)
    video.write_videofile(vid_file, fps, verbose=False, logger=None)


# Size in pixels of a cell in the full-scale human view
CELL_PIXELS = 32

# Earlier object-type tables, kept for reference only (not used by the code).
# OBJECT_TO_IDX = {
#     'unseen': 0,
#     'empty': 1,
#     'wall': 2,
#     'floor': 3,
#     'door': 4,
#     'key': 5,
#     'ball': 6,
#     'box': 7,
#     'mine': 8,
#     'goal': 9,
#     'lava': 10,
#     'agent': 11,
#     'subgoal': 12,
# }
# OBJECT_TO_IDX = {
#     'unseen': 0,
#     'empty': 1,
#     'wall': 2,
#     'mine': 3,
#     'goal': 4,
#     'agent': 5,
# }

# Active map of object type names to integer indices; the Grid encoders use
# these values as channel indices or as normalised type ids.
OBJECT_TO_IDX = {
    'empty': 0,
    'mine': 1,
    'goal': 2,
    'agent': 3,
}
# Number of object types in the active table
IDX_LEN = len(OBJECT_TO_IDX)

# Inverse lookup: integer index -> object type name
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))

# Map of agent direction indices to vectors
DIR_TO_VEC = [
    # Pointing right (positive X)
    np.array((1, 0)),
    # Down (positive Y)
    np.array((0, 1)),
    # Pointing left (negative X)
    np.array((-1, 0)),
    # Up (negative Y)
    np.array((0, -1)),
]


class WorldObj(object):
    """
    Common base for everything that can occupy a grid cell.

    ``color`` is a comma-separated string of integer components
    (e.g. ``'100,100,100'``), as parsed by ``_set_color``.
    """

    def __init__(self, type, color, **kwargs):
        # Extra keyword arguments are accepted but not used.
        self.type, self.color = type, color
        self.contains = None

        # Initial and current position of the object (unset until placed)
        self.init_pos = None
        self.cur_pos = None

    def can_overlap(self):
        """Tell whether the agent may stand on the same cell (default: no)."""
        return False

    def render(self, r):
        """Draw the object with renderer ``r``; subclasses must override."""
        raise NotImplementedError

    def _set_color(self, r):
        """Make this object's color the renderer's line and fill color."""
        components = list(map(int, self.color.split(',')))
        red, green, blue = components[0], components[1], components[2]
        r.setLineColor(red, green, blue)
        r.setColor(red, green, blue)


class Goal(WorldObj):
    """
    Goal tile: the agent can overlap it; drawn as a filled full-cell square.
    """

    def __init__(self, obj_type='goal', color='000,255,000', **kwargs):
        # Extra keyword arguments are accepted but not forwarded.
        super().__init__(type=obj_type, color=color)

    def can_overlap(self):
        return True

    def render(self, r):
        self._set_color(r)
        side = CELL_PIXELS
        corners = [(0, side), (side, side), (side, 0), (0, 0)]
        r.drawPolygon(corners)


class Floor(WorldObj):
    """
    Colored floor tile the agent can walk over
    """

    def __init__(self, color='000,000,255'):
        super().__init__('floor', color)

    def can_overlap(self):
        return True

    def render(self, r):
        """Draw the tile as a square filled with a darkened version of its color."""
        rgb = [int(ci) for ci in self.color.split(',')]
        r.setLineColor(100, 100, 100, 0)
        # Halve each component individually: dividing the list itself by 2
        # raises TypeError. Integer division keeps the components integral,
        # like every other color passed to the renderer in this file.
        r.setColor(*(ci // 2 for ci in rgb))
        r.drawPolygon([
            (1, CELL_PIXELS),
            (CELL_PIXELS, CELL_PIXELS),
            (CELL_PIXELS, 1),
            (1, 1)
        ])


class Wall(WorldObj):
    """
    Wall tile: not overlappable (base-class default); drawn as a filled
    full-cell square.
    """

    def __init__(self, color='100,100,100'):
        super().__init__('wall', color)

    def render(self, r):
        self._set_color(r)
        side = CELL_PIXELS
        outline = [(0, side), (side, side), (side, 0), (0, 0)]
        r.drawPolygon(outline)


class Mine(WorldObj):
    """
    Mine tile carrying an integer index and a score.
    """

    def __init__(self, color, index=-1, score=0):
        super().__init__(type='mine', color=color)
        self._index = index
        self._score = score

    @property
    def mine_id(self):
        """Index given to this mine at construction time."""
        return self._index

    @property
    def score(self):
        """Current score attached to this mine."""
        return self._score

    def set_score(self, score):
        """Replace the score attached to this mine."""
        self._score = score

    def render(self, r):
        """Draw the mine as one large circle plus three small ones."""
        # Same line/fill color setup as the shared helper on the base class
        self._set_color(r)
        r.setLineWidth(2)

        # (center x, center y, radius) for each circle, in drawing order
        circles = (
            (CELL_PIXELS * 0.5, CELL_PIXELS * 0.5, CELL_PIXELS / 3),
            (CELL_PIXELS * 0.3, CELL_PIXELS * 0.3, CELL_PIXELS / 8),
            (CELL_PIXELS * 0.7, CELL_PIXELS * 0.3, CELL_PIXELS / 8),
            (CELL_PIXELS * 0.5, CELL_PIXELS * 0.7, CELL_PIXELS / 8),
        )
        for cx, cy, radius in circles:
            r.drawCircle(cx, cy, radius)

        # Restore the default line width for whatever is drawn next
        r.setLineWidth(1)


class Grid(object):
    """
    Represent a grid and operations on it
    """

    def __init__(self, COLOR_TO_IDX, width, height, wall_color, args):
        self.COLOR_TO_IDX = COLOR_TO_IDX
        # assert width >= 3
        # assert height >= 3

        self._args = args
        self.width = width
        self.height = height

        self.grid = [None] * width * height
        self._wall_color = wall_color

    def __contains__(self, key):
        """
        Membership test.

        A ``WorldObj`` key matches by identity. A tuple key ``(color, type)``
        matches any non-empty cell with that color and type; a ``None`` color
        matches on type alone. Any other key is never contained.
        """
        if isinstance(key, WorldObj):
            return any(cell is key for cell in self.grid)

        if isinstance(key, tuple):
            occupied = (cell for cell in self.grid if cell is not None)
            return any(
                (cell.color, cell.type) == key
                or (key[0] is None and key[1] == cell.type)
                for cell in occupied
            )

        return False

    def __eq__(self, other):
        grid1 = self.encode()
        grid2 = other.encode()
        return np.array_equal(grid2, grid1)

    def __ne__(self, other):
        return not self == other

    def copy(self):
        from copy import deepcopy
        return deepcopy(self)

    def set(self, i, j, v):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        self.grid[j * self.width + i] = v

    def get(self, i, j):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        return self.grid[j * self.width + i]

    def horz_wall(self, x, y, length=None):
        if length is None:
            length = self.width - x
        for i in range(0, length):
            self.set(x + i, y, Wall(color=self._wall_color))

    def vert_wall(self, x, y, length=None):
        if length is None:
            length = self.height - y
        for j in range(0, length):
            self.set(x, y + j, Wall(color=self._wall_color))

    def wall_rect(self, x, y, w, h):
        self.horz_wall(x, y, w)
        self.horz_wall(x, y + h - 1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x + w - 1, y, h)

    def rotate_left(self):
        """
        Rotate the grid to the left (counter-clockwise)
        """

        grid = Grid(COLOR_TO_IDX=self.COLOR_TO_IDX, height=self.height, width=self.width, wall_color=self._wall_color,
                    args=self._args)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)
                grid.set(j, grid.height - 1 - i, v)
        return grid

    def slice(self, topX, topY, width, height):
        """
        Get a subset of the grid
        """

        grid = Grid(COLOR_TO_IDX=self.COLOR_TO_IDX, width=width, height=height, wall_color=self._wall_color,
                    args=self._args)

        for j in range(0, height):
            for i in range(0, width):
                x = topX + i
                y = topY + j

                if x >= 0 and x < self.width and y >= 0 and y < self.height:
                    v = self.get(x, y)
                else:
                    v = Wall(color=self._wall_color)

                grid.set(i, j, v)

        return grid

    def render(self, r, tile_size):
        """
        Draw the whole grid with renderer ``r``.

        :param r: target renderer object; its size must be exactly
            ``width * tile_size`` by ``height * tile_size``
        :param tile_size: tile size in pixels
        """
        assert r.width == self.width * tile_size
        assert r.height == self.height * tile_size

        # Grid size at the native (CELL_PIXELS per cell) resolution
        full_w = self.width * CELL_PIXELS
        full_h = self.height * CELL_PIXELS

        r.push()

        # Everything below is drawn at native resolution; the renderer
        # scales it down (or up) to the requested tile size.
        ratio = tile_size / CELL_PIXELS
        r.scale(ratio, ratio)

        # Black background
        r.fillRect(0, 0, full_w, full_h, 0, 0, 0)

        # Grid lines: one per row, then one per column
        r.setLineColor(100, 100, 100)
        for row in range(self.height):
            y = CELL_PIXELS * row
            r.drawLine(0, y, full_w, y)
        for col in range(self.width):
            x = CELL_PIXELS * col
            r.drawLine(x, 0, x, full_h)

        # Cell contents, each drawn in its own translated frame
        for j in range(self.height):
            for i in range(self.width):
                cell = self.get(i, j)
                if cell is None:
                    continue
                r.push()
                r.translate(i * CELL_PIXELS, j * CELL_PIXELS)
                cell.render(r)
                r.pop()

        r.pop()

    def encode(self, vis_mask=None):
        """
        Produce a compact numpy encoding of the grid
        """

        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        channel = self._args['mw_obs_channel']
        array = np.zeros((self.width, self.height, channel), dtype='float')
        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)  # get object at this coordinate
                    # if self._args["mw_if_simple_obs"]:
                    #     if v is not None: array[i, j, 0] = OBJECT_TO_IDX[v.type]
                    # else:
                    if v is None:
                        array[i, j, OBJECT_TO_IDX['empty']] = 1
                    else:
                        array[i, j, OBJECT_TO_IDX[v.type]] = 1
                        # if v.type == 'mine':
                        array[i, j, 6] = self.COLOR_TO_IDX[v.color][0] / 255.0
                        array[i, j, 7] = self.COLOR_TO_IDX[v.color][1] / 255.0
                        array[i, j, 8] = self.COLOR_TO_IDX[v.color][2] / 255.0
        return array

    def encode_whole(self, agent_pos=None, agent_dir=None):
        """
        Encode the whole grid, including the agent, as a float array of
        shape ``(width, height, args['mw_obs_channel'])``.

        With ``args['mw_obs_id']`` set: channel 0 holds the object type index
        divided by ``IDX_LEN``, channel 1 the agent direction divided by 3
        (agent cell only) and channel 2 the mine id divided by
        ``args['mw_mine_size']`` (mine cells only).

        Otherwise: a 1 in the channel of the object type, color components
        divided by 255 in channels ``IDX_LEN .. IDX_LEN + 2`` for non-empty
        cells, and the agent direction in channel ``IDX_LEN + 3`` (plus
        ``IDX_LEN + 4`` when ``args['mw_dir_one_hot']`` is false).

        :param agent_pos: (x, y) of the agent, defaults to (0, 0)
        :param agent_dir: direction index of the agent, defaults to 0
        :return: the array, flattened to 1-D when ``args['mw_obs_flatten']``
        """
        if agent_pos is None:
            agent_pos = (0, 0)
        if agent_dir is None:
            agent_dir = 0

        # Positions may arrive as numpy arrays or lists; comparing those to a
        # tuple is either ambiguous (array) or always False (list), so
        # normalise to a tuple of plain ints once.
        agent_pos = (int(agent_pos[0]), int(agent_pos[1]))

        channel = self._args['mw_obs_channel']
        array = np.zeros((self.width, self.height, channel), dtype='float')
        if self._args['mw_obs_id']:
            for i in range(self.width):
                for j in range(self.height):
                    v = self.get(i, j)  # get object at this coordinate
                    if (i, j) == agent_pos:  # agent
                        array[i, j, 0] = OBJECT_TO_IDX['agent'] / IDX_LEN
                        array[i, j, 1] = agent_dir / 3
                    elif v is None:  # empty
                        array[i, j, 0] = OBJECT_TO_IDX['empty'] / IDX_LEN
                    elif v.type == 'mine':  # mine (equality, not identity)
                        array[i, j, 0] = OBJECT_TO_IDX[v.type] / IDX_LEN
                        array[i, j, 2] = v.mine_id / self._args['mw_mine_size']
                    else:  # goal
                        array[i, j, 0] = OBJECT_TO_IDX[v.type] / IDX_LEN
        else:
            for i in range(self.width):
                for j in range(self.height):
                    v = self.get(i, j)  # get object at this coordinate
                    if (i, j) == agent_pos:
                        array[i, j, OBJECT_TO_IDX['agent']] = 1
                        if self._args['mw_dir_one_hot']:
                            array[i, j, IDX_LEN + 3] = agent_dir / 3
                        else:
                            array[i, j, IDX_LEN + 3] = DIR_TO_VEC[agent_dir][0]
                            array[i, j, IDX_LEN + 4] = DIR_TO_VEC[agent_dir][1]
                    elif v is None:
                        array[i, j, OBJECT_TO_IDX['empty']] = 1
                    else:
                        array[i, j, OBJECT_TO_IDX[v.type]] = 1
                        rgb = self.COLOR_TO_IDX[v.color]
                        array[i, j, IDX_LEN] = rgb[0] / 255.0
                        array[i, j, IDX_LEN + 1] = rgb[1] / 255.0
                        array[i, j, IDX_LEN + 2] = rgb[2] / 255.0

        if self._args['mw_obs_flatten']:
            return array.reshape(-1, )
        return array

    def encode_truth(self, agent_pos=None, agent_dir=None):
        """
        Produce a compact 1-D encoding of the agent's local situation.

        The result is ``[x, y, dir, right, down, left, up, front...]``:

        * the four neighbour codes are 0 for a blocked cell (out of the grid
          or any object other than a goal), 1 for a free cell, and 2 for a
          goal when ``args['mw_obs_goal']`` is set (otherwise a goal is 1);
        * ``front`` is the id of the mine in front of the agent, or
          ``mw_mine_size`` when there is none; with ``args['mw_obs_wall']``
          it is ``mw_mine_size + 1`` for the grid border and, if
          ``mw_obs_goal`` is also set, ``mw_mine_size + 2`` for the goal;
        * with ``args['mw_obs_mine_one_hot']`` the front value is expanded
          into a one-hot vector.

        :param agent_pos: (x, y) of the agent, defaults to (0, 0)
        :param agent_dir: direction index of the agent, defaults to 0
        :return: 1-D numpy array
        """
        if agent_pos is None:
            agent_pos = (0, 0)
        if agent_dir is None:
            agent_dir = 0

        mine_size = self._args['mw_mine_size']

        # Locate the goal. The break only leaves the inner loop, so with
        # several goals the one in the right-most column wins.
        goal_pos = None
        for i in range(self.width):
            for j in range(self.height):
                cell = self.get(i=i, j=j)
                if cell is not None and cell.type == 'goal':
                    goal_pos = (i, j)
                    break

        surround_cells = []  # right, down, left, up
        for i in range(4):
            x = agent_pos[0] + DIR_TO_VEC[i][0]
            y = agent_pos[1] + DIR_TO_VEC[i][1]
            code = 0
            if 0 <= x < self.width and 0 <= y < self.height:
                v = self.get(x, y)
                if self._args["mw_obs_goal"] and v is not None and v.type == 'goal':
                    code = 2
                elif v is None or v.type == 'goal':
                    code = 1
            surround_cells.append(code)

        front_cell_number = mine_size
        front_cell_x = agent_pos[0] + DIR_TO_VEC[agent_dir][0]
        front_cell_y = agent_pos[1] + DIR_TO_VEC[agent_dir][1]

        if self._args["mw_obs_wall"]:
            if front_cell_x == -1 or front_cell_x == self.width or front_cell_y == -1 or front_cell_y == self.height:
                front_cell_number = mine_size + 1
            if self._args["mw_obs_goal"]:
                # goal_pos stays None when the grid has no goal; guard so
                # that case no longer raises TypeError.
                if goal_pos is not None and front_cell_x == goal_pos[0] and front_cell_y == goal_pos[1]:
                    front_cell_number = mine_size + 2

        if 0 <= front_cell_x < self.width and 0 <= front_cell_y < self.height:
            front_cell = self.get(front_cell_x, front_cell_y)
            if front_cell is not None and front_cell.type == 'mine':
                front_cell_number = front_cell.mine_id

        if self._args["mw_obs_mine_one_hot"]:
            # One slot per mine id plus "nothing", plus optional slots for
            # "wall" and "goal" when those observations are enabled.
            one_hot_len = mine_size + 1
            if self._args["mw_obs_wall"]:
                one_hot_len = mine_size + 3 if self._args["mw_obs_goal"] else mine_size + 2
            front_part = [0.0] * one_hot_len
            front_part[front_cell_number] = 1.0
        else:
            front_part = [front_cell_number]

        final_return = [agent_pos[0], agent_pos[1], agent_dir] + surround_cells + front_part
        return np.array(final_return)






    def encode_mine(self):
        # if self._args["mw_if_simple_obs"]:
        #     channels = 1
        if self._args['mw_if_obs_mine_number']:
            channels = self._args['mw_mine_size'] + 1
        else:
            channels = 3

        array = np.zeros((1, 1, channels), dtype='float')
        v = self.get(0, 0)  # get object at this coordinate

        # if self._args["mw_if_simple_obs"]:
        #     if v is not None: array[0, 0, 0] = OBJECT_TO_IDX[v.type]
        if self._args['mw_if_obs_mine_number']:
            array[0, 0, 0] = self.COLOR_TO_IDX[v.color][0] / 255.0
        else:
            array[0, 0, 0] = self.COLOR_TO_IDX[v.color][0] / 255.0
            array[0, 0, 1] = self.COLOR_TO_IDX[v.color][1] / 255.0
            array[0, 0, 2] = self.COLOR_TO_IDX[v.color][2] / 255.0
        return array

    @staticmethod
    def decode(array):
        """
        Decode an array grid encoding back into a grid

        Expects a ``(width, height, 3)`` array whose last axis is
        ``(type index, color index, state)``.

        NOTE(review): this looks like a leftover from an earlier encoding and
        does not match the rest of this class as visible here:
        ``Grid(width, height)`` omits the required COLOR_TO_IDX, wall_color
        and args arguments; ``OBJECT_TO_IDX`` has no 'unseen' entry; and
        ``IDX_TO_COLOR`` is not defined in the visible part of the file.
        Confirm whether it is still called before relying on it.
        """

        width, height, channels = array.shape
        assert channels == 3

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                typeIdx, colorIdx, state = array[i, j]

                # Cells without an object stay None in the new grid
                if typeIdx == OBJECT_TO_IDX['unseen'] or typeIdx == OBJECT_TO_IDX['empty']:
                    continue

                objType = IDX_TO_OBJECT[typeIdx]
                color = IDX_TO_COLOR[colorIdx]

                # Rebuild the object from its type name; `state` is unused
                if objType == 'wall':
                    v = Wall(color=color)
                elif objType == 'floor':
                    v = Floor(color=color)
                elif objType == 'mine':
                    v = Mine(color=color)
                elif objType == 'goal':
                    v = Goal(color=color)
                else:
                    assert False, "unknown obj type in decode '%s'" % objType

                grid.set(i, j, v)

        return grid

    def process_vis(grid, agent_pos):
        mask = np.zeros(shape=(grid.width, grid.height), dtype=np.bool)

        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(0, grid.height)):
            for i in range(0, grid.width - 1):
                if not mask[i, j]:
                    continue

                mask[i + 1, j] = True
                if j > 0:
                    mask[i + 1, j - 1] = True
                    mask[i, j - 1] = True

            for i in reversed(range(1, grid.width)):
                if not mask[i, j]:
                    continue

                mask[i - 1, j] = True
                if j > 0:
                    mask[i - 1, j - 1] = True
                    mask[i, j - 1] = True

        for j in range(0, grid.height):
            for i in range(0, grid.width):
                if not mask[i, j]:
                    grid.set(i, j, None)
        return mask


class MiniGridEnv(gym.Env):
    """
    2D grid world game environment
    """

    # Enumeration of possible actions
    class Actions(IntEnum):
        """Discrete actions; the values are the integers of the action space."""
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2

    def __init__(
            self,
            grid_size=None,
            width=None,
            height=None,
            max_steps=100,
            see_through_walls=True,  # False
            seed=1337,
            agent_view_size=7,
            use_subgoal_reward=False,
            args: dict = None
    ):
        """
        Set up spaces and bookkeeping, then generate the first grid through
        ``reset(if_on_launch=True)``.

        :param grid_size: side length of a square grid; mutually exclusive
            with ``width`` / ``height``
        :param width: grid width in cells
        :param height: grid height in cells
        :param max_steps: stored as ``self.max_steps``
        :param see_through_walls: stored as ``self.see_through_walls``
        :param seed: currently unused here (seeding is not done in __init__)
        :param agent_view_size: number of cells (width and height) in the
            agent view
        :param use_subgoal_reward: stored as ``self.use_subgoal_reward``
        :param args: configuration dict; this method reads 'mw_obs_channel',
            'mw_observation_size' and 'mw_obs_flatten'
        :raises ValueError: if ``args`` is not provided
        """
        if args is None:
            # Fail with a clear message instead of "'NoneType' object is not
            # subscriptable" a few lines further down.
            raise ValueError('MiniGridEnv requires an `args` configuration dict')
        self._args = args

        # Can't set both grid_size and width/height
        if grid_size:
            assert width is None and height is None
            width = grid_size
            height = grid_size

        # Action enumeration for this environment
        self.actions = MiniGridEnv.Actions

        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(self.actions))

        # Number of cells (width and height) in the agent view
        self.agent_view_size = agent_view_size

        # Observations are grid encodings with values in [0, 1], either
        # flattened or shaped (size, size, channels)
        _dim = self._args['mw_obs_channel']
        obs_size = self._args['mw_observation_size']
        if self._args['mw_obs_flatten']:
            obs_shape = (obs_size * obs_size * _dim,)
        else:
            obs_shape = (obs_size, obs_size, _dim)
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=obs_shape,
            dtype='float'
        )

        # Range of possible rewards
        self.reward_range = (0, 1)

        # Renderer object used to render the whole grid (full-scale)
        self.grid_render = None

        # Renderer used to render observations (small-scale agent view)
        self.obs_render = None

        # Environment configuration
        self.width = width
        self.height = height
        self.max_steps = max_steps
        self.see_through_walls = see_through_walls

        # Starting position and direction for the agent
        self.start_pos = None
        self.start_dir = None

        # Subgoal Reward
        self.use_subgoal_reward = use_subgoal_reward
        self.subgoal_reached = False
        self.subgoal_reward = 0.1

        self.goal_reached = False
        self.episode_reward = 0.

        self.fixed_rivers = False

        self.render_info = False

        # Initialize the state
        self.reset(if_on_launch=True)

    def reset(self, if_on_launch: bool = False):
        """
        Start a new episode: regenerate the grid, put the agent back at its
        start position and direction, and return the first observation.

        :param if_on_launch: forwarded to ``_gen_grid``; when True the start
            position is additionally reset to ``self.start_pos_original``
        """
        # Per-episode flags and counters
        self.subgoal_reached = False
        self.goal_reached = False
        self.episode_reward = 0.

        self._gen_grid(self.width, self.height, if_on_launch=if_on_launch)
        if if_on_launch:
            self.start_pos = self.start_pos_original

        # _gen_grid is responsible for defining both of these
        assert self.start_pos is not None
        assert self.start_dir is not None

        # Look up the start cell. The result is not used; the call only
        # exercises the bounds assertions of Grid.get.
        self.grid.get(*self.start_pos)

        # Place the agent in the starting position and direction
        self.agent_pos, self.agent_dir = self.start_pos, self.start_dir

        # Step count since episode start
        self.step_count = 0

        # First observation of the episode
        return self.gen_obs()

    def seed(self, seed=1337):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``."""
        rng, _unused = seeding.np_random(seed)
        self.np_random = rng
        self._seed = seed
        return [seed]

    @property
    def steps_remaining(self):
        """Number of steps left before ``max_steps`` is reached."""
        used = self.step_count
        return self.max_steps - used

    def __str__(self):
        """
        Produce a pretty string of the environment's grid along with the agent.
        A grid cell is represented by 2-character string, the first one for
        the object and the second one for the color (the first character of
        the object's color string, upper-cased). The agent is shown as its
        direction arrow repeated twice, an empty cell as two spaces.
        """

        # Map of object types to short string
        OBJECT_TO_STR = {
            'wall': 'W',
            'floor': 'F',
            'door': 'D',
            'key': 'K',
            'ball': 'A',
            'box': 'B',
            'mine': 'M',
            'goal': 'G',
            'subgoal': 'S',
            'lava': 'V',
        }

        # Map agent's direction to short string
        AGENT_DIR_TO_STR = {
            0: '>',
            1: 'V',
            2: '<',
            3: '^'
        }

        # Build the rows as lists and join once, without shadowing `str`
        rows = []
        for j in range(self.grid.height):
            cells = []
            for i in range(self.grid.width):
                if i == self.agent_pos[0] and j == self.agent_pos[1]:
                    cells.append(2 * AGENT_DIR_TO_STR[self.agent_dir])
                    continue

                c = self.grid.get(i, j)

                if c is None:
                    cells.append('  ')
                    continue

                if c.type == 'door':
                    if c.is_open:
                        cells.append('__')
                    elif c.is_locked:
                        cells.append('L' + c.color[0].upper())
                    else:
                        cells.append('D' + c.color[0].upper())
                    continue

                cells.append(OBJECT_TO_STR[c.type] + c.color[0].upper())

            rows.append(''.join(cells))

        return '\n'.join(rows)

    def _gen_grid(self, width, height, if_on_launch: bool = False):
        """
        Build ``self.grid`` and define the agent's start position/direction.

        Abstract hook called by ``reset``; every concrete environment must
        override it.

        :raises NotImplementedError: always, in this base class
        """
        # An explicit raise survives `python -O`, which strips assert
        # statements and would let this silently return None.
        raise NotImplementedError("_gen_grid needs to be implemented by each environment")

    def _step_reward(self):
        """Per-step penalty: minus one over ``max_steps``."""
        penalty = 1 / self.max_steps
        return -penalty

    def _reward(self):
        """
        Reward handed out on success (a constant 1.0).
        """
        success_reward = 1.0
        return success_reward

    def _rand_int(self, low, high):
        """
        Draw a random integer from the half-open range [low, high)
        using ``self.np_random``.
        """
        rng = self.np_random
        return rng.randint(low, high)

    def _rand_float(self, low, high):
        """
        Draw a random float from the half-open range [low, high)
        using ``self.np_random``.
        """
        rng = self.np_random
        return rng.uniform(low, high)

    def _rand_bool(self):
        """
        Draw a random boolean (True when a 0/1 draw comes up 0)
        """
        coin = self.np_random.randint(0, 2)
        return coin == 0

    def _rand_elem(self, iterable):
        """
        Return one uniformly chosen element of ``iterable``
        """
        candidates = list(iterable)
        chosen = self._rand_int(0, len(candidates))
        return candidates[chosen]

    def _rand_subset(self, iterable, num_elems):
        """
        Draw ``num_elems`` elements from ``iterable`` without replacement,
        in the order they were drawn.
        """
        pool = list(iterable)
        assert num_elems <= len(pool)

        picked = []
        while len(picked) < num_elems:
            choice = self._rand_elem(pool)
            # Take the choice out of the pool so it cannot be drawn again
            pool.remove(choice)
            picked.append(choice)

        return picked

    def _rand_color(self):
        """
        Pick a random entry of the module-level ``COLOR_NAMES`` collection
        """
        # NOTE(review): COLOR_NAMES is not defined in the visible part of
        # this file — confirm it exists before calling this helper.
        palette = COLOR_NAMES
        return self._rand_elem(palette)

    def _rand_pos(self, xLow, xHigh, yLow, yHigh):
        """
        Draw a random (x, y) tuple with x in [xLow, xHigh) and y in [yLow, yHigh)
        """
        # x is drawn before y
        x = self.np_random.randint(xLow, xHigh)
        y = self.np_random.randint(yLow, yHigh)
        return (x, y)

    def place_obj(self,
                  obj,
                  top=None,
                  size=None,
                  pos=None,
                  reject_fn=None,
                  max_tries=math.inf):
        """
        Place an object at an empty position in the grid

        A position is acceptable when the cell is empty, it is not the
        agent's start position, and ``reject_fn`` (if given) does not reject
        it. Without ``pos``, candidates are drawn at random inside the
        rectangle until one is acceptable.

        :param obj: object to place; may be None to only pick a position
        :param top: top-left position of the rectangle where to place
        :param size: size of the rectangle where to place
        :param pos: explicit position to use instead of random sampling
        :param reject_fn: function to filter out potential positions, called
            as ``reject_fn(env, pos)``
        :param max_tries: maximum number of random candidates to try
        :return: the position the object was placed at
        :raises RecursionError: if no acceptable position is found within
            ``max_tries``, or if the explicit ``pos`` is not acceptable
        """

        if top is None:
            top = (0, 0)

        if size is None:
            size = (self.grid.width, self.grid.height)

        # Remember whether the caller fixed the position: only then is the
        # candidate kept between iterations.
        fixed_pos = pos
        num_tries = 0

        while True:
            # This is to handle with rare cases where rejection sampling
            # gets stuck in an infinite loop
            if num_tries > max_tries:
                raise RecursionError('rejection sampling failed in place_obj')

            num_tries += 1

            # Draw a fresh candidate on every iteration. Sampling only once
            # made the loop retry the same rejected position forever.
            if fixed_pos is None:
                pos = np.array((
                    self._rand_int(top[0], top[0] + size[0]),
                    self._rand_int(top[1], top[1] + size[1])
                ))

            rejected = (
                # Don't place the object on top of another object
                self.grid.get(*pos) is not None
                # Don't place the object where the agent is
                or np.array_equal(pos, self.start_pos)
                # Check if there is a filtering criterion
                or (reject_fn and reject_fn(self, pos))
            )
            if not rejected:
                break

            # Retrying an explicit position can never succeed, so fail now
            # (same exception type as exhausted sampling).
            if fixed_pos is not None:
                raise RecursionError('rejection sampling failed in place_obj')

        self.grid.set(*pos, obj)

        if obj is not None:
            obj.init_pos = pos
            obj.cur_pos = pos

        return pos

    def place_agent(
            self,
            top=None,
            size=None,
            pos=None,
            rand_dir=True,
            max_tries=math.inf
    ):
        """
        Choose the agent's starting cell via ``place_obj`` (placing no
        object) and, if ``rand_dir`` is set, a random starting direction.

        :return: the chosen start position
        """

        # Clear the previous start so place_obj does not reject that cell
        self.start_pos = None
        chosen = self.place_obj(obj=None, top=top, size=size, pos=pos, max_tries=max_tries)
        self.start_pos = self.start_pos_original = chosen

        if rand_dir:
            self.start_dir = self._rand_int(0, 4)

        return chosen

    @property
    def dir_vec(self):
        """
        Unit vector of the agent's current direction, i.e. the direction
        of forward movement.
        """

        direction = self.agent_dir
        assert 0 <= direction < 4
        return DIR_TO_VEC[direction]

    @property
    def right_vec(self):
        """
        Unit vector pointing to the agent's right: the forward vector
        (dx, dy) mapped to (-dy, dx).
        """

        forward_x, forward_y = self.dir_vec
        return np.array((-forward_y, forward_x))

    @property
    def front_pos(self):
        """
        Position of the cell directly ahead of the agent
        (agent position plus direction vector).
        """

        step = self.dir_vec
        return self.agent_pos + step

    def get_view_coords(self, i, j):
        """
        Convert absolute grid coordinates (i, j) into coordinates of the
        agent's view (sub-grid), taking the agent's position and direction
        into account. The result may be negative or exceed the view size.
        """

        agent_x, agent_y = self.agent_pos
        fwd_x, fwd_y = self.dir_vec
        right_x, right_y = self.right_vec

        view = self.agent_view_size
        half = self.agent_view_size // 2

        # Absolute coordinates of the view's top-left corner
        corner_x = agent_x + (fwd_x * (view - 1)) - (right_x * half)
        corner_y = agent_y + (fwd_y * (view - 1)) - (right_y * half)

        # Offset of the queried cell from that corner
        off_x = i - corner_x
        off_y = j - corner_y

        # Project the offset onto the agent's right / backward axes
        view_x = (right_x * off_x + right_y * off_y)
        view_y = -(fwd_x * off_x + fwd_y * off_y)

        return view_x, view_y

    def get_view_exts(self):
        """
        Get the extents of the square set of tiles visible to the agent.

        Returns (topX, topY, botX, botY); the bottom corner is the top corner
        plus agent_view_size on each axis. Fails an assertion when agent_dir
        is not one of 0, 1, 2, 3.
        """

        heading = self.agent_dir
        span = self.agent_view_size
        half = span // 2

        # Offset of the top-left corner from the agent, per heading
        if heading == 0:        # facing right
            shift_x, shift_y = 0, -half
        elif heading == 1:      # facing down
            shift_x, shift_y = -half, 0
        elif heading == 2:      # facing left
            shift_x, shift_y = 1 - span, -half
        elif heading == 3:      # facing up
            shift_x, shift_y = -half, 1 - span
        else:
            assert False, "invalid agent direction"

        left = self.agent_pos[0] + shift_x
        top = self.agent_pos[1] + shift_y

        return (left, top, left + span, top + span)

    def relative_coords(self, x, y):
        """
        Translate a grid position into view coordinates if it lies inside the
        agent's field of view.

        Returns the (vx, vy) pair from get_view_coords when both components
        are within [0, agent_view_size), otherwise None.
        """

        view_x, view_y = self.get_view_coords(x, y)
        limit = self.agent_view_size

        outside = view_x < 0 or view_y < 0 or view_x >= limit or view_y >= limit
        return None if outside else (view_x, view_y)

    def in_view(self, x, y):
        """
        Tell whether a grid position falls inside the agent's view square,
        i.e. whether relative_coords yields coordinates for it.
        """

        located = self.relative_coords(x, y)
        return located is not None

    def agent_sees(self, x, y):
        """
        Check if a non-empty grid position is visible to the agent.

        Returns False when (x, y) is outside the view square or when the
        decoded observation has nothing at the matching view cell; otherwise
        returns whether the observed cell's type equals the type of the
        world cell at (x, y).
        """

        view_xy = self.relative_coords(x, y)
        if view_xy is None:
            return False
        vx, vy = view_xy

        # NOTE(review): gen_obs() as defined in this class returns
        # grid.encode(vis_mask); indexing it with 'user' assumes a mapping --
        # confirm an override provides that key.
        observation = self.gen_obs()
        seen = Grid.decode(observation['user']).get(vx, vy)
        actual = self.grid.get(x, y)

        if seen is None:
            return False
        return seen.type == actual.type

    def step(self, action):
        """
        Advance the environment by one agent action.

        ``action`` must equal ``self.actions.left``, ``self.actions.right`` or
        ``self.actions.forward``. Pickup, drop, toggle and done are not
        handled by this method, so any other value raises ``AssertionError``.

        Returns ``(obs, reward, done, info)``:
          * ``obs``    -- ``self.gen_obs()`` evaluated after the action.
          * ``reward`` -- ``self._step_reward()``, re-evaluated after a
            successful forward move. ``self.subgoal_reward`` is added the
            first time a forward action faces a 'subgoal' cell while
            ``use_subgoal_reward`` is set. The value is replaced by
            ``self._reward()`` when a forward action faces a 'goal' cell.
          * ``done``   -- True when a forward action faces a 'goal' or 'lava'
            cell, or once ``step_count`` reaches ``max_steps``.
          * ``info``   -- contains ``'frame'`` when ``render_info`` is set and
            the ``ep_*`` episode summary once ``done`` is True.
        """
        self.step_count += 1

        reward = self._step_reward()
        done = False
        info = {}

        # Position and contents of the cell in front of the agent, sampled
        # before the action is applied
        fwd_pos = self.front_pos
        fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            self.agent_dir = (self.agent_dir - 1) % 4

        # Rotate right
        elif action == self.actions.right:
            self.agent_dir = (self.agent_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            # Identity checks against None: an equality test could be
            # hijacked by a cell type that overloads __eq__.
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = fwd_pos
                reward = self._step_reward()  # refresh step reward with new agent position

            if fwd_cell is not None:
                cell_type = fwd_cell.type

                if self.use_subgoal_reward and cell_type == 'subgoal' and not self.subgoal_reached:
                    reward += self.subgoal_reward
                    self.subgoal_reached = True

                if cell_type == 'goal':
                    done = True
                    self.goal_reached = True
                    reward = self._reward()

                if cell_type == 'lava':
                    done = True

        else:
            # Raised explicitly (same exception type as the former
            # `assert False`) so the check also holds under `python -O`.
            raise AssertionError("unknown action")

        if self.step_count >= self.max_steps:
            done = True

        obs = self.gen_obs()

        self.episode_reward += reward

        if self.render_info:
            info['frame'] = self.render(mode='rgb_array')

        if done:
            info['ep_success'] = self.goal_reached
            info['ep_len'] = self.step_count
            info['ep_reward'] = self.episode_reward
            info['ep_subgoal_reached'] = self.subgoal_reached
        return obs, reward, done, info

    def gen_obs_grid(self):
        """
        Generate the sub-grid observed by the agent.

        Returns ``(grid, vis_mask)``: ``grid`` is the
        agent_view_size x agent_view_size slice of the world grid starting at
        the top-left corner given by get_view_exts, rotated left
        ``agent_dir + 1`` times; ``vis_mask`` tells which of its cells the
        agent can actually see. With ``see_through_walls`` set the mask is
        all True, otherwise it comes from ``grid.process_vis``.
        """

        # Only the top-left corner is needed; the slice size fixes the rest
        top_x, top_y, _, _ = self.get_view_exts()

        view = self.agent_view_size
        grid = self.grid.slice(top_x, top_y, view, view)

        # if self._args["mw_if_rotate_obs"]:
        for _ in range(self.agent_dir + 1):
            grid = grid.rotate_left()

        # Process occluders and visibility
        if not self.see_through_walls:
            vis_mask = grid.process_vis(agent_pos=(view // 2, view - 1))
        else:
            # Builtin bool: the np.bool alias was deprecated in NumPy 1.20
            # and removed in 1.24.
            vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)

        # Clear the agent's own cell (bottom-centre of the view). No carried
        # object is placed there in this environment.
        agent_pos = grid.width // 2, grid.height - 1
        grid.set(*agent_pos, None)
        return grid, vis_mask

    def gen_obs(self):
        """
        Produce the agent's observation: the sub-grid from gen_obs_grid,
        encoded together with its visibility mask via ``grid.encode``.
        """
        view, visible = self.gen_obs_grid()
        encoded = view.encode(visible)
        return encoded

    def get_obs_render(self, obs, tile_pixels=CELL_PIXELS // 2):
        """
        Render an agent observation for visualization.

        obs: encoded observation, passed to ``Grid.decode``.
        tile_pixels: side length in pixels of one rendered tile.

        Returns the rendered frame from the renderer's ``getArray()``. The
        renderer is created on first use and cached in ``self.obs_render``.
        """

        if self.obs_render is None:
            # NOTE(review): render() imports Renderer from
            # large_rl.envs.mining.rendering; confirm this different module
            # path is intended.
            from large_rl.envs.minigrid.rendering import Renderer
            self.obs_render = Renderer(
                self.agent_view_size * tile_pixels,
                self.agent_view_size * tile_pixels
            )

        r = self.obs_render

        r.beginFrame()

        grid = Grid.decode(obs)

        # Render the whole grid
        grid.render(r, tile_pixels)

        # Draw the agent as a red triangle at the bottom-centre tile. Shapes
        # are specified at CELL_PIXELS scale and shrunk by `ratio`.
        ratio = tile_pixels / CELL_PIXELS
        r.push()
        r.scale(ratio, ratio)
        r.translate(
            CELL_PIXELS * (0.5 + self.agent_view_size // 2),
            CELL_PIXELS * (self.agent_view_size - 0.5)
        )
        r.rotate(3 * 90)
        r.setLineColor(255, 0, 0)
        r.setColor(255, 0, 0)
        r.drawPolygon([
            (-CELL_PIXELS / 3, CELL_PIXELS / 3),
            (CELL_PIXELS / 3, 0),
            (-CELL_PIXELS / 3, -CELL_PIXELS / 3)
        ])
        r.pop()

        r.endFrame()

        # The former trailing `return r.getPixmap()` could never execute and
        # has been removed.
        return r.getArray()

    def render(self, mode='human', close=False):
        """
        Render the whole-grid human view.

        Rendering only happens when ``mw_test_save_video`` or
        ``mw_train_save_video`` is set in ``self._args``; otherwise the
        method returns 1 immediately. With ``close`` set, any existing
        renderer is closed and None is returned.

        Returns the frame array for ``mode='rgb_array'``, the pixmap for
        ``mode='pixmap'``, and the renderer object for any other mode.
        """
        video_enabled = self._args["mw_test_save_video"] or self._args["mw_train_save_video"]
        if not video_enabled:
            return 1

        if close:
            if self.grid_render:
                self.grid_render.close()
            return

        if self.grid_render is None:
            from large_rl.envs.mining.rendering import Renderer
            # The third argument is True only for the 'human' mode
            self.grid_render = Renderer(
                self.width * CELL_PIXELS,
                self.height * CELL_PIXELS,
                mode == 'human'
            )

        canvas = self.grid_render

        if canvas.window:
            canvas.window.setText(self.mission)

        canvas.beginFrame()

        # Render the whole grid
        self.grid.render(canvas, CELL_PIXELS)

        # Agent triangle dimensions. `-CELL_PIXELS // 3` floors toward
        # negative infinity, so it is kept separate from `-(CELL_PIXELS // 3)`.
        third = CELL_PIXELS // 3
        neg_third = -CELL_PIXELS // 3
        extra = CELL_PIXELS // 15

        # Draw the agent
        canvas.push()
        canvas.translate(
            CELL_PIXELS * (self.agent_pos[0] + 0.5),
            CELL_PIXELS * (self.agent_pos[1] + 0.5)
        )
        canvas.rotate(self.agent_dir * 90)
        canvas.setLineColor(255, 0, 0)
        canvas.setColor(255, 0, 0)
        canvas.drawPolygon([
            (neg_third - extra, third),
            (third + extra, 0),
            (neg_third - extra, neg_third)
        ])
        canvas.pop()

        # The visibility mask is computed but not used below.
        # NOTE(review): confirm gen_obs_grid has no needed side effects
        # before dropping this call.
        _view, _vis_mask = self.gen_obs_grid()
        canvas.endFrame()

        if mode == 'rgb_array':
            return canvas.getArray()
        if mode == 'pixmap':
            return canvas.getPixmap()
        return canvas
