from collections.abc import Sequence

import numpy as np
import shapely
from numpy import floating
from numpy.typing import NDArray
from shapely import LineString, MultiPolygon, Polygon

def interpolate_line(x_start: float, y_start: float, max_segment_length: float, x_end: float, y_end: float, tolerance: float = 0.5) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Densify the straight segment from (x_start, y_start) to (x_end, y_end).

    The two-point line is passed through ``shapely.segmentize`` so that no
    sub-segment of the result is longer than ``max_segment_length``.

    Args:
        x_start: X coordinate of the first endpoint.
        y_start: Y coordinate of the first endpoint.
        max_segment_length: Upper bound on the length of each sub-segment.
        x_end: X coordinate of the last endpoint.
        y_end: Y coordinate of the last endpoint.
        tolerance: Currently unused; kept so existing callers that pass it
            keep working.

    Returns:
        The ``(xs, ys)`` coordinate sequences of the densified line, both
        endpoints included, as produced by shapely's ``coords.xy``.
        NOTE(review): shapely's ``.xy`` yields ``array.array('d')`` objects
        rather than ndarrays as the annotation states — confirm callers are
        fine with that.
    """
    segment = LineString([(x_start, y_start), (x_end, y_end)])
    densified = shapely.segmentize(segment, max_segment_length=max_segment_length)
    return densified.coords.xy


def simplify_and_segmentize_polygon(polygon: shapely.Polygon, tolerance: float, max_segment_length: float) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Simplify a polygon, densify it, and return its exterior as an open ring.

    Pipeline:
      1. ``polygon.simplify(tolerance, preserve_topology=False)``. Per shapely's
         docs, skipping topology preservation may yield an invalid geometry.
         The result is then ``normalize()``-d into canonical form.
      2. ``shapely.segmentize`` so no edge exceeds ``max_segment_length``.
      3. Take the exterior ring's x/y sequences and drop the final (closing)
         vertex via ``remove_closing_point``.

    Interior rings (holes) are not included in the result.
    """
    canonical = polygon.simplify(tolerance, preserve_topology=False).normalize()
    densified = shapely.segmentize(canonical, max_segment_length=max_segment_length)
    exterior_xy = densified.exterior.xy
    return remove_closing_point(exterior_xy)

def line_string(x: NDArray[floating], y: NDArray[floating]) -> LineString:
    """Build a LineString whose vertices pair ``x`` and ``y`` element-wise.

    Pairing follows ``zip`` semantics, so it stops at the shorter input.
    """
    vertices = list(zip(x, y))
    return LineString(vertices)

def remove_closing_point(xy: tuple[Sequence[float], Sequence[float]]) -> tuple[Sequence[float], Sequence[float]]:
    """Drop the last vertex from a closed ring's x and y coordinate sequences.

    Closed rings (such as shapely exterior rings) end with a copy of their
    first vertex; removing it leaves each vertex exactly once.

    Args:
        xy: ``(xs, ys)`` pair of sliceable sequences, e.g. ``ring.xy``. Only
            the first two entries are used. The last element of each is
            removed unconditionally, so the ring is assumed to be closed.

    Returns:
        ``(xs[:-1], ys[:-1])``. Slicing preserves each input's sequence type.
    """
    return (xy[0][:-1], xy[1][:-1])

def get_multipolygon_coords(mp: MultiPolygon | Polygon) -> tuple[np.ndarray, np.ndarray]:
    """Flatten every ring of a Polygon or MultiPolygon into two coordinate arrays.

    Rings are visited polygon by polygon: the exterior ring first, then each
    interior ring (hole). Every ring's vertices are emitted in stored order,
    including its closing vertex. Ring boundaries are not marked in the output.

    Args:
        mp: A Polygon or a MultiPolygon.

    Returns:
        ``(xs, ys)`` float arrays of equal length. Any Z values are dropped.
        An empty geometry yields two empty arrays.
    """
    # Shapely 2 multi-part geometries are not iterable; their parts are
    # exposed through ``.geoms``.
    polygons = [mp] if isinstance(mp, Polygon) else mp.geoms

    coords = [
        coord
        for polygon in polygons
        for ring in (polygon.exterior, *polygon.interiors)
        for coord in ring.coords
    ]

    # Index rather than unpack so 3-D (x, y, z) coordinates are accepted, and
    # build arrays directly so an empty geometry returns empty arrays instead
    # of failing to unpack an empty zip().
    xs = np.array([coord[0] for coord in coords], dtype=float)
    ys = np.array([coord[1] for coord in coords], dtype=float)
    return xs, ys


def unique_xy_with_tol(xy: np.ndarray, tol: float = 1e-9, reducer: str = "first") -> np.ndarray:
    """
    Deduplicate points with a numeric tolerance by snapping to a tol grid first.

    Each coordinate is divided by ``tol`` and rounded to the nearest integer.
    Points whose rounded keys match exactly share a bucket. Two points closer
    than ``tol`` can still land in different buckets if they straddle a grid
    boundary: this is a grid snap, not a true distance threshold.

    Args:
        xy: (N, D) array of points (D is typically 2).
        tol: positive grid size for quantization (e.g., 1e-9 .. 1e-6)
        reducer: 'first' keeps the first point per bucket, in input order
                 (stable);
                 'mean' returns the centroid per bucket, ordered by the
                 buckets' snapped keys (lexicographic).

    Returns:
        (M, D) array of unique points. 'first' keeps the input dtype; 'mean'
        returns floats. An empty input is returned unchanged.

    Raises:
        ValueError: if ``tol`` is not positive or ``reducer`` is not
            'first' or 'mean'.
    """
    # ``not tol > 0`` also rejects NaN; a zero/negative tol would otherwise
    # divide by zero or flip keys and silently produce garbage buckets.
    if not tol > 0:
        raise ValueError(f"tol must be positive, got {tol!r}")
    if reducer not in ("first", "mean"):
        raise ValueError(f"reducer must be 'first' or 'mean', got {reducer!r}")
    if xy.size == 0:
        return xy

    # Quantize to an integer grid to make equality robust
    keys = np.rint(xy / tol).astype(np.int64)

    if reducer == "first":
        _, first_idx = np.unique(keys, axis=0, return_index=True)
        return xy[np.sort(first_idx)]

    # reducer == "mean": average all points in each bucket
    uniq_keys, inv = np.unique(keys, axis=0, return_inverse=True)
    # Some NumPy 2.0 releases return the inverse with a non-1-D shape when
    # ``axis`` is given; flatten so add.at / bincount always get a 1-D index.
    inv = inv.reshape(-1)
    n_buckets = len(uniq_keys)
    sums = np.zeros((n_buckets,) + xy.shape[1:], dtype=float)
    np.add.at(sums, inv, xy)
    counts = np.bincount(inv, minlength=n_buckets)
    # Broadcast per-bucket counts across every trailing (coordinate) axis.
    return sums / counts.reshape((-1,) + (1,) * (xy.ndim - 1))
