"""
C4 线性高斯极限 ⇒ RTS 回收（Sanity 钟）
专业化实验脚本：生成双子图与统计表，验证我们的方法在线性高斯极限回收 RTS（机器精度级）。
"""

import os
import math
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import ticker as mticker

from mmsbvi.utils.logger import get_logger
from mmsbvi.core.types import GridConfig1D, OUProcessParams, MMSBProblem, IPFPConfig
from mmsbvi.algorithms.ipfp_1d import solve_mmsb_ipfp_1d_fixed, jax_trapz
from ..utils.rts import simulate_lgssm, kalman_rts
from ..visualization.rts_plots import save_rts_recovery_figs


# Enable float64 in JAX: the RTS comparison targets machine-precision agreement,
# which float32 rounding would mask. NOTE(review): this runs after the project
# imports above, so any JAX arrays those modules build at import time would
# still be float32 — confirm none do.
jax.config.update("jax_enable_x64", True)
# Module-level logger (used by save_fig_and_table to report artifact paths).
logger = get_logger(__name__)


def _gauss(x, mu, var):
    return (1.0 / jnp.sqrt(2 * jnp.pi * var)) * jnp.exp(-0.5 * (x - mu) ** 2 / var)


def _density_moments(densities, x, h):
    """
    Mean and variance of each gridded density, computed with the trapezoid rule.

    Each density is divided by its own trapezoid mass, so the results are the
    moments of the normalized density even when the solver output carries a
    small mass error (e.g. at the IPFP tolerance). Using ``jax_trapz`` keeps
    this quadrature consistent with the one used to normalize the observed
    marginals.

    Args:
        densities: sequence of 1-D arrays, each sampled on ``x``.
        x: grid points.
        h: grid spacing passed to ``jax_trapz`` as ``dx``.

    Returns:
        (means, variances) as 1-D jnp arrays of length ``len(densities)``.
    """
    means = []
    variances = []
    for p in densities:
        mass = jax_trapz(p, dx=h)
        mu = jax_trapz(p * x, dx=h) / mass
        var = jax_trapz(p * (x - mu) ** 2, dx=h) / mass
        means.append(mu)
        variances.append(var)
    return jnp.array(means), jnp.array(variances)


def run_single_seed(
    seed: int = 0,
    A: float = 0.9,
    C: float = 1.0,
    Q: float = 0.1 ** 2,
    R: float = 0.2 ** 2,
    K_steps: int = 64,
    grid_points: int = 801,
    grid_coverage_sigma: float = 6.0,
    ipfp_tol: float = 1e-5,
    ipfp_maxit: int = 400,
):
    """
    Run one C4 instance and return its results (this function writes no files).

    Pipeline:
    1. Simulate a 1-D LG-SSM and smooth it with Kalman/RTS.
    2. Use the RTS marginals as the observed marginals of a 1-D MMSB problem
       whose prior is the OU process matching (A, Q).
    3. Solve that problem with IPFP.
    4. Compare the moments of the recovered path densities with the RTS
       means and variances.

    Args:
        seed: PRNG seed for the simulation.
        A, C, Q, R: LG-SSM parameters forwarded to ``simulate_lgssm`` and
            ``kalman_rts`` (presumably transition, observation, process-noise
            and observation-noise terms). A must be positive and != 1, and Q
            positive, for the AR(1) -> OU mapping below.
        K_steps: number of time steps (>= 1).
        grid_points: number of spatial grid points.
        grid_coverage_sigma: grid padding, in units of the largest RTS std.
        ipfp_tol: IPFP stopping tolerance.
        ipfp_maxit: IPFP iteration cap.

    Returns:
        dict with keys:
        - 'GRID', 'solution', 'm_rts', 'P_rts', 'mu_mmsb', 'P_mmsb';
        - 'mean_abs_err', 'cov_abs_err': per-step absolute errors;
        - 'max_mean_err', 'max_cov_err': the corresponding maxima, as Python
          floats.

    Raises:
        ValueError: if A, Q or K_steps are outside the supported range.
    """
    # Validate up front: math.log(A) needs A > 0, and the sigma formula below
    # divides by (1 - A**2), so A == 1 would raise ZeroDivisionError.
    if A <= 0.0 or A == 1.0:
        raise ValueError(f"A must be positive and different from 1, got {A!r}")
    if Q <= 0.0:
        raise ValueError(f"Q must be positive, got {Q!r}")
    if K_steps < 1:
        raise ValueError(f"K_steps must be >= 1, got {K_steps!r}")

    key = jax.random.PRNGKey(seed)
    # 1) Simulate the LG-SSM (the latent path itself is not needed here).
    _, ys = simulate_lgssm(key, A, C, Q, R, K_steps)
    # 2) Baseline: Kalman filter + RTS smoother.
    m_rts, P_rts = kalman_rts(ys, A, C, Q, R)

    # 3) Map (A, Q) to OU(theta, sigma) with dt = 1 so that exp(-theta*dt) = A
    #    and sigma^2 * (1 - A^2) / (2*theta) = Q, i.e. the OU transition matches
    #    the discrete model in both mean factor and variance.
    dt = 1.0
    theta = -math.log(A) / dt
    sigma = math.sqrt(2.0 * theta * Q / (1.0 - A ** 2))
    ou_params = OUProcessParams(mean_reversion=theta, diffusion=sigma, equilibrium_mean=0.0)
    obs_times = dt * jnp.arange(float(K_steps))

    # 4) Grid bounds from the RTS-smoothed moments: the extreme means padded by
    #    grid_coverage_sigma times the largest smoothed standard deviation.
    mu_min = float(jnp.min(m_rts))
    mu_max = float(jnp.max(m_rts))
    sigma_max = float(jnp.max(jnp.sqrt(P_rts)))
    bounds = (mu_min - grid_coverage_sigma * sigma_max, mu_max + grid_coverage_sigma * sigma_max)
    GRID = GridConfig1D.create(grid_points, bounds)

    # 5) Observed marginals rho_k = N(m_rts[k], P_rts[k]) sampled on the grid
    #    and renormalized with the trapezoid rule (1e-15 avoids division by 0).
    h = GRID.spacing
    obs_densities = []
    for k in range(K_steps):
        d = _gauss(GRID.points, m_rts[k], P_rts[k])
        obs_densities.append(d / (jax_trapz(d, dx=h) + 1e-15))

    problem = MMSBProblem(
        observation_times=obs_times,
        observed_marginals=obs_densities,
        ou_params=ou_params,
        grid=GRID,
    )

    # 6) Our method: IPFP (Doob h-transform) to recover the path densities.
    config = IPFPConfig(
        max_iterations=ipfp_maxit,
        tolerance=ipfp_tol,
        check_interval=10,
        verbose=True,
    )
    solution = solve_mmsb_ipfp_1d_fixed(problem, config)

    # 7) Moments of the recovered densities at each observation step, using
    #    the same quadrature as step 5, then compare with RTS.
    mu_mmsb, P_mmsb = _density_moments(
        [solution.path_densities[k] for k in range(K_steps)], GRID.points, h
    )

    mean_abs_err = jnp.abs(mu_mmsb - m_rts)
    cov_abs_err = jnp.abs(P_mmsb - P_rts)
    max_mean_err = float(jnp.max(mean_abs_err))
    max_cov_err = float(jnp.max(cov_abs_err))

    return {
        'GRID': GRID,
        'solution': solution,
        'm_rts': m_rts,
        'P_rts': P_rts,
        'mu_mmsb': mu_mmsb,
        'P_mmsb': P_mmsb,
        'mean_abs_err': mean_abs_err,
        'cov_abs_err': cov_abs_err,
        'max_mean_err': max_mean_err,
        'max_cov_err': max_cov_err,
    }


def save_fig_and_table(res: dict, out_dir: str = "theoretical_verification/results"):
    """
    Write the RTS-recovery figures and a plain-text error report to ``out_dir``.

    Args:
        res: result dict from ``run_single_seed``. Reads 'GRID',
            'mean_abs_err', 'cov_abs_err', 'max_mean_err' and 'max_cov_err'.
        out_dir: output directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    GRID = res['GRID']
    mean_abs_err = res['mean_abs_err']
    cov_abs_err = res['cov_abs_err']

    # Plotting is delegated to the separate visualization module.
    save_rts_recovery_figs(mean_abs_err, cov_abs_err, GRID, out_dir, dpi=300)

    # Report table: maximum absolute errors. The text contains non-ASCII
    # symbols (μ̂, −, Σ̂, ∞), so pin UTF-8 instead of the platform default
    # encoding, which may not be able to encode them (e.g. cp1252).
    table_path = os.path.join(out_dir, "rts_recovery_report.txt")
    with open(table_path, 'w', encoding='utf-8') as f:
        f.write("C4 RTS recovery (1D)\n")
        f.write(f"max |μ̂−μ^{{RTS}}| = {res['max_mean_err']:.2e}\n")
        f.write(f"max ||Σ̂−Σ^{{RTS}}||∞ = {res['max_cov_err']:.2e}\n")
        f.write(f"grid_points = {GRID.n_points}, bounds = {GRID.bounds}\n")

    # NOTE(review): these file names are assumed to match what
    # save_rts_recovery_figs writes; they are only logged, not checked.
    curves_path = os.path.join(out_dir, "rts_recovery_curves.png")
    mean_path = os.path.join(out_dir, "rts_mean_error.png")
    cov_path = os.path.join(out_dir, "rts_cov_error.png")
    logger.info(
        "Saved RTS recovery artifacts\n  curves: %s\n  mean  : %s\n  cov   : %s\n  table : %s",
        curves_path,
        mean_path,
        cov_path,
        table_path,
    )


def main(argv=None):
    """
    Command-line entry point: run one seed, then write the figures and report.

    Options:
        --K        number of time steps (default 64, long-chain robustness
                   test); 32 or 16 run faster.
        --grid     number of grid points (default 801); can be reduced to
                   401 or 301 for quicker runs.
        --tol      IPFP tolerance (default 1e-10, stricter than the 1e-5
                   default of run_single_seed).
        --maxit    IPFP iteration cap (default 600).
        --seed     PRNG seed for the simulation (default 0).
        --out-dir  output directory (default theoretical_verification/results).

    Args:
        argv: argument list to parse; ``None`` (the default) reads ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="C4: linear-Gaussian limit recovers RTS (sanity check)."
    )
    parser.add_argument("--K", type=int, default=64, help="number of time steps")
    parser.add_argument("--grid", type=int, default=801, help="number of grid points")
    parser.add_argument("--tol", type=float, default=1e-10, help="IPFP tolerance")
    parser.add_argument("--maxit", type=int, default=600, help="IPFP max iterations")
    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
    parser.add_argument(
        "--out-dir",
        default="theoretical_verification/results",
        help="directory for the figures and the report",
    )
    args = parser.parse_args(argv)

    # CLI values override the (faster, looser) run_single_seed defaults.
    res = run_single_seed(
        seed=args.seed,
        K_steps=args.K,
        grid_points=args.grid,
        ipfp_tol=args.tol,
        ipfp_maxit=args.maxit,
    )
    save_fig_and_table(res, out_dir=args.out_dir)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, run one seed, write artifacts.
    # NOTE(review): the module uses relative imports (`from ..utils.rts ...`),
    # so it presumably has to be launched as a module (`python -m <pkg>.<mod>`)
    # rather than by file path — confirm against the package layout.
    main()