import pygame as pg
from pygame import Surface
from loguru import logger
from typing import Optional, Tuple, List
from colour import Color
from molecule_movement.colour_utils import c

# RGB colour triple used by the renderers in this module.
RGB = tuple[int, int, int]

# Colour name (parsed with `colour.Color`) used to draw the scale-bar legend.
LEGEND_COLOR = "white"

# Fonts keyed by size; filled lazily by `get_font` so each size is created once.
_FONT_CACHE: dict[int, pg.font.Font] = {}

def get_font(size: int) -> pg.font.Font:
    """Return pygame's default font (``SysFont(None, size)``) at *size*.

    The font object is created on first request and reused afterwards.
    """
    if size not in _FONT_CACHE:
        _FONT_CACHE[size] = pg.font.SysFont(None, size)
    return _FONT_CACHE[size]

class PyGameObject:
    """Base class for things rendered onto pygame surfaces.

    Stores the shared rendering settings: drawing colour, ``nm_size`` (pixels
    per nm, i.e. the scale applied to world coordinates), the colorkey used
    as a transparent background, an origin offset in world units, and the
    axis flips. A font sized from ``nm_size`` is also attached.
    """
    def __init__(self,
                 nm_size: int,
                 color: tuple[int, int, int],
                 colorkey: tuple[int, int, int],
                 origin_offset: tuple[float, float] = (0,0),
                 flip_y: bool = False,
                 flip_x: bool = False) -> None:
        self.color = color
        self.nm_size = nm_size

        # Text grows with the zoom level but stays clamped to 15..28.
        scaled_font = 6 * self.nm_size
        self.font_size = max(15, min(28, scaled_font))
        self.font = get_font(self.font_size)

        self.colorkey = colorkey
        self.origin_offset = origin_offset
        self.flip_y = flip_y

        # Mirroring along x is not implemented; refuse it rather than ignore it.
        if flip_x:
            raise NotImplementedError("Cannot flip x-axis yet.")
        self.flip_x = False

    def draw(self) -> Surface:
        """Render the object. Subclasses override this; the base returns None."""
        return None

    def _prepare_surface(self, size: tuple[float, float]) -> Surface:
        """Create a *size* surface filled with, and keyed on, ``self.colorkey``."""
        key = self.colorkey
        canvas = pg.Surface(size)
        canvas.fill(key)
        canvas.set_colorkey(key)
        return canvas

    @property
    def y_factor(self):
        """Multiplier for world y: -1 when ``flip_y`` is set, otherwise 1."""
        if self.flip_y:
            return -1
        return 1




class StaticPyGameObject(PyGameObject):
    """
    Cached sprite for a static polygonal object.

    The wrapped ``obj`` must provide:
      - .position: shapely Point (world units)
      - .size_x / .size_y: floats (world units)
      - .rotation: degrees
      - .scale(x, y): list of (x, y) LOCAL polygon coords (already rotated to origin)

    The sprite is re-rendered only when the scaled points, ``nm_size``,
    rotation, colour or outline width change.
    """
    def __init__(
        self,
        obj,
        nm_size: int,
        color: RGB,
        *,
        colorkey: RGB = (0, 0, 0),
        origin_offset: Tuple[float, float] = (0, 0),
        flip_y: bool = False,
        flip_x: bool = False,
        outline_width: int = 0,
    ) -> None:
        super().__init__(nm_size, color, colorkey, origin_offset, flip_y, flip_x)
        self.obj = obj
        self.outline_width = outline_width

        # Rendered sprite plus the key it was rendered for. The flips are fixed
        # for the lifetime of this object, so they are not part of the key.
        self._surf: Optional[pg.Surface] = None
        self._cache_key = None

        # Nominal pixel footprint used to centre the sprite in anchor().
        # NOTE(review): computed once from size_x/size_y; it does not follow
        # rotation or later size changes. Confirm that is intended.
        width_px = int(self.obj.size_x * nm_size)
        height_px = int(self.obj.size_y * nm_size)
        self._size_px = (width_px, height_px)

    # ---- internals ---------------------------------------------------------
    def _local_scaled_points(self) -> List[tuple[int, int]]:
        """Return the local polygon scaled by ``nm_size`` and truncated to ints.

        If the ring is closed (last point repeats the first), the duplicate
        closing point is dropped.
        """
        ring = self.obj.scale(self.nm_size, self.nm_size)
        is_closed = len(ring) > 1 and ring[0] == ring[-1]
        open_ring = ring[:-1] if is_closed else ring
        return [(int(px), int(py)) for px, py in open_ring]

    def _ensure_surface(self):
        """Rebuild the cached sprite unless it already matches the current state."""
        points = self._local_scaled_points()
        key = (
            tuple(points),
            self.nm_size,
            self.obj.rotation,
            tuple(self.color),
            self.outline_width,
        )
        if key == self._cache_key and self._surf is not None:
            return

        # Size the surface to the polygon's maximum extent (at least 1x1).
        # NOTE(review): points with negative coordinates would be clipped;
        # presumably scale() yields non-negative local coords. Confirm.
        width = max(1, max(x for x, _ in points) + 1)
        height = max(1, max(y for _, y in points) + 1)

        sprite = pg.Surface((width, height), pg.SRCALPHA)
        pg.draw.polygon(sprite, self.color, points, width=self.outline_width)

        # Apply the renderer's constant axis flips once, at build time.
        if self.flip_x or self.flip_y:
            sprite = pg.transform.flip(sprite, self.flip_x, self.flip_y)

        self._surf = sprite.convert_alpha()
        self._cache_key = key

    # ---- public API (matches your draw/anchor pattern) ---------------------
    def draw(self) -> pg.Surface:
        """Return the sprite, re-rendering it first if it is stale."""
        self._ensure_surface()
        return self._surf

    def anchor(self, surface_height: int) -> Tuple[int, int]:
        """Return the top-left pixel at which to blit the sprite.

        The sprite is centred on ``obj.position`` plus ``origin_offset``,
        scaled by ``nm_size``. With ``flip_y``, y is negated and shifted by
        *surface_height*, so larger world y appears higher on screen.
        """
        half_w = self._size_px[0] / 2
        half_h = self._size_px[1] / 2
        scale = self.nm_size

        left = (self.obj.position.x + self.origin_offset[0]) * scale - half_w
        top = self.y_factor * (self.obj.position.y + self.origin_offset[1]) * scale - half_h
        if self.flip_y:
            top += surface_height
        return (int(left), int(top))



class GridCache:
    """
    Prebuilds a grid tile (grey lines) and tiles it into a full-size grid layer
    whose lines pass through the world origin.

    The tile is rebuilt only when nm_size, the cell limits, the grid colour
    or the line width change.
    """
    def __init__(self) -> None:
        self._tile: Optional[pg.Surface] = None
        self._tile_size: Tuple[int, int] = (0, 0)
        self._tile_key: Optional[tuple] = None  # (nm_size, block_w, block_h, grid_color, line_width)

    def build_tile(
        self,
        nm_size: int,
        grid_color: RGB,
        *,
        min_cell_px: int = 10,
        max_cell_px: int = 100,
        line_width: int = 1,
    ) -> None:
        """Build (or reuse) the single-cell tile used by render_grid_layer.

        Args:
            nm_size: pixels per nm; a cell nominally spans 10 nm.
            grid_color: colour of the grid lines.
            min_cell_px / max_cell_px: clamp range for the cell size in pixels.
            line_width: grid line thickness in pixels.
        """
        # Clamp to the configured range, and never below 1 px: a zero-size
        # cell would break the modulo and tiling loop in render_grid_layer.
        block = max(1, min(max_cell_px, max(min_cell_px, 10 * nm_size)))
        block_w = block_h = block
        key = (nm_size, block_w, block_h, grid_color, line_width)
        if key == self._tile_key and self._tile is not None:
            return

        # Tile draws only the *top* and *left* lines to avoid double-thick seams when tiling.
        tile = pg.Surface((block_w, block_h), pg.SRCALPHA)
        pg.draw.line(tile, grid_color, (0, 0), (block_w, 0), width=line_width)  # top
        pg.draw.line(tile, grid_color, (0, 0), (0, block_h), width=line_width)  # left
        # no bottom/right borders in the tile!

        self._tile = tile.convert_alpha()
        self._tile_size = (block_w, block_h)
        self._tile_key = key

    def render_grid_layer(
        self,
        size_px: Tuple[int, int],
        origin_offset_nm: Tuple[float, float],
        nm_size: int,
        *,
        y_factor: int = 1,  # -1 if flip_y else 1
    ) -> pg.Surface:
        """Tile the prebuilt pattern so that grid lines pass through the world origin.

        The origin's pixel position uses the same mapping as the sprite
        anchors in this module:
          - x = origin_offset_nm[0] * nm_size
          - y = y_factor * origin_offset_nm[1] * nm_size, plus the layer
            height when y is flipped (y_factor < 0)

        Raises:
            RuntimeError: if build_tile() has not been called yet.
        """
        if self._tile is None:
            raise RuntimeError("call build_tile() first")
        W, H = size_px
        cell_w, cell_h = self._tile_size

        # Pixel position of the world origin on this layer.
        # NOTE(review): assumes this layer has the same height as the surface
        # passed to the sprites' anchor() when flip_y is set.
        origin_x = int(origin_offset_nm[0] * nm_size)
        origin_y = int(y_factor * origin_offset_nm[1] * nm_size + (H if y_factor < 0 else 0))

        # Choose start in (-cell, 0] with start ≡ origin (mod cell). A tile's
        # top/left line then lands exactly on the origin. The previous
        # `-(origin % cell)` put lines at -origin instead.
        start_x = -((-origin_x) % cell_w)
        start_y = -((-origin_y) % cell_h)

        surf = pg.Surface((W, H), pg.SRCALPHA)
        tile = self._tile

        # Tile blits (cheap). ~ (W/cell_w) * (H/cell_h) blits.
        x = start_x
        while x < W:
            y = start_y
            while y < H:
                surf.blit(tile, (x, y))
                y += cell_h
            x += cell_w

        return surf


class Legend(PyGameObject):
    """Scale bar: a horizontal bar ``count`` nm long, with end ticks and a "<count> nm" label."""

    # Rendered legends shared across instances, keyed by everything that
    # affects the resulting pixels.
    _CACHE: dict[tuple[int, float, RGB, RGB], pg.Surface] = {}

    def __init__(self, nm_size: int, colorkey: RGB = (0,0,0)) -> None:
        super().__init__(nm_size, c(Color(LEGEND_COLOR).rgb), colorkey)

    def draw(self, count: float = 5) -> Surface:
        """Return a (cached) surface showing a ``count``-nm scale bar.

        Args:
            count: bar length in nm. The bar is ``count * nm_size`` pixels long.

        Returns:
            A surface converted with ``convert_alpha()``. It is shared through
            the class cache, so callers should not draw onto it.
        """
        # colorkey is part of the key: the anti-aliased label is blended over
        # the colorkey background before conversion, so it changes the pixels.
        # tuple() keeps the key hashable even if colours arrive as lists.
        key = (self.nm_size, float(count), tuple(self.color), tuple(self.colorkey))
        cached = Legend._CACHE.get(key)
        if cached is not None:
            return cached

        label = f"{int(count) if isinstance(count, int) or float(count).is_integer() else count} nm"
        name = self.font.render(label, True, self.color)
        bar_px = int(count * self.nm_size)
        size = (int(count * self.nm_size + name.get_width()),
                int(2 * self.nm_size + name.get_height() * 2))
        surface = self._prepare_surface(size)
        mid_y = self.nm_size // 2

        # Horizontal bar plus a tick at each end, rising from the bar to row 0.
        pg.draw.line(surface, self.color, (0, mid_y), (bar_px, mid_y), width=1)
        pg.draw.line(surface, self.color, (0, mid_y), (0, 0), width=1)
        pg.draw.line(surface, self.color, (bar_px, mid_y), (bar_px, 0), width=1)
        # Label sits just below the bar, a quarter of the way along it.
        surface.blit(name, (int((count/4) * self.nm_size), mid_y))

        surface = surface.convert_alpha()
        Legend._CACHE[key] = surface
        return surface
