#!/usr/bin/env python3
"""
Doob–h vs Anchor (E1) Visualization
与几何极限可视化一致的发表级配色与风格
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
import matplotlib.pyplot as plt

# Style/colours kept consistent with geometric_limits_visualization.
# NOTE: this updates matplotlib's global rcParams as a module-level statement,
# i.e. it takes effect as soon as this module is imported.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'legend.fontsize': 11,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'figure.dpi': 200,
    'savefig.dpi': 400,
    'text.usetex': False,
    'axes.unicode_minus': False,
    'figure.constrained_layout.use': True
})

# Colour palette looked up by key in the figure-saving functions below
# (figure background, line/fill colours and grid colour).
COLORS = {
    'primary': '#9D110E',       # Deep red
    'secondary': '#000000',     # Black
    'background': '#FFFFFF',    # White
    'grid': '#E5E5E5'           # Light gray for grid
}


def _stats_text(l1: jnp.ndarray, grid) -> str:
    """
    Format summary statistics of the L1 distances as a multi-line text block.

    Args:
        l1: Array-like of L1 distances; converted with ``jnp.asarray``.
        grid: Object that may expose ``n_points`` and a two-element ``bounds``.
            ``None`` or an object lacking these attributes is accepted: the
            point count is then reported as 0 and the bounds as ``n/a``.

    Returns:
        Text with max / mean / median / 95th-percentile of ``l1`` followed by
        the grid point count and bounds, one item per line.
    """
    l1 = jnp.asarray(l1)

    # ``n_points`` was already read defensively; treat a None value like a
    # missing attribute so ``int(...)`` cannot fail on it.
    n_points = int(getattr(grid, 'n_points', 0) or 0)

    # ``grid`` may be None (e.g. obtained via ``res.get("GRID")``), so the
    # bounds must be read defensively as well instead of raising.
    bounds = getattr(grid, 'bounds', None)
    if bounds is None:
        bounds_line = "bounds = n/a\n"
    else:
        bounds_line = f"bounds = ({float(bounds[0]):.2f}, {float(bounds[1]):.2f})\n"

    return (
        "E1: Doob vs Anchors\n"
        f"max L1 = {float(jnp.max(l1)):.2e}\n"
        f"mean L1 = {float(jnp.mean(l1)):.2e}\n"
        f"median L1 = {float(jnp.median(l1)):.2e}\n"
        f"p95 L1 = {float(jnp.percentile(l1, 95.0)):.2e}\n\n"
        f"grid_points = {n_points}\n"
        + bounds_line
    )


def save_e1_figures(res: Dict, save_dir: str = "theoretical_verification/results") -> Tuple[str, str]:
    """
    Save the E1 figures: heat strips of the per-step L1 distance (no line plots).

    Both figures use the same 10x6 inch canvas so they are easy to compare, and
    both show log10(L1), floored at 1e-16, as a one-row 'magma' image:
    - e1_composite.png: heat strip plus colorbar
    - e1_l1_strip.png: heat strip only

    Args:
        res: Result mapping; ``res["l1_diffs"]`` holds the L1 distances
            (expected to be 1-D, one entry per time step k).
        save_dir: Output directory; created if it does not exist.

    Returns:
        ``(composite_path, strip_path)`` of the two written PNG files.
    """
    os.makedirs(save_dir, exist_ok=True)
    l1 = jnp.asarray(res["l1_diffs"])
    K = l1.shape[0]

    # Floor at 1e-16 so exact zeros do not turn into -inf on the log scale.
    vals = jnp.log10(jnp.clip(l1, 1e-16))
    vmin, vmax = float(vals.min()), float(vals.max())
    background = COLORS['background']

    def _draw_strip(title: str, filename: str, with_colorbar: bool) -> str:
        # Render one heat-strip figure and write it to ``save_dir/filename``.
        fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=background)
        try:
            im = ax.imshow(vals[None, :], aspect='auto', cmap='magma',
                           extent=[-0.5, K - 0.5, 0.0, 1.0],
                           vmin=vmin, vmax=vmax)
            ax.set_yticks([])
            ax.set_xlim(-0.5, K - 0.5)
            ax.set_xlabel('time k')
            ax.set_title(title)
            if with_colorbar:
                cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                cbar.set_label(r'$\log_{10}(\mathrm{L1})$')
            path = os.path.join(save_dir, filename)
            # Save through the figure object rather than pyplot's implicit
            # "current figure", so the right canvas is always written.
            fig.savefig(path, dpi=400, bbox_inches='tight', facecolor=background)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        return path

    # 1) Composite: heat strip + colorbar
    comp_path = _draw_strip('E1. L1(ρ_doob, ρ_anchor) — log10 scale', 'e1_composite.png', True)
    # 2) Independent: heat strip only
    series_path = _draw_strip('E1. L1 — log10 scale (strip)', 'e1_l1_strip.png', False)

    return comp_path, series_path



def save_e2_figures(res: Dict, save_dir: str = "theoretical_verification/results") -> Tuple[str, str]:
    """
    Save the E2 figures: heat strips of the per-step L1 distance.

    Both figures use a 10x6 inch canvas and show log10(L1), floored at 1e-16,
    as a one-row 'magma' image:
    - e2_composite.png: heat strip plus colorbar
    - e2_l1_strip.png: heat strip only

    Args:
        res: Result mapping; ``res["l1_diffs"]`` holds the L1 distances
            (expected to be 1-D, one entry per time step k).
        save_dir: Output directory; created if it does not exist.

    Returns:
        ``(composite_path, strip_path)`` of the two written PNG files.
    """
    os.makedirs(save_dir, exist_ok=True)
    l1 = jnp.asarray(res["l1_diffs"])
    K = l1.shape[0]

    # Floor at 1e-16 so exact zeros do not turn into -inf on the log scale.
    vals = jnp.log10(jnp.clip(l1, 1e-16))
    vmin, vmax = float(vals.min()), float(vals.max())
    background = COLORS['background']

    def _draw_strip(title: str, filename: str, with_colorbar: bool) -> str:
        # Render one heat-strip figure and write it to ``save_dir/filename``.
        fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=background)
        try:
            im = ax.imshow(vals[None, :], aspect='auto', cmap='magma',
                           extent=[-0.5, K - 0.5, 0.0, 1.0],
                           vmin=vmin, vmax=vmax)
            ax.set_yticks([])
            ax.set_xlim(-0.5, K - 0.5)
            ax.set_xlabel('time k')
            ax.set_title(title)
            if with_colorbar:
                cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                cbar.set_label(r'$\log_{10}(\mathrm{L1})$')
            path = os.path.join(save_dir, filename)
            # Save through the figure object rather than pyplot's implicit
            # "current figure", so the right canvas is always written.
            fig.savefig(path, dpi=400, bbox_inches='tight', facecolor=background)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        return path

    # Composite: heat strip + colorbar
    comp_path = _draw_strip('E2. L1(ρ_doob, ρ_obs-anchored) — log10 scale', 'e2_composite.png', True)
    # Strip only
    series_path = _draw_strip('E2. L1 — log10 scale (strip)', 'e2_l1_strip.png', False)

    return comp_path, series_path


def _trapz_weights(n: int) -> jnp.ndarray:
    """
    Return ``n`` trapezoid-rule weights for unit spacing.

    The first and last entries are 0.5 and every other entry is 1.0; the
    array is requested as float64.
    """
    weights = jnp.ones((n,), dtype=jnp.float64)
    # Halve the two end points (for n == 1 both updates hit the same entry).
    weights = weights.at[0].set(0.5)
    weights = weights.at[-1].set(0.5)
    return weights


def _discrete_quantiles(x: jnp.ndarray, pdf: jnp.ndarray, qs: jnp.ndarray) -> jnp.ndarray:
    """
    Compute quantiles of a PDF sampled on a 1-D grid.

    The CDF is the cumulative trapezoid integral of ``pdf`` over ``x`` (so it
    is exactly 0 at ``x[0]``), normalised by its final value; quantiles are
    obtained by linear interpolation of the inverse CDF. Actual grid spacings
    are used, so the grid does not have to be uniform.

    Args:
        x: Increasing grid points, shape (n,).
        pdf: Density values on the grid, shape (n,); need not be normalised.
        qs: Probability levels, shape (m,).

    Returns:
        Array of shape (m,) holding the quantile for each level in ``qs``.
        Results are confined to ``[x[0], x[-1]]``. For a single-point grid
        every quantile equals ``x[0]``.
    """
    x = jnp.asarray(x)
    pdf = jnp.asarray(pdf)
    qs = jnp.asarray(qs)
    n = x.shape[0]
    if n < 2:
        # Degenerate grid: all mass sits on the single point.
        return jnp.full(qs.shape, x[0])

    # Cumulative trapezoid: cdf[i] = integral of pdf from x[0] to x[i].
    cell_mass = 0.5 * (pdf[1:] + pdf[:-1]) * jnp.diff(x)
    cdf = jnp.concatenate([jnp.zeros((1,), dtype=cell_mass.dtype), jnp.cumsum(cell_mass)])
    cdf = cdf / (cdf[-1] + 1e-15)

    # Locate, for all levels at once, the cell [idx - 1, idx] bracketing each q.
    idx = jnp.clip(jnp.searchsorted(cdf, qs, side='left'), 1, n - 1)
    x_lo, x_hi = x[idx - 1], x[idx]
    c_lo, c_hi = cdf[idx - 1], cdf[idx]

    # Interpolation weight inside the cell; clipping prevents extrapolating
    # outside the grid when q lies beyond the CDF range or the cell is flat.
    t = jnp.clip((qs - c_lo) / jnp.maximum(c_hi - c_lo, 1e-15), 0.0, 1.0)
    return x_lo + t * (x_hi - x_lo)


def save_e2_quantile_trajectories(res: Dict, save_dir: str = "theoretical_verification/results") -> str:
    """
    Plot quantile trajectories (0.1 / 0.5 / 0.9) of the Doob path against the
    ρ_obs-anchored path over the time index k.

    Reads ``res["dens_doob"]`` and ``res["dens_obs"]`` (indexed as [k] to get
    the density on the grid) and ``res["x"]`` (the grid). Writes
    ``e2_quantiles.png`` into ``save_dir`` and returns its path.
    """
    os.makedirs(save_dir, exist_ok=True)
    doob_density = jnp.asarray(res.get("dens_doob"))  # (K,n)
    obs_density = jnp.asarray(res.get("dens_obs"))    # (K,n)
    grid_x = jnp.asarray(res.get("x"))
    n_steps = doob_density.shape[0]
    levels = jnp.array([0.1, 0.5, 0.9])

    # One row of three quantiles per time step -> (K,3)
    doob_q = jnp.stack(
        [_discrete_quantiles(grid_x, doob_density[step], levels) for step in range(n_steps)],
        axis=0,
    )
    obs_q = jnp.stack(
        [_discrete_quantiles(grid_x, obs_density[step], levels) for step in range(n_steps)],
        axis=0,
    )

    steps = jnp.arange(n_steps)
    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=COLORS['background'])

    thick = {'linewidth': 3}
    thin = {'linewidth': 1.5, 'alpha': 0.7}
    # (table, quantile column, line format, colour key, line kwargs, legend label)
    # Doob curves are solid, ρ_obs curves dashed; order fixes the legend order.
    curves = (
        (doob_q, 1, '-', 'primary', thick, 'Doob median (0.5)'),
        (doob_q, 0, '-', 'primary', thin, 'Doob q=0.1'),
        (doob_q, 2, '-', 'primary', thin, 'Doob q=0.9'),
        (obs_q, 1, '--', 'secondary', thick, 'ρ_obs median (0.5)'),
        (obs_q, 0, '--', 'secondary', thin, 'ρ_obs q=0.1'),
        (obs_q, 2, '--', 'secondary', thin, 'ρ_obs q=0.9'),
    )
    for table, column, fmt, colour_key, line_kw, label in curves:
        ax.plot(steps, table[:, column], fmt, color=COLORS[colour_key], label=label, **line_kw)

    ax.set_title('E2. Quantile trajectories (0.1/0.5/0.9): Doob vs ρ_obs')
    ax.set_xlabel('time k')
    ax.set_ylabel('quantiles')
    ax.grid(True, alpha=0.3, color=COLORS['grid'])
    ax.legend(ncol=2)

    out = os.path.join(save_dir, 'e2_quantiles.png')
    fig.savefig(out, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)
    return out


def _trapz(x: jnp.ndarray, y: jnp.ndarray) -> float:
    """
    Trapezoid-rule integral of samples ``y`` on the grid ``x`` as a Python float.

    The step is taken from the first two grid points only (uniform spacing is
    assumed); for a single-point grid the step defaults to 1.0.
    """
    count = x.shape[0]
    step = 1.0
    if count > 1:
        step = float(x[1] - x[0])
    weighted_total = jnp.sum(y * _trapz_weights(count))
    return float(step * weighted_total)


def save_e2_moment_trajectories(res: Dict, save_dir: str = "theoretical_verification/results") -> str:
    """
    Plot mean and standard-deviation trajectories of the two paths over time.

    For every time index k the mean ∫ρ·x dx and the variance ∫ρ·(x−mean)² dx
    are integrated with ``_trapz`` for ``res["dens_doob"][k]`` and
    ``res["dens_obs"][k]`` on the grid ``res["x"]``. The densities are used as
    given (no renormalisation). Writes ``e2_moments.png`` into ``save_dir``
    and returns its path.
    """
    os.makedirs(save_dir, exist_ok=True)
    doob = jnp.asarray(res.get("dens_doob"))  # (K,n)
    obs = jnp.asarray(res.get("dens_obs"))    # (K,n)
    grid_x = jnp.asarray(res.get("x"))
    n_steps = doob.shape[0]

    def _mean_and_std(rho):
        # First and second central moment of one density row.
        mean = _trapz(grid_x, rho * grid_x)
        variance = _trapz(grid_x, rho * (grid_x - mean) ** 2)
        return mean, float(jnp.sqrt(variance))

    doob_stats = [_mean_and_std(doob[step]) for step in range(n_steps)]
    obs_stats = [_mean_and_std(obs[step]) for step in range(n_steps)]
    mean_d = [m for m, _ in doob_stats]
    std_d = [s for _, s in doob_stats]
    mean_o = [m for m, _ in obs_stats]
    std_o = [s for _, s in obs_stats]

    steps = jnp.arange(n_steps)
    fig, (ax_mean, ax_std) = plt.subplots(
        2, 1, figsize=(10, 7), sharex=True, facecolor=COLORS['background']
    )

    # Upper panel: means
    ax_mean.plot(steps, mean_d, '-', color=COLORS['primary'], linewidth=3, label='Doob mean')
    ax_mean.plot(steps, mean_o, '--', color=COLORS['secondary'], linewidth=3, label='ρ_obs mean')
    ax_mean.set_title('E2. Mean and Std trajectories: Doob vs ρ_obs')
    ax_mean.set_ylabel('mean')
    ax_mean.grid(True, alpha=0.3, color=COLORS['grid'])
    ax_mean.legend()

    # Lower panel: standard deviations
    ax_std.plot(steps, std_d, '-', color=COLORS['primary'], linewidth=3, label='Doob std')
    ax_std.plot(steps, std_o, '--', color=COLORS['secondary'], linewidth=3, label='ρ_obs std')
    ax_std.set_ylabel('std')
    ax_std.set_xlabel('time k')
    ax_std.grid(True, alpha=0.3, color=COLORS['grid'])
    ax_std.legend()

    out = os.path.join(save_dir, 'e2_moments.png')
    fig.savefig(out, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)
    return out


def save_e2_ridgeline_overlays(res: Dict, save_dir: str = "theoretical_verification/results") -> str:
    """
    Ridgeline overlay of the Doob and ρ_obs density curves at selected time
    indices, each pair shifted vertically by a common gap.

    The indices come from ``res["peaks_k"]`` (presumably the top-L1 steps
    chosen by the caller — confirm there); when the key is absent the first
    ``min(5, K)`` steps are used. Writes ``e2_ridgeline.png`` into
    ``save_dir`` and returns its path.
    """
    os.makedirs(save_dir, exist_ok=True)
    doob = jnp.asarray(res.get("dens_doob"))  # (K,n)
    obs = jnp.asarray(res.get("dens_obs"))    # (K,n)
    grid_x = jnp.asarray(res.get("x"))
    fallback = list(range(min(5, doob.shape[0])))
    chosen = res.get("peaks_k", fallback)

    # Global density maximum sets the vertical spacing between ridges.
    tallest = float(jnp.maximum(jnp.max(doob), jnp.max(obs)))
    spacing = 1.2 * tallest

    doob_colour = COLORS['primary']
    obs_colour = COLORS['secondary']

    fig, ax = plt.subplots(1, 1, figsize=(10, 7), facecolor=COLORS['background'])
    for row, k in enumerate(sorted(chosen)):
        offset = row * spacing
        first = row == 0  # only the first ridge contributes legend entries
        doob_curve = offset + doob[k]
        obs_curve = offset + obs[k]
        ax.fill_between(grid_x, offset, doob_curve, color=doob_colour, alpha=0.35, linewidth=0)
        ax.plot(grid_x, doob_curve, color=doob_colour, linewidth=2.5,
                label='Doob' if first else None)
        ax.fill_between(grid_x, offset, obs_curve, color=obs_colour, alpha=0.25, linewidth=0)
        ax.plot(grid_x, obs_curve, color=obs_colour, linewidth=2.5, linestyle='--',
                label='ρ_obs' if first else None)
        ax.text(grid_x[0], offset + 0.02 * spacing, f"k={k}", fontsize=10, color=obs_colour)

    ax.set_title('E2. Ridgeline overlays at Top-L1 time indices')
    ax.set_xlabel('x')
    ax.set_ylabel('offset index')
    ax.grid(True, alpha=0.2, color=COLORS['grid'])
    ax.legend()

    out = os.path.join(save_dir, 'e2_ridgeline.png')
    fig.savefig(out, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)
    return out


def _log_r_ou(x: jnp.ndarray, ou_params) -> jnp.ndarray:
    """
    Log-density of a Gaussian with mean ``equilibrium_mean`` and the OU
    stationary variance ``diffusion**2 / (2 * mean_reversion)``, evaluated at ``x``.

    When ``mean_reversion`` is not above 1e-10 the variance falls back to
    ``10 * diffusion**2``. (Intended to match the anchor construction of the
    1D algorithm — not verifiable from this file.)
    """
    theta = jnp.asarray(ou_params.mean_reversion, jnp.float64)
    sigma = jnp.asarray(ou_params.diffusion, jnp.float64)
    mu = jnp.asarray(ou_params.equilibrium_mean, jnp.float64)

    sigma_sq = sigma ** 2
    stationary_var = jnp.where(theta > 1e-10, sigma_sq / (2.0 * theta), sigma_sq * 10.0)

    log_norm = -0.5 * jnp.log(2 * jnp.pi * stationary_var)
    return log_norm - 0.5 * (x - mu) ** 2 / stationary_var


def _log_lik_student(x: jnp.ndarray, y: float, C: float, R: float, nu: float = 50.0) -> jnp.ndarray:
    """
    Student-t log-likelihood of observation ``y`` given state ``x``.

    The residual is ``C * x - y``, the squared scale is ``R`` and ``nu`` is
    the degrees of freedom; the normalising term is included. (Intended to
    match the solver's likelihood — not verifiable from this file.)
    """
    half_dof_plus_one = (nu + 1.0) / 2.0
    log_norm = (
        gammaln(half_dof_plus_one)
        - gammaln(nu / 2.0)
        - 0.5 * jnp.log(nu * jnp.pi * R)
    )
    resid = C * x - y
    scaled_sq = (resid * resid) / (nu * R)
    return log_norm - half_dof_plus_one * jnp.log1p(scaled_sq)


def save_e2_compatibility_residuals(res: Dict, save_dir: str = "theoretical_verification/results") -> str:
    """
    Plot the compatibility residual R_k with the L1 distance overlaid.

    R_k = sqrt(∫ (log ρ_doob,k(x) − (log r(x) + log ℓ(y_k|x) + c_k))^2 dx),
    where c_k is the constant that makes the trapezoid-weighted mean of the
    log-difference zero (the least-squares optimal shift). R_k is drawn as
    bars; ``res["l1_diffs"]`` is drawn as a dashed line on a secondary y-axis.

    Args:
        res: Result mapping read with keys "dens_doob" (indexed as [k] to get
            the density on the grid), "x" (grid), "y_obs" (indexed by k),
            "C" and "R" (converted to float), "ou_params" and "l1_diffs".
        save_dir: Output directory; created if it does not exist.

    Returns:
        Path of the written file ``e2_compatibility_residuals.png``.
    """
    os.makedirs(save_dir, exist_ok=True)
    dens_doob = jnp.asarray(res.get("dens_doob"))  # (K,n)
    x = jnp.asarray(res.get("x"))
    K = dens_doob.shape[0]
    y_obs = jnp.asarray(res.get("y_obs"))
    C = float(res.get("C"))
    R = float(res.get("R"))
    ou_params = res.get("ou_params")

    # Quadrature weights and grid step (uniform spacing assumed).
    n = x.shape[0]
    h = float(x[1] - x[0]) if n > 1 else 1.0
    w = _trapz_weights(n)

    log_r = _log_r_ou(x, ou_params)  # (n,)

    # Density floor applied before taking the log. A literal 1e-300 underflows
    # to 0 in float32 (JAX's default without x64), which would give
    # log(0) = -inf and NaN residuals; never go below the smallest normal
    # number of the working dtype. In float64 the floor stays at 1e-300.
    working_dtype = jnp.result_type(dens_doob, 1.0)
    rho_floor = max(1e-300, float(jnp.finfo(working_dtype).tiny))

    def rk_for_k(k):
        log_rho = jnp.log(jnp.maximum(dens_doob[k], rho_floor))
        # Keep y_obs[k] as a JAX scalar: k is traced under vmap, so it must
        # not be coerced to a Python float.
        log_lik = _log_lik_student(x, y_obs[k], C, R)
        diff = log_rho - (log_r + log_lik)
        # c_k: weighted mean of the difference, so the error has zero mean
        # (the grid step cancels between numerator and denominator).
        c = jnp.sum(diff * w) / jnp.sum(w)
        err = diff - c
        return jnp.sqrt(h * jnp.sum((err * err) * w))

    R_series = jax.vmap(rk_for_k)(jnp.arange(K))
    L1 = jnp.asarray(res.get("l1_diffs"))
    ks = jnp.arange(K)

    fig, ax1 = plt.subplots(1, 1, figsize=(10, 6), facecolor=COLORS['background'])
    try:
        ax1.bar(ks, R_series, color=COLORS['primary'], alpha=0.7, label='R_k (compatibility residual)')
        ax1.set_xlabel('time k')
        ax1.set_ylabel('R_k (L2 of log-diff)')
        ax1.grid(True, alpha=0.3, color=COLORS['grid'])

        ax2 = ax1.twinx()
        ax2.plot(ks, L1, '--', color=COLORS['secondary'], linewidth=2.5, label='L1(ρ_doob, ρ_obs)')
        ax2.set_ylabel('L1 distance')

        # Merge the legends of both axes into a single box.
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines + lines2, labels + labels2, loc='upper right')
        ax1.set_title('E2. Compatibility residual R_k with L1 overlay')

        out = os.path.join(save_dir, 'e2_compatibility_residuals.png')
        fig.savefig(out, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return out


def _w2_1d_from_quantiles(p: jnp.ndarray, q: jnp.ndarray, x: jnp.ndarray, num_q: int = 256) -> jnp.ndarray:
    """
    Approximate the 1D Wasserstein-2 distance between densities ``p`` and ``q``
    on the grid ``x`` via the quantile formula.

    W2^2 ≈ mean_u (F_p^{-1}(u) - F_q^{-1}(u))^2 with u = linspace(0, 1, num_q),
    where the levels are clipped to [1e-6, 1 - 1e-6] to avoid end-point
    numerical issues. Returns W2 itself (not squared).
    """
    levels = jnp.clip(jnp.linspace(0.0, 1.0, num_q), 1e-6, 1.0 - 1e-6)
    quantile_gap = _discrete_quantiles(x, p, levels) - _discrete_quantiles(x, q, levels)
    mean_sq_gap = jnp.mean(quantile_gap ** 2)
    return jnp.sqrt(jnp.maximum(mean_sq_gap, 0.0))


def save_e2_w2_series(res: Dict, save_dir: str = "theoretical_verification/results") -> str:
    """
    Plot the W2(ρ_doob, ρ_obs) series: the discrete 1D Wasserstein-2 distance
    between ``res["dens_doob"][k]`` and ``res["dens_obs"][k]`` on the grid
    ``res["x"]`` for every time index k.

    Writes ``e2_w2_series.png`` into ``save_dir`` and returns its path.
    """
    os.makedirs(save_dir, exist_ok=True)
    doob = jnp.asarray(res.get("dens_doob"))  # (K,n)
    obs = jnp.asarray(res.get("dens_obs"))    # (K,n)
    grid_x = jnp.asarray(res.get("x"))
    steps = jnp.arange(doob.shape[0])

    # Evaluate the distance for all time indices in one vectorised call.
    distances = jax.vmap(
        lambda k: _w2_1d_from_quantiles(doob[k], obs[k], grid_x, num_q=256)
    )(steps)

    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=COLORS['background'])
    ax.plot(steps, distances, 'o-', color=COLORS['secondary'], linewidth=3, markersize=6)
    ax.set_title('E2. Wasserstein-2 distance vs time')
    ax.set_xlabel('time k')
    ax.set_ylabel('W2(ρ_doob, ρ_obs)')
    ax.grid(True, alpha=0.3, color=COLORS['grid'])

    out = os.path.join(save_dir, 'e2_w2_series.png')
    fig.savefig(out, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)
    return out
