"""
traj_gt_metrics.py – Ground-truth-aware trajectory error metrics.

Each metric is a standalone NumPy function. No external deps except an
optional SciPy import for the Wasserstein distance used by
`dynamic_consistency`; a pure-NumPy fallback is provided.

Conventions
-----------
* Prediction & GT arrays have the same shape (..., T, C) with C∈{2,3};
  leading batch dims are preserved in outputs.
* `axis` selects the time dimension (default -2).

Implemented metrics
-------------------
1. ade                  – Average Displacement Error
2. fde                  – Final Displacement Error
3. success_rate         – FDE < threshold ratio
4. hausdorff            – Symmetric Hausdorff distance
5. ndtw                 – Normalised Dynamic Time-Warping score
6. dynamic_consistency  – Motion-dynamic similarity (Wasserstein)
7. sdtw                 – Success-weighted nDTW
"""
from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike

# Public API exposed by ``from <module> import *``: the seven metric
# functions. The underscore-prefixed helpers in this file are not listed.
__all__ = [
    "ade",
    "fde",
    "success_rate",
    "hausdorff",
    "ndtw",
    "sdtw",
    "dynamic_consistency",
]

# ------------------------------------------------------------------
# Common preprocessing
# ------------------------------------------------------------------
def _prep(a: np.ndarray, b: np.ndarray, axis: int):
    a = np.asarray(a)
    b = np.asarray(b)[..., :2]  # enforce 2D
    if a.shape != b.shape:
        raise ValueError("pred and gt must share shape")

    a = np.moveaxis(a, axis, -2)          # (..., T, C)
    b = np.moveaxis(b, axis, -2)
    batch_shape = a.shape[:-2]
    T, C = a.shape[-2:]
    a = a.reshape(-1, T, C)               # (N, T, C)
    b = b.reshape(-1, T, C)
    return a, b, batch_shape

# ------------------------------------------------------------------
# 1. ADE / 2. FDE / 3. Success-rate
# ------------------------------------------------------------------
def ade(pred: ArrayLike, gt: ArrayLike, *, axis: int = -2, reduce: str = "none"):
    """
    Average Displacement Error: mean L2 distance over all time steps.

    Parameters
    ----------
    pred, gt : array-like
        Predicted and ground-truth trajectories; validated and flattened by
        `_prep`.
    axis : int
        Time dimension of the inputs.
    reduce : str
        "none" returns one value per trajectory (shaped like the batch dims);
        any other value returns the mean over all trajectories.
    """
    pred_flat, gt_flat, batch_shape = _prep(pred, gt, axis)
    step_dist = np.linalg.norm(pred_flat - gt_flat, axis=-1)      # (N, T)
    per_traj = step_dist.mean(-1).reshape(batch_shape)
    if reduce == "none":
        return per_traj
    return per_traj.mean()

def fde(pred: ArrayLike, gt: ArrayLike, *, axis: int = -2, reduce: str = "none"):
    """
    Final Displacement Error: L2 distance between the last time steps.

    Parameters
    ----------
    pred, gt : array-like
        Predicted and ground-truth trajectories; validated and flattened by
        `_prep`.
    axis : int
        Time dimension of the inputs.
    reduce : str
        "none" returns one value per trajectory (shaped like the batch dims);
        any other value returns the mean over all trajectories.
    """
    pred_flat, gt_flat, batch_shape = _prep(pred, gt, axis)
    end_gap = pred_flat[:, -1] - gt_flat[:, -1]                   # (N, C)
    per_traj = np.linalg.norm(end_gap, axis=-1).reshape(batch_shape)
    if reduce == "none":
        return per_traj
    return per_traj.mean()

def success_rate(pred: ArrayLike, gt: ArrayLike, *, threshold: float = 3.0,
                 axis: int = -2, reduce: str = "none"):
    """
    Success indicator: whether the final displacement error is below
    `threshold` (strict inequality).

    With reduce == "none" the per-trajectory boolean result is returned;
    any other value returns its mean, i.e. the fraction of successes.
    """
    final_err = fde(pred, gt, axis=axis, reduce="none")
    hit = final_err < threshold
    if reduce == "none":
        return hit
    return hit.mean()

# ------------------------------------------------------------------
# 4. Hausdorff (per-sample loop)
# ------------------------------------------------------------------
def hausdorff(pred: ArrayLike, gt: ArrayLike, *, axis: int = -2, reduce: str = "none"):
    """
    Symmetric Hausdorff distance between the point sets of each trajectory pair.

    For every pair the full (T, T) matrix of point-to-point L2 distances is
    built; the result is the larger of the two directed distances
    (pred -> gt and gt -> pred).

    reduce == "none" returns one value per trajectory (shaped like the batch
    dims); any other value returns the mean over all trajectories.
    """
    pred_flat, gt_flat, batch_shape = _prep(pred, gt, axis)
    result = np.empty(pred_flat.shape[0])
    # One pair at a time, so only a single (T, T) matrix is alive at once.
    for k, (traj_p, traj_g) in enumerate(zip(pred_flat, gt_flat)):
        pairwise = np.linalg.norm(traj_p[:, None] - traj_g[None, :], axis=-1)  # (T,T)
        pred_to_gt = pairwise.min(1).max()
        gt_to_pred = pairwise.min(0).max()
        result[k] = max(pred_to_gt, gt_to_pred)
    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return result.mean()

# ------------------------------------------------------------------
# 5. nDTW (O(T^2) DTW)
# ------------------------------------------------------------------
def _dtw(a: np.ndarray, b: np.ndarray) -> float:
    """Classic DTW cost with L2; inputs (T, C)."""
    n, m = len(a), len(b)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        ai = a[i - 1]
        for j in range(1, m + 1):
            dist = np.linalg.norm(ai - b[j - 1])
            D[i, j] = dist + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return float(D[n, m])

def ndtw(pred: ArrayLike, gt: ArrayLike, *, alpha: float = 4.0,
         axis: int = -2, reduce: str = "none"):
    """
    Normalised DTW score, higher is better:
        score = exp(-DTW_cost / (alpha * T))

    `alpha * T` is floored at 1e-12 before dividing. reduce == "none"
    returns one score per trajectory (shaped like the batch dims); any other
    value returns the mean over all trajectories.
    """
    pred_flat, gt_flat, batch_shape = _prep(pred, gt, axis)
    n_traj, n_steps, _ = pred_flat.shape
    scale = max(alpha * n_steps, 1e-12)       # guard against a zero divisor
    scores = np.empty(n_traj, dtype=float)
    for k, (traj_p, traj_g) in enumerate(zip(pred_flat, gt_flat)):
        scores[k] = np.exp(-_dtw(traj_p, traj_g) / scale)
    scores = scores.reshape(batch_shape)
    if reduce == "none":
        return scores
    return scores.mean()

def sdtw(pred: ArrayLike, gt: ArrayLike, *, threshold: float = 2.0, alpha: float = 4.0,
         axis: int = -2, reduce: str = "none"):
    """
    Success-weighted nDTW: the nDTW score where the trajectory succeeds
    (per `success_rate` with `threshold`), 0 where it does not.

    reduce == "none" returns one value per trajectory; any other value
    returns the mean over all trajectories.
    """
    success = success_rate(pred, gt, threshold=threshold, axis=axis, reduce="none")
    weight = success.astype(float)            # booleans -> 0.0 / 1.0
    score = ndtw(pred, gt, alpha=alpha, axis=axis, reduce="none")
    weighted = weight * score
    if reduce == "none":
        return weighted
    return weighted.mean()

# ------------------------------------------------------------------
# 6. Dynamic-consistency (Wasserstein)  (per-sample loop)
# ------------------------------------------------------------------
def _w1_empirical(x, y) -> float:
    """1-D Wasserstein-1 distance between two non-empty samples, NumPy only."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    support = np.sort(np.concatenate([x, y]))
    widths = np.diff(support)
    # Empirical CDFs are right-continuous step functions, constant on each
    # interval [support[k], support[k+1]); the integral of |Fx - Fy| is
    # therefore an exact sum of rectangles, not a trapezoid rule.
    cdf_x = np.searchsorted(x, support[:-1], side="right") / x.size
    cdf_y = np.searchsorted(y, support[:-1], side="right") / y.size
    return float(np.sum(np.abs(cdf_x - cdf_y) * widths))


def dynamic_consistency(pred: ArrayLike, gt: ArrayLike, *, dt: float = 0.1,
                        axis: int = -2, reduce: str = "none"):
    """
    Returns exp(-W1(speed)) * exp(-W1(accel)) in (0,1]. Higher = closer dynamics.

    Uses the 1-D Wasserstein (Earth Mover's) distance between the per-trajectory
    distributions of speed magnitude and of step-to-step speed change. SciPy's
    `wasserstein_distance` is used when importable; otherwise an exact
    empirical-CDF integration in NumPy (sort-based, O(T log T) per trajectory).

    Parameters
    ----------
    pred, gt : array-like
        Predicted and ground-truth trajectories; validated and flattened by
        `_prep`.
    dt : float
        Sampling interval used to turn displacements into speeds.
    axis : int
        Time dimension of the inputs.
    reduce : str
        "none" returns one score per trajectory (shaped like the batch dims);
        any other value returns the mean over all trajectories.

    Raises
    ------
    ValueError
        If the trajectories have fewer than 3 time steps, in which case the
        speed-change distribution is empty.
    """
    try:
        from scipy.stats import wasserstein_distance as w1
    except ImportError:
        w1 = _w1_empirical

    p, g, bs = _prep(pred, gt, axis)
    N, T, _ = p.shape
    if T < 3:
        raise ValueError("dynamic_consistency requires at least 3 time steps")

    # Speeds for the whole batch at once; only the distance needs a loop.
    v_p = np.linalg.norm(np.diff(p, axis=1) / dt, axis=-1)       # (N, T-1)
    v_g = np.linalg.norm(np.diff(g, axis=1) / dt, axis=-1)
    # NOTE(review): the "accel" term is the signed per-step change in speed
    # and is not divided by dt. Kept as-is to preserve existing metric
    # values; confirm whether true acceleration (diff / dt) was intended.
    a_p = np.diff(v_p, axis=1)                                   # (N, T-2)
    a_g = np.diff(v_g, axis=1)

    out = np.empty(N, dtype=float)
    for i in range(N):
        dv = w1(v_p[i], v_g[i])
        da = w1(a_p[i], a_g[i])
        out[i] = np.exp(-dv) * np.exp(-da)
    out = out.reshape(bs)
    return out if reduce == "none" else out.mean()

# -----------------------------------------------------------------------------
# Quick demo
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Demo: noisy, partly time-shifted copies of a two-turn circle of
    # radius 30, scored against the clean circle.
    np.random.seed(0)
    n_traj, n_steps = 8, 101
    phase = np.linspace(0, 4 * np.pi, n_steps)

    circle = np.stack([30 * np.cos(phase), 30 * np.sin(phase)], -1)   # (T,2)
    gt = np.repeat(circle[None, ...], n_traj, axis=0)                 # (B,T,2)

    jitter = 0.4 * np.random.randn(n_traj, n_steps, 2)
    lagged = 0.3 * np.roll(gt, 10, axis=1)                            # time shift
    pred = gt + jitter + lagged

    # (label, metric, extra keyword arguments); printed in this order.
    report = (
        ("ADE:", ade, {}),
        ("FDE:", fde, {}),
        ("SR @2m:", success_rate, {"threshold": 2.0}),
        ("Hausdorff:", hausdorff, {}),
        ("nDTW:", ndtw, {}),
        # ("DynCons:", dynamic_consistency, {}),
    )
    for label, metric, extra in report:
        print(label, metric(pred, gt, reduce="mean", **extra))
