import time
import warnings
from collections import deque
from enum import IntEnum

import numpy as np
from numpy import array
from recordtype import recordtype

from flatland.envs.step_utils.states import TrainState

from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_pgl import PGLGL


# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!

class AgentRenderVariant(IntEnum):
    BOX_ONLY = 0
    ONE_STEP_BEHIND = 1
    AGENT_SHOWS_OPTIONS = 2
    ONE_STEP_BEHIND_AND_BOX = 3
    AGENT_SHOWS_OPTIONS_AND_BOX = 4


class RenderTool(object):
    """ RenderTool is a facade to a renderer.
        (This was introduced for the Browser / JS renderer which has now been removed.)
    """
    def __init__(self, env, gl="PGL", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600,
                 host="localhost", port=None):

        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()

        self.agent_render_variant = agent_render_variant

        if gl in ["PIL", "PILSVG", "PGL"]:
            self.renderer = RenderLocal(env, gl, jupyter,
                 agent_render_variant,
                 show_debug, clear_debug_text, screen_width, screen_height)
            self.gl = self.renderer.gl
        else:
            print("[", gl, "] not found, switch to PGL")

    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,  # whether to show agents before they start
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False, # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor):
                   return_image=False): # indicate if image is returned for use in monitor:
        return self.renderer.render_env(show, show_agents, show_inactive_agents, show_observations,
                    show_predictions, show_rowcols, frames, episode, step, selected_agent, return_image)

    def close_window(self):
        self.renderer.close_window()

    def reset(self):
        self.renderer.reset()
    
    def set_new_rail(self):
        self.renderer.set_new_rail()
        self.renderer.env = self.env  # bit of a hack - copy our env to the delegate

    def update_background(self):
        self.renderer.update_background()
    
    def get_endpoint_URL(self):
        """ Returns a string URL for the root of the HTTP server
            TODO: Need to update this work work on a remote server!  May be tricky...
        """
        #return "http://localhost:{}".format(self.renderer.get_port())
        if hasattr(self.renderer, "get_endpoint_url"):
            return self.renderer.get_endpoint_url()
        else:
            print("Attempt to get_endpoint_url from RenderTool - only supported with BROWSER")
            return None

    def get_image(self):
        """ 
        """
        if hasattr(self.renderer, "gl"):
            return self.renderer.gl.get_image()
        else:
            print("Attempt to retrieve image from RenderTool - not supported with BROWSER")
            return None


class RenderBase(object):
    def __init__(self, env):
        pass

    def render_env(self):
        pass

    def close_window(self):
        pass

    def reset(self):
        pass

    def set_new_rail(self):
        """ Signal to the renderer that the env has changed and will need re-rendering.
        """
        pass

    def update_background(self):
        """ A lesser version of set_new_rail?  
            TODO: can update_background be pruned for simplicity?
        """
        pass



class RenderLocal(RenderBase):
    """ Class to render the RailEnv and agents.
        Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
        The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
        Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
    """
    visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])

    color_list = list("brgcmyk")
    # \delta RC for NESW
    transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pix_per_cell = 1  # misnomer...
    half_pix_per_cell = pix_per_cell / 2
    x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
    row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
    grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
    theta = np.linspace(0, np.pi / 2, 5)
    arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]

    def __init__(self, env, gl="PILSVG", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600):

        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()

        self.agent_render_variant = agent_render_variant

        self.gl_str = gl

        if gl == "PIL":
            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
        elif gl == "PILSVG":
            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
        else:
            if gl != "PGL":
                print("[", gl, "] not found, switch to PGL, PILSVG")
                print("Using PGL")
            self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)

        self.new_rail = True
        self.show_debug = show_debug
        self.clear_debug_text = clear_debug_text
        self.update_background()

    def reset(self):
        """
        Resets the environment
        :return:
        """
        self.set_new_rail()
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()
        return

    def update_background(self):
        # create background map
        targets = {}
        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            #print(f"updatebg: {agent_idx} {agent.target}")
            targets[tuple(agent.target)] = agent_idx
        self.gl.build_background_map(targets)

    def resize(self):
        self.gl.resize(self.env)

    def set_new_rail(self):
        """ Tell the renderer that the rail has changed.
            eg when the rail has been regenerated, or updated in the editor.
        """
        self.new_rail = True

    def plot_agents(self, targets=True, selected_agent=None):
        color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))

        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            color = color_map(agent_idx)
            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None,
                                   static=True, selected=agent_idx == selected_agent)

        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            color = color_map(agent_idx)
            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None)

    def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
        """
        Get the available transitions for row_col_pos in direction direction,
        as row & col deltas.

        If bgiTrans is True, return a grid of indices of available transitions.

        eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
        where the available transitions are N and E, returns:
        [[-1,0], [0,1]] ie N=up one row, and E=right one col.
        and if bgiTrans is True, returns a tuple:
        (
            [[-1,0], [0,1]], # deltas as before
            [0, 1] #  available transition indices, ie N, E
        )
        """

        transitions = self.env.rail.get_transitions(*row_col_pos, direction)
        transition_list = np.where(transitions)[0]  # RC list of transitions

        # HACK: workaround dead-end transitions
        if len(transition_list) == 0:
            reverse_direciton = (direction + 2) % 4
            transitions = tuple(int(tmp_dir == reverse_direciton) for tmp_dir in range(4))
            transition_list = np.where(transitions)[0]  # RC list of transitions

        transition_grid = self.__class__.transitions_row_col[transition_list]

        if bgiTrans:
            return transition_grid, transition_list
        else:
            return transition_grid

    def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
        """
        Plot a simple agent.
        Assumes a working graphics layer context (cf a MPL figure).
        """
        if position_row_col is None:
            return

        rt = self.__class__

        direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
        direction_xy = np.matmul(direction_row_col, rt.row_col_to_xy)  # agent direction in xy

        xyPos = np.matmul(position_row_col - direction_row_col / 2, rt.row_col_to_xy) + rt.x_y_half

        if static:
            color = self.gl.adapt_color(color, lighten=True)

        color = color

        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
        xy_dir_line = array([xyPos, xyPos + direction_xy / 2]).T  # line for agent orient.
        self.gl.plot(*xy_dir_line, color=color, layer=1, lw=5, ms=0, alpha=0.6)
        if selected:
            self._draw_square(xyPos, 1, color)

        if target is not None:
            target_row_col = array(target)
            target_xy = np.matmul(target_row_col, rt.row_col_to_xy) + rt.x_y_half
            self._draw_square(target_xy, 1 / 3, color, layer=1)

    def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
        """
        plot the transitions in transition_row_col at position position_row_col.
        transition_row_col is a 2d numpy array containing a list of RC transitions,
        eg [[-1,0], [0,1]] means N, E.

        """

        rt = self.__class__
        position_xy = np.matmul(position_row_col, rt.row_col_to_xy) + rt.x_y_half
        transition_xy = position_xy + np.matmul(transition_row_col, rt.row_col_to_xy / 2.4)
        self.gl.scatter(*transition_xy.T, color=color, marker="o", s=50, alpha=0.2)
        if depth is not None:
            for x, y in transition_xy:
                self.gl.text(x, y, depth)

    def draw_transition(self,
                        line,
                        center,
                        rotation,
                        dead_end=False,
                        curves=False,
                        color="gray",
                        arrow=True,
                        spacing=0.1):
        """
        gLine is a numpy 2d array of points,
        in the plotting space / coords.
        eg:
        [[0,.5],[1,0.2]] means a line
        from x=0, y=0.5
        to   x=1, y=0.2
        """

        if not curves and not dead_end:
            # just a straigt line, no curve nor dead_end included in this basic rail element
            self.gl.plot(
                [line[0][0], line[1][0]],  # x
                [line[0][1], line[1][1]],  # y
                color=color
            )
        else:
            # it was not a simple line to draw: the rail has a curve or dead_end included.
            rt = self.__class__
            straight = rotation in [0, 2]
            dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2

            if straight:

                if color == "auto":
                    if dx > 0 or dy > 0:
                        color = "C1"  # N or E
                    else:
                        color = "C2"  # S or W

                if dead_end:
                    line_xy = array([
                        line[1] + [dy, dx],
                        center,
                        line[1] - [dy, dx],
                    ])
                    self.gl.plot(*line_xy.T, color=color)
                else:
                    line_xy = line + [-dy, dx]
                    self.gl.plot(*line_xy.T, color=color)

                    if arrow:
                        middle_xy = np.sum(line_xy * [[1 / 4], [3 / 4]], axis=0)

                        arrow_xy = array([
                            middle_xy + [-dx - dy, +dx - dy],
                            middle_xy,
                            middle_xy + [-dx + dy, -dx - dy]])
                        self.gl.plot(*arrow_xy.T, color=color)

            else:

                middle_xy = np.mean(line, axis=0)
                dxy = middle_xy - center
                corner = middle_xy + dxy
                if rotation == 1:
                    arc_factor = 1 - spacing
                    color_auto = "C1"
                else:
                    arc_factor = 1 + spacing
                    color_auto = "C2"
                dxy2 = (center - corner) * arc_factor  # for scaling the arc

                if color == "auto":
                    color = color_auto

                self.gl.plot(*(rt.arc * dxy2 + corner).T, color=color)

                if arrow:
                    dx, dy = np.squeeze(np.diff(line, axis=0)) / 20
                    iArc = int(len(rt.arc) / 2)
                    middle_xy = corner + rt.arc[iArc] * dxy2
                    arrow_xy = array([
                        middle_xy + [-dx - dy, +dx - dy],
                        middle_xy,
                        middle_xy + [-dx + dy, -dx - dy]])
                    self.gl.plot(*arrow_xy.T, color=color)

    def render_observation(self, agent_handles, observation_dict):
        """
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted.
        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary containing sets of cells of the agent observation

        """
        rt = self.__class__

        # Check if the observation builder provides an observation
        if len(observation_dict) < 1:
            warnings.warn(
                "Predictor did not provide any predicted cells to render. \
                Observation builder needs to populate: env.dev_obs_dict")
        else:
            for agent in agent_handles:
                color = self.gl.get_agent_color(agent)
                for visited_cell in observation_dict[agent]:
                    cell_coord = array(visited_cell[:2])
                    cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                    self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)

    def render_prediction(self, agent_handles, prediction_dict):
        """
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted.
        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary containing sets of cells of the agent observation

        """
        rt = self.__class__
        if len(prediction_dict) < 1:
            warnings.warn(
                "Predictor did not provide any predicted cells to render. \
                Predictors builder needs to populate: env.dev_pred_dict")
        else:
            for agent in agent_handles:
                color = self.gl.get_agent_color(agent)
                for visited_cell in prediction_dict[agent]:
                    cell_coord = array(visited_cell[:2])
                    if type(self.gl) is PILSVG:
                        # TODO : Track highlighting (Adrian)
                        r = cell_coord[0]
                        c = cell_coord[1]
                        transitions = self.env.rail.grid[r, c]
                        self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
                    else:
                        cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                        self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)

    def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):

        cell_size = 1  # TODO: remove cell_size
        env = self.env

        # Draw cells grid
        grid_color = [0.95, 0.95, 0.95]
        for row in range(env.height + 1):
            self.gl.plot([0, (env.width + 1) * cell_size],
                         [-row * cell_size, -row * cell_size],
                         color=grid_color, linewidth=2)
        for col in range(env.width + 1):
            self.gl.plot([col * cell_size, col * cell_size],
                         [0, -(env.height + 1) * cell_size],
                         color=grid_color, linewidth=2)

        # Draw each cell independently
        for row in range(env.height):
            for col in range(env.width):

                # bounding box of the grid cell
                x0 = cell_size * col  # left
                x1 = cell_size * (col + 1)  # right
                y0 = cell_size * -row  # top
                y1 = cell_size * -(row + 1)  # bottom

                # centres of cell edges
                coords = [
                    ((x0 + x1) / 2.0, y0),  # N middle top
                    (x1, (y0 + y1) / 2.0),  # E middle right
                    ((x0 + x1) / 2.0, y1),  # S middle bottom
                    (x0, (y0 + y1) / 2.0)  # W middle left
                ]

                # cell centre
                center_xy = array([x0, y1]) + cell_size / 2

                # cell transition values
                cell = env.rail.get_full_transitions(row, col)

                cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)

                # Special Case 7, with a single bit; terminate at center
                nbits = 0
                tmp = cell

                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1

                # as above - move the from coord to the centre
                # it's a dead env.
                is_dead_end = nbits == 1

                if not cell_valid:
                    self.gl.scatter(*center_xy, color="r", s=30)

                for orientation in range(4):  # ori is where we're heading
                    from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
                    from_xy = coords[from_ori]

                    moves = env.rail.get_transitions(row, col, orientation)

                    for to_ori in range(4):
                        to_xy = coords[to_ori]
                        rotation = (to_ori - from_ori) % 4
                        if (moves[to_ori]):  # if we have this transition
                            self.draw_transition(
                                array([from_xy, to_xy]), center_xy,
                                rotation, dead_end=is_dead_end, curves=curves and not is_dead_end, spacing=spacing,
                                color=rail_color)

    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False,  # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor
                   return_image=False): # indicate if image is returned for use in monitor:
        """ Draw the environment using the GraphicsLayer this RenderTool was created with.
            (Use show=False from a Jupyter notebook with %matplotlib inline)
        """

        # if type(self.gl) is PILSVG:
        if self.gl_str in ["PILSVG", "PGL"]:
            return self.render_env_svg(show=show,
                                show_observations=show_observations,
                                show_predictions=show_predictions,
                                selected_agent=selected_agent,
                                show_agents=show_agents,
                                show_inactive_agents=show_inactive_agents,
                                show_rowcols=show_rowcols,
                                return_image=return_image
                                )
        else:
            return self.render_env_pil(show=show,
                                show_agents=show_agents,
                                show_inactive_agents=show_inactive_agents,
                                show_observations=show_observations,
                                show_predictions=show_predictions,
                                show_rowcols=show_rowcols,
                                frames=frames,
                                episode=episode,
                                step=step,
                                selected_agent=selected_agent,
                                return_image=return_image
                                )

    def _draw_square(self, center, size, color, opacity=255, layer=0):
        x0 = center[0] - size / 2
        x1 = center[0] + size / 2
        y0 = center[1] - size / 2
        y1 = center[1] + size / 2
        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)

    def get_image(self):
        return self.gl.get_image()

    def render_env_pil(self,
                       show=False,  # whether to call matplotlib show() or equivalent after completion
                       # use false when calling from Jupyter.  (and matplotlib no longer supported!)
                       show_agents=True,  # whether to include agents
                       show_inactive_agents=False, 
                       show_observations=True,  # whether to include observations
                       show_predictions=False,  # whether to include predictions
                       show_rowcols=False, # label the rows and columns
                       frames=False,  # frame counter to show (intended since invocation)
                       episode=None,  # int episode number to show
                       step=None,  # int step number to show in image
                       selected_agent=None,  # indicate which agent is "selected" in the editor
                       return_image=False # indicate if image is returned for use in monitor:
                       ):

        if type(self.gl) is PILGL:
            self.gl.begin_frame()

        env = self.env

        self.render_rail()

        # Draw each agent + its orientation + its target
        if show_agents:
            self.plot_agents(targets=True, selected_agent=selected_agent)
        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        if show_predictions and len(env.dev_pred_dict) > 0:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
        # Draw some textual information like fps
        text_y = [-0.3, -0.6, -0.9]
        if frames:
            self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
        self.frame_nr += 1

        if episode is not None:
            self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))

        if step is not None:
            self.gl.text(0.1, text_y[0], "Step:{}".format(step))

        time_now = time.time()
        self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
        self.times_list.append(time_now)
        if len(self.times_list) > 20:
            self.times_list.popleft()
        if len(self.times_list) > 1:
            rFps = (len(self.times_list) - 1) / (self.times_list[-1] - self.times_list[0])
            self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))

        self.gl.prettify2(env.width, env.height, self.pix_per_cell)

        # TODO: for MPL, we don't want to call clf (called by endframe)
        # if not show:

        if show and type(self.gl) is PILGL:
            self.gl.show()

        self.gl.pause(0.00001)

        if return_image:
            return self.get_image()
        return

    def render_env_svg(
        self, show=False, show_observations=True, show_predictions=False, selected_agent=None,
        show_agents=True, show_inactive_agents=False, show_rowcols=False, return_image=False
    ):
        """
        Renders the environment with SVG support (nice image)
        """

        env = self.env

        self.gl.begin_frame()

        if self.new_rail:
            self.new_rail = False
            self.gl.clear_rails()

            # store the targets
            targets = {}
            selected = {}
            for agent_idx, agent in enumerate(self.env.agents):
                if agent is None:
                    continue
                targets[tuple(agent.target)] = agent_idx
                selected[tuple(agent.target)] = (agent_idx == selected_agent)

            # Draw each cell independently
            for r in range(env.height):
                for c in range(env.width):
                    transitions = env.rail.grid[r, c]
                    if (r, c) in targets:
                        target = targets[(r, c)]
                        is_selected = selected[(r, c)]
                    else:
                        target = None
                        is_selected = False

                    self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
                                        rail_grid=env.rail.grid, num_agents=env.get_num_agents(),
                                        show_debug=self.show_debug)

            self.gl.build_background_map(targets)

            if show_rowcols:
                # label rows, cols
                for iRow in range(env.height):
                    self.gl.text_rowcol((iRow, 0), str(iRow), layer=self.gl.RAIL_LAYER)
                for iCol in range(env.width):
                    self.gl.text_rowcol((0, iCol), str(iCol), layer=self.gl.RAIL_LAYER)


        if show_agents:
            for agent_idx, agent in enumerate(self.env.agents):

                if agent is None:
                    continue

                # Show an agent even if it hasn't already started
                if agent.position is None:
                    if show_inactive_agents:
                        # print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
                        self.gl.set_agent_at(agent_idx, *(agent.initial_position), 
                            agent.initial_direction, agent.initial_direction,
                            is_selected=(selected_agent == agent_idx),
                            rail_grid=env.rail.grid,
                            show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                            malfunction=False)
                    continue

                is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0

                if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
                    self.gl.set_cell_occupied(agent_idx, *(agent.position))

                elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
                    self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:  # noqa: E125

                    # Most common case - the agent has been running for >1 steps
                    if agent.old_position is not None:
                        position = agent.old_position
                        direction = agent.direction
                        old_direction = agent.old_direction

                    # the agent's first step - it doesn't have an old position yet
                    elif agent.position is not None:
                        position = agent.position
                        direction = agent.direction
                        old_direction = agent.direction
                        
                    # When the editor has just added an agent
                    elif agent.initial_position is not None:
                        position = agent.initial_position
                        direction = agent.initial_direction
                        old_direction = agent.initial_direction

                    # set_agent_at uses the agent index for the color
                    if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *(agent.position))
                    self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
                                         selected_agent == agent_idx, rail_grid=env.rail.grid,
                                         show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                                         malfunction=is_malfunction)
                else:
                    position = agent.position
                    direction = agent.direction
                    for possible_direction in range(4):
                        # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
                        isValid = env.rail.get_transition((*agent.position, agent.direction), possible_direction)
                        if isValid:
                            direction = possible_direction

                            # set_agent_at uses the agent index for the color
                            self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
                                                 selected_agent == agent_idx, rail_grid=env.rail.grid,
                                                 show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                                                 malfunction=is_malfunction)

                    # set_agent_at uses the agent index for the color
                    if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *(agent.position))
                    
                    if show_inactive_agents:
                        show_this_agent = True
                    else:
                        show_this_agent = agent.state.is_on_map_state()

                    if show_this_agent:
                        self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, 
                                        selected_agent == agent_idx,
                                        rail_grid=env.rail.grid, malfunction=is_malfunction)

        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        if show_predictions:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
        


        if show:
            self.gl.show()
        for i in range(3):
            self.gl.process_events()

        self.frame_nr += 1
        if return_image:
            return self.get_image()
        return

    def close_window(self):
        self.gl.close_window()
