import numpy as np
from math import sqrt, cos, sin, radians
from shapely.geometry import Polygon
import shapely.affinity
from typing import Tuple

from .helpers import unique_xy_with_tol

from loguru import logger

def hexagon(width: float,
            height: float | None = None,
            center: Tuple[float, float] = (0.0, 0.0),
            orientation: str = "flat") -> Polygon:
    """
    Build a hexagon whose axis-aligned bounding box is ``width`` x ``height``.

    A unit-circumradius hexagon (one vertex every 60 degrees, starting at 0
    degrees for ``"flat"`` and 30 degrees for ``"pointy"``) is stretched
    independently along x and y about the origin so its bounding box matches
    the requested size. It is then translated so the bounding box is centred
    on ``center``.

    Args:
        width: Bounding-box width.
        height: Bounding-box height; defaults to ``width``. A regular hexagon
            does not have a square bounding box (flat: 2 x sqrt(3), pointy:
            sqrt(3) x 2). The default therefore yields a hexagon stretched
            along one axis, not a regular one.
        center: (x, y) of the bounding-box centre.
        orientation: ``"flat"`` (horizontal edges at top and bottom) or
            ``"pointy"`` (vertices at top and bottom).

    Returns:
        A shapely ``Polygon`` with six vertices.

    Raises:
        ValueError: if ``orientation`` is not ``"flat"`` or ``"pointy"``.
    """
    if orientation not in ("flat", "pointy"):
        raise ValueError("orientation must be 'flat' or 'pointy'")
    if height is None:
        height = width

    pointy = orientation == "pointy"
    # Bounding box of the unit-circumradius hexagon in each orientation.
    offset = 30 if pointy else 0
    unit_w, unit_h = (sqrt(3.0), 2.0) if pointy else (2.0, sqrt(3.0))

    angles = [radians(deg + offset) for deg in range(0, 360, 60)]
    unit_hex = Polygon([(cos(t), sin(t)) for t in angles])

    # Scale about the origin first (unit hex is centred there), then move.
    stretched = shapely.affinity.scale(
        unit_hex, xfact=width / unit_w, yfact=height / unit_h, origin=(0, 0)
    )
    return shapely.affinity.translate(stretched, xoff=center[0], yoff=center[1])

def _honeycomb_centers(cols: int,
                       rows: int,
                       hex_w: float,
                       hex_h: float,
                       center: Tuple[float, float],
                       orientation: str) -> list[tuple[float, float]]:
    """Return cell centres of a staggered ``cols`` x ``rows`` hex grid, row-major.

    ``orientation`` must already be validated ("flat" or "pointy").
    """
    if orientation == "flat":
        # unit bbox (2, √3) → spacings (3/2, √3), vertical stagger = √3/2
        sx = hex_w / 2.0
        sy = hex_h / sqrt(3.0)
        dx = (3.0 / 2.0) * sx
        dy = sqrt(3.0) * sy
        shift_x, shift_y = 0.0, (sqrt(3.0) / 2.0) * sy
        # odd columns are shifted in y; that only enlarges the bbox
        # when there is more than one column
        extra_w, extra_h = 0.0, (shift_y if cols > 1 else 0.0)
    else:  # pointy
        # unit bbox (√3, 2) → spacings (√3, 3/2), horizontal stagger = √3/2
        sx = hex_w / sqrt(3.0)
        sy = hex_h / 2.0
        dx = sqrt(3.0) * sx
        dy = (3.0 / 2.0) * sy
        shift_x, shift_y = (sqrt(3.0) / 2.0) * sx, 0.0
        # odd rows are shifted in x; that only enlarges the bbox
        # when there is more than one row
        extra_w, extra_h = (shift_x if rows > 1 else 0.0), 0.0

    total_w = (cols - 1) * dx + extra_w + hex_w
    total_h = (rows - 1) * dy + extra_h + hex_h

    # Centre of cell (0, 0), chosen so the whole grid's bbox is centred.
    x0 = center[0] - total_w / 2.0 + hex_w / 2.0
    y0 = center[1] - total_h / 2.0 + hex_h / 2.0

    centers = []
    for r in range(rows):
        for c in range(cols):
            cx = x0 + c * dx + (shift_x if (r & 1) else 0.0)
            cy = y0 + r * dy + (shift_y if (c & 1) else 0.0)
            centers.append((cx, cy))
    return centers


def honeycomb(cols: int,
              rows: int,
              hex_w: float,
              hex_h: float | None = None,
              center: Tuple[float, float] = (0.0, 0.0),
              orientation: str = "flat",
              tol: float = 1e-3) -> tuple[np.ndarray, np.ndarray]:
    """
    Return the vertex coordinates of a ``cols`` x ``rows`` hexagonal grid.

    Each cell is built with ``hexagon(hex_w, hex_h, ..., orientation)``. The
    grid's overall bounding box is centred on ``center``. Shared vertices
    between neighbouring cells are collapsed with the project helper
    ``unique_xy_with_tol`` (called with ``tol`` and ``reducer="first"``).

    Layout (in terms of the cell bounding box):
      * ``"flat"``: columns are ``0.75 * hex_w`` apart, cells within a column
        are ``hex_h`` apart, and odd columns are offset by ``+hex_h / 2`` in y.
      * ``"pointy"``: rows are ``0.75 * hex_h`` apart, cells within a row
        are ``hex_w`` apart, and odd rows are offset by ``+hex_w / 2`` in x.

    Args:
        cols: Number of columns; non-positive gives empty output.
        rows: Number of rows; non-positive gives empty output.
        hex_w: Bounding-box width of one cell.
        hex_h: Bounding-box height of one cell; defaults to ``hex_w``.
        center: (x, y) centre of the grid's bounding box.
        orientation: ``"flat"`` or ``"pointy"``.
        tol: Tolerance forwarded to ``unique_xy_with_tol``, in the same units
            as ``hex_w``. It should stay well below the cell edge length;
            otherwise distinct vertices of small cells may be merged.

    Returns:
        ``(xs, ys)``: 1-D float arrays of vertex coordinates. Both are empty
        when the grid is empty.

    Raises:
        ValueError: if ``orientation`` is not ``"flat"`` or ``"pointy"``
            (checked even when the grid is empty).
    """
    # Validate first so a bad orientation is never silently accepted.
    if orientation not in ("flat", "pointy"):
        raise ValueError("orientation must be 'flat' or 'pointy'")
    if cols <= 0 or rows <= 0:
        return np.array([]), np.array([])
    if hex_h is None:
        hex_h = hex_w

    pts = []
    for cx, cy in _honeycomb_centers(cols, rows, hex_w, hex_h, center, orientation):
        cell = hexagon(hex_w, hex_h, (cx, cy), orientation=orientation)
        # drop the closing vertex that repeats the first one
        pts.extend(cell.exterior.coords[:-1])

    # cols, rows >= 1 here, so pts is never empty.
    xy = np.asarray(pts, dtype=float)
    xy = unique_xy_with_tol(xy, tol=tol, reducer="first")
    return xy[:, 0], xy[:, 1]

