from __future__ import annotations

from molecule_movement.PyGameObject import PyGameObject, StaticPyGameObject
from molecule_movement import  Goal, Obstacle, Molecule, Matching
from molecule_movement.colour_utils import c

import pygame as pg
from typing import Optional, Tuple, Literal, List

from molecule_movement.wrapper import CorridorConfiguration

from colour import Color
from .colour_utils import Highlight
from pygame import Surface
from shapely import GeometryCollection
import shapely
from molecule_movement.shapes import compute_corridor

from loguru import logger

from math import isfinite

# (r, g, b) colour triple used in the type hints throughout this module.
RGB = tuple[int, int, int]

# STM indicator styling (see the STM class below).
STM_VERTICAL_INDICATOR_COLOR = "red"
STM_LATERAL_START_INDICATOR_COLOR = "red"
STM_LATERAL_END_INDICATOR_COLOR = "blue"
STM_SIZE = 3  # sprite edge length, in multiples of nm_size pixels
STM_INDICATOR_MIN = 2  # minimum edge length (px) of the filled centre square
STM_INDICATOR_WIDTH_MIN = 2  # minimum stroke width (px) of the ring

# Matching polyline styling.
MATCHING_INDICATOR_COLOR = "white"
MATCHING_INDICATOR_MIN = 1  # minimum line width (px)

# Goals are drawn as outlines of this colour and width by default.
GOAL_COLOR = "green"
GOAL_DRAW_WIDTH = 3

OBSTACLE_COLOR = "grey"

MIN_MOLECULE_ROTATION_INDICATOR = 1

# Orientation of an STM indicator sprite.
STMType = Literal["vertical", "lateral_start", "lateral_end"]


# NOTE(review): not referenced in this part of the file; presumably debug
# counters of built surfaces - confirm before removing.
built = {'poly': 0, 'sensor': 0, 'stm': 0}

class MoleculeGameObject(PyGameObject):
    """Base molecule sprite: owns the name labels and the backing surfaces.

    ``draw`` is not implemented here and always raises; ``draw_name`` renders
    the molecule's name onto ``self.surface``.
    """

    def __init__(self,
                 molecule: Molecule,
                 size: tuple[float, float],
                 orientation: int,
                 nm_size,
                 color=c(Color("red").rgb),
                 colorkey: tuple[int, int, int] = (0, 0, 0),
                 origin_offset: Optional[tuple[float, float]] = (0, 0),
                 flip_y: Optional[bool] = False,
                 flip_x: Optional[bool] = False) -> None:
        """
        Args:
            molecule: model object; its ``name`` and ``crashed`` are read.
            size: molecule extent in nm as (width, height).
            orientation: stored unchanged on ``self.orientation``.
            nm_size: pixels per nm, forwarded to the base class.
            color, colorkey, origin_offset, flip_y, flip_x: forwarded to
                ``PyGameObject``.
        """
        super().__init__(nm_size, color, colorkey, origin_offset, flip_y, flip_x)
        self.molecule = molecule
        # Two pre-rendered labels: normal colour, and red for a crashed molecule.
        self.text = self.font.render(molecule.name, True, self.color)
        self.crashed_text = self.font.render(molecule.name, True, c(Color("red").rgb))
        self.molecule_size = (size[0] * self.nm_size, size[1] * self.nm_size)
        # Full surface = molecule body plus room for the name label
        # (one text width extra, two text heights extra).
        self.size = (
            self.molecule_size[0] + self.text.get_width(),
            self.molecule_size[1] + self.text.get_height() * 2,
        )
        self.molecule_center = (self.molecule_size[0] // 2, self.molecule_size[1] // 2)
        self.orientation = orientation

        self.create_surfaces()

    def draw(self) -> Surface:
        """Not implemented for the base molecule sprite.

        Raises:
            NotImplementedError: always. An explicit raise (instead of
                ``assert False``) keeps the failure under ``python -O``.
        """
        logger.warning("Tried to draw empty molecule.")
        raise NotImplementedError("Tried to draw empty molecule.")

    def create_surfaces(self) -> None:
        """Allocate the molecule-body surface and the full (body + label) surface."""
        self.molecule_surface = self._prepare_surface(self.molecule_size)
        self.surface = self._prepare_surface(self.size)

    def draw_name(self) -> Surface:
        """Blit the name label onto ``self.surface`` and return that surface.

        The red label is used while ``self.molecule.crashed`` is truthy.
        """
        label = self.crashed_text if self.molecule.crashed else self.text
        col, row = self.molecule_center
        # The label sits one text height below the molecule centre.
        self.surface.blit(label, (col, row + self.text.get_height()))
        return self.surface



class PolygonMoleculeGameObject:
    """Cached polygon sprite, plus an optional sensor overlay, for one molecule.

    Surfaces are rebuilt only when their cache key changes. Placement in the
    window is computed separately by :meth:`anchor` and :meth:`sensor_anchor`.
    """

    def __init__(
        self,
        molecule,
        nm_size: int,
        color: RGB,
        colorkey: RGB = (0, 0, 0),
        origin_offset: Optional[tuple[float, float]] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,
        draw_names: bool = True,
    ) -> None:
        """
        Args:
            molecule: model object; this class reads its ``crashed``,
                ``orientation``, ``center``, ``size_x``, ``size_y``,
                ``sensors`` and ``get_drawable_polygon``.
            nm_size: pixels per nm.
            color: polygon fill colour while the molecule is not crashed.
            colorkey: accepted for signature parity; not used by this class.
            origin_offset: world offset (nm) added before scaling; ``None``
                is treated as ``(0, 0)``.
            flip_y: flip sprites vertically and invert the y axis in anchors.
            flip_x: flip sprites horizontally (anchors are not mirrored).
            draw_names: stored on ``self.draw_names``; not read by this class.
        """
        self.molecule = molecule
        self.nm_size = nm_size
        self.color = color
        self.origin_offset = origin_offset or (0, 0)
        self.flip_x = flip_x
        self.flip_y = flip_y
        self.y_factor = -1 if flip_y else 1
        self.draw_names = draw_names

        self._poly_surface: Optional[pg.Surface] = None
        self._poly_cache_key = None  # (orientation, crashed, nm_size)

        self._sensor_surface: Optional[pg.Surface] = None
        self._sensor_cache_key = None  # (orientation, nm_size)
        self._sensor_anchor: Tuple[int, int] = (0, 0)  # last value returned by sensor_anchor()

    def _ensure_polygon_surface(self) -> None:
        """(Re)build the filled polygon surface when (orientation, crashed, nm_size) changes.

        The polygon from ``molecule.get_drawable_polygon(nm_size)`` is drawn
        as-is, so its coordinates are presumably surface-local, non-negative
        pixels (negative parts would be clipped).
        """
        crashed = self.molecule.crashed
        cache_key = (self.molecule.orientation, crashed, self.nm_size)
        if cache_key == self._poly_cache_key and self._poly_surface is not None:
            return

        pts = self.molecule.get_drawable_polygon(self.nm_size)
        # Drop an explicit closing vertex; pygame closes polygons itself.
        if len(pts) > 1 and pts[0] == pts[-1]:
            pts = pts[:-1]

        if len(pts) > 0:
            w = max(1, int(max(p[0] for p in pts)) + 1)
            h = max(1, int(max(p[1] for p in pts)) + 1)
        else:
            # Nothing to draw: a 1x1 transparent surface instead of max([]) failing.
            w = h = 1

        surf = pg.Surface((w, h), pg.SRCALPHA)
        fill = (255, 0, 0) if crashed else self.color

        pts_i = [(int(x), int(y)) for (x, y) in pts]
        # pygame.draw.polygon raises ValueError for fewer than three vertices.
        if len(pts_i) >= 3:
            pg.draw.polygon(surf, fill, pts_i, 0)

        if self.flip_x or self.flip_y:
            surf = pg.transform.flip(surf, self.flip_x, self.flip_y)

        self._poly_surface = surf.convert_alpha()
        self._poly_cache_key = cache_key

    def draw_surface(self) -> pg.Surface:
        """Return the cached polygon surface, rebuilding it if stale."""
        self._ensure_polygon_surface()
        return self._poly_surface

    def anchor(self, surface_height: int) -> Tuple[int, int]:
        """
        World-space top-left for blitting the cached polygon surface.

        The molecule centre (plus origin_offset) is shifted by half the
        molecule's size and scaled by nm_size; with flip_y the y axis is
        inverted and shifted down by ``surface_height``.
        """
        ax = (self.molecule.center.x + self.origin_offset[0] - self.molecule.size_x / 2) * self.nm_size
        ay = (
            self.y_factor * (self.molecule.center.y + self.origin_offset[1]) * self.nm_size
            - self.molecule.size_y / 2 * self.nm_size
            + (surface_height if self.flip_y else 0)
        )
        return (int(ax), int(ay))

    def _ensure_sensor_surface(self) -> None:
        """(Re)build the sensor-outline surface when (orientation, nm_size) changes.

        Sensors are scaled by nm_size about the molecule centre, translated so
        their joint bounding box starts at (pad, pad), and outlined in red,
        1 px wide, on a surface of at least 10x10 px. With fewer than two
        sensors the cached surface is cleared instead.
        """
        sensors = getattr(self.molecule, "sensors", [])
        if len(sensors) < 2:
            self._sensor_surface = None
            self._sensor_cache_key = None
            return

        cache_key = (self.molecule.orientation, self.nm_size)
        if cache_key == self._sensor_cache_key and self._sensor_surface is not None:
            return

        origin = self.molecule.center

        scaled_geoms = [shapely.affinity.scale(s, self.nm_size, self.nm_size, origin=origin) for s in sensors]
        minx, miny, maxx, maxy = GeometryCollection(scaled_geoms).convex_hull.bounds

        pad = 1
        w = max(10, int(maxx - minx) + 2 * pad)
        h = max(10, int(maxy - miny) + 2 * pad)
        surf = pg.Surface((w, h), pg.SRCALPHA)

        tx = -minx + pad
        ty = -miny + pad

        for s in scaled_geoms:
            # NOTE(review): assumes every sensor is a Polygon (uses .exterior).
            placed = shapely.affinity.translate(s, tx, ty)
            coords = list(placed.exterior.coords)
            if len(coords) > 1 and coords[0] == coords[-1]:
                coords = coords[:-1]
            pts_i = [(int(x), int(y)) for (x, y) in coords]
            pg.draw.polygon(surf, (255, 0, 0), pts_i, width=1)

        if self.flip_x or self.flip_y:
            surf = pg.transform.flip(surf, self.flip_x, self.flip_y)

        self._sensor_surface = surf.convert_alpha()
        self._sensor_cache_key = cache_key

        self._sensor_surface_size = surf.get_size()

    def sensor_surface(self) -> Optional[pg.Surface]:
        """Return the cached sensor surface, or ``None`` if there are too few sensors."""
        self._ensure_sensor_surface()
        return self._sensor_surface

    def sensor_anchor(self, surface_height: int) -> Tuple[int, int]:
        """Top-left window position for blitting the sensor surface.

        Call only after :meth:`sensor_surface` returned a surface; this reads
        the cached surface's size. The surface is centred on the molecule
        centre (plus origin_offset). The returned integer anchor, including
        the flip_y shift, is also stored in ``self._sensor_anchor``.
        """
        sensor_surf = self._sensor_surface
        ax = (self.molecule.center.x + self.origin_offset[0]) * self.nm_size - sensor_surf.get_width() / 2
        ay = (
            self.y_factor * (self.molecule.center.y + self.origin_offset[1]) * self.nm_size
            - sensor_surf.get_height() / 2
        )
        if self.flip_y:
            ay += surface_height
        # NOTE(review): centring assumes the sensors' bounding box is centred on
        # the molecule centre (and not enlarged by the 10 px minimum) - confirm.
        self._sensor_anchor = (int(ax), int(ay))
        return self._sensor_anchor

class GoalObject(StaticPyGameObject):
    """Static sprite for a goal region, outlined in GOAL_COLOR by default."""

    def __init__(
        self,
        goal: Goal,
        nm_size: int,
        colorkey: RGB = (0, 0, 0),
        origin_offset: Tuple[float, float] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,
        color: Optional[RGB] = None,
        outline_width: int = GOAL_DRAW_WIDTH,  # goals outlined by default
    ) -> None:
        """
        Args:
            goal: goal model, forwarded to the base class as ``obj``.
            color: explicit colour; ``None`` selects GOAL_COLOR.
            outline_width: stroke width, GOAL_DRAW_WIDTH unless overridden.
            The remaining arguments pass straight through to StaticPyGameObject.
        """
        resolved_color = color if color is not None else c(Color(GOAL_COLOR).rgb)
        super().__init__(
            obj=goal,
            nm_size=nm_size,
            color=resolved_color,
            colorkey=colorkey,
            origin_offset=origin_offset,
            flip_y=flip_y,
            flip_x=flip_x,
            outline_width=outline_width,
        )

class ObstacleObject(StaticPyGameObject):
    """Static sprite for an obstacle, filled with OBSTACLE_COLOR by default."""

    def __init__(
        self,
        obstacle: Obstacle,
        nm_size: int,
        colorkey: RGB = (0, 0, 0),
        origin_offset: Tuple[float, float] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,
        color: Optional[RGB] = None,
        outline_width: int = 0,  # filled by default
    ) -> None:
        """
        Args:
            obstacle: obstacle model, forwarded to the base class as ``obj``.
            color: explicit colour; ``None`` selects OBSTACLE_COLOR.
            outline_width: stroke width; 0 (the default) draws it filled.
            The remaining arguments pass straight through to StaticPyGameObject.
        """
        resolved_color = color if color is not None else c(Color(OBSTACLE_COLOR).rgb)
        super().__init__(
            obj=obstacle,
            nm_size=nm_size,
            color=resolved_color,
            colorkey=colorkey,
            origin_offset=origin_offset,
            flip_y=flip_y,
            flip_x=flip_x,
            outline_width=outline_width,
        )

class STM(PyGameObject):
    """
    Cached STM indicator sprite: a filled centre square inside a ring.

    The pixel pattern depends only on (nm_size, type, color), so surfaces are
    shared across instances via a class-level cache. ``type`` only
    participates in that key; the drawing itself does not depend on it.
    Vertical flipping is applied when computing positions; flip_x is not.
    """
    # Shared across all STM instances; key is (nm_size, type, colour as tuple).
    _CACHE: dict[tuple[int, STMType, RGB], pg.Surface] = {}

    def __init__(
        self,
        nm_size: int,
        color: RGB,
        *,
        type: STMType = "vertical",
        colorkey: RGB = (0, 0, 0),
        origin_offset: tuple[float, float] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,  # forwarded to the base class; STM positioning ignores it
    ) -> None:
        """
        Args:
            nm_size: pixels per nm.
            color: colour of both the centre square and the ring.
            type: indicator kind; part of the cache key.
            colorkey, origin_offset, flip_y, flip_x: forwarded to PyGameObject.
        """
        super().__init__(nm_size, color, colorkey, origin_offset, flip_y, flip_x)

        self.type: STMType = type
        # Square sprite, STM_SIZE * nm_size pixels on a side.
        self._size_px: Tuple[int, int] = (self.nm_size * STM_SIZE, self.nm_size * STM_SIZE)
        self._center_px: Tuple[int, int] = (self._size_px[0] // 2, self._size_px[1] // 2)
        # Edge length (px) of the filled centre square.
        self._indicator_px: int = max(STM_INDICATOR_MIN, self.nm_size // 5)
        # Per-instance handle to the surface held in the shared cache.
        self._surf: Optional[pg.Surface] = None

    def draw(self) -> pg.Surface:
        """Return the indicator sprite, fetching/building it on first use."""
        if self._surf is None:
            self._surf = self._get_or_build_surface()
        return self._surf

    def anchor(self, center_nm: tuple[float, float], surface_height: int) -> Tuple[int, int]:
        """
        World center (nm) -> top-left anchor for blitting this sprite.
        """
        cx_px, cy_px = self.center_to_pixel_xy(center_nm, surface_height)
        ax = int(cx_px - self._size_px[0] // 2)
        ay = int(cy_px - self._size_px[1] // 2)
        return (ax, ay)

    def center_to_pixel_xy(self, center_nm: tuple[float, float], surface_height: int) -> Tuple[int, int]:
        """World center (nm) -> pixel center with origin_offset and vertical flips."""
        x_nm, y_nm = center_nm
        x_px = int((x_nm + self.origin_offset[0]) * self.nm_size)
        y_px = int(self.y_factor * (y_nm + self.origin_offset[1]) * self.nm_size)
        if self.flip_y:
            y_px += surface_height
        return (x_px, y_px)

    def _get_or_build_surface(self) -> pg.Surface:
        """Return the shared surface for this (nm_size, type, colour), building it once."""
        # tuple(): the colour may arrive as a list/Color, which is unhashable
        # or compares differently; MatchingObject normalises it the same way.
        key = (self.nm_size, self.type, tuple(self.color))
        surf = STM._CACHE.get(key)
        if surf is not None:
            return surf

        surf = pg.Surface(self._size_px, pg.SRCALPHA)
        cx, cy = self._center_px

        half = self._indicator_px // 2
        pg.draw.rect(surf, self.color, (cx - half, cy - half, self._indicator_px, self._indicator_px))

        ring_radius = self._size_px[0] // 2
        ring_width = max(STM_INDICATOR_WIDTH_MIN, self.nm_size // 5)
        pg.draw.circle(surf, self.color, (cx, cy), ring_radius, width=ring_width)

        surf = surf.convert_alpha()

        STM._CACHE[key] = surf
        return surf


class MatchingObject(PyGameObject):
    """Cached sprite drawing a matching's waypoint polyline and optional corridor.

    The sprite is built in *local pixel space*: world nm scaled by nm_size,
    with y also multiplied by y_factor and no origin_offset applied. It is
    placed in the window with :meth:`anchor`.
    """

    def __init__(
        self,
        matching: "Matching",
        nm_size: int,
        *,
        highlight: Optional["Highlight"] = None,
        index: int | None = None,
        origin_offset: tuple[float, float] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,
        corridor_config: Optional["CorridorConfiguration"] = None,
    ) -> None:
        """
        Args:
            matching: model whose ``anchors`` (points with ``.x``/``.y`` in nm)
                are joined by line segments.
            nm_size: pixels per nm.
            highlight: styling; ``None`` means a fresh ``Highlight()``. When
                ``highlight.enabled``, its colour and width replace the defaults.
            index: optional index; only part of the cache key (no label is drawn).
            corridor_config: required when ``highlight.draw_corridor`` is set.

        Raises:
            ValueError: if a corridor is requested without ``corridor_config``.
        """
        # Build the default per call, so instances never share one Highlight.
        if highlight is None:
            highlight = Highlight()

        color = c(Color(MATCHING_INDICATOR_COLOR).rgb)
        if highlight.enabled:
            color = c(Color(highlight.colour).rgb)

        super().__init__(nm_size, color, colorkey=(0, 0, 0), origin_offset=origin_offset, flip_y=flip_y, flip_x=flip_x)

        self.matching = matching
        self.index = index

        self.width = max(MATCHING_INDICATOR_MIN, self.nm_size // 10)
        if highlight.enabled:
            self.width = highlight.width

        self.draw_corridor = False
        self.corridor_draw_width = 1
        self._corridor_world = None
        if highlight.enabled and getattr(highlight, "draw_corridor", False):
            # Explicit raise rather than assert so the check survives `python -O`.
            if not corridor_config:
                raise ValueError("When highlight.draw_corridor=True you must pass corridor_config")
            self.draw_corridor = True
            self._corridor_world = compute_corridor(
                self.matching,
                corridor_width=corridor_config.width,
                parking_distance=corridor_config.parking_distance,
                parking_buffer=corridor_config.parking_buffer,
            )
            self.corridor_draw_width = highlight.corridor_draw_width

        self._surf: Optional[pg.Surface] = None
        self._cache_key = None
        # Local-pixel position of the surface's (0, 0) pixel.
        self._anchor_min_local: Tuple[int, int] = (0, 0)
        self._size_px: Tuple[int, int] = (0, 0)

    # ---------- public API ---------------------------------------------------

    def draw(self) -> pg.Surface:
        """Return the cached sprite surface (build on first use / when invalidated)."""
        self._ensure_surface()
        return self._surf

    def anchor(self, surface_height: int) -> Tuple[int, int]:
        """
        Top-left anchor in window pixels.

        = origin_offset contribution + local position of the surface's top-left
        pixel (+ surface_height when flip_y). Call after :meth:`draw`.
        """
        minx_local, miny_local = self._anchor_min_local
        ax = int(self.origin_offset[0] * self.nm_size + minx_local)
        ay = int(self.y_factor * self.origin_offset[1] * self.nm_size + miny_local)
        if self.flip_y:
            ay += surface_height
        return (ax, ay)

    # ---------- internals ----------------------------------------------------

    def _waypoints_local_px(self) -> List[Tuple[float, float]]:
        """Waypoints in *local pixel space* (no origin_offset applied); non-finite points are skipped."""
        pts = []
        for p in self.matching.anchors:
            x = p.x * self.nm_size
            y = self.y_factor * p.y * self.nm_size
            if isfinite(x) and isfinite(y):
                pts.append((x, y))
        return pts

    def _corridor_local_px(self):
        """Corridor coordinates in local pixel space, or ``None`` when no corridor is drawn."""
        if not self.draw_corridor or self._corridor_world is None:
            return None
        # scale x by nm_size, y by y_factor*nm_size (so it lives in same local pixel space)
        scaled = shapely.affinity.scale(self._corridor_world, self.nm_size, self.y_factor * self.nm_size, origin=(0, 0))
        # NOTE(review): get_coordinates flattens every ring/part into one list; a
        # multi-part or holed corridor would be drawn as one joined polygon.
        return shapely.get_coordinates(scaled).tolist()

    def _compute_cache_key(self):
        """Everything that affects the sprite's pixels, as a hashable tuple."""
        anchors_nm = tuple((float(p.x), float(p.y)) for p in self.matching.anchors)
        key = (
            anchors_nm,
            self.nm_size,
            self.width,
            tuple(self.color),
            bool(self.draw_corridor),
            self.corridor_draw_width if self.draw_corridor else 0,
            # Include a light thumbprint of the corridor (bounds) if present
            (tuple(map(float, self._corridor_world.bounds)) if self._corridor_world is not None else None),
            # Index text affects pixels if drawn
            (self.index, self.font_size) if self.index is not None else None,
        )
        return key

    def _ensure_surface(self) -> None:
        """(Re)build the sprite surface when the cache key changed."""
        from math import ceil, floor

        cache_key = self._compute_cache_key()
        if cache_key == self._cache_key and self._surf is not None:
            return

        pts = self._waypoints_local_px()
        corridor_coords = self._corridor_local_px()

        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        if corridor_coords:
            xs.extend(x for x, _ in corridor_coords)
            ys.extend(y for _, y in corridor_coords)

        if not xs or not ys:
            # Nothing drawable: cache a 1x1 transparent placeholder.
            self._surf = pg.Surface((1, 1), pg.SRCALPHA).convert_alpha()
            self._size_px = (1, 1)
            self._anchor_min_local = (0, 0)
            self._cache_key = cache_key
            return

        # floor/ceil instead of int(): with y_factor == -1 coordinates are
        # negative, and int() would round them toward zero and clip them.
        minx, maxx = floor(min(xs)), ceil(max(xs))
        miny, maxy = floor(min(ys)), ceil(max(ys))

        # Thick strokes extend roughly half their width beyond the geometry,
        # so pad by that much plus a small margin to avoid edge clipping.
        stroke = max(self.width, self.corridor_draw_width if self.draw_corridor else 0)
        pad = 2 + int(stroke) // 2
        w = max(1, (maxx - minx) + 1 + 2 * pad)
        h = max(1, (maxy - miny) + 1 + 2 * pad)

        # Local point p is drawn at surface pixel p + shift.
        shift_x = -minx + pad
        shift_y = -miny + pad

        surf = pg.Surface((w, h), pg.SRCALPHA)

        if len(pts) >= 2:
            p0x, p0y = pts[0]
            for p1x, p1y in pts[1:]:
                pg.draw.line(
                    surf,
                    self.color,
                    (int(p0x + shift_x), int(p0y + shift_y)),
                    (int(p1x + shift_x), int(p1y + shift_y)),
                    width=self.width,
                )
                p0x, p0y = p1x, p1y

        if corridor_coords:
            coords = corridor_coords[:]
            if len(coords) > 1 and coords[0] == coords[-1]:
                coords = coords[:-1]
            pts_i = [(int(x + shift_x), int(y + shift_y)) for (x, y) in coords]
            # pygame.draw.polygon raises ValueError for fewer than three vertices.
            if len(pts_i) >= 3:
                pg.draw.polygon(surf, self.color, pts_i, width=self.corridor_draw_width)

        # NOTE: self.index is part of the cache key, but no index label is
        # drawn at present (label blitting is disabled).

        self._surf = surf.convert_alpha()
        self._cache_key = cache_key
        # Surface pixel (0, 0) sits at local (minx - pad, miny - pad); that is
        # the corner anchor() must place so the drawing lands where it belongs.
        self._anchor_min_local = (minx - pad, miny - pad)
        self._size_px = (w, h)
