# datasets.py
"""Synthetic datasets tailored for STAGE benchmarks.

Each generator returns a `Dataset` object with fields:
    * `X`      – (n,d) noisy observations in ℝᵈ.
    * `t`      – (n,) ground‑truth parameter along the curve (arc‑length or user specified).
    * `order`  – indices that sort by *t* (pre‑computed for convenience).

Design principles
-----------------
* **Vectorised** NumPy implementation for speed.
* **Stateless** – no global RNG; each call accepts a `seed`.
* Provide both *low‑dim* (2‑D) and *high‑dim* datasets mirroring those in the
  original experiment scripts, but with cleaner APIs.
* No external dependencies beyond NumPy (and SciPy for optional helpers).
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Optional
import numpy as np

__all__ = [
    "Dataset",
    "sine_curve",
    "clothoid_spiral",
    "random_tangent_walk",
]


@dataclass(slots=True)
class Dataset:
    X: np.ndarray        # (n,d) noisy observations
    t: np.ndarray        # (n,) latent parameter
    order: np.ndarray    # indices that sort by *t*

    def __iter__(self):
        yield from (self.X, self.t, self.order)


# -----------------------------------------------------------------------------
# 1. Noisy sine curve in 2‑D (matches `generate_sin_data` from old scripts)
# -----------------------------------------------------------------------------

def sine_curve(
    n: int = 200,
    *,
    amplitude: float = 2.0,
    periods: float = 1.0,
    noise_std: float = 0.05,
    rotation_angle: float = 0.0,
    seed: Optional[int] = None,
) -> Dataset:
    """Sample a noisy planar sine wave, optionally rotated about the origin.

    The noise-free curve is ``(x, amplitude * sin(x))`` with
    ``x = 2π * periods * t`` and ``t`` taken at *n* equispaced values in
    ``[0, 1]``.

    Parameters
    ----------
    n : int, default 200
        Number of samples.
    amplitude : float, default 2.0
        Peak height of the wave.
    periods : float, default 1.0
        Number of full sine cycles covered as ``t`` runs from 0 to 1.
    noise_std : float, default 0.05
        Standard deviation of the Gaussian noise added independently to both
        coordinates.
    rotation_angle : float, default 0.0
        Counter-clockwise rotation (radians) applied to the noisy points.
        A value of exactly ``0.0`` skips the rotation altogether.
    seed : int, optional
        Seed for the local random generator that draws the noise.

    Returns
    -------
    Dataset
        ``X`` has shape ``(n, 2)``, ``t`` is the equispaced parameter and
        ``order`` is ``0 .. n-1`` because the samples are generated in
        increasing ``t``.
    """
    rng = np.random.default_rng(seed)

    # Latent parameter and the clean curve it traces out.
    t = np.linspace(0.0, 1.0, n)
    abscissa = t * 2.0 * np.pi * periods
    ordinate = amplitude * np.sin(abscissa)

    # Noise is added before the rotation, in the unrotated frame.
    X = np.stack((abscissa, ordinate), axis=1)
    X += rng.normal(scale=noise_std, size=X.shape)

    if rotation_angle != 0.0:
        cos_a = np.cos(rotation_angle)
        sin_a = np.sin(rotation_angle)
        rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        # Points are stored as rows, hence right-multiplication by the transpose.
        X = X @ rotation.T

    return Dataset(X=X.astype(np.float64), t=t, order=np.arange(n))


# -----------------------------------------------------------------------------
# 2. Clothoid / Fresnel spiral (linear curvature) – generalises `generate_spiral_data`
# -----------------------------------------------------------------------------

def clothoid_spiral(
    n: int = 200,
    *,
    kappa_end: float = 4.0 * np.pi,  # total heading change = κ_end * length / 2
    length: float = 1.0,
    start: Tuple[float, float] = (0.0, 0.0),
    heading: float = 0.0,
    noise_std: float = 0.05,
    seed: Optional[int] = None,
) -> Dataset:
    """Sample a noisy planar clothoid, i.e. a curve whose curvature grows linearly.

    The curvature rises from 0 at arc-length 0 to *kappa_end* at arc-length
    *length*. The curve is integrated numerically with the trapezoid rule on
    *n* equispaced arc-length values, rotated by *heading*, shifted to
    *start*, and finally perturbed with Gaussian noise.

    Parameters
    ----------
    n : int, default 200
        Number of samples; must be at least 2.
    kappa_end : float, default 4π
        Curvature reached at the end of the curve.
    length : float, default 1.0
        Total arc-length of the noise-free curve; must be positive.
    start : tuple of float, default (0.0, 0.0)
        Position of the first noise-free point.
    heading : float, default 0.0
        Initial tangent direction in radians.
    noise_std : float, default 0.05
        Standard deviation of the Gaussian noise added to both coordinates.
    seed : int, optional
        Seed for the local random generator that draws the noise.

    Returns
    -------
    Dataset
        ``X`` has shape ``(n, 2)``; ``t`` is arc-length divided by *length*
        (so it runs from 0 to 1); ``order`` is the argsort of ``t``.

    Raises
    ------
    ValueError
        If ``n < 2`` or ``length <= 0``.
    """
    if n < 2:
        raise ValueError("Need at least two points.")
    if length <= 0:
        raise ValueError("length must be positive.")

    rng = np.random.default_rng(seed)

    arc = np.linspace(0.0, length, n)
    slope = kappa_end / length          # κ(s) = slope * s
    angle = 0.5 * slope * arc**2        # ∫ κ ds, heading relative to the start
    h = arc[1] - arc[0]                 # uniform grid spacing

    def _cumulative_trapezoid(values):
        # Running trapezoid-rule integral over the grid, starting at 0.
        panels = 0.5 * (values[:-1] + values[1:]) * h
        return np.cumsum(np.concatenate([[0.0], panels]))

    # Curve in its local frame: starts at the origin heading along +x.
    local_x = _cumulative_trapezoid(np.cos(angle))
    local_y = _cumulative_trapezoid(np.sin(angle))

    # Rigid motion into the requested frame: rotate by `heading`, move to `start`.
    cos_h, sin_h = np.cos(heading), np.sin(heading)
    world_x = cos_h * local_x - sin_h * local_y
    world_y = sin_h * local_x + cos_h * local_y
    X = np.column_stack((world_x, world_y))
    X += start

    X += rng.normal(scale=noise_std, size=X.shape)

    t = arc / length
    return Dataset(X=X.astype(np.float64), t=t, order=np.argsort(t))


# -----------------------------------------------------------------------------
# 3. High‑dimensional random tangent walk (Brownian‐like curve in ℝᵈ)
# -----------------------------------------------------------------------------

def random_tangent_walk(
    n: int = 300,
    d: int = 10,
    *,
    step_scale: float = 0.02,
    noise_std: float = 0.01,
    alpha: float = 0.9,
    seed: Optional[int] = None,
) -> Dataset:
    """1-D curve in ℝᵈ obtained by integrating a smoothed random unit tangent.

    Independent isotropic unit directions are drawn first. Each direction is
    then blended with the previous (already smoothed) one using weight
    *alpha* and re-normalised, so consecutive tangents are correlated and the
    path bends gently. Positions are the running sum of the tangents scaled by
    *step_scale*, plus Gaussian noise.

    Parameters
    ----------
    n : int, default 300
        Number of samples.
    d : int, default 10
        Ambient dimension.
    step_scale : float, default 0.02
        Length of every noise-free step.
    noise_std : float, default 0.01
        Standard deviation of the Gaussian noise added to every coordinate.
    alpha : float, default 0.9
        Weight of the previous tangent in the blend, in ``[0, 1]``. ``0``
        leaves the directions independent; ``1`` repeats the first direction,
        giving a straight line. The default reproduces the previously
        hard-coded behaviour.
    seed : int, optional
        Seed for the local random generator (tangents and noise).

    Returns
    -------
    Dataset
        ``X`` has shape ``(n, d)``; ``t`` is equispaced in ``[0, 1]``, which
        equals the normalised arc-length of the noise-free path because all
        steps have the same length; ``order`` is ``0 .. n-1``.

    Raises
    ------
    ValueError
        If *alpha* is outside ``[0, 1]``.
    """
    # Written as a negated chained comparison so that NaN is rejected as well.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1].")

    rng = np.random.default_rng(seed)

    # Independent isotropic unit directions, one per step.
    tangents = rng.normal(size=(n, d))
    tangents /= np.linalg.norm(tangents, axis=1, keepdims=True)

    # Sequential smoothing: each tangent depends on the previous *smoothed*
    # tangent and is re-normalised, so the recurrence cannot be vectorised.
    fresh_weight = 1 - alpha
    for i in range(1, n):
        tangents[i] = alpha * tangents[i - 1] + fresh_weight * tangents[i]
        tangents[i] /= np.linalg.norm(tangents[i])

    # Integrate tangents to obtain positions; the first point lies one step
    # away from the origin.
    X = np.cumsum(tangents * step_scale, axis=0)

    # Add small isotropic noise
    X += rng.normal(scale=noise_std, size=X.shape)

    t = np.linspace(0.0, 1.0, n)
    order = np.arange(n)
    return Dataset(X=X.astype(np.float64), t=t, order=order)
