import numpy as np
from numpy.typing import NDArray
import shapely

from .helpers import simplify_and_segmentize_polygon

def ellipse_points(width: float, height: float, num_points: int, x_offset: float = 0.0, y_offset: float = 0.0, span: float = 2*np.pi, span_start: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample evenly spaced (in angle) points along an axis-aligned elliptical arc.

    Args:
        width: Full extent of the ellipse along x; the x semi-axis is ``width / 2``.
        height: Full extent of the ellipse along y; the y semi-axis is ``height / 2``.
        num_points: Number of samples to return.
        x_offset: x coordinate of the ellipse centre.
        y_offset: y coordinate of the ellipse centre.
        span: Angular length of the arc in radians (default: a full turn).
        span_start: Angle in radians at which the arc starts.

    Returns:
        ``(x, y)`` coordinate arrays. The first sample lies at ``span_start``;
        the end angle ``span_start + span`` itself is not included.
    """
    # Sample one extra angle and drop it, so the end of the span is left out.
    # For a full turn this avoids repeating the starting point.
    span_end = span_start + span
    angles = np.linspace(span_start, span_end, num_points + 1)[:-1]

    semi_x = 0.5 * width
    semi_y = 0.5 * height
    xs = semi_x * np.cos(angles) + x_offset
    ys = semi_y * np.sin(angles) + y_offset
    return xs, ys

def compute_semi_ellipse(width: float, height: float, num_points: int, x_offset: float = 0.0, y_offset: float = 0.0, span: float = np.pi, span_start: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample points along a partial elliptical arc (half an ellipse by default).

    Thin wrapper around ``ellipse_points`` whose only difference is the
    default ``span`` of ``pi`` radians instead of a full turn.

    Args:
        width: Full extent of the ellipse along x.
        height: Full extent of the ellipse along y.
        num_points: Number of samples to return.
        x_offset: x coordinate of the ellipse centre.
        y_offset: y coordinate of the ellipse centre.
        span: Angular length of the arc in radians.
        span_start: Angle in radians at which the arc starts.

    Returns:
        ``(x, y)`` coordinate arrays as produced by ``ellipse_points``.
    """
    return ellipse_points(
        width,
        height,
        num_points,
        x_offset=x_offset,
        y_offset=y_offset,
        span=span,
        span_start=span_start,
    )

def compute_semi_circle(radius: float, num_points: int, x_offset: float = 0.0, y_offset: float = 0.0, span: float = np.pi, span_start: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample points along a circular arc (half a circle by default).

    Args:
        radius: Radius of the circle.
        num_points: Number of samples to return.
        x_offset: x coordinate of the circle centre.
        y_offset: y coordinate of the circle centre.
        span: Angular length of the arc in radians.
        span_start: Angle in radians at which the arc starts.

    Returns:
        ``(x, y)`` coordinate arrays as produced by ``ellipse_points``.
    """
    # ellipse_points takes the full width/height and halves them internally,
    # so the diameter must be passed to obtain an arc of the requested radius.
    # Passing the radius directly would yield a circle of half the size.
    diameter = 2 * radius
    return ellipse_points(diameter, diameter, num_points, x_offset, y_offset, span, span_start)

def compute_circle(radius: float, num_points: int, x_offset: float = 0.0, y_offset: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample points around a full circle.

    Args:
        radius: Radius of the circle.
        num_points: Number of samples to return.
        x_offset: x coordinate of the circle centre.
        y_offset: y coordinate of the circle centre.

    Returns:
        ``(x, y)`` coordinate arrays as produced by ``ellipse_points``.
    """
    # ellipse_points takes the full width/height and halves them internally,
    # so the diameter must be passed to obtain a circle of the requested
    # radius. Passing the radius directly would yield half the size.
    diameter = 2 * radius
    return ellipse_points(diameter, diameter, num_points, x_offset, y_offset)

def compute_ellipse(width: float, height: float, num_points: int, x_offset: float = 0.0 , y_offset: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample points around a full axis-aligned ellipse.

    Thin wrapper around ``ellipse_points`` that always uses its default
    full-turn span starting at angle zero.

    Args:
        width: Full extent of the ellipse along x.
        height: Full extent of the ellipse along y.
        num_points: Number of samples to return.
        x_offset: x coordinate of the ellipse centre.
        y_offset: y coordinate of the ellipse centre.

    Returns:
        ``(x, y)`` coordinate arrays as produced by ``ellipse_points``.
    """
    return ellipse_points(
        width,
        height,
        num_points,
        x_offset=x_offset,
        y_offset=y_offset,
    )

def interpolate_ellipse(width: float, height: float, max_segment_length: float, x_offset: float = 0.0, y_offset: float = 0.0, tolerance: float = 0.5) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Build an ellipse polygon with Shapely and resample its outline.

    Args:
        width: Full extent of the ellipse along x, matching the convention
            used by ``ellipse_points`` in this module.
        height: Full extent of the ellipse along y.
        max_segment_length: Forwarded to ``simplify_and_segmentize_polygon``.
        x_offset: x coordinate of the ellipse centre.
        y_offset: y coordinate of the ellipse centre.
        tolerance: Forwarded to ``simplify_and_segmentize_polygon``.

    Returns:
        Whatever ``simplify_and_segmentize_polygon`` returns for the ellipse
        (annotated as a pair of coordinate arrays).
    """
    # A bare ``import shapely`` does not by itself guarantee that these
    # submodules are loaded as attributes, so import them explicitly.
    import shapely.affinity
    import shapely.geometry

    # buffer(1) yields a circle of radius 1 (diameter 2). Scaling by half of
    # the requested width/height therefore produces an ellipse whose full
    # extents are ``width`` x ``height`` rather than twice that.
    circle = shapely.geometry.Point(x_offset, y_offset).buffer(1)
    ellipse = shapely.affinity.scale(circle, xfact=width / 2, yfact=height / 2)
    return simplify_and_segmentize_polygon(ellipse, tolerance, max_segment_length)

def _rectangle_edge_counts(width: float, height: float, num_points: int) -> list[int]:
    """Split ``num_points`` over the bottom, right, top and left edges."""
    perimeter = 2 * (width + height)

    if perimeter == 0:
        # Degenerate rectangle: there is no length to weight by (and dividing
        # by the perimeter would fail), so start from an even split.
        n_w = n_h = num_points // 4
    else:
        # Proportional to edge length, but never fewer than one point per
        # edge so that every corner is emitted.
        n_w = max(1, int(num_points * width / perimeter))
        n_h = max(1, int(num_points * height / perimeter))

    counts = [n_w, n_h, n_w, n_h]
    diff = num_points - sum(counts)

    # Too few points allocated: hand out the remainder round-robin.
    for i in range(diff):
        counts[i % 4] += 1

    # Too many (the one-per-edge minimum overshot): take points back from
    # the currently fullest edge so no edge is ever reduced to zero points,
    # which would drop a corner of the rectangle.
    for _ in range(-diff):
        counts[counts.index(max(counts))] -= 1

    return counts


def compute_rectangle(width: float, height: float, num_points: int, x_offset: float = 0.0, y_offset: float = 0.0) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Sample points along the outline of an axis-aligned rectangle.

    Points are distributed over the four edges roughly in proportion to edge
    length. The outline starts at the bottom-left corner and runs
    bottom -> right -> top -> left; each edge contributes its starting corner
    but not its end corner, so corners are not duplicated.

    Args:
        width: Full extent of the rectangle along x.
        height: Full extent of the rectangle along y.
        num_points: Total number of samples to return (at least 4).
        x_offset: x coordinate of the rectangle centre.
        y_offset: y coordinate of the rectangle centre.

    Returns:
        ``(x, y)`` coordinate arrays, each of length ``num_points``.

    Raises:
        ValueError: If ``num_points`` is smaller than 4.
    """
    if num_points < 4:
        raise ValueError("Rectangle requires at least 4 points")

    half_w = width / 2
    half_h = height / 2

    bottom, right, top, left = _rectangle_edge_counts(width, height, num_points)

    # endpoint=False leaves each edge's final corner to the next edge.
    x = np.concatenate([
        np.linspace(-half_w, half_w, bottom, endpoint=False),
        np.full(right, half_w),
        np.linspace(half_w, -half_w, top, endpoint=False),
        np.full(left, -half_w),
    ])
    y = np.concatenate([
        np.full(bottom, -half_h),
        np.linspace(-half_h, half_h, right, endpoint=False),
        np.full(top, half_h),
        np.linspace(half_h, -half_h, left, endpoint=False),
    ])

    return x + x_offset, y + y_offset
