"""
plannability_metrics.py – Independent plannability metrics

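The metric functions depend only on NumPy; `get_others_quality` additionally
uses `traj_fid_mtr` from the sibling `traj_distribution` module. Hyperparameters
are passed via arguments with practical defaults. Trajectory tensors use
(..., T, 2) holding (x, y); `comfort_raw_gt` instead takes (..., T, 7) states.
The time dimension is selected by `axis` (default -2).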
"""
from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike

# Public names exported by a star-import of this module. Helpers whose names
# start with an underscore (_prep_xy, _kinematics_xy, _speed_mask) are not listed.
__all__ = [
    # Feasibility & comfort
    "feasibility_rate",
    "jerk_rms",
    "curvature_rms",
    "dcurvature_rms",
    "speed_var_score",
    "yaw_rate_rms_deg",
    "trajectory_consistency",
    "comfort_score_3",
    "comfort_raw",
    "comfort_raw_gt",
    "fit_comfort_scales",
    "comfort_score_from_raw",
    "comfort_score_norm",
    "speed_score",
    # Consistency metrics
    "velocity_consistency",
    "acceleration_consistency",
    # Quality aggregator helpers
    "get_others_quality",
]

# ------------------------------------------------------------------
# Common preprocessing
# ------------------------------------------------------------------
def _prep_xy(arr: np.ndarray, axis: int):
    """
    Normalise a trajectory tensor to a flat batch of (x, y) sequences.

    Only the first two entries of the last dimension are kept, the dimension
    given by `axis` is moved to position -2, and every leading dimension is
    collapsed into one.

    Returns:
      flat        : float array of shape (N, T, 2)
      batch_shape : tuple with the leading (batch) dimensions before flattening
    """
    coords = np.asarray(arr, float)[..., :2]
    time_major = np.moveaxis(coords, axis, -2)        # (..., T, 2)
    *lead_dims, n_steps, _ = time_major.shape
    return time_major.reshape(-1, n_steps, 2), tuple(lead_dims)

# ------------------------------------------------------------------
# Kinematics helper
# ------------------------------------------------------------------
def _kinematics_xy(traj_xy: np.ndarray, dt: float):
    """
    Forward-difference kinematics for a flat batch of trajectories.

    traj_xy: (N, T, 2) positions; dt: sampling period.

    Returns a dict of sequences, each at its natural length:
      v_vec (N, T-1, 2), a_vec (N, T-2, 2), j_vec (N, T-3, 2)
      speed (N, T-1),    accel (N, T-2),    jerk  (N, T-3)
      heading (N, T-1)        – arctan2 of the velocity vector, in [-pi, pi]
      yaw_rate (N, T-2)       – wrapped heading difference divided by dt
      curvature (N, T-2)      – yaw_rate / speed (speed taken at the later step)
      dcurvature_dt (N, T-3)  – forward difference of curvature
    """
    v_vec = np.diff(traj_xy, axis=1) / dt         # (N, T-1, 2)
    a_vec = np.diff(v_vec,  axis=1) / dt          # (N, T-2, 2)
    j_vec = np.diff(a_vec,  axis=1) / dt          # (N, T-3, 2)

    speed = np.linalg.norm(v_vec, axis=-1)        # (N, T-1)
    accel = np.linalg.norm(a_vec, axis=-1)        # (N, T-2)
    jerk  = np.linalg.norm(j_vec, axis=-1)        # (N, T-3)

    heading = np.arctan2(v_vec[..., 1], v_vec[..., 0])        # (N, T-1)

    # arctan2 jumps by 2*pi when the heading crosses +/-pi. Wrap the step into
    # [-pi, pi) so that crossing does not show up as a huge fake yaw rate.
    d_heading = (np.diff(heading, axis=1) + np.pi) % (2.0 * np.pi) - np.pi
    yaw_rate  = d_heading / dt                                # (N, T-2)

    # curvature aligned with yaw_rate length (T-2); 1e-9 guards zero speed
    curvature = yaw_rate / (speed[:, 1:] + 1e-9)              # (N, T-2)
    dcurv_dt  = np.diff(curvature, axis=1) / dt               # (N, T-3)

    return dict(
        v_vec=v_vec, a_vec=a_vec, j_vec=j_vec,
        speed=speed, accel=accel, jerk=jerk,
        heading=heading, yaw_rate=yaw_rate,
        curvature=curvature, dcurvature_dt=dcurv_dt,
    )

# ------------------------------------------------------------------
# Feasibility
# ------------------------------------------------------------------
def feasibility_rate(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,                 # typical 0.05–0.1 s
    speed_max: float = 30.0,         # km/h urban ~30–50; highway ~80–120
    accel_max: float = 3.0,          # m/s^2 comfortable 2–4
    mu: float = 0.9,                 # dry 0.9–1.0; wet 0.5–0.7
    axis: int = -2,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Fraction in [0,1] of timesteps satisfying all constraints.

    A timestep (aligned to the T-2 acceleration samples) is feasible when
      * speed            <= speed_max / 3.6   (speed_max is in km/h),
      * |acceleration|   <= accel_max,
      * |lateral accel.| <= mu * 9.81.

    The lateral acceleration is the component of the acceleration vector
    normal to the velocity direction, so the result does not depend on how the
    trajectory is oriented in the world frame. Whenever accel_max <= mu * 9.81
    (true for the defaults) the lateral test is implied by the acceleration
    test.

    Returns an array with the batch shape of `traj_xy` when reduce == "none";
    any other `reduce` value returns the mean over the batch as a float.
    With fewer than 3 frames there are no samples and the result is NaN.
    """
    xy, bs = _prep_xy(np.asarray(traj_xy, float), axis)
    k = _kinematics_xy(xy, dt)

    g = 9.81
    v_lim = speed_max / 3.6                            # km/h → m/s

    speed = k["speed"][:, 1:]                          # align to accel (T-2)
    v_ref = k["v_vec"][:, 1:, :]                       # velocity at the same steps
    a_vec = k["a_vec"]                                 # (N, T-2, 2)

    speed_ok = speed <= v_lim
    accel_ok = k["accel"] <= accel_max                 # (N, T-2)

    # 2-D cross product of the unit heading with the acceleration gives the
    # signed lateral component; the small constant avoids 0/0 at standstill.
    a_lat = (v_ref[..., 0] * a_vec[..., 1] - v_ref[..., 1] * a_vec[..., 0]) / (speed + 1e-9)
    lat_ok = np.abs(a_lat) <= mu * g                   # (N, T-2)

    ok = speed_ok & accel_ok & lat_ok
    out = ok.mean(1).reshape(bs)
    return out if reduce == "none" else float(out.mean())

# ------------------------------------------------------------------
# Utility: masks for moving segments & static clips
# ------------------------------------------------------------------
def _speed_mask(k, v_static=0.1):
    """
    Build "is moving" masks from the speed sequence in a kinematics dict.

    Returns:
      mask_j   : speed > v_static, aligned to jerk-length series (T-3)
      mask_c_y : speed > v_static, aligned to curvature/yaw-rate series (T-2)
      is_static: per-clip flag, True when the peak speed is below v_static
    """
    speed = k["speed"]                           # (N, T-1)
    above = speed > v_static
    # Drop the leading samples so the masks line up with the shorter series.
    return above[:, 2:], above[:, 1:], speed.max(axis=1) < v_static

# ------------------------------------------------------------------
# Jerk RMS
# ------------------------------------------------------------------
def jerk_rms(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Root-mean-square jerk magnitude per clip.

    Only timesteps whose speed exceeds the default static threshold of
    `_speed_mask` contribute; clips without such timesteps yield NaN.
    With reduce == "none" the per-clip values are returned in the batch shape,
    otherwise their NaN-ignoring mean as a float.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    kin = _kinematics_xy(flat, dt)
    moving_steps, _, clip_static = _speed_mask(kin)

    def _clip_rms(values, keep, static):
        # NaN marks clips that never move fast enough to be evaluated.
        if static or not keep.any():
            return np.nan
        return np.sqrt((values[keep] ** 2).mean())

    per_clip = np.array(
        [_clip_rms(j, m, s) for j, m, s in zip(kin["jerk"], moving_steps, clip_static)],
        dtype=float,
    ).reshape(batch_shape)

    if reduce == "none":
        return per_clip
    return float(np.nanmean(per_clip))

# ------------------------------------------------------------------
# Curvature RMS  → score in (0,1], S=1/(1+RMS)
# ------------------------------------------------------------------
def curvature_rms(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    pct: float | None = None,   # optional percentile filter over time
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Curvature smoothness score S = 1 / (1 + RMS(kappa)); higher is better.

    kappa is |x' y'' - y' x''| / (x'^2 + y'^2 + eps)^1.5 from forward
    differences. When `pct` is given, samples above that per-clip percentile
    are excluded from the RMS. Clips whose peak speed is below `v_static`
    yield NaN.

    Raises ValueError when fewer than 3 frames are supplied.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    n_clips, n_steps, _ = flat.shape
    if n_steps < 3:
        raise ValueError("Need ≥3 frames to compute curvature")

    vel = np.diff(flat, axis=1) / dt                             # (N, T-1, 2)
    is_moving = np.linalg.norm(vel, axis=-1).max(axis=1) >= v_static

    result = np.full(n_clips, np.nan, dtype=float)
    if is_moving.any():
        vel_mv = vel[is_moving]
        acc_mv = np.diff(vel_mv, axis=1) / dt                    # (M, T-2, 2)

        # Velocity is taken at the later sample so it matches the acceleration.
        vx, vy = vel_mv[:, 1:, 0], vel_mv[:, 1:, 1]              # (M, T-2)
        ax, ay = acc_mv[:, :, 0], acc_mv[:, :, 1]
        cross = np.abs(vx * ay - vy * ax)
        kappa = cross / (vx**2 + vy**2 + eps) ** 1.5             # (M, T-2)

        if pct is not None:
            cutoff = np.percentile(kappa, pct, axis=-1, keepdims=True)
            kappa = np.where(kappa <= cutoff + eps, kappa, np.nan)

        result[is_moving] = 1.0 / (1.0 + np.sqrt(np.nanmean(kappa**2, axis=-1)))

    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return float(np.nanmean(result))

# ------------------------------------------------------------------
# dCurvature RMS  → score in (0,1], S=1/(1+RMS(dkappa/dt))
# ------------------------------------------------------------------
def dcurvature_rms(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    v_static: float = 0.1,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Curvature-rate score S = 1 / (1 + RMS(dkappa/dt)); higher is better.

    The curvature rate (1/m/s) comes from `_kinematics_xy`. Only timesteps
    whose speed exceeds `v_static` enter the RMS; clips without any such
    timestep yield NaN. With reduce == "none" the per-clip scores are returned
    in the batch shape, otherwise their NaN-ignoring mean as a float.
    """
    xy, bs = _prep_xy(np.asarray(traj_xy, float), axis)
    k = _kinematics_xy(xy, dt)
    dcurv = k["dcurvature_dt"]                      # (N, T-3)
    # Forward the caller's threshold; without it the argument has no effect.
    mask_j, _, is_static = _speed_mask(k, v_static)

    out = np.full(dcurv.shape[0], np.nan)
    for i, (dc, m, st) in enumerate(zip(dcurv, mask_j, is_static)):
        if not st and m.any():
            rms = np.sqrt((dc[m] ** 2).mean())
            out[i] = 1.0 / (1.0 + rms)

    out = out.reshape(bs)
    return out if reduce == "none" else float(np.nanmean(out))

# ------------------------------------------------------------------
# Speed variability score via coefficient of variation
# ------------------------------------------------------------------
def speed_var_score(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    v_static: float = 0.1,
    eps: float = 1e-9,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Speed steadiness score in (0,1], higher is steadier:
        cv = std(speed) / (mean(speed) + eps)
        S  = 1 / (1 + cv)
    Clips whose peak speed is below `v_static` yield NaN.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    speed = np.linalg.norm(np.diff(flat, axis=1) / dt, axis=-1)  # (N, T-1)

    is_moving = speed.max(axis=1) >= v_static
    result = np.full(speed.shape[0], np.nan, dtype=float)

    if is_moving.any():
        sel = speed[is_moving]
        coeff_var = sel.std(axis=1) / (sel.mean(axis=1) + eps)
        result[is_moving] = 1.0 / (1.0 + coeff_var)

    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return float(np.nanmean(result))

# ------------------------------------------------------------------
# Yaw-rate RMS (deg/s)
# ------------------------------------------------------------------
def yaw_rate_rms_deg(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Root-mean-square yaw rate per clip, in deg/s.

    Only timesteps whose speed exceeds the default static threshold of
    `_speed_mask` contribute; clips without such timesteps yield NaN.
    With reduce == "none" the per-clip values are returned in the batch shape,
    otherwise their NaN-ignoring mean as a float.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    kin = _kinematics_xy(flat, dt)
    yaw_deg = kin["yaw_rate"] * 180.0 / np.pi       # (N, T-2), rad/s → deg/s
    _, moving_steps, clip_static = _speed_mask(kin)

    def _clip_rms(values, keep, static):
        # NaN marks clips that never move fast enough to be evaluated.
        if static or not keep.any():
            return np.nan
        return np.sqrt((values[keep] ** 2).mean())

    per_clip = np.array(
        [_clip_rms(y, m, s) for y, m, s in zip(yaw_deg, moving_steps, clip_static)],
        dtype=float,
    ).reshape(batch_shape)

    if reduce == "none":
        return per_clip
    return float(np.nanmean(per_clip))

# ------------------------------------------------------------------
# Velocity/Acceleration consistency
# ------------------------------------------------------------------
def trajectory_consistency(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Combined speed/acceleration consistency score:
        S = 0.5 * [ exp(-σ_v/(μ_v+eps)) + exp(-σ_a/(μ_a+eps)) ]
    where v is the speed sequence and a is the forward difference of the
    speed (μ_a is the mean of |a|). Clips whose peak speed is below
    `v_static` yield NaN.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    speed = np.linalg.norm(np.diff(flat, axis=1) / dt, axis=-1)  # (N, T-1)

    result = np.full(speed.shape[0], np.nan, dtype=float)
    is_moving = speed.max(axis=1) >= v_static

    if is_moving.any():
        spd = speed[is_moving]
        vel_term = np.exp(-spd.std(axis=1) / (spd.mean(axis=1) + eps))

        lon_acc = np.diff(spd, axis=1) / dt                      # (M, T-2)
        acc_spread = np.std(lon_acc, axis=1)
        acc_level = np.mean(np.abs(lon_acc), axis=1)
        acc_term = np.exp(-acc_spread / (acc_level + eps))

        result[is_moving] = 0.5 * (vel_term + acc_term)

    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return float(np.nanmean(result))

def velocity_consistency(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Speed consistency exp(-σ_v / (μ_v + eps)) computed on the speed sequence.
    Clips whose peak speed is below `v_static` yield NaN.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    speed = np.linalg.norm(np.diff(flat, axis=1) / dt, axis=-1)  # (N, T-1)

    is_moving = speed.max(axis=1) >= v_static
    result = np.full(speed.shape[0], np.nan, dtype=float)
    if is_moving.any():
        sel = speed[is_moving]
        result[is_moving] = np.exp(-sel.std(axis=1) / (sel.mean(axis=1) + eps))

    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return float(np.nanmean(result))

def acceleration_consistency(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    reduce: str = "none",
) -> np.ndarray | float:
    """
    Acceleration consistency exp(-σ_a / (μ_a + eps)) computed on the magnitude
    of the acceleration vector. Clips whose peak speed is below `v_static`
    yield NaN.
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    vel = np.diff(flat, axis=1) / dt                 # (N, T-1, 2)
    acc = np.diff(vel, axis=1) / dt                  # (N, T-2, 2)

    is_moving = np.linalg.norm(vel, axis=-1).max(axis=1) >= v_static
    result = np.full(vel.shape[0], np.nan, dtype=float)
    if is_moving.any():
        acc_mag = np.linalg.norm(acc[is_moving], axis=-1)   # (M, T-2)
        level = np.mean(np.abs(acc_mag), axis=1)
        spread = np.std(acc_mag, axis=1)
        result[is_moving] = np.exp(-spread / (level + eps))

    result = result.reshape(batch_shape)
    if reduce == "none":
        return result
    return float(np.nanmean(result))

# ------------------------------------------------------------------
# Comfort metrics (variants)
# ------------------------------------------------------------------
def comfort_score_3(
    traj_xy: ArrayLike, *, dt: float = 0.1, axis: int = -2,
    eps: float = 1e-9, v_static: float = 0.1, reduce: str = "none",
    jerk_scale: float = 4.13, lat_scale: float = 4.89, yaw_scale: float = 0.95,
) -> np.ndarray | float:
    """
    Simple comfort proxy from three per-clip peaks.

    For every clip the maximum absolute longitudinal jerk, lateral
    acceleration and yaw rate are mapped through exp(-peak / scale) and
    combined by a geometric mean. Longitudinal/lateral directions are taken
    from the velocity direction at the earlier sample of each acceleration.
    Clips whose peak speed is below `v_static` yield NaN.

    jerk_scale, lat_scale, yaw_scale are the divisors of the three exponential
    terms; their defaults reproduce the previously hard-coded values.

    Raises ValueError when fewer than 4 frames are supplied (no jerk sample
    can be formed).
    """
    xy, bs = _prep_xy(np.asarray(traj_xy, float), axis)
    if xy.shape[1] < 4:
        raise ValueError("Need ≥4 frames")

    v = np.diff(xy, axis=1) / dt
    speed = np.linalg.norm(v, axis=-1)
    moving = speed.max(axis=1) >= v_static

    heading = np.arctan2(v[..., 1], v[..., 0])
    # Wrap the heading step into [-pi, pi): arctan2 jumps by 2*pi at +/-pi and
    # the resulting spike would otherwise dominate the max() below.
    d_heading = (np.diff(heading, axis=1) + np.pi) % (2.0 * np.pi) - np.pi
    yaw_rate = d_heading / dt                                # (N, T-2)

    a = np.diff(v, axis=1) / dt                              # (N, T-2, 2)
    h = v[..., :-1, :]
    h_unit = h / (np.linalg.norm(h, axis=-1, keepdims=True) + eps)
    perp = np.stack([-h_unit[..., 1], h_unit[..., 0]], -1)   # left-hand normal
    a_lon = (a * h_unit).sum(-1)
    a_lat = (a * perp).sum(-1)
    j_lon = np.diff(a_lon, axis=1) / dt                      # (N, T-3)

    jerks = np.abs(j_lon).max(axis=1)
    lats  = np.abs(a_lat).max(axis=1)
    yaws  = np.abs(yaw_rate).max(axis=1)

    s_jerk = np.exp(-jerks / jerk_scale)
    s_lat  = np.exp(-lats  / lat_scale)
    s_yaw  = np.exp(-yaws  / yaw_scale)

    score = (s_jerk * s_lat * s_yaw) ** (1/3)
    score[~moving] = np.nan

    out = score.reshape(bs)
    return out if reduce == "none" else float(np.nanmean(out))

def fit_comfort_scales(raw_Nx3, method="median", p=95):
    """
    Fit one positive scale per column of raw (N,3) =
    [jerk_per_m, acc_per_m, yaw_per_m].

    Non-finite entries are ignored. With method == "percentile" the scale is
    the p-th percentile of the column, for any other `method` it is the
    median. A column without finite entries gets scale 1.0, and every scale
    is floored at 1e-9.
    """
    data = np.asarray(raw_Nx3, float)

    def _column_scale(column):
        finite = column[np.isfinite(column)]
        if finite.size == 0:
            return 1.0
        if method == "percentile":
            return np.percentile(finite, p)
        return np.median(finite)

    scales = np.array([_column_scale(data[:, dim]) for dim in range(3)], dtype=float)
    return np.maximum(scales, 1e-9)

def comfort_score_from_raw(raw_Nx3, sigmas):
    """
    Turn (N,3) raw spikes into one comfort score per row.

    Each finite entry q is mapped to S = 1 / (1 + q / sigma) with the matching
    per-column `sigmas`; non-finite entries are skipped. The row score is the
    geometric mean of the remaining S values.
    """
    spikes = np.asarray(raw_Nx3, float)
    per_dim = np.where(np.isfinite(spikes), 1.0 / (1.0 + spikes / sigmas), np.nan)
    # Geometric mean computed in log space, ignoring the NaN placeholders.
    mean_log = np.nanmean(np.log(per_dim), axis=1)
    return np.exp(mean_log)

def comfort_raw(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    reduce: str = "none",
    win: int = 11,
    poly: int = 3,
    pct: float | None = None,       # percentile over time (e.g., 95); None→mean
    length_eps: float = 1.0,        # min path length; otherwise NaN
):
    """
    Per-metre comfort spikes (jerk_per_m, acc_per_m, yaw_per_m) for each clip.

    Steps:
      - centred differences give velocity, acceleration, jerk and a wrapped
        yaw rate
      - each magnitude series is aggregated over time by its mean, or by the
        `pct`-th percentile when `pct` is given
      - the aggregate is divided by the total path length
      - clips whose peak speed is below `v_static`, or whose path length does
        not exceed `length_eps`, yield NaN

    `win` and `poly` belong to an optional smoothing step that is currently
    disabled; they have no effect.

    Returns an array of shape (*batch, 3) when reduce == "none", otherwise the
    NaN-ignoring mean over all batch dimensions (shape (3,)).
    Raises ValueError when fewer than 5 frames are supplied.
    """
    pts, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    if pts.shape[1] < 5:
        raise ValueError("Need ≥5 frames")

    # NOTE: a Savitzky–Golay smoothing pass over `pts` (window `win`, order
    # `poly`) could be inserted here; it is intentionally left out for now.

    behind, centre, ahead = pts[:, :-2, :], pts[:, 1:-1, :], pts[:, 2:, :]
    vel = (ahead - behind) / (2 * dt)                       # (N, T-2, 2)
    acc = (ahead - 2 * centre + behind) / dt**2             # (N, T-2, 2)
    jerk = (acc[:, 2:, :] - acc[:, :-2, :]) / (2 * dt)      # (N, T-4, 2)
    acc_inner = np.linalg.norm(acc[:, 1:-1, :], axis=-1)    # (N, T-4)

    is_moving = np.linalg.norm(vel, axis=-1).max(axis=1) >= v_static

    heading = np.arctan2(vel[..., 1], vel[..., 0])          # (N, T-2)
    d_heading = heading[:, 2:] - heading[:, :-2]
    # Wrap into [-pi, pi) before converting to a rate.
    yaw_rate = ((d_heading + np.pi) % (2 * np.pi) - np.pi) / (2 * dt)   # (N, T-4)

    if pct is None:
        def over_time(series):
            return series.mean(axis=-1)
    else:
        def over_time(series):
            return np.percentile(series, pct, axis=-1)

    peaks = [
        over_time(np.linalg.norm(jerk, axis=-1)),
        over_time(acc_inner),
        over_time(np.abs(yaw_rate)),
    ]

    path_len = np.linalg.norm(np.diff(pts, axis=-2), axis=-1).sum(axis=-1)   # (N,)
    keep = is_moving & (path_len > length_eps)

    per_metre = [np.where(keep, peak / (path_len + eps), np.nan) for peak in peaks]
    out = np.stack(per_metre, axis=-1).reshape(*batch_shape, 3)

    if reduce == "none":
        return out
    return np.nanmean(out, axis=tuple(range(out.ndim - 1)))

def comfort_raw_gt(
    traj_state: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    reduce: str = "none",
    use_percentile: float | None = 99.0,   # None → use max()
) -> np.ndarray | float:
    """
    Comfort peaks from full state trajectories.

    Input (..., T, 7): [x, y, heading(rad), vx, vy, ax, ay], with the time
    dimension selected by `axis`.

    Output (..., 3): (longitudinal jerk, lateral acceleration, yaw rate), each
    the `use_percentile`-th percentile of the absolute value over time, or the
    maximum when `use_percentile` is None. Unlike `comfort_raw`, the values
    are not divided by path length. Clips whose peak speed is below `v_static`
    yield NaN.

    With reduce == "none" the (..., 3) array is returned; otherwise the
    NaN-ignoring mean over all leading dimensions (shape (3,)).
    `eps` is accepted for signature parity with the other metrics and unused.
    Raises ValueError when fewer than 4 frames are supplied.
    """
    state = np.moveaxis(np.asarray(traj_state, float), axis, -2)
    if state.shape[-2] < 4:
        raise ValueError("Need ≥4 frames")

    psi     = state[..., 2]
    vx, vy  = state[..., 3], state[..., 4]
    ax, ay  = state[..., 5], state[..., 6]

    speed = np.sqrt(vx**2 + vy**2)
    moving = speed.max(axis=-1) >= v_static

    # Rotate the world-frame acceleration into the heading frame.
    hx, hy = np.cos(psi), np.sin(psi)
    a_lon = ax * hx + ay * hy
    a_lat = -ax * hy + ay * hx

    j_lon = np.diff(a_lon, axis=-1) / dt
    # Wrap the heading step into [-pi, pi) so a heading that passes +/-pi does
    # not register as a 2*pi/dt yaw-rate spike.
    yaw_rate = ((np.diff(psi, axis=-1) + np.pi) % (2.0 * np.pi) - np.pi) / dt

    def agg(arr):
        if use_percentile is None:
            return np.abs(arr).max(axis=-1)
        return np.percentile(np.abs(arr), use_percentile, axis=-1)

    jerk_max = agg(j_lon)
    lat_acc_max = agg(a_lat[..., 1:])     # align with j_lon / yaw_rate
    yaw_rate_max = agg(yaw_rate)

    jerk_max = np.where(moving, jerk_max, np.nan)
    lat_acc_max = np.where(moving, lat_acc_max, np.nan)
    yaw_rate_max = np.where(moving, yaw_rate_max, np.nan)

    # The time axis is already reduced, so the result is simply (..., 3);
    # no axis needs to be moved back.
    out = np.stack([jerk_max, lat_acc_max, yaw_rate_max], axis=-1)

    if reduce == "none":
        return out
    return np.nanmean(out, axis=tuple(range(out.ndim - 1)))

def comfort_score_norm(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    eps: float = 1e-9,
    v_static: float = 0.1,
    pct: float | None = None,
    length_eps: float = 1.0,
    j_scale: float = 1.0,
    a_scale: float = 1.0,
    y_scale: float = 1.0,
    reduce: str = "none",
    return_components: bool = False,
):
    """
    Normalized comfort score in (0,1], higher is better.

    Per-metre jerk, acceleration and yaw-rate spikes (centred differences,
    aggregated over time by mean or by the `pct`-th percentile, divided by
    path length) are each mapped through S = 1 / (1 + value / scale) using
    j_scale, a_scale and y_scale, then combined by a geometric mean.
    Clips whose peak speed is below `v_static`, or whose path length does not
    exceed `length_eps`, yield NaN.

    Only the first two entries of the last dimension are used as (x, y), so
    inputs carrying extra state columns are accepted.

    Returns S in the batch shape (or its NaN-ignoring mean as a float when
    reduce != "none"). With return_components=True a tuple (S, components) is
    returned, where components has shape (*batch, 3) (or (3,) when reduced)
    and holds (Sj, Sa, Sy).
    Raises ValueError when fewer than 5 frames are supplied.
    """
    # Keep only (x, y) before flattening; reshaping a wider last dimension to
    # (-1, T, 2) would otherwise scramble the samples.
    xy = np.asarray(traj_xy, float)[..., :2]
    xy = np.moveaxis(xy, axis, -2)
    bs = xy.shape[:-2]
    xy_s = xy.reshape(-1, xy.shape[-2], 2)
    N, T, _ = xy_s.shape
    if T < 5:
        raise ValueError("Need ≥5 frames")

    v = (xy_s[:, 2:, :] - xy_s[:, :-2, :]) / (2 * dt)
    a = (xy_s[:, 2:, :] - 2*xy_s[:, 1:-1, :] + xy_s[:, :-2, :]) / (dt**2)
    a_c = a[:, 1:-1, :]
    j = (a[:, 2:, :] - a[:, :-2, :]) / (2 * dt)

    speed = np.linalg.norm(v, axis=-1)
    moving = speed.max(axis=1) >= v_static

    th2 = np.arctan2(v[:, 2:, 1], v[:, 2:, 0])
    th0 = np.arctan2(v[:, :-2, 1], v[:, :-2, 0])
    # Heading step wrapped into [-pi, pi) before converting to a rate.
    yaw_rt = ((th2 - th0 + np.pi) % (2*np.pi) - np.pi) / (2 * dt)

    def t_reduce(x):
        return np.percentile(x, pct, axis=1) if pct is not None else x.mean(axis=1)

    jerk_p = t_reduce(np.linalg.norm(j, axis=-1))
    acc_p  = t_reduce(np.linalg.norm(a_c, axis=-1))
    yaw_p  = t_reduce(np.abs(yaw_rt))

    lengths = np.sum(np.linalg.norm(np.diff(xy_s, axis=1), axis=-1), axis=1)
    valid = moving & (lengths > length_eps)

    jerk_pm = np.where(valid, jerk_p / (lengths + eps), np.nan)
    acc_pm  = np.where(valid, acc_p  / (lengths + eps), np.nan)
    yaw_pm  = np.where(valid, yaw_p  / (lengths + eps), np.nan)

    # max(scale, eps) keeps a zero scale from dividing by zero.
    Sj = 1.0 / (1.0 + (jerk_pm / max(j_scale, eps)))
    Sa = 1.0 / (1.0 + (acc_pm  / max(a_scale, eps)))
    Sy = 1.0 / (1.0 + (yaw_pm  / max(y_scale, eps)))

    comp = np.vstack([Sj, Sa, Sy])
    S = np.exp(np.nanmean(np.log(comp), axis=0)).reshape(bs)

    if return_components:
        comps = np.stack([Sj, Sa, Sy], axis=-1).reshape(*bs, 3)
        if reduce == "none":
            return S, comps
        return float(np.nanmean(S)), np.nanmean(comps, axis=tuple(range(comps.ndim - 1)))
    else:
        if reduce == "none":
            return S
        return float(np.nanmean(S))

def speed_score(
    traj_xy: ArrayLike,
    *,
    dt: float = 0.1,
    axis: int = -2,
    v_ref: float = 6.0,
    k: float = 2.5,           # v_max = k * v_ref
    v_static: float = 0.1,
    use_percentile=None,      # None→mean; or e.g., 90/95
    reduce: str = "none",
):
    """
    Log-scaled speed score that rewards faster movement up to a soft cap:
        v_max = k * v_ref
        S = log(1 + v_stat) / log(1 + v_max), clipped to [0,1]
    where v_stat is the mean speed of the clip, or its `use_percentile`-th
    percentile when given. Clips whose peak speed is below `v_static` score
    0.0 (not NaN).
    """
    flat, batch_shape = _prep_xy(np.asarray(traj_xy, float), axis)
    speed = np.linalg.norm(np.diff(flat, axis=1) / dt, axis=-1)

    if use_percentile is None:
        v_stat = speed.mean(axis=1)
    else:
        v_stat = np.percentile(speed, use_percentile, axis=1)

    is_moving = speed.max(axis=1) >= v_static

    # Floor the denominator so a degenerate cap cannot divide by zero.
    log_cap = np.maximum(np.log1p(k * v_ref), 1e-12)

    scores = np.zeros(v_stat.shape, dtype=float)    # non-moving clips stay 0.0
    scores[is_moving] = np.clip(np.log1p(v_stat[is_moving]) / log_cap, 0.0, 1.0)

    scores = scores.reshape(batch_shape)
    if reduce == "none":
        return scores
    return float(np.nanmean(scores))

# ------------------------------------------------------------------
# Aggregation helper using external distribution metric
# ------------------------------------------------------------------
from .traj_distribution import traj_fid_mtr

def get_others_quality(agents_traj, gts):
    """
    Aggregate per-agent quality and consistency metrics over a scene.

    agents_traj: iterable of agents, each an iterable of (ids, traj_xy) pairs.
                 Trajectories shorter than 11 samples are skipped.
    gts: ground-truth set handed to `traj_fid_mtr` together with every kept
         trajectory.

    Metrics are NaN-averaged first over the trajectories of each agent and
    then over agents.

    Returns (agent_quality_dict|None, scene_quality|None,
             agent_consistency_dict|None); scene_quality is always None.
    """
    per_agent_quality = []
    per_agent_consistency = []
    kept_trajs = []

    for agent_traj in agents_traj:
        quality_rows = []
        consistency_rows = []
        for _, traj in agent_traj:
            if len(traj) < 11:
                continue
            kept_trajs.append(np.array(traj))

            batch = [traj]                                  # single-clip batch
            comfort = comfort_raw(batch, reduce='mean')     # (3,)
            speed = speed_score(batch, reduce='mean')
            curvature = curvature_rms(batch, reduce='mean')
            consistency = trajectory_consistency(batch, reduce='mean')

            quality_rows.append([comfort[0], comfort[1], comfort[2], speed, curvature])
            consistency_rows.append(consistency)

        if not quality_rows:
            continue
        # Both row lists grow together, so consistency_rows is non-empty here.
        per_agent_quality.append(np.nanmean(np.array(quality_rows), axis=0).tolist())
        per_agent_consistency.append(np.nanmean(np.array(consistency_rows), axis=0))

    scene_quality = None

    if per_agent_quality:
        scene_mean = np.nanmean(np.array(per_agent_quality), axis=0)
        agent_jerk, agent_acc, agent_yaw_rate, agent_ss, agent_crms = scene_mean

        agents_fid = traj_fid_mtr(kept_trajs, gts, stride=1)

        agent_quality = {
            'agent_2_ego_distribution': {'traj_fid': agents_fid},
            'agent_quality': {
                'comfort_raw': [agent_jerk, agent_acc, agent_yaw_rate],
                'speed_score': agent_ss,
                'geometry-curvature': agent_crms,
            }
        }
    else:
        agent_quality = None

    if per_agent_consistency:
        agent_consistency = {
            'agent_consistency': np.nanmean(np.array(per_agent_consistency)),
        }
    else:
        agent_consistency = None

    return agent_quality, scene_quality, agent_consistency

# ------------------------------------------------------------------
# Self-test
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: evaluate every batch metric on noisy circular trajectories.
    np.random.seed(0)
    B, T = 8, 101
    t = np.linspace(0, 4 * np.pi, T)

    single_traj = np.stack([30 * np.cos(t), 30 * np.sin(t)], -1)      # (T,2)
    pred = np.repeat(single_traj[None, ...], B, axis=0)               # (B,T,2)
    noise = 0.4 * np.random.randn(B, T, 2)
    echo = 0.3 * np.roll(pred, 10, axis=1)     # time-shifted copy of the clean path
    pred = pred + (noise + echo)

    report = (
        ("Feasibility (mean):", feasibility_rate),
        ("Jerk RMS:", jerk_rms),
        ("Curvature RMS (score):", curvature_rms),
        ("dCurvature RMS (score):", dcurvature_rms),
        ("Speed-var score:", speed_var_score),
        ("Yaw-rate RMS (deg/s):", yaw_rate_rms_deg),
        ("Vel-consistency:", velocity_consistency),
        ("Acc-consistency:", acceleration_consistency),
        ("Trajectory consistency:", trajectory_consistency),
    )
    for label, metric in report:
        print(label, metric(pred, reduce="mean"))
