import pygame as pg
import numpy as np
from numpy.typing import NDArray

from copy import deepcopy
from typing import Optional

from shapely import Point, LineString
import shapely

from colour import Color
from .colour_utils import Highlight
from loguru import logger

from molecule_movement.Simulator import Simulator
from molecule_movement import VerticalAction, LateralAction, Goal, Obstacle, Movement, Matching, Molecule
from molecule_movement.colour_utils import c, hex_to_rgb

from molecule_movement.wrapper import CorridorConfiguration

import os
from datetime import datetime

from molecule_movement.MoleculeGameObject import (
        PolygonMoleculeGameObject,
        GoalObject,
        ObstacleObject,
        STM,
        MatchingObject,
        STM_LATERAL_START_INDICATOR_COLOR,
        STM_LATERAL_END_INDICATOR_COLOR,
        STM_VERTICAL_INDICATOR_COLOR
        )

from molecule_movement.PyGameObject import Legend, GridCache

# Smallest pixels-per-nm scale at which STM manipulation indicators are drawn;
# below it the renderer logs a warning and skips them.
MIN_NM_RENDERING: int = 3

# Background colour for each supported substrate; the keys are the accepted
# values of PyGameRenderer's ``surface`` argument.
SURFACE_MAP: dict[str, Color] = dict(silver=Color(rgb=(0.02, 0, 0.1)))

class PyGameRenderer:
    """Render a molecule-movement simulation into a pygame window.

    Static content (grid, goals, obstacles, legend) is pre-composited into
    ``static_background`` and ``static_background_without_goals`` by
    :meth:`build_static_layer`. Dynamic content (molecules, sensors, STM
    indicators, matchings, movement trails) is drawn on demand by the
    ``render_*`` methods. Movement trails accumulate on ``movement_screen``
    until :meth:`clear_movement` is called. With ``store_images`` enabled,
    every :meth:`update` also saves the window as a numbered PNG.
    """

    def __init__(
        self,
        simulator: Simulator,
        obstacles: list[Obstacle],
        environment_name: str,
        window_size: tuple[int, int] = (100,100),
        size: tuple[int, int] = (10, 10),
        surface: str = "silver",
        render_grid: bool = False,
        render_sensors: str = "all",
        origin_offset: tuple[float,float] = (0,0),
        flip_y: bool = True,
        flip_x: bool = False,
        draw_names: bool = True,
        store_images: bool = False):
        """Open the pygame window and build every sprite for ``simulator``.

        Args:
            simulator: Source of the molecules and goals to draw.
            obstacles: Obstacles drawn into the static background.
            environment_name: Shown in the window caption and used to name
                the image dump directory.
            window_size: Window size in pixels (width, height).
            size: World size in nm. The pixels-per-nm scale is
                ``window_size[0] // size[0]``.
            surface: Key into ``SURFACE_MAP`` selecting the background colour.
            render_grid: Draw a grid into the static background. It is forced
                off when the scale is below ``MIN_NM_RENDERING``.
            render_sensors: ``"all"`` draws the sensors of every molecule,
                ``"focus"`` draws only those of the current molecule, and any
                other value draws none.
            origin_offset: World offset (nm) added before scaling to pixels.
            flip_y: Mirror the drawing along the y axis.
            flip_x: Mirror the drawing along the x axis.
            draw_names: Forwarded to the molecule sprites.
            store_images: Save every frame as a PNG under ``image_dump/``.
        """
        pg.init()

        # Pixels per nm: integer division of the window width by the world width.
        self.nm_size = window_size[0] // size[0]
        self.flip_y = flip_y
        self.flip_x = flip_x
        self.origin_offset = origin_offset
        self.window = pg.display.set_mode(window_size)
        self.window_size = window_size
        pg.display.set_caption(f'MoleculeMovement - {environment_name}')
        world_px = (size[0] * self.nm_size, size[1] * self.nm_size)
        self.matching_screen = pg.Surface(world_px)
        self.movement_screen = pg.Surface(world_px)

        self.render_grid = render_grid
        self.render_sensors = render_sensors

        self.surface = SURFACE_MAP[surface]
        self.surface_color = hex_to_rgb(self.surface.hex_l)
        self.window.fill(self.surface_color)
        # The background colour doubles as the colorkey, so unpainted pixels
        # of these off-screen layers are transparent when blitted.
        self.matching_screen.set_colorkey(self.surface_color)
        self.movement_screen.fill(self.surface_color)
        self.movement_screen.set_colorkey(self.surface_color)

        self.simulator = simulator
        self.obstacles = obstacles
        self.environment_name = environment_name
        self.store_images = store_images
        if self.store_images:
            self.__set_directory()
            self.frame_idx = 0

        # Lets clear() run at most once between two update() calls.
        self.cleared_in_step = False

        logger.info(f"Initialized PyGameRenderer with window_size: {window_size}, surface_size: {size[0]}x{size[1]} and nm_size: {self.nm_size}")

        self.draw_names = draw_names
        self._matching_objects: dict[str, MatchingObject] = {}
        self._build_molecule_objects()

        self.goal_objects: dict[int, GoalObject] = {}
        self.obstacle_objects: dict[int, ObstacleObject] = {}

        # Check the scale first so that a too-small scale never builds a grid tile.
        self.check_rendering_sizes()
        self._build_static_objects()
        self.build_static_layer()

    def check_rendering_sizes(self):
        """Disable the grid when the pixels-per-nm scale is below ``MIN_NM_RENDERING``."""
        if self.nm_size < MIN_NM_RENDERING:
            logger.warning(f"The rendering size for one nm is smaller than {MIN_NM_RENDERING}: {self.nm_size}. The STM manipulations will not be visible")
            logger.warning("Disabling the rendering of STM manipulations.")
            self.render_grid = False

    def _build_molecule_objects(self) -> None:
        """(Re)create the molecule colour map and one sprite per simulator molecule."""
        molecules = self.simulator.molecules
        # Color.range_to raises for a step count of 0, so an empty simulator gets no colours.
        if len(molecules) > 0:
            colors = list(Color("orange").range_to(Color("lightblue"), len(molecules)))
        else:
            colors = []
        self.molecule_name_color_map = {mol.name: color for mol, color in zip(molecules, colors)}

        self.objects = {}
        for mol in molecules:
            self.objects[mol.name] = PolygonMoleculeGameObject(
                molecule=mol,
                nm_size=self.nm_size,
                color=c(self.molecule_name_color_map[mol.name].rgb),
                colorkey=(0, 0, 0),
                origin_offset=self.origin_offset,
                flip_y=self.flip_y,
                flip_x=self.flip_x,
                draw_names=self.draw_names,
            )

    def _build_grid(self) -> None:
        """Create the cached grid tile used by :meth:`build_static_layer`."""
        self._grid = GridCache()
        grid_grey = c(Color("grey").rgb)
        self._grid.build_tile(self.nm_size, grid_grey, line_width=1)  # 1px avoids seam issues

    def _build_static_objects(self):
        """(Re)create the goal, obstacle, STM-indicator, grid and legend sprites."""
        surface_color = self.surface_color
        flip_x, flip_y, origin_offset = self.flip_x, self.flip_y, self.origin_offset
        nm_size = self.nm_size

        # Sprites are keyed by id() of the goal/obstacle they draw.
        self.goal_objects.clear()
        for g in self.simulator.get_goals():
            self.goal_objects[id(g)] = GoalObject(
                g, nm_size, surface_color, origin_offset, flip_y, flip_x
            )

        self.obstacle_objects.clear()
        for o in self.obstacles:
            self.obstacle_objects[id(o)] = ObstacleObject(
                o, nm_size, surface_color, origin_offset, flip_y, flip_x
            )

        self.stm_vertical = STM(
            nm_size, c(Color(STM_VERTICAL_INDICATOR_COLOR).rgb),
            type="vertical", origin_offset=origin_offset, flip_y=flip_y, flip_x=flip_x
        )
        self.stm_lateral_start = STM(
            nm_size, c(Color(STM_LATERAL_START_INDICATOR_COLOR).rgb),
            type="lateral_start", origin_offset=origin_offset, flip_y=flip_y, flip_x=flip_x
        )
        self.stm_lateral_end = STM(
            nm_size, c(Color(STM_LATERAL_END_INDICATOR_COLOR).rgb),
            type="lateral_end", origin_offset=origin_offset, flip_y=flip_y, flip_x=flip_x
        )

        if self.render_grid:
            self._build_grid()
        self._legend = Legend(self.nm_size, surface_color)

    def build_static_layer(self):
        """Pre-composite grid, goals, obstacles and legend into the two static backgrounds.

        ``static_background`` contains the goals, while
        ``static_background_without_goals`` omits them. Both contain the grid
        (when enabled), the obstacles and the legend.
        """
        self.static_background = pg.Surface(self.window_size, pg.SRCALPHA).convert_alpha()
        self.static_background_without_goals = pg.Surface(self.window_size, pg.SRCALPHA).convert_alpha()
        height = self.window_size[1]

        if self.render_grid:
            grid_layer = self._grid.render_grid_layer(
                self.window_size,
                self.origin_offset,
                self.nm_size,
                y_factor=self.y_factor,
            )
            grid_static = pg.Surface(self.window_size, pg.SRCALPHA).convert_alpha()
            grid_static.blit(grid_layer, (0, 0))

            if self.flip_x or self.flip_y:
                grid_static = pg.transform.flip(grid_static, self.flip_x, self.flip_y)

            self.grid_static = grid_static
            self.static_background.blit(self.grid_static, (0, 0))
            self.static_background_without_goals.blit(self.grid_static, (0, 0))

        for obj in self.goal_objects.values():
            self.static_background.blit(obj.draw(), obj.anchor(height))
        # NOTE(review): this surface is not drawn to here; kept in case other code reads it.
        self.obstacle_static = pg.Surface(self.window_size, pg.SRCALPHA).convert_alpha()
        for obj in self.obstacle_objects.values():
            obstacle_surf = obj.draw()
            obstacle_pos = obj.anchor(height)
            self.static_background.blit(obstacle_surf, obstacle_pos)
            self.static_background_without_goals.blit(obstacle_surf, obstacle_pos)
        legend_surf = self._legend.draw(10)
        self.static_background.blit(legend_surf, (10, 10))
        self.static_background_without_goals.blit(legend_surf, (10, 10))

    def render_background(self, surpress_goals: bool = False):
        """Blit the pre-built static background, with or without goals, onto the window."""
        if not surpress_goals:
            self.window.blit(self.static_background, (0, 0))
        else:
            self.window.blit(self.static_background_without_goals, (0, 0))

    def render_goals(self) -> None:
        """Blit every goal sprite onto the window."""
        for obj in self.goal_objects.values():
            surf = obj.draw()
            self.window.blit(surf, obj.anchor(self.window_size[1]))

    def render_obstacles(self):
        """Blit every obstacle sprite onto the window."""
        for obj in self.obstacle_objects.values():
            self.window.blit(obj.draw(), obj.anchor(self.window_size[1]))

    def render_molecules(self, current_molecule: Molecule) -> None:
        """Blit every molecule, plus its sensors as selected by ``render_sensors``.

        Args:
            current_molecule: The molecule whose sensors are drawn in ``"focus"`` mode.
        """
        height = self.window_size[1]
        for obj in self.objects.values():
            if self.render_sensors == "all" or (
                self.render_sensors == "focus" and obj.molecule is current_molecule
            ):
                s = obj.sensor_surface()
                if s is not None:
                    self.window.blit(s, obj.sensor_anchor(height))

            poly = obj.draw_surface()
            self.window.blit(poly, obj.anchor(height))

    def render_STM(self, action: "VerticalAction | LateralAction") -> None:
        """Draw the STM tip indicator(s) for ``action``.

        A vertical action draws a single marker. A lateral action draws start
        and end markers joined by a white line. Nothing is drawn when the
        scale is below ``MIN_NM_RENDERING``.
        """
        if self.nm_size < MIN_NM_RENDERING:
            return

        H = self.window_size[1]

        if isinstance(action, VerticalAction):
            surf = self.stm_vertical.draw()
            anchor = self.stm_vertical.anchor(action.xy.coords[0], H)
            self.window.blit(surf, anchor)

        elif isinstance(action, LateralAction):
            start = action.xy.coords[0]
            end = action.xy_dest.coords[0]
            p0 = self.stm_lateral_start.center_to_pixel_xy(start, H)
            p1 = self.stm_lateral_end.center_to_pixel_xy(end, H)
            pg.draw.line(self.window, (255, 255, 255), p0, p1, width=1)

            a0 = self.stm_lateral_start.anchor(start, H)
            a1 = self.stm_lateral_end.anchor(end, H)
            self.window.blit(self.stm_lateral_start.draw(), a0)
            self.window.blit(self.stm_lateral_end.draw(), a1)

    def render_matching(
        self,
        matching: Matching,
        highlight: Optional[Highlight] = None,
        index: int | None = None,
        corridor_config: Optional[CorridorConfiguration] = None,
    ) -> None:
        """Draw a molecule-goal matching, reusing a cached sprite when possible.

        Args:
            matching: The matching to draw.
            highlight: Styling for the sprite. When it is omitted or its
                ``colour`` is ``""``, the molecule's colour is used. The
                caller's object is never modified.
            index: Optional index forwarded to the sprite and part of the cache key.
            corridor_config: Forwarded to the sprite when it is first built.
                NOTE(review): not part of the cache key, so a changed config
                does not rebuild an existing sprite.
        """
        if highlight is None:
            highlight = Highlight()
        if highlight.colour == "":
            # Fill in the colour on a private copy. Mutating the caller's
            # object (or a shared default) would leak this molecule's colour
            # into unrelated later calls.
            highlight = deepcopy(highlight)
            highlight.colour = self.molecule_name_color_map[matching.molecule.name]

        name = matching.molecule.name
        gx, gy = matching.goal.position.x, matching.goal.position.y
        key = f"{index}@{name}@{gx:.6f},{gy:.6f}-{highlight.colour}"

        sprite = self._matching_objects.get(key)
        if sprite is None:
            sprite = MatchingObject(
                matching,
                nm_size=self.nm_size,
                highlight=highlight,
                index=index,
                origin_offset=self.origin_offset,
                flip_y=self.flip_y,
                flip_x=self.flip_x,
                corridor_config=corridor_config,
            )
            self._matching_objects[key] = sprite

        # Draw and blit; the sprite decides whether its surface is cached.
        surf = sprite.draw()
        self.window.blit(surf, sprite.anchor(self.window_size[1]))

    def render_movements(self, moved_molecules: Optional[dict[Molecule, tuple[Point, Movement]]] = None):
        """Append movement trails to ``movement_screen`` and blit it onto the window.

        Each trail runs from the molecule's previous centre to its current
        one. It is coloured from yellow towards purple by
        ``distance_moved / maximum_movement``, and is purple at or above 1.
        When ``moved_molecules`` is None, nothing is drawn or blitted.
        """
        if moved_molecules is None:
            return
        rgb_low = np.array(Color("yellow").rgb)
        rgb_high = np.array(Color("purple").rgb)
        ox, oy = self.origin_offset
        nm = self.nm_size
        for molecule, (old_center, movement) in moved_molecules.items():
            factor = movement.distance_moved / molecule.maximum_movement
            if factor >= 1.0:
                color = c(Color("purple").rgb)
            else:
                color = c(tuple((1 - factor) * rgb_low + factor * rgb_high))
            # Drawn unflipped; the whole layer is flipped when blitted below.
            pg.draw.line(self.movement_screen,
                         color,
                         (nm * (old_center.x + ox), nm * (old_center.y + oy)),
                         (nm * (molecule.center.x + ox), nm * (molecule.center.y + oy)), width=3)
        self.window.blit(pg.transform.flip(self.movement_screen, self.flip_x, self.flip_y), (0,0))

    def render_point(self, p: Point, screen: pg.Surface) -> None:
        """Draw a filled red 3x3 px marker centred on ``p`` (nm) onto ``screen``.

        Only ``nm_size`` scaling is applied; origin offset and flips are not.
        """
        # width=0 fills the rect; a negative width makes pygame draw nothing.
        pg.draw.rect(screen, "red", (p.x * self.nm_size - 1, p.y * self.nm_size - 1, 3, 3), width=0)

    def set_render_grid(self, value: bool) -> None:
        """Enable or disable the background grid and rebuild the static layer.

        Rebuilding is required because the grid is baked into the static
        backgrounds. The grid tile is created on first use.
        """
        if value == self.render_grid:
            return
        self.render_grid = value
        if value and not hasattr(self, "_grid"):
            self._build_grid()
        self.build_static_layer()

    def get_render_grid(self) -> bool:
        """Return whether the background grid is drawn."""
        return self.render_grid

    def clear(self, force: bool = False) -> None:
        """Fill the window with the background colour.

        Without ``force``, the fill happens at most once between two
        :meth:`update` calls. With ``force``, it always fills. When image
        storing is on, it also starts a new dump directory (the frame counter
        keeps running).
        """
        if force:
            if self.store_images:
                self.__set_directory()
            self.window.fill(self.surface_color)
            return
        if not self.cleared_in_step:
            self.window.fill(self.surface_color)
            self.cleared_in_step = True

    def clear_movement(self) -> None:
        """Erase all accumulated movement trails."""
        self.movement_screen.fill(self.surface_color)

    def update(self) -> None:
        """Flip the display, optionally save the frame, and re-arm :meth:`clear`."""
        pg.display.update()
        if self.store_images:
            filename = f"{self.frame_idx:06d}.png"
            pg.image.save(self.window, os.path.join(self.directory, filename))
            self.frame_idx += 1
        self.cleared_in_step = False

    def set_simulator(self, simulator: Simulator) -> None:
        """Switch to a new simulator and rebuild every sprite and the static layer."""
        self.simulator = simulator
        self._build_molecule_objects()
        # Cached matching sprites were built from the previous simulator's matchings.
        self._matching_objects.clear()
        self._build_static_objects()
        self.build_static_layer()

    def __scaled_center_position(self, action_xy: Point, surface_dimensions: tuple[float, float], surface_height: float, anchor: bool = True) -> tuple[float, float]:
        """Map a world point (nm) to window pixels.

        Args:
            action_xy: Point in world coordinates.
            surface_dimensions: (width, height) of the surface to be placed.
            surface_height: Height used to re-origin the y axis when ``flip_y`` is set.
            anchor: If True, shift by half of ``surface_dimensions`` so that the
                result is the top-left corner of a surface centred on the point.
        """
        # Imported explicitly, because ``import shapely`` alone does not
        # guarantee that the ``affinity`` submodule is loaded.
        from shapely import affinity

        scaled = affinity.translate(action_xy, self.origin_offset[0], self.origin_offset[1])
        scaled = affinity.scale(scaled, self.nm_size, self.y_factor * self.nm_size, origin=(0, 0))
        if self.flip_y:
            scaled = affinity.translate(scaled, 0, surface_height)
        if anchor:
            scaled = affinity.translate(scaled, -surface_dimensions[0] / 2, -surface_dimensions[1] / 2)
        return (scaled.x, scaled.y)

    def _build_goal_objects(self):
        """(Re)create goal sprites only, keyed by id() of each goal."""
        self.goal_objects.clear()
        for goal in self.simulator.get_goals():
            obj = GoalObject(
                goal=goal,
                nm_size=self.nm_size,
                colorkey=self.surface_color,
                origin_offset=self.origin_offset,
                flip_y=self.flip_y,
                flip_x=self.flip_x,
            )
            self.goal_objects[id(goal)] = obj

    def get_image(self) -> NDArray:
        """
        The rendered image as a rgb array.

        Gymnasium's channel convention is H x W x C
        """
        data = pg.surfarray.array3d(self.window)  # in W x H x C channel convention
        return np.moveaxis(data, 0, 1)

    def __set_directory(self):
        """Create ``image_dump/<environment>_<timestamp>`` and store it as ``self.directory``."""
        # NOTE(review): isoformat() includes ':', which is not valid in Windows paths.
        now = datetime.now().isoformat(timespec='minutes')
        directory = f"image_dump/{self.environment_name.replace('/','_')}_{now}"
        os.makedirs(directory, exist_ok=True)
        self.directory = directory
        logger.trace(f"Storing all rendered frames into {self.directory}")

    def _scroll_movement(self, dx: int = 0, dy: int = 0) -> None:
        """Scroll ``movement_screen`` by (dx, dy) px and blank the strips it uncovers.

        ``Surface.scroll`` leaves uncovered pixels unchanged, which would
        duplicate trail fragments along the edges.
        """
        screen = self.movement_screen
        screen.scroll(dx=dx, dy=dy)
        width, height = screen.get_size()
        if dx > 0:
            screen.fill(self.surface_color, pg.Rect(0, 0, dx, height))
        elif dx < 0:
            screen.fill(self.surface_color, pg.Rect(width + dx, 0, -dx, height))
        if dy > 0:
            screen.fill(self.surface_color, pg.Rect(0, 0, width, dy))
        elif dy < 0:
            screen.fill(self.surface_color, pg.Rect(0, height + dy, width, -dy))

    # NOTE(review): the offset methods below only update ``origin_offset`` and
    # the trail layer. Sprites built earlier were given the previous offset
    # tuple and are only rebuilt by set_simulator(); confirm that callers
    # rebuild after scrolling.

    def inc_x_offset(self, value: int = 1):
        """Shift the view origin by ``value`` nm along +x."""
        self.origin_offset = (self.origin_offset[0] + value, self.origin_offset[1])
        self._scroll_movement(dx=self.nm_size * value)

    def dec_x_offset(self, value: int = 1):
        """Shift the view origin by ``value`` nm along -x."""
        self.origin_offset = (self.origin_offset[0] - value, self.origin_offset[1])
        self._scroll_movement(dx=-self.nm_size * value)

    def inc_y_offset(self, value: int = 1):
        """Shift the view origin by ``value`` nm along +y."""
        self.origin_offset = (self.origin_offset[0], self.origin_offset[1] + value)
        self._scroll_movement(dy=self.nm_size * value)

    def dec_y_offset(self, value: int = 1):
        """Shift the view origin by ``value`` nm along -y."""
        self.origin_offset = (self.origin_offset[0], self.origin_offset[1] - value)
        self._scroll_movement(dy=-self.nm_size * value)

    @property
    def y_factor(self) -> int:
        """-1 when the y axis is flipped, else 1."""
        return -1 if self.flip_y else 1