#!/usr/bin/env python3
"""
Run C2 (metric ⟂ drift) experiment at publication quality.

Pipeline:
1) Simulate 1D LGSSM and build observed marginal densities (anchors)
2) Solve MMSB with two drifts (OU params) using IPFP (compiled loop)
3) Build metric baselines: W2 geodesic quantile lines from anchors
4) Compute FP residuals per interval (for each drift), normalize by W2(μ_k, μ_{k+1})
5) Save density-time plots with geodesic quantiles and residual bars
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path
import math

import jax
import jax.numpy as jnp

# project path
BASE_DIR = Path(__file__).resolve().parents[2]
import sys
sys.path.append(str(BASE_DIR))

from mmsbvi.core.types import GridConfig1D, OUProcessParams, MMSBProblem, IPFPConfig
from mmsbvi.algorithms.ipfp_1d import solve_mmsb_ipfp_1d_fixed
from theoretical_verification.utils.rts import simulate_lgssm, kalman_rts
from theoretical_verification.core_experiments.metric_vs_constraint import (
    w2_distance_1d,
    quantile_lines_for_geodesic,
    fp_residual_series_ou,
)
from visualization.metric_vs_constraint_visualization import save_c2_figures, save_density_overlay


# Enable 64-bit floating point in JAX (it defaults to 32-bit precision).
jax.config.update("jax_enable_x64", True)


def _gauss(x, mu, var):
    """Evaluate the univariate normal density N(mu, var) at ``x``.

    ``var`` is the variance (not the standard deviation). The arguments are
    combined with ordinary arithmetic, so array inputs broadcast.
    """
    normalizer = 1.0 / jnp.sqrt(2 * jnp.pi * var)
    deviation = x - mu
    kernel = jnp.exp(-0.5 * deviation ** 2 / var)
    return normalizer * kernel


def _lg_to_ou_coeffs(A: float, Q: float, dt: float = 1.0):
    """Return ``(theta, sigma)`` of the OU process matching an AR(1) step.

    The exact ``dt``-discretization of ``dX = -theta * X dt + sigma dW`` is
    ``X_{k+1} = A X_k + w`` with ``A = exp(-theta * dt)`` and
    ``Var(w) = Q = sigma**2 * (1 - A**2) / (2 * theta)``. This inverts that
    relation.

    Args:
        A: AR(1) transition coefficient; must be strictly positive.
        Q: Process-noise variance of the discrete model; must be >= 0.
        dt: Time step between discrete observations; must be > 0.

    Returns:
        A ``(theta, sigma)`` tuple of floats. ``theta`` is negative when
        ``A > 1`` and zero when ``A == 1``.

    Raises:
        ValueError: If ``A <= 0``, ``Q < 0`` or ``dt <= 0`` (also for NaN).
    """
    if not A > 0.0:
        raise ValueError(f"A must be > 0 to map onto an OU process, got {A!r}")
    if not dt > 0.0:
        raise ValueError(f"dt must be > 0, got {dt!r}")
    if not Q >= 0.0:
        raise ValueError(f"Q must be >= 0, got {Q!r}")

    log_a = math.log(A)
    theta = -log_a / dt
    # 1 - A**2 == -expm1(2 * log(A)); expm1 avoids the catastrophic
    # cancellation of the direct subtraction when A is close to 1.
    one_minus_a2 = -math.expm1(2.0 * log_a)
    if one_minus_a2 == 0.0:
        # A == 1 is the theta -> 0 limit (pure Brownian motion): Q = sigma**2 * dt.
        sigma = math.sqrt(Q / dt)
    else:
        # theta and (1 - A**2) always share the same sign, so the ratio is >= 0
        # for A < 1 and A > 1 alike; no clamping of the denominator is needed.
        sigma = math.sqrt(2.0 * theta * Q / one_minus_a2)
    return theta, sigma


def _map_lg_to_ou(A: float, Q: float, dt: float = 1.0) -> "OUProcessParams":
    """Map linear-Gaussian transition parameters to OU process parameters.

    Args:
        A: Discrete-time transition coefficient (must be > 0).
        Q: Discrete-time process-noise variance (must be >= 0).
        dt: Sampling interval of the discrete model (must be > 0).

    Returns:
        ``OUProcessParams`` with ``mean_reversion = -log(A) / dt``, the
        matching ``diffusion`` and ``equilibrium_mean = 0.0``.

    Raises:
        ValueError: If the inputs are outside the ranges above.
    """
    theta, sigma = _lg_to_ou_coeffs(A, Q, dt)
    return OUProcessParams(mean_reversion=theta, diffusion=sigma, equilibrium_mean=0.0)


def _normalized_anchor_density(x, h, mean, var):
    """Gaussian density on grid ``x`` rescaled to unit trapezoid-rule mass."""
    density = _gauss(x, mean, var)
    # Trapezoid rule with spacing h: h * (sum minus half of the two endpoints).
    mass = h * (jnp.sum(density) - 0.5 * (density[0] + density[-1]))
    # The tiny offset keeps the division finite if the mass is zero.
    return density / (mass + 1e-15)


def _make_ipfp_config(tol, maxit, compiled):
    """Build a fresh IPFP configuration; called once per solve."""
    return IPFPConfig(
        max_iterations=maxit,
        tolerance=tol,
        check_interval=10,
        verbose=False,
        compiled_loop=compiled,
        compiled_max_iterations=maxit,
        compiled_check_interval=10,
    )


def _solve_for_drift(ou_params, anchors, grid, K_steps, C, R, config):
    """Solve the MMSB problem on the shared anchors for one OU drift."""
    problem = MMSBProblem(
        observation_times=jnp.arange(float(K_steps)),
        observed_marginals=anchors,
        ou_params=ou_params,
        grid=grid,
        C=C,
        R=R,
    )
    return solve_mmsb_ipfp_1d_fixed(problem, config)


def _robust_w2_normalizers(anchors_arr, x, K_steps, eps=0.05):
    """Per-interval W2 normalizers floored at ``eps * median(W2)``."""
    w2_arr = jnp.array([
        w2_distance_1d(anchors_arr[k], anchors_arr[k + 1], x) + 1e-12
        for k in range(K_steps - 1)
    ])
    # eps-robust normalization: W2_norm = max(W2, eps * median(W2)), so
    # intervals with a near-zero W2 distance cannot blow up the ratio.
    floor = eps * jnp.median(w2_arr)
    return jnp.maximum(w2_arr, floor)


def run_c2(
    K_steps: int = 32,
    grid_points: int = 601,
    tol: float = 1e-8,
    maxit: int = 600,
    compiled: bool = True,
    # LGSSM params
    A1: float = 0.9,
    A2: float = 0.7,
    C: float = 1.0,
    Q: float = 0.1 ** 2,
    R: float = 0.2 ** 2,
    seed: int = 0,
    coverage_sigma: float = 7.0,
    out_dir: str = "theoretical_verification/results",
):
    """Run the C2 (metric vs. drift) experiment and write figures and a report.

    Observations are simulated and smoothed with transition coefficient ``A1``
    only; ``A2`` enters solely through the second OU drift. Both MMSB solves
    share the same anchors and grid.

    Args:
        K_steps: Number of time steps / anchor marginals; must be >= 2 so that
            at least one interval exists.
        grid_points: Number of spatial grid points.
        tol: IPFP convergence tolerance.
        maxit: Maximum IPFP iterations (also used for the compiled loop).
        compiled: Whether to request the compiled IPFP loop.
        A1: Transition coefficient used for simulation, smoothing and drift 1.
        A2: Transition coefficient used for drift 2.
        C: Observation coefficient passed to the simulator, smoother and solver.
        Q: Process-noise variance.
        R: Observation-noise variance.
        seed: Seed for ``jax.random.PRNGKey``.
        coverage_sigma: Grid margin, in multiples of the largest smoothed
            standard deviation, added beyond the extreme smoothed means.
        out_dir: Directory for figures and the text report (created if missing).

    Returns:
        None. Figures and ``c2_metric_vs_constraint_report.txt`` are written
        to ``out_dir`` and a one-line summary is printed.

    Raises:
        ValueError: If ``K_steps < 2``.
    """
    if K_steps < 2:
        raise ValueError(f"K_steps must be >= 2 to form at least one interval, got {K_steps}")

    key = jax.random.PRNGKey(seed)
    # Only the observations are needed; the simulated latent states are unused.
    _, ys = simulate_lgssm(key, A1, C, Q, R, K_steps)
    # presumably RTS-smoothed means and variances (used as such below)
    m_rts, P_rts = kalman_rts(ys, A1, C, Q, R)

    # grid from RTS stats
    mu_min = float(jnp.min(m_rts))
    mu_max = float(jnp.max(m_rts))
    sigma_max = float(jnp.max(jnp.sqrt(P_rts)))
    margin = coverage_sigma * sigma_max
    bounds = (mu_min - margin, mu_max + margin)
    GRID = GridConfig1D.create(grid_points, bounds)
    x = GRID.points
    h = GRID.spacing

    # observed marginals from RTS (anchors)
    anchors = [
        _normalized_anchor_density(x, h, m_rts[k], P_rts[k])
        for k in range(K_steps)
    ]
    anchors_arr = jnp.stack(anchors, axis=0)

    # two drifts -> OU params
    ou1 = _map_lg_to_ou(A1, Q)
    ou2 = _map_lg_to_ou(A2, Q)

    # Same problem data for both solves; only the reference drift differs.
    sol1 = _solve_for_drift(ou1, anchors, GRID, K_steps, C, R,
                            _make_ipfp_config(tol, maxit, compiled))
    sol2 = _solve_for_drift(ou2, anchors, GRID, K_steps, C, R,
                            _make_ipfp_config(tol, maxit, compiled))

    # build geodesic quantile lines from endpoints (anchors)
    geo_q = quantile_lines_for_geodesic(anchors_arr, x, quantiles=(0.1, 0.5, 0.9))

    # residuals per interval
    rho1 = jnp.stack(sol1.path_densities, axis=0)
    rho2 = jnp.stack(sol2.path_densities, axis=0)
    residuals1 = fp_residual_series_ou(rho1, x, 1.0, ou1)
    residuals2 = fp_residual_series_ou(rho2, x, 1.0, ou2)

    w2_norm = _robust_w2_normalizers(anchors_arr, x, K_steps)
    r_norm1 = residuals1 / w2_norm
    r_norm2 = residuals2 / w2_norm

    os.makedirs(out_dir, exist_ok=True)
    # choose peak intervals for annotations: top-3 of residuals2 (more pronounced)
    top_idx = jnp.argsort(residuals2)[-3:]
    peaks = [int(i) for i in top_idx]
    meta = f"σ={float(ou1.diffusion):.2f}/{float(ou2.diffusion):.2f}, K={K_steps}, grid={grid_points}"
    save_c2_figures(rho1, rho2, geo_q, x, residuals1, residuals2, r_norm1, r_norm2, out_dir=out_dir, meta_title=meta, peaks_k=peaks)
    # overlay plot for central lines
    save_density_overlay(rho1, rho2, x, geo_q, peaks, f"Density-time overlay (b1 base, q=0.5 & b2 mean) [{meta}]",
                         os.path.join(out_dir, 'c2_density_overlay.png'))

    # report
    rpt = os.path.join(out_dir, 'c2_metric_vs_constraint_report.txt')
    # Explicit encoding so the report does not depend on the platform locale.
    with open(rpt, 'w', encoding='utf-8') as f:
        f.write('C2 metric vs constraint (publication settings)\n')
        f.write(f'K={K_steps}, grid={grid_points}, tol={tol}, maxit={maxit}, compiled={compiled}\n')
        f.write(f'A1={A1}, A2={A2}, Q={Q}, R={R}\n')
        f.write(f'mean residual b1={float(jnp.mean(residuals1)):.3e}, b2={float(jnp.mean(residuals2)):.3e}\n')
        f.write(f'mean normalized residual b1={float(jnp.mean(r_norm1)):.3e}, b2={float(jnp.mean(r_norm2)):.3e}\n')
    print(f'C2 done. figures in {out_dir}, report={rpt}')


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the C2 experiment."""
    p = argparse.ArgumentParser(
        description='Run the C2 (metric vs. drift) experiment.'
    )
    p.add_argument('--K', type=int, default=32)
    p.add_argument('--grid', type=int, default=601)
    p.add_argument('--tol', type=float, default=1e-8)
    p.add_argument('--maxit', type=int, default=600)
    # NOTE(review): this flag defaults to False here, whereas run_c2 itself
    # defaults to compiled=True; kept as-is to preserve existing CLI behavior.
    p.add_argument('--compiled', action='store_true')
    p.add_argument('--A1', type=float, default=0.9)
    p.add_argument('--A2', type=float, default=0.7)
    p.add_argument('--Q', type=float, default=0.1 ** 2)
    p.add_argument('--R', type=float, default=0.2 ** 2)
    # Options below mirror run_c2 keyword arguments that the CLI previously
    # could not set; their defaults equal run_c2's own defaults.
    p.add_argument('--C', type=float, default=1.0)
    p.add_argument('--seed', type=int, default=0)
    p.add_argument('--coverage-sigma', dest='coverage_sigma', type=float, default=7.0)
    p.add_argument('--out-dir', dest='out_dir', type=str,
                   default='theoretical_verification/results')
    return p


def main(argv=None):
    """Parse command-line options and run the C2 experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    args = _build_arg_parser().parse_args(argv)
    run_c2(K_steps=args.K, grid_points=args.grid, tol=args.tol, maxit=args.maxit,
           compiled=args.compiled, A1=args.A1, A2=args.A2, C=args.C, Q=args.Q,
           R=args.R, seed=args.seed, coverage_sigma=args.coverage_sigma,
           out_dir=args.out_dir)


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()


