"""
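E1/E2 experiments for Doob-h IPFP consistency (JAX first-class implementation).

E1: Self-consistency of the Doob-likelihood path. Solve once in Doob-likelihood
    mode for the observations {y_k}, then re-solve with the resulting path
    densities imposed as hard marginal anchors.
    - Metric: per-time L1(rho_k^doob, rho_k^anchor), plus its maximum and mean.

E2: For the same observations {y_k}, compare the Doob-likelihood path with the
    path obtained from hard anchors rho^{obs}_{t_k} built by the solver's
    "construct_anchors" observation mode. The RTS smoother output is used only
    to choose the grid bounds.
    - Metric: per-time L1 difference, plus its maximum and mean.

Note: the defaults use a lightweight configuration to keep the runtime short;
adjust them through the CLI arguments.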
"""

from __future__ import annotations

import os
import math
import argparse
from typing import Dict

import jax
import jax.numpy as jnp
from pathlib import Path
import sys

# Ensure project root is on sys.path for namespace imports / 确保项目根路径在sys.path中
sys.path.append(str(Path(__file__).resolve().parents[2]))

from mmsbvi.core.types import GridConfig1D, OUProcessParams, MMSBProblem, IPFPConfig
from mmsbvi.algorithms.ipfp_1d import solve_mmsb_ipfp_1d_fixed, jax_trapz
from theoretical_verification.utils.rts import simulate_lgssm, kalman_rts
from visualization.doobh_ipfp_visualization import (
    save_e1_figures,
    save_e2_figures,
    save_e2_quantile_trajectories,
    save_e2_moment_trajectories,
    save_e2_ridgeline_overlays,
    save_e2_compatibility_residuals,
    save_e2_w2_series,
)


# JAX computes in 32-bit precision by default; switch the whole process to
# 64-bit before any arrays are created.
jax.config.update("jax_enable_x64", True)


def _gauss(x: jnp.ndarray, mu: float, var: float) -> jnp.ndarray:
    return (1.0 / jnp.sqrt(2 * jnp.pi * var)) * jnp.exp(-0.5 * (x - mu) ** 2 / var)


def _map_lg_to_ou(A: float, Q: float, dt: float = 1.0) -> OUProcessParams:
    """Map a discrete linear-Gaussian transition onto zero-mean OU parameters.

    The discrete model ``x_{k+1} = A x_k + w_k`` with ``w_k ~ N(0, Q)`` is
    matched to the exact discretisation of ``dX = -theta X dt + sigma dW``
    over a step of length ``dt``::

        A = exp(-theta * dt)
        Q = sigma^2 * (1 - A^2) / (2 * theta)

    Args:
        A: Discrete transition coefficient; must satisfy ``0 < A <= 1``.
        Q: Discrete process-noise variance; must be non-negative.
        dt: Time between consecutive discrete steps (expected to be positive).

    Returns:
        ``OUProcessParams`` with ``mean_reversion=theta``, ``diffusion=sigma``
        and ``equilibrium_mean=0.0``.

    Raises:
        ValueError: If ``A`` lies outside ``(0, 1]`` or ``Q`` is negative.
    """
    # Reject inputs that previously only surfaced as an opaque
    # "math domain error" from math.log / math.sqrt.
    if not 0.0 < A <= 1.0:
        raise ValueError(
            f"A must satisfy 0 < A <= 1 to map onto an OU process, got {A!r}"
        )
    if Q < 0.0:
        raise ValueError(f"Q must be a non-negative variance, got {Q!r}")

    theta = -math.log(A) / dt
    one_minus_a2 = 1.0 - A ** 2
    if one_minus_a2 < 1e-12:
        # A ~ 1 is the random-walk limit theta -> 0, where
        # 2 * theta * Q / (1 - exp(-2 * theta * dt)) -> Q / dt.
        # Clamping the denominator instead would drive sigma to 0 and
        # silently discard the process noise.
        sigma = math.sqrt(Q / dt)
    else:
        sigma = math.sqrt(2.0 * theta * Q / one_minus_a2)
    return OUProcessParams(mean_reversion=theta, diffusion=sigma, equilibrium_mean=0.0)


def run_e1_doobh_vs_anchor(
    seed: int = 0,
    A: float = 0.9,
    C: float = 1.0,
    Q: float = 0.1 ** 2,
    R: float = 0.2 ** 2,
    K_steps: int = 32,
    grid_points: int = 401,
    coverage_sigma: float = 6.0,
    compiled: bool = True,
    tol: float = 1e-6,
    maxit: int = 400,
) -> Dict:
    """Run E1: self-consistency of the Doob-likelihood IPFP path.

    The solver is first run in ``"doob_likelihood"`` observation mode on a
    simulated observation sequence. The resulting path densities are then fed
    back as ``observed_marginals`` (hard anchors) into a second solve, and the
    two sets of path densities are compared time step by time step.

    Args:
        seed: Seed for the JAX PRNG key passed to ``simulate_lgssm``.
        A, C, Q, R: Linear-Gaussian model parameters forwarded to
            ``simulate_lgssm`` and ``kalman_rts``; ``A`` and ``Q`` also define
            the OU parameters, ``C`` and ``R`` are passed to ``MMSBProblem``.
        K_steps: Number of time steps; observation times are ``0..K_steps-1``.
        grid_points: Number of grid points for ``GridConfig1D.create``.
        coverage_sigma: Grid padding, in multiples of the largest RTS standard
            deviation, added beyond the extreme RTS means.
        compiled: Forwarded as ``compiled_loop`` to both solver configs.
        tol: Solver tolerance.
        maxit: Iteration cap (used for both ``max_iterations`` and
            ``compiled_max_iterations``).

    Returns:
        Dict with keys ``"GRID"``, ``"l1_diffs"`` (per-time L1 distances),
        ``"max_l1"``, ``"mean_l1"`` and ``"note"``.
    """
    # Only the second output of the simulator is used in this experiment.
    _, observations = simulate_lgssm(jax.random.PRNGKey(seed), A, C, Q, R, K_steps)

    # OU parameters matching the discrete model at unit time step.
    ou = _map_lg_to_ou(A, Q, dt=1.0)
    times = jnp.arange(float(K_steps))

    # Size the grid from the RTS output: extreme means padded by
    # coverage_sigma times the largest standard deviation.
    rts_mean, rts_var = kalman_rts(observations, A, C, Q, R)
    padding = coverage_sigma * float(jnp.max(jnp.sqrt(rts_var)))
    grid = GridConfig1D.create(
        grid_points,
        (float(jnp.min(rts_mean)) - padding, float(jnp.max(rts_mean)) + padding),
    )

    def _config(**extra):
        # Solver settings common to both runs; `extra` carries per-run options.
        return IPFPConfig(
            max_iterations=maxit,
            tolerance=tol,
            check_interval=10,
            verbose=False,
            compiled_loop=compiled,
            compiled_max_iterations=maxit,
            compiled_check_interval=10,
            **extra,
        )

    # --- Run 1: Doob-likelihood mode driven directly by the observations ---
    doob_solution = solve_mmsb_ipfp_1d_fixed(
        MMSBProblem(
            observation_times=times,
            y_observations=observations,
            ou_params=ou,
            grid=grid,
            C=C,
            R=R,
        ),
        _config(observation_mode="doob_likelihood"),
    )

    # --- Run 2: the Doob path densities imposed as hard marginals ---
    # Self-consistency test: in theory the second path should coincide with
    # the first. observation_mode is left at the config default here.
    anchors = list(doob_solution.path_densities)
    anchor_solution = solve_mmsb_ipfp_1d_fixed(
        MMSBProblem(
            observation_times=times,
            observed_marginals=anchors,
            ou_params=ou,
            grid=grid,
            C=C,
            R=R,
        ),
        # Explicitly no fixed-potential mask, so the potentials stay updatable.
        _config(fixed_potential_mask=None),
    )

    # Per-time L1 distance via the trapezoid rule on the uniform grid:
    # full sum minus half of the two endpoint terms, times the spacing.
    doob_stack = jnp.stack(doob_solution.path_densities, axis=0)
    anchor_stack = jnp.stack(anchor_solution.path_densities, axis=0)
    gap = jnp.abs(doob_stack - anchor_stack)
    l1_per_time = grid.spacing * (
        jnp.sum(gap, axis=1) - 0.5 * (gap[:, 0] + gap[:, -1])
    )

    return {
        "GRID": grid,
        "l1_diffs": l1_per_time,
        "max_l1": float(jnp.max(l1_per_time)),
        "mean_l1": float(jnp.mean(l1_per_time)),
        "note": "anchors induced from Doob path (self-consistency)",
    }


def save_e1_report(res: Dict, out_dir: str = "theoretical_verification/results") -> str:
    """Write the E1 summary as a small text report and return its path.

    Args:
        res: Result dict providing ``"max_l1"``, ``"mean_l1"`` and ``"GRID"``
            (an object exposing ``n_points`` and ``bounds``).
        out_dir: Directory for the report; created if it does not exist.

    Returns:
        Path of the written file, ``<out_dir>/doobh_ipfp_e1_report.txt``.
    """
    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, "doobh_ipfp_e1_report.txt")
    with open(report_path, "w") as handle:
        report_lines = [
            "E1: Doob-likelihood vs Constructed Anchors\n",
            f"max L1 = {res['max_l1']:.3e}\n",
            f"mean L1 = {res['mean_l1']:.3e}\n",
            f"grid_points = {res['GRID'].n_points}, bounds = {res['GRID'].bounds}\n",
        ]
        handle.writelines(report_lines)
    return report_path


def run_e2_doobh_vs_rho_obs(
    seed: int = 0,
    A: float = 0.9,
    C: float = 1.0,
    Q: float = 0.1 ** 2,
    R: float = 0.2 ** 2,
    K_steps: int = 16,
    grid_points: int = 301,
    coverage_sigma: float = 6.0,
    compiled: bool = True,
    tol: float = 1e-6,
    maxit: int = 400,
) -> Dict:
    """Run small-scale E2: Doob-likelihood path vs. hard-anchor path.

    For one simulated observation sequence the solver is run twice on
    identically specified problems: once with
    ``observation_mode="doob_likelihood"`` and once with
    ``observation_mode="construct_anchors"`` (hard anchors rho^{obs}_{t_k}
    built by the solver from r * likelihood). The two sets of path densities
    are compared per time step in L1.

    Args:
        seed: Seed for the JAX PRNG key passed to ``simulate_lgssm``.
        A, C, Q, R: Linear-Gaussian model parameters forwarded to
            ``simulate_lgssm`` and ``kalman_rts``; ``A`` and ``Q`` also define
            the OU parameters, ``C`` and ``R`` are passed to ``MMSBProblem``.
        K_steps: Number of time steps; observation times are ``0..K_steps-1``.
        grid_points: Number of grid points for ``GridConfig1D.create``.
        coverage_sigma: Grid padding, in multiples of the largest RTS standard
            deviation, added beyond the extreme RTS means.
        compiled: Forwarded as ``compiled_loop`` to both solver configs.
        tol: Solver tolerance.
        maxit: Iteration cap (used for both ``max_iterations`` and
            ``compiled_max_iterations``).

    Returns:
        Dict with the summary keys ``"GRID"``, ``"l1_diffs"``, ``"max_l1"``,
        ``"mean_l1"``, ``"note"``, plus data for the visualisation helpers:
        ``"dens_doob"``, ``"dens_obs"``, ``"x"``, ``"times"``, ``"peaks_k"``
        (indices of the five largest L1 values, ascending by L1), ``"y_obs"``,
        ``"C"``, ``"R"`` and ``"ou_params"``.
    """
    # Only the second output of the simulator is used in this experiment.
    _, y_seq = simulate_lgssm(jax.random.PRNGKey(seed), A, C, Q, R, K_steps)

    # OU parameters and observation time grid (unit spacing).
    ou = _map_lg_to_ou(A, Q, dt=1.0)
    t_obs = jnp.arange(float(K_steps))

    # Size the grid from the RTS output for robust coverage.
    rts_mean, rts_var = kalman_rts(y_seq, A, C, Q, R)
    padding = coverage_sigma * float(jnp.max(jnp.sqrt(rts_var)))
    grid = GridConfig1D.create(
        grid_points,
        (float(jnp.min(rts_mean)) - padding, float(jnp.max(rts_mean)) + padding),
    )

    def _problem():
        # Both runs use the same problem specification; a fresh instance is
        # built for each solve.
        return MMSBProblem(
            observation_times=t_obs,
            y_observations=y_seq,
            ou_params=ou,
            grid=grid,
            C=C,
            R=R,
        )

    def _config(**extra):
        # Solver settings common to both runs; `extra` carries per-run options.
        return IPFPConfig(
            max_iterations=maxit,
            tolerance=tol,
            check_interval=10,
            verbose=False,
            compiled_loop=compiled,
            compiled_max_iterations=maxit,
            compiled_check_interval=10,
            **extra,
        )

    # --- Doob-likelihood path ---
    doob_solution = solve_mmsb_ipfp_1d_fixed(
        _problem(),
        _config(observation_mode="doob_likelihood"),
    )

    # --- Hard-anchor path: the solver's built-in construct_anchors mode
    # builds the anchors as normalised r * likelihood ---
    obs_solution = solve_mmsb_ipfp_1d_fixed(
        _problem(),
        _config(observation_mode="construct_anchors", fixed_potential_mask=None),
    )

    # Per-time L1 distance via the trapezoid rule on the uniform grid:
    # full sum minus half of the two endpoint terms, times the spacing.
    dens_doob = jnp.stack(doob_solution.path_densities, axis=0)
    dens_obs = jnp.stack(obs_solution.path_densities, axis=0)
    gap = jnp.abs(dens_doob - dens_obs)
    l1_per_time = grid.spacing * (
        jnp.sum(gap, axis=1) - 0.5 * (gap[:, 0] + gap[:, -1])
    )

    # Time indices with the five largest L1 gaps (the most visible contrasts).
    peak_indices = [int(k) for k in jnp.argsort(l1_per_time)[-5:]]

    return {
        "GRID": grid,
        "l1_diffs": l1_per_time,
        "max_l1": float(jnp.max(l1_per_time)),
        "mean_l1": float(jnp.mean(l1_per_time)),
        "note": "anchors are rho_obs constructed via r*likelihood",
        # For visualisation: both density paths, the grid and the times.
        "dens_doob": dens_doob,
        "dens_obs": dens_obs,
        "x": grid.points,
        "times": t_obs,
        "peaks_k": peak_indices,
        # Extra inputs needed for the Doob compatibility residuals (r and l).
        "y_obs": y_seq,
        "C": C,
        "R": R,
        "ou_params": ou,
    }


def save_e2_report(res: Dict, out_dir: str = "theoretical_verification/results") -> str:
    """Write the E2 summary as a small text report and return its path.

    Args:
        res: Result dict providing ``"max_l1"``, ``"mean_l1"`` and ``"GRID"``
            (an object exposing ``n_points`` and ``bounds``).
        out_dir: Directory for the report; created if it does not exist.

    Returns:
        Path of the written file, ``<out_dir>/doobh_ipfp_e2_report.txt``.
    """
    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, "doobh_ipfp_e2_report.txt")
    with open(report_path, "w") as handle:
        report_lines = [
            "E2: Doob-likelihood vs Hard Anchors rho_obs\n",
            f"max L1 = {res['max_l1']:.3e}\n",
            f"mean L1 = {res['mean_l1']:.3e}\n",
            f"grid_points = {res['GRID'].n_points}, bounds = {res['GRID'].bounds}\n",
        ]
        handle.writelines(report_lines)
    return report_path


def main():
    """Command-line entry point: always run E1, and E2 when ``--run_e2`` is set.

    Options: ``--K`` (time steps), ``--grid`` (grid points), ``--tol``,
    ``--maxit``, and the flags ``--compiled`` and ``--run_e2``.

    Note: ``--compiled`` is a ``store_true`` flag, so the compiled loop is off
    unless the flag is passed, although the ``run_*`` functions default to
    ``compiled=True`` when called directly.

    E2 is kept small: ``K`` is capped at 16, ``grid`` at 401 and ``maxit`` at
    400 regardless of the values given on the command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--K", type=int, default=32)
    cli.add_argument("--grid", type=int, default=401)
    cli.add_argument("--compiled", action="store_true")
    cli.add_argument("--tol", type=float, default=1e-6)
    cli.add_argument("--maxit", type=int, default=400)
    cli.add_argument("--run_e2", action="store_true")
    opts = cli.parse_args()

    # --- E1: report plus figures ---
    e1 = run_e1_doobh_vs_anchor(
        K_steps=opts.K,
        grid_points=opts.grid,
        compiled=opts.compiled,
        tol=opts.tol,
        maxit=opts.maxit,
    )
    e1_report = save_e1_report(e1)
    e1_composite, e1_series = save_e1_figures(e1)
    print(
        f"E1 done. report={e1_report}, composite={e1_composite}, series={e1_series}, "
        f"maxL1={e1['max_l1']:.3e}, meanL1={e1['mean_l1']:.3e}"
    )

    if not opts.run_e2:
        return

    # --- E2 (small scale): Doob vs. rho^{obs} hard anchors ---
    e2 = run_e2_doobh_vs_rho_obs(
        K_steps=min(opts.K, 16),
        grid_points=min(opts.grid, 401),
        compiled=opts.compiled,
        tol=opts.tol,
        maxit=min(opts.maxit, 400),
    )
    e2_report = save_e2_report(e2)
    e2_composite, e2_series = save_e2_figures(e2)

    # Extra figures: quantile trajectories, moment statistics, ridgeline
    # overlays, compatibility residuals and the W2 series. Their return
    # values are not used.
    for write_figure in (
        save_e2_quantile_trajectories,
        save_e2_moment_trajectories,
        save_e2_ridgeline_overlays,
        save_e2_compatibility_residuals,
        save_e2_w2_series,
    ):
        write_figure(e2)

    print(
        f"E2 done. report={e2_report}, composite={e2_composite}, series={e2_series}, "
        f"maxL1={e2['max_l1']:.3e}, meanL1={e2['mean_l1']:.3e}"
    )


# Run the experiments only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


