﻿from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path

import numpy as np

# Make the sibling ``oll_stage1_probe`` directory importable so that the
# ``probe_core`` import below resolves when this script is run directly.
HERE = Path(__file__).resolve()
CORE_DIR = HERE.parents[1] / "oll_stage1_probe"
if str(CORE_DIR) not in sys.path:
    # Prepend so the Stage-1 probe_core takes precedence over any other module of that name.
    sys.path.insert(0, str(CORE_DIR))

from probe_core import (
    ScanConfig,
    compute_delta_true,
    compute_lambda_used_seq,
    cosine_mean_windows,
    estimate_lle_benettin,
    generate_inputs,
    generate_targets,
    init_base_weights,
    init_output_weights,
    lowpass_ema,
    simulate_tanh_rnn,
    split_rngs,
)


def _timestamp() -> str:
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _find_gcrit(w_hh_base, w_xh, bias, inputs0, burnin_steps, seed, coarse_min, coarse_max, coarse_points, fine_window, fine_step, lle_trials):
    """Find the recurrent gain whose estimated largest Lyapunov exponent (LLE) is closest to zero.

    A coarse ``np.linspace`` grid over ``[coarse_min, coarse_max]`` is scanned
    first. A fine grid with spacing ``fine_step`` is then centred on the best
    coarse gain ``g0`` and spans ``[g0 - fine_window, g0 + fine_window]``.
    For every candidate gain ``g``, the network ``g * w_hh_base`` is simulated
    on ``inputs0`` via ``simulate_tanh_rnn``. The LLE is the ``nanmean`` of
    ``max(1, lle_trials)`` calls to ``estimate_lle_benettin``.

    Args:
        w_hh_base: recurrent weight matrix that is scaled by each candidate gain.
        w_xh: input weights, forwarded unchanged to ``simulate_tanh_rnn``.
        bias: bias, forwarded unchanged to ``simulate_tanh_rnn``.
        inputs0: input sequence used for every simulation.
        burnin_steps: forwarded to ``estimate_lle_benettin``.
        seed: base integer from which each (grid index, trial) RNG seed is derived.
        coarse_min: lower end of the coarse gain grid.
        coarse_max: upper end of the coarse gain grid.
        coarse_points: number of coarse grid points.
        fine_window: half-width of the fine grid around the best coarse gain.
        fine_step: requested spacing of the fine grid.
        lle_trials: number of LLE estimates averaged per gain (at least 1 is used).

    Returns:
        dict with ``"gcrit"`` (selected gain, float) and ``"lle_at_gcrit"``
        (its mean LLE estimate, float).

    Raises:
        ValueError: if ``coarse_points < 1``, ``fine_step <= 0``,
            ``fine_window < 0``, or every LLE estimate on a grid is NaN.
    """
    if int(coarse_points) < 1:
        raise ValueError(f"coarse_points must be >= 1, got {coarse_points}")
    if not float(fine_step) > 0.0:
        raise ValueError(f"fine_step must be positive, got {fine_step}")
    if float(fine_window) < 0.0:
        raise ValueError(f"fine_window must be non-negative, got {fine_window}")

    n_trials = max(1, int(lle_trials))

    def eval_grid(gains):
        # Mean LLE per gain. Each (grid index, trial) pair gets a deterministic RNG seed.
        lles = np.full((len(gains),), np.nan, dtype=np.float64)
        for i, g in enumerate(gains):
            w_hh = float(g) * w_hh_base
            _pre, _states, u_seq0 = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            vals = []
            for j in range(n_trials):
                lle_rng = np.random.default_rng(int(seed) * 10_000 + i * 97 + j * 997)
                vals.append(estimate_lle_benettin(w_hh, u_seq0, burnin_steps, lle_rng))
            lles[i] = float(np.nanmean(vals))
        return lles

    def closest_to_zero(lles, stage):
        # Index of the gain with the smallest |LLE|, skipping NaNs; fail clearly if none are usable.
        if np.all(np.isnan(lles)):
            raise ValueError(f"[GCRIT] every LLE estimate on the {stage} grid is NaN; cannot locate gcrit.")
        return int(np.nanargmin(np.abs(lles)))

    coarse = np.linspace(float(coarse_min), float(coarse_max), int(coarse_points), dtype=np.float64)
    lles_coarse = eval_grid(coarse)
    g0 = float(coarse[closest_to_zero(lles_coarse, "coarse")])

    # The fine grid spans 2 * fine_window, so it needs 2 * fine_window / fine_step
    # intervals to actually reach the requested spacing (using fine_window / fine_step
    # doubles the step). round() stops float division (e.g. 39.999...) from losing an interval.
    n_intervals = max(3, int(round(2.0 * float(fine_window) / float(fine_step))))
    fine = np.linspace(g0 - float(fine_window), g0 + float(fine_window), n_intervals + 1)
    lles_fine = eval_grid(fine)
    idx1 = closest_to_zero(lles_fine, "fine")

    return {"gcrit": float(fine[idx1]), "lle_at_gcrit": float(lles_fine[idx1])}


def _apply_plot_style() -> None:
    """Set global matplotlib ``rcParams`` for paper-style figures.

    Uses a serif font stack headed by Times New Roman, 10 pt text (9 pt tick
    labels), outward ticks, 0.8 pt axes lines, TrueType (type 42) font
    embedding for PDF/PS output, and STIX mathtext. This mutates
    ``matplotlib.rcParams`` for the whole process.
    """
    import matplotlib as mpl

    mpl.rcParams.update(
        {
            # Request the serif family with Times New Roman first. If it is not
            # installed (common on Linux), matplotlib walks this list instead of
            # warning and falling back to the default sans-serif. STIXGeneral
            # ships with matplotlib and matches the STIX mathtext set below.
            "font.family": "serif",
            "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "STIXGeneral", "DejaVu Serif"],
            "font.size": 10,
            "axes.titlesize": 10,
            "axes.labelsize": 10,
            "xtick.labelsize": 9,
            "ytick.labelsize": 9,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "axes.linewidth": 0.8,
            # Embed TrueType (type 42) fonts rather than Type 3 in PDF/PS output.
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "mathtext.fontset": "stix",
        }
    )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Stage-2 single-seed scan."""
    parser = argparse.ArgumentParser(description="Stage-2: single-seed short-horizon scan")
    parser.add_argument("--seed", type=int, default=120)
    parser.add_argument("--input-mode", type=str, default="gaussian", choices=["gaussian", "poisson"])

    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--time-steps", type=int, default=800)
    parser.add_argument("--burnin-steps", type=int, default=100)
    parser.add_argument("--fit-burnin-steps", type=int, default=100)

    parser.add_argument("--input-std", type=float, default=0.05)
    parser.add_argument("--poisson-rate", type=float, default=5.0)
    parser.add_argument("--gext-std", type=float, default=1.0)

    parser.add_argument("--low-rank-rank", type=int, default=1)
    parser.add_argument("--low-rank-frac", type=float, default=0.95)

    parser.add_argument("--output-dim", type=int, default=16)
    parser.add_argument("--output-weight-scale", type=float, default=1.0)

    parser.add_argument("--target-std", type=float, default=1.0)
    parser.add_argument("--target-ar1-rho", type=float, default=0.995)

    parser.add_argument("--alpha-rho", type=float, default=0.995)
    parser.add_argument("--alpha-source", type=str, default="h", choices=["h", "g"])
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument("--eps-lambda", type=float, default=1e-8)
    parser.add_argument("--lam-cap", type=float, default=0.99)
    parser.add_argument("--denom-floor", type=float, default=1e-3)
    # `store_true` with default=True on its own can never be switched off; the
    # paired negative flag writes to the same destination so the cap can be disabled.
    parser.add_argument("--use-safe-cap", dest="use_safe_cap", action="store_true", default=True)
    parser.add_argument("--no-use-safe-cap", dest="use_safe_cap", action="store_false")

    parser.add_argument("--gain-coarse-min", type=float, default=0.8)
    parser.add_argument("--gain-coarse-max", type=float, default=1.3)
    parser.add_argument("--gain-coarse-points", type=int, default=51)
    parser.add_argument("--gain-fine-window", type=float, default=0.02)
    parser.add_argument("--gain-fine-step", type=float, default=0.001)
    parser.add_argument("--lle-trials", type=int, default=2)

    parser.add_argument("--error-lp-values", type=float, nargs="+", default=[0.95, 0.98, 0.99, 0.995, 0.998, 0.999, 0.9995])
    parser.add_argument("--eval-window", type=int, default=512)
    parser.add_argument("--cos-threshold", type=float, default=0.70)

    parser.add_argument("--out-dir", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")
    return parser


def _prepare_out_dir(out_dir_arg, overwrite: bool, stamp: str) -> Path:
    """Create and return the output directory; an explicit existing one is refused unless *overwrite* is set."""
    if out_dir_arg is not None:
        out_dir = Path(str(out_dir_arg))
        if out_dir.exists() and not overwrite:
            raise SystemExit(f"[OUT] {out_dir} exists; pass --overwrite to write into it.")
    else:
        out_dir = Path("paper_figure") / f"stage2_short_horizon_{stamp}"
    _ensure_dir(out_dir)
    return out_dir


def _save_alignment_plot(out_dir: Path, error_lp_grid: np.ndarray, cos_vals: list[float], cos_threshold: float, eval_window: int) -> None:
    """Plot mean |cos| against the error low-pass coefficient and save it as PDF and PNG in *out_dir*."""
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:  # plotting is optional: missing package or a broken backend
        print(f"[PLOT] matplotlib unavailable ({exc}); skipping plots.")
        return

    _apply_plot_style()

    fig, ax = plt.subplots(1, 1, figsize=(5.6, 3.2), constrained_layout=True)
    try:
        ax.plot(error_lp_grid, cos_vals, marker="o", markersize=4, linewidth=1.2)
        ax.axhline(float(cos_threshold), color="black", linestyle="--", linewidth=1.0, alpha=0.6)
        ax.set_xlabel(r"Error low-pass $\rho_e$")
        ax.set_ylabel(r"Mean $|\cos|$")
        ax.set_ylim(0.0, 1.0)
        ax.grid(True, linestyle=":", alpha=0.3)
        fig.suptitle(rf"Single-seed alignment (K={int(eval_window)})")
        fig.savefig(out_dir / "stage2_single_seed_cos_vs_error_lp.pdf", bbox_inches="tight")
        fig.savefig(out_dir / "stage2_single_seed_cos_vs_error_lp.png", dpi=240, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def main() -> None:
    """Run the Stage-2 single-seed short-horizon scan and write its artefacts.

    Steps:
    1. Parse and validate the CLI arguments, then prepare the output directory.
    2. Build a low-rank tanh RNN and scale it by ``gcrit`` from ``_find_gcrit``.
    3. Simulate it and form the output-error feedback
       ``w_out.T @ (w_out @ states[t] - targets[t])`` at each step.
    4. For each ``--error-lp-values`` entry, low-pass that feedback, compute
       ``lambda_used_seq`` and ``delta_true``, and score them with
       ``cosine_mean_windows``.

    Results are written to ``results.json`` in the output directory. A PDF/PNG
    plot is added when matplotlib can be imported.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Validate before touching the filesystem so bad arguments leave no empty run directory.
    if not 0 < args.eval_window <= args.time_steps:
        # Otherwise the evaluation start index (time_steps - eval_window) would be negative.
        parser.error(f"--eval-window must be in [1, --time-steps={args.time_steps}]; got {args.eval_window}")
    if not args.gain_fine_step > 0:
        parser.error(f"--gain-fine-step must be positive; got {args.gain_fine_step}")

    stamp = _timestamp()
    out_dir = _prepare_out_dir(args.out_dir, bool(args.overwrite), stamp)
    print(f"[OUT] {out_dir}")

    rngs = split_rngs(int(args.seed))
    cfg = ScanConfig(
        hidden_size=int(args.hidden_size),
        batch_size=int(args.batch_size),
        time_steps=int(args.time_steps),
        burnin_steps=int(args.burnin_steps),
        input_dim=1,
        input_mode=str(args.input_mode),
        input_std=float(args.input_std),
        poisson_rate=float(args.poisson_rate),
        gext_mode="output_mse",
        gext_std=float(args.gext_std),
        gext_ar1_rho=0.0,
        w_hh_mode="low_rank",
        low_rank_rank=int(args.low_rank_rank),
        low_rank_frac=float(args.low_rank_frac),
        output_dim=int(args.output_dim),
        target_std=float(args.target_std),
        target_ar1_rho=float(args.target_ar1_rho),
        error_lp_rho=float(args.error_lp_values[0]),
        output_weight_scale=float(args.output_weight_scale),
        output_weight_mode="align_low_rank",
        output_basis_rank=0,
        alpha_rho=float(args.alpha_rho),
        alpha_source=str(args.alpha_source),
        lambda_window=int(args.lambda_window),
        eps_lambda=float(args.eps_lambda),
        lam_cap=float(args.lam_cap),
        denom_floor=float(args.denom_floor),
        fit_burnin_steps=int(args.fit_burnin_steps),
        use_safe_cap=bool(args.use_safe_cap),
    )

    inputs0 = generate_inputs(rngs["inputs"], cfg)
    w_hh_base, w_xh, bias, low_rank_basis = init_base_weights(rngs["recurrent"], cfg)
    gcrit_info = _find_gcrit(
        w_hh_base,
        w_xh,
        bias,
        inputs0,
        int(args.burnin_steps),
        int(args.seed),
        float(args.gain_coarse_min),
        float(args.gain_coarse_max),
        int(args.gain_coarse_points),
        float(args.gain_fine_window),
        float(args.gain_fine_step),
        int(args.lle_trials),
    )
    gcrit = float(gcrit_info["gcrit"])
    w_hh = gcrit * w_hh_base

    w_out, _b_out = init_output_weights(rngs["output"], cfg, w_hh, low_rank_basis)
    targets = generate_targets(
        rngs["targets"],
        int(cfg.time_steps),
        int(cfg.output_dim),
        int(cfg.batch_size),
        float(cfg.target_std),
        float(cfg.target_ar1_rho),
    )

    _pre, states, u_seq = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
    # Output-error feedback per step: w_out^T (w_out @ states[t] - targets[t]).
    g_seq = np.zeros_like(u_seq)
    for t in range(int(cfg.time_steps)):
        e_t = w_out @ states[t] - targets[t]
        g_seq[t] = w_out.T @ e_t

    error_lp_grid = np.array([float(x) for x in args.error_lp_values], dtype=np.float64)
    # A single evaluation window starting eval_window steps before the end (validated >= 0 above).
    start_eval = [int(cfg.time_steps) - int(args.eval_window)]
    cos_vals: list[float] = []
    for elp in error_lp_grid.tolist():
        g_lp = lowpass_ema(g_seq, float(elp))
        lambda_used_seq, _debug = compute_lambda_used_seq(states, u_seq, g_lp, cfg)
        delta_true = compute_delta_true(w_hh, u_seq, g_lp)
        cos_val = cosine_mean_windows(
            delta_true,
            u_seq,
            g_lp,
            lambda_used_seq,
            float(cfg.denom_floor),
            int(cfg.fit_burnin_steps),
            start_eval,
        )[0]
        cos_vals.append(float(cos_val))

    results = {
        "seed": int(args.seed),
        "gcrit": float(gcrit),
        "lle_at_gcrit": float(gcrit_info["lle_at_gcrit"]),
        "error_lp_rho_grid": error_lp_grid.tolist(),
        "cos_mean": cos_vals,
        "eval_window": int(args.eval_window),
        "threshold_cos": float(args.cos_threshold),
        # Full CLI configuration so the run can be reproduced from results.json alone.
        "args": dict(vars(args)),
    }
    (out_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")

    _save_alignment_plot(out_dir, error_lp_grid, cos_vals, float(args.cos_threshold), int(args.eval_window))


if __name__ == "__main__":
    # Script entry point: run the Stage-2 single-seed scan from the command line.
    main()
