# curve.py
from __future__ import annotations
from abc import ABC, abstractmethod
import numpy as np
import numpy.linalg as la
import numpy as np
from scipy.spatial import cKDTree          # 10× faster than naïve loops

# Public API exported by ``from curve import *`` (includes every concrete
# curve class defined in this module).
__all__ = ["ParametricCurve", "RandomFourierCurve", "RandomLineCurve", "ScaledCurve"]


class ParametricCurve(ABC):
    """
    Abstract 1‑D curve in R^d, parameter domain t∈[0,1].

    Generic tools (finite-difference derivatives, speed, curvature,
    arc-length re-parametrisation, curvature capping, self-intersection
    test) live here.  Child classes only need to implement ``c(t)``;
    analytic ``c_prime`` / ``c_double_prime`` overrides are optional.

    Shape convention (expected of ``c`` and its overrides; the generic
    methods below reduce over / index the last axis and rely on it): for
    ``t`` of shape ``S`` the point/derivative methods return shape
    ``S + (d,)`` and ``speed`` / ``curvature`` return shape ``S``.
    """
    def __init__(self, d: int):
        self.d = int(d)   # ambient dimension

    # -------- mandatory --------------------------------------------------
    @abstractmethod
    def c(self, t):
        """Point(s) at parameter t  (broadcast over numpy arrays)."""
    # ---------------------------------------------------------------------

    # -------- optional fast overrides ------------------------------------
    def c_prime(self, t, eps=1e-5):
        """
        1st derivative dc/dt; central finite-difference fallback.

        Evaluates ``c`` at ``t ± eps``, i.e. slightly outside [0, 1] at the
        end points, so ``c`` must accept such parameters.
        """
        t = np.asarray(t, dtype=float)
        return (self.c(t + eps) - self.c(t - eps)) / (2 * eps)

    def c_double_prime(self, t, eps=1e-4):
        """
        2nd derivative d²c/dt²; central second-difference fallback.

        Like ``c_prime``, samples ``c`` at ``t ± eps``.
        """
        t = np.asarray(t, dtype=float)
        return (self.c(t + eps) - 2 * self.c(t) + self.c(t - eps)) / eps**2
    # ---------------------------------------------------------------------

    # ===== geometry ======================================================
    def speed(self, t):
        """Speed ‖c′(t)‖ (Euclidean norm over the last axis)."""
        return la.norm(self.c_prime(t), axis=-1)

    def curvature(self, t):
        """
        Curvature κ = ‖dT/ds‖ = ‖c″⊥‖ / ‖c′‖², valid for non-unit speed.

        c″⊥ is the component of c″ orthogonal to c′.  A 1e-12 guard in the
        denominators avoids division by zero at stationary points (where κ
        is undefined and the returned value is not meaningful).
        Result has the shape of ``t``.
        """
        cp  = self.c_prime(t)
        cpp = self.c_double_prime(t)

        cp_norm2 = np.sum(cp**2, axis=-1, keepdims=True)          # ‖c′‖², shape S+(1,)
        proj     = (np.sum(cpp * cp, axis=-1, keepdims=True) / (cp_norm2 + 1e-12)) * cp
        num      = la.norm(cpp - proj, axis=-1)                   # ‖c′′⊥‖, shape S
        # Drop only the trailing keepdims axis.  ``squeeze()`` would also
        # collapse size-1 axes of ``t`` and then mis-broadcast against ``num``.
        den      = cp_norm2[..., 0] + 1e-12                       # ‖c′‖², shape S
        return num / den

    # ===== arc‑length re‑parametrisation =================================
    def unit_speed_grid(self, n_pts=2000):
        """
        Sample the curve uniformly in arc length.

        Returns ``(t_uniform, s_uniform)``: ``s_uniform`` holds ``n_pts``
        equally spaced arc-length values on [0, L] (L = total length) and
        ``t_uniform`` the parameters at which s(t) reaches them.  s(t) is the
        cumulative trapezoidal integral of the speed over a uniform t grid
        of ``n_pts`` points, inverted by linear interpolation.
        """
        t = np.linspace(0.0, 1.0, n_pts)
        speeds = self.speed(t)
        # cum. trapezoidal integral → s(t)
        s = np.concatenate(([0], np.cumsum((speeds[:-1] + speeds[1:]) *
                                           np.diff(t) / 2)))
        L = s[-1]
        s_uniform = np.linspace(0.0, L, n_pts)
        t_uniform = np.interp(s_uniform, s, t)
        return t_uniform, s_uniform

    # ===== curvature capping via spatial scaling =========================
    def stretch_to_curvature(self, kappa_max, n_probe=2000):
        """
        Return a curve whose sampled max κ is ≤ ``kappa_max``.

        Scaling a curve by λ divides its curvature by λ, so if the maximum
        curvature over ``n_probe`` uniform parameter samples exceeds
        ``kappa_max``, a ``ScaledCurve`` with λ = κ_max_current / kappa_max
        is returned; otherwise ``self`` is returned unchanged.

        Raises
        ------
        ValueError
            If ``kappa_max`` is not a positive number (incl. NaN), which
            would otherwise yield an infinite, negative or NaN scale factor.
        """
        if not kappa_max > 0:
            raise ValueError(f"kappa_max must be positive, got {kappa_max!r}")
        t = np.linspace(0.0, 1.0, n_probe)
        current = self.curvature(t).max()
        if current <= kappa_max:            # already OK
            return self
        lam = current / kappa_max           # >1  (enlarge to reduce κ)
        return ScaledCurve(self, lam)
    # =====================================================================

    def has_self_intersection(self, n_pts=5000, eps=1e-3, neighbor_skip=5,
                              *, closed=False):
        """
        Heuristic self-intersection test on a uniform parameter sample.

        Returns True if any two sample points whose indices differ by more
        than ``neighbor_skip`` lie within Euclidean distance ``eps``.

        n_pts        : how densely to sample the parameter [0,1]
        eps          : distance threshold for declaring a 'collision'
        neighbor_skip: how many adjacent indices to ignore (numeric adjacency)
        closed       : treat the curve as a loop: index adjacency is measured
                       cyclically, so the coincident end points c(0) and c(1)
                       are not reported as a collision.  The default (False)
                       keeps the open-curve behaviour, under which a closed
                       loop is always reported as self-intersecting.

        Limitations of the sampling approach:
        * a genuine crossing is missed if no two samples from the two passes
          land within ``eps`` (sample spacing ≈ speed / n_pts);
        * where the curve is slow, more than ``neighbor_skip`` consecutive
          samples can fall within ``eps`` and give a false positive.
        """
        t = np.linspace(0.0, 1.0, n_pts)
        points = self.c(t)                      # shape (n_pts, d)
        pairs = cKDTree(points).query_pairs(r=eps, output_type="ndarray")
        if pairs.size == 0:
            return False
        gap = np.abs(pairs[:, 1] - pairs[:, 0])  # index distance of each pair
        if closed:
            # Indices 0 and n_pts-1 sample the same point of a loop made of
            # n_pts-1 distinct samples, so wrap the distance around.
            gap = np.minimum(gap, (n_pts - 1) - gap)
        return bool(np.any(gap > neighbor_skip))  # ignore local neighbours


# -------------------------------------------------------------------------
class ScaledCurve(ParametricCurve):
    """
    Pure spatial scaling of another curve:  c̃(t) = λ · c(t).

    Derivatives scale linearly (c̃′ = λ c′, c̃″ = λ c″), so the speed is
    multiplied by |λ| and the curvature divided by |λ|.

    Parameters
    ----------
    base : ParametricCurve
        Curve being scaled; it is referenced, not copied.
    lam : float
        Scale factor λ.
    """
    def __init__(self, base: ParametricCurve, lam: float):
        super().__init__(base.d)
        self.base, self.lam = base, float(lam)

    def c(self, t):
        return self.lam * self.base.c(t)

    def c_prime(self, t, *args, **kwargs):
        # Forward extra arguments (e.g. the finite-difference ``eps`` of
        # ParametricCurve.c_prime) so this override accepts the same call
        # signature as the base class instead of raising TypeError.
        return self.lam * self.base.c_prime(t, *args, **kwargs)

    def c_double_prime(self, t, *args, **kwargs):
        # Same argument forwarding as c_prime.
        return self.lam * self.base.c_double_prime(t, *args, **kwargs)

# -------------------------------------------------------------------------
class RandomFourierCurve(ParametricCurve):
    """
    Curve given by a random truncated Fourier series

        c(t) = Σ_{k=1..K} a_k sin(ω k t) + b_k cos(ω k t),   ω = 2π · span,

    whose coefficient vectors a_k, b_k ∈ R^d are drawn from a normal
    distribution with standard deviation k^(-alpha).  All derivatives are
    evaluated analytically.

    Parameters
    ----------
    d      : ambient dimension
    K      : highest harmonic
    alpha  : decay rate of Fourier amplitudes (std of harmonic k is k^-alpha)
    span   : fraction of a full 2π cycle traversed as t runs 0→1
             • span = 1.0  → closed loop, c(0) = c(1) (original behaviour)
             • span = 0.25 → quarter-loop, open curve
    seed   : RNG seed passed to ``np.random.default_rng``
    """
    def __init__(self, d, K=8, alpha=1.5, span=1.0, seed=None):
        super().__init__(d)
        self.span = float(span)
        generator = np.random.default_rng(seed)
        sigma = np.array([k ** (-alpha) for k in range(1, K + 1)])
        # Sine coefficients are drawn before cosine ones; this order decides
        # which random numbers each set receives for a given seed.
        self.a = generator.normal(scale=sigma, size=(d, K))   # shape (d, K)
        self.b = generator.normal(scale=sigma, size=(d, K))   # shape (d, K)
        self.K = K
        self._w = 2 * np.pi * self.span     # base angular frequency ω

    # ------- shared harmonic iteration ------------------------------------
    def _harmonics(self, t_col):
        """Yield (ω·k, a_k, b_k, ω·k·t) for k = 1..K; ``t_col`` has a trailing unit axis."""
        for k in range(1, self.K + 1):
            omega_k = self._w * k
            yield omega_k, self.a[:, k - 1], self.b[:, k - 1], omega_k * t_col

    # ------- analytic formulas ------------------------------------------
    def c(self, t):
        t_col = np.asarray(t)[..., None]
        out = np.zeros(t_col.shape[:-1] + (self.d,))
        for _, a_k, b_k, phase in self._harmonics(t_col):
            out += a_k * np.sin(phase) + b_k * np.cos(phase)
        return out

    def c_prime(self, t):
        t_col = np.asarray(t)[..., None]
        out = np.zeros(t_col.shape[:-1] + (self.d,))
        for omega_k, a_k, b_k, phase in self._harmonics(t_col):
            out += omega_k * (a_k * np.cos(phase) - b_k * np.sin(phase))
        return out

    def c_double_prime(self, t):
        t_col = np.asarray(t)[..., None]
        out = np.zeros(t_col.shape[:-1] + (self.d,))
        for omega_k, a_k, b_k, phase in self._harmonics(t_col):
            out += -(omega_k ** 2) * (a_k * np.sin(phase) + b_k * np.cos(phase))
        return out



class RandomLineCurve(ParametricCurve):
    """
    Straight line segment centred at `center` with random unit direction.

    c(t) = x0 + t · v with v = length · u (u a random unit vector) and
    x0 = center − v/2, so c(0.5) = center and the speed is constant,
    equal to |length|; curvature is identically zero.

    Parameters
    ----------
    d       : ambient dimension.
    length  : total arc length of the segment (speed == |length|).
    center  : optional array‑like, the midpoint of the segment; default 0.
    seed    : RNG seed for reproducibility.
    """
    def __init__(self, d, length=1.0, center=None, seed=None):
        super().__init__(d)
        rng      = np.random.default_rng(seed)
        v        = rng.normal(size=d)
        v       /= la.norm(v)                 # random unit direction
        self.v   = v * length                # displacement over t∈[0,1]
        self.len = float(length)

        if center is None:
            center = np.zeros(d, dtype=float)
        self.center = np.asarray(center, dtype=float)
        self.x0     = self.center - 0.5 * self.v   # start point

    # analytic formulas ---------------------------------------------------
    def c(self, t):
        t = np.asarray(t)[..., None]          # broadcast
        return self.x0 + t * self.v           # shape (..., d)

    def c_prime(self, t):
        t = np.asarray(t)
        shape = t.shape + (self.d,)
        # np.broadcast_to returns a read-only view; copy it so callers get an
        # ordinary writable array, like every other c_prime implementation.
        return np.broadcast_to(self.v, shape).copy()   # constant velocity v

    def c_double_prime(self, t):
        t = np.asarray(t)
        shape = t.shape + (self.d,)
        return np.zeros(shape)                     # curvature ≡ 0