"""
Metric vs Constraint (C2) – 1D W2 geodesic baselines, FP residuals, and fast JAX utilities.
度量与约束分离（C2）——1D W2 测地基线、Fokker–Planck 残差、JAX 高性能工具。

提供：
- build_cdf_1d / invert_cdf_1d：CDF/分位反演（单调重排）
- w2_distance_1d：W2 距离（分位 L2 积分）
- quantile_lines_for_geodesic：基线测地的分位等值线（用于叠加）
- fp_residual_series_ou：给定 OU 漂移下的 FP 残差（时序 L² 范数）
"""

from __future__ import annotations

from typing import Tuple, Sequence

import jax
import jax.numpy as jnp

from mmsbvi.core.types import OUProcessParams
from mmsbvi.solvers.pde_solver_1d import _gradient_neumann_1d_fixed as grad_n
from mmsbvi.solvers.pde_solver_1d import _divergence_neumann_1d_fixed as div_n


@jax.jit
def jax_trapz(y: jnp.ndarray, dx: float) -> float:
    """Integrate samples `y` on a uniform grid of spacing `dx` (trapezoid rule).

    The rule is evaluated as the full sum of the samples minus half of the
    two end samples, scaled by `dx`.
    """
    full_sum = jnp.sum(y)
    half_ends = 0.5 * (y[0] + y[-1])
    return dx * (full_sum - half_ends)


@jax.jit
def build_cdf_1d(rho: jnp.ndarray, h: float) -> jnp.ndarray:
    """Build a normalized piecewise-linear CDF of a density on a uniform grid.

    Args:
        rho: 1D array of density samples at the grid nodes.
        h: Grid spacing.

    Returns:
        Array with the same length as `rho`. Entry 0 is 0 and entry i is the
        cumulative trapezoid integral of `rho` up to node i, divided by the
        total trapezoid mass and clipped to [0, 1].
    """
    # Trapezoid mass of each grid cell.
    cell_mass = 0.5 * (rho[:-1] + rho[1:]) * h
    # Cumulative trapezoid with cdf[0] = 0.
    cdf = jnp.concatenate([jnp.zeros((1,), dtype=rho.dtype), jnp.cumsum(cell_mass)])
    # The last cumulative value IS the trapezoid integral of rho, so reuse it
    # as the normalizer instead of running a second reduction; the tiny offset
    # guards against division by zero for an all-zero density.
    total = cdf[-1] + 1e-15
    return jnp.clip(cdf / total, 0.0, 1.0)


def _interp_inverse_cdf(cdf: jnp.ndarray, x: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray:
    """Invert a tabulated CDF by searchsorted + linear interpolation (JAX compatible).

    Args:
        cdf: Non-decreasing CDF values at the grid nodes `x`.
        x: Grid nodes, same length as `cdf`.
        u: Probability levels at which to evaluate the inverse.

    Returns:
        Array shaped like `u` with the interpolated positions. Results always
        lie between the two bracketing grid nodes, so levels outside the
        tabulated range map to a grid node instead of extrapolating.
    """
    # Keep CDF values inside [eps, 1 - eps]. This does NOT make the table
    # strictly increasing: flat cells (zero density) remain flat.
    eps = 1e-12
    cdf = jnp.clip(cdf, eps, 1.0 - eps)
    # First node whose CDF value exceeds u, forced to have a left neighbour.
    idx = jnp.searchsorted(cdf, u, side="right")
    idx = jnp.clip(idx, 1, cdf.shape[0] - 1)
    x0, x1 = x[idx - 1], x[idx]
    c0, c1 = cdf[idx - 1], cdf[idx]
    # When idx was clipped, the bracketing cell can be flat (c1 == c0), which
    # used to produce inf/NaN. Divide by a safe width and fall back to the
    # left node for flat cells.
    width = c1 - c0
    has_width = width > 0
    t = jnp.where(has_width, (u - c0) / jnp.where(has_width, width, 1.0), 0.0)
    # Levels below c0 or above c1 (only possible after clipping idx) are
    # pinned to the cell ends rather than extrapolated.
    t = jnp.clip(t, 0.0, 1.0)
    return x0 + t * (x1 - x0)


@jax.jit
def w2_distance_1d(rho0: jnp.ndarray, rho1: jnp.ndarray, x: jnp.ndarray) -> float:
    """Compute the 1D W2 distance between two densities via the quantile formula.

    Both densities are sampled on the same uniform grid `x`. Their quantile
    functions are evaluated at as many equally spaced probability levels as
    there are grid nodes, ranging from 1e-9 to 1 - 1e-9, and the integral of
    the squared quantile gap over [0, 1] is approximated by the plain mean
    over those levels.

    Returns:
        The square root of the mean squared quantile difference.
    """
    num_levels = x.shape[0]
    spacing = (x[1] - x[0]).astype(jnp.float64)
    levels = jnp.linspace(1e-9, 1.0 - 1e-9, num_levels, dtype=jnp.float64)

    quantiles0 = _interp_inverse_cdf(build_cdf_1d(rho0, spacing), x, levels)
    quantiles1 = _interp_inverse_cdf(build_cdf_1d(rho1, spacing), x, levels)

    gap = quantiles0 - quantiles1
    mean_sq_gap = jnp.mean(gap ** 2)
    return jnp.sqrt(mean_sq_gap)


def quantile_lines_for_geodesic(
    anchors: jnp.ndarray,  # (K, n)
    x: jnp.ndarray,        # (n,)
    quantiles: Sequence[float] = (0.1, 0.5, 0.9),
) -> jnp.ndarray:
    """Return the positions of the requested quantiles for every anchor density.

    Each row of `anchors` is treated as a density on the uniform grid `x`;
    its CDF is built and inverted at the given quantile levels.

    Args:
        anchors: Densities, one per time index, shape (K, n).
        x: Grid nodes, shape (n,).
        quantiles: Probability levels to track.

    Returns:
        Array of shape (len(quantiles), K): row q holds the position of
        quantile q at each time index.
    """
    # The unpacking doubles as a check that `anchors` is two-dimensional.
    _num_steps, _num_nodes = anchors.shape
    spacing = (x[1] - x[0]).astype(jnp.float64)
    levels = jnp.array(quantiles, dtype=jnp.float64)

    # Batch over the leading (time) axis; grid, spacing and levels are shared.
    cdfs = jax.vmap(build_cdf_1d, in_axes=(0, None))(anchors, spacing)  # (K, n)
    per_step = jax.vmap(_interp_inverse_cdf, in_axes=(0, None, None))(cdfs, x, levels)  # (K, Q)
    return jnp.transpose(per_step)  # (Q, K)


def fp_residual_series_ou(
    rho_series: jnp.ndarray,  # (K, n)
    x: jnp.ndarray,            # (n,)
    dt: float,
    ou_params: OUProcessParams,
) -> jnp.ndarray:
    """Return the L² norm of the Fokker–Planck residual for each time interval.

    For consecutive snapshots rho_k, rho_{k+1} the residual is

        (rho_{k+1} - rho_k) / dt + d/dx(b * rho_k) - (sigma^2 / 2) * d²/dx² rho_k

    with OU drift b(x) = -theta * (x - mu_inf). The time derivative is a
    forward difference; both spatial terms are evaluated on the earlier
    snapshot rho_k.

    Args:
        rho_series: Density snapshots, shape (K, n).
        x: Uniform spatial grid, shape (n,).
        dt: Time step between consecutive snapshots.
        ou_params: Supplies `mean_reversion` (theta), `diffusion` (sigma) and
            `equilibrium_mean` (mu_inf).

    Returns:
        Array of shape (K - 1,) with one residual norm per interval.
    """
    spacing = (x[1] - x[0]).astype(jnp.float64)
    reversion = jnp.asarray(ou_params.mean_reversion, jnp.float64)
    diffusion = jnp.asarray(ou_params.diffusion, jnp.float64)
    equilibrium = jnp.asarray(ou_params.equilibrium_mean, jnp.float64)
    drift = -reversion * (x - equilibrium)  # (n,)
    step = jnp.asarray(dt, jnp.float64)

    def interval_norm(rho_now, rho_next):
        # Forward difference in time.
        time_term = (rho_next - rho_now) / step
        # Spatial operators come from the project's 1D PDE solver helpers;
        # presumably Neumann-consistent stencils, going by their names.
        second_derivative = div_n(grad_n(rho_now, spacing), spacing)
        advection_term = grad_n(drift * rho_now, spacing)
        residual = time_term + advection_term - 0.5 * diffusion ** 2 * second_derivative
        return jnp.sqrt(jax_trapz(residual ** 2, dx=spacing))

    def fill_slot(k, norms):
        return norms.at[k].set(interval_norm(rho_series[k], rho_series[k + 1]))

    num_intervals = rho_series.shape[0] - 1
    empty_norms = jnp.zeros((num_intervals,), dtype=jnp.float64)
    return jax.lax.fori_loop(0, num_intervals, fill_slot, empty_norms)


