﻿from __future__ import annotations

import argparse
import json
import math
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np

import sys

# ``probe_core`` lives in ``oll_stage1_probe``, a sibling of the directory that
# contains this script; prepend it to ``sys.path`` (once) so the import below
# resolves regardless of the current working directory.
HERE = Path(__file__).resolve()
CORE_DIR = HERE.parents[1] / "oll_stage1_probe"
if str(CORE_DIR) not in sys.path:
    sys.path.insert(0, str(CORE_DIR))

from probe_core import (
    ScanConfig,
    compute_delta_true_tbptt,
    compute_lambda_used_seq,
    config_asdict,
    cosine_mean_windows,
    estimate_lag1_rho,
    estimate_lle_benettin,
    generate_inputs,
    generate_targets,
    init_base_weights,
    init_output_weights,
    lowpass_ema,
    simulate_tanh_rnn,
    split_rngs,
)


def _timestamp() -> str:
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


def _ensure_dir(path: Path) -> None:
    """Create directory ``path`` including any missing parents; no-op if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def _arange_inclusive(start: float, stop: float, step: float) -> np.ndarray:
    """Return ``start, start+step, ...`` up to and including ``stop`` (float64).

    A small absolute tolerance (1e-12) is applied both when counting points and
    when trimming, so an endpoint that floating-point arithmetic lands a hair
    past ``stop`` (e.g. 0.1 * 3) is still included.

    Raises:
        ValueError: if ``step`` is not strictly positive.
    """
    tol = 1e-12
    lo = float(start)
    hi = float(stop)
    delta = float(step)
    if delta <= 0:
        raise ValueError("step must be > 0")

    count = int(math.floor((hi - lo) / delta + tol)) + 1
    if count <= 0:
        # stop lies before start: nothing to generate.
        return np.array([], dtype=np.float64)

    grid = lo + delta * np.arange(count, dtype=np.float64)
    return grid[grid <= (hi + tol)]


def _find_gcrit(
    w_hh_base: np.ndarray,
    w_xh: np.ndarray,
    bias: np.ndarray,
    inputs0: np.ndarray,
    burnin_steps: int,
    seed: int,
    coarse_min: float,
    coarse_max: float,
    coarse_points: int,
    fine_window: float,
    fine_step: float,
    lle_trials: int,
) -> dict[str, Any]:
    """Find the recurrent gain whose estimated largest Lyapunov exponent is nearest 0.

    Two-stage grid search:

    * coarse: ``coarse_points`` evenly spaced gains in ``[coarse_min, coarse_max]``;
    * fine: gains spaced ``fine_step`` apart spanning ``+/- fine_window`` around
      the best coarse gain.

    For each gain ``g`` the network with recurrent matrix ``g * w_hh_base`` is
    simulated on ``inputs0`` and ``estimate_lle_benettin`` is run
    ``max(1, lle_trials)`` times; the per-gain value is the NaN-aware mean.
    Each trial gets its own RNG seeded from ``seed``, the gain's position in its
    grid, and the trial index (the same seeds are reused by both stages).

    Returns:
        ``{"coarse": {...}, "fine": {...}, "gcrit": float, "lle_at_gcrit": float}``
        where each stage dict holds ``gains``, ``lles``, ``best_gain``, ``best_lle``.
    """

    def lle_profile(gain_grid: np.ndarray) -> np.ndarray:
        # One NaN-aware mean LLE estimate per gain in ``gain_grid``.
        profile = np.full((len(gain_grid),), np.nan, dtype=np.float64)
        for pos, gain in enumerate(gain_grid):
            w_rec = float(gain) * w_hh_base
            _, _, drive = simulate_tanh_rnn(w_rec, w_xh, bias, inputs0)
            trial_lles = [
                estimate_lle_benettin(
                    w_rec,
                    drive,
                    burnin_steps,
                    np.random.default_rng(int(seed) * 10_000 + pos * 97 + trial * 997),
                )
                for trial in range(max(1, int(lle_trials)))
            ]
            profile[pos] = float(np.nanmean(trial_lles))
        return profile

    def closest_to_zero(gain_grid: np.ndarray, lle_grid: np.ndarray) -> tuple[float, float]:
        # (gain, lle) pair minimising |lle|, ignoring NaN entries.
        best = int(np.nanargmin(np.abs(lle_grid)))
        return float(gain_grid[best]), float(lle_grid[best])

    coarse_gains = np.linspace(float(coarse_min), float(coarse_max), int(coarse_points), dtype=np.float64)
    coarse_lles = lle_profile(coarse_gains)
    g_coarse, lle_coarse = closest_to_zero(coarse_gains, coarse_lles)

    half_width = float(fine_window)
    fine_gains = _arange_inclusive(g_coarse - half_width, g_coarse + half_width, float(fine_step))
    fine_lles = lle_profile(fine_gains)
    g_fine, lle_fine = closest_to_zero(fine_gains, fine_lles)

    coarse_summary = {
        "gains": coarse_gains.tolist(),
        "lles": coarse_lles.tolist(),
        "best_gain": g_coarse,
        "best_lle": lle_coarse,
    }
    fine_summary = {
        "gains": fine_gains.tolist(),
        "lles": fine_lles.tolist(),
        "best_gain": g_fine,
        "best_lle": lle_fine,
    }
    return {
        "coarse": coarse_summary,
        "fine": fine_summary,
        "gcrit": g_fine,
        "lle_at_gcrit": lle_fine,
    }


def _compute_g_seq(w_out: np.ndarray, states: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Per-step output-error signal ``g[t] = W_out^T (W_out h_t - target_t)``.

    This is the gradient of ``0.5 * ||W_out h_t - target_t||^2`` with respect to
    ``h_t``.  ``states`` is indexed as (time, hidden, batch); ``targets[t]`` must
    be compatible with ``W_out @ states[t]`` (presumably targets is
    (time, output, batch) — confirm against ``generate_targets``).

    Returns:
        float64 array of shape (time, hidden, batch).
    """
    shape = states.shape
    n_steps, n_hidden, n_batch = int(shape[0]), int(shape[1]), int(shape[2])
    out = np.zeros((n_steps, n_hidden, n_batch), dtype=np.float64)
    w_out_t = w_out.T
    for step in range(n_steps):
        residual = w_out @ states[step] - targets[step]
        out[step] = w_out_t @ residual
    return out


def _make_scan_config(args: argparse.Namespace, error_lp_rho: float) -> "ScanConfig":
    """Build the sweep's ``ScanConfig`` from parsed CLI ``args``.

    Fields are taken from ``args`` except those fixed for this experiment
    (``input_dim=1``, ``gext_mode="output_mse"``, ``gext_ar1_rho=0.0``,
    ``w_hh_mode="low_rank"``, ``output_weight_mode="align_low_rank"``);
    ``error_lp_rho`` fills the config's ``error_lp_rho`` field.
    """
    return ScanConfig(
        hidden_size=int(args.hidden_size),
        batch_size=int(args.batch_size),
        time_steps=int(args.time_steps),
        burnin_steps=int(args.burnin_steps),
        input_dim=1,
        input_mode=str(args.input_mode),
        input_std=float(args.input_std),
        poisson_rate=float(args.poisson_rate),
        gext_mode="output_mse",
        gext_std=float(args.gext_std),
        gext_ar1_rho=0.0,
        w_hh_mode="low_rank",
        low_rank_rank=int(args.low_rank_rank),
        low_rank_frac=float(args.low_rank_frac),
        output_dim=int(args.output_dim),
        target_std=float(args.target_std),
        target_ar1_rho=float(args.target_ar1_rho),
        error_lp_rho=float(error_lp_rho),
        output_weight_scale=float(args.output_weight_scale),
        output_weight_mode="align_low_rank",
        output_basis_rank=int(args.output_basis_rank),
        alpha_rho=float(args.alpha_rho),
        alpha_source=str(args.alpha_source),
        lambda_window=int(args.lambda_window),
        eps_lambda=float(args.eps_lambda),
        lam_cap=float(args.lam_cap),
        denom_floor=float(args.denom_floor),
        fit_burnin_steps=int(args.fit_burnin_steps),
        use_safe_cap=bool(args.use_safe_cap),
    )


def run_sweep(args: argparse.Namespace) -> dict[str, Any]:
    """Run the Stage-2 TBPTT-K sweep over seeds ``seed_start..seed_end`` (inclusive).

    Per seed:

    1. Generate inputs and base weights, then locate the critical gain
       ``gcrit`` (estimated LLE closest to 0) with ``_find_gcrit``.
    2. Simulate the RNN at ``gcrit`` and form the output-error signal
       ``g_seq[t] = W_out^T (W_out h_t - target_t)``.
    3. For each error low-pass coefficient in ascending order, low-pass
       ``g_seq``, compute ``lambda_used_seq``, and for every truncation K compute
       ``delta_true`` and its window-mean cosine via ``cosine_mean_windows``.
       The first (hence smallest) coefficient whose cosine reaches
       ``cos_threshold`` is stored as that K's ``min_error_lp``.

    Returns:
        ``{"sweep_cfg": {...}, "seeds": [per-seed result dicts]}``.

    Raises:
        ValueError: if the seed range is empty, or no TBPTT K values or no
            error low-pass values are given.
    """
    seeds = list(range(int(args.seed_start), int(args.seed_end) + 1))
    if not seeds:
        raise ValueError(
            f"Empty seed range: seed_start={int(args.seed_start)} > seed_end={int(args.seed_end)}"
        )
    input_mode = str(args.input_mode)

    tbptt_ks = sorted({int(x) for x in args.tbptt_k_values})
    if not tbptt_ks:
        raise ValueError("No tbptt_k_values provided")

    # Sorted ascending and de-duplicated (like tbptt_ks): the scan below keeps the
    # *first* passing value per K as "min_error_lp", which is only the minimum if
    # the grid is visited in ascending order.
    error_lp_grid = np.unique(np.asarray([float(x) for x in args.error_lp_values], dtype=np.float64))
    if error_lp_grid.size == 0:
        raise ValueError("No error_lp_values configured")

    thr = float(args.cos_threshold)
    results: list[dict[str, Any]] = []

    for s in seeds:
        rngs = split_rngs(int(s))
        # Rebuilt per seed rather than hoisted, in case downstream helpers mutate
        # the config. NOTE(review): hoist if ScanConfig is known to be immutable.
        cfg_template = _make_scan_config(args, float(error_lp_grid[0]))

        inputs0 = generate_inputs(rngs["inputs"], cfg_template)
        w_hh_base, w_xh, bias, low_rank_basis = init_base_weights(rngs["recurrent"], cfg_template)

        gcrit_info = _find_gcrit(
            w_hh_base,
            w_xh,
            bias,
            inputs0,
            int(args.burnin_steps),
            int(s),
            float(args.gain_coarse_min),
            float(args.gain_coarse_max),
            int(args.gain_coarse_points),
            float(args.gain_fine_window),
            float(args.gain_fine_step),
            int(args.lle_trials),
        )
        gcrit = float(gcrit_info["gcrit"])
        lle_at_gcrit = float(gcrit_info["lle_at_gcrit"])
        w_hh = gcrit * w_hh_base

        w_out, _b_out = init_output_weights(rngs["output"], cfg_template, w_hh, low_rank_basis)
        targets = generate_targets(
            rngs["targets"],
            int(cfg_template.time_steps),
            int(cfg_template.output_dim),
            int(cfg_template.batch_size),
            float(cfg_template.target_std),
            float(cfg_template.target_ar1_rho),
        )

        _, states, u_seq = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
        g_seq = _compute_g_seq(w_out, states, targets)

        per_k = [{"tbptt_k": int(k), "min_error_lp": None} for k in tbptt_ks]
        rho_eff: list[float] = []
        lambda_used_mean: list[float] = []
        alpha_hat_mean: list[float] = []

        for elp in error_lp_grid.tolist():
            g_lp = lowpass_ema(g_seq, float(elp))
            lambda_used_seq, debug = compute_lambda_used_seq(states, u_seq, g_lp, cfg_template)

            for entry in per_k:
                delta_true = compute_delta_true_tbptt(w_hh, u_seq, g_lp, int(entry["tbptt_k"]))
                cos_mean = cosine_mean_windows(
                    delta_true,
                    u_seq,
                    g_lp,
                    lambda_used_seq,
                    float(cfg_template.denom_floor),
                    int(cfg_template.fit_burnin_steps),
                    [int(cfg_template.fit_burnin_steps)],
                )[0]

                # NOTE(review): this compares the signed cosine, while the plot
                # title reads |cos| >= threshold -- confirm which is intended.
                if entry["min_error_lp"] is None and bool(cos_mean >= thr):
                    entry["min_error_lp"] = {
                        "error_lp_rho": float(elp),
                        "cos": float(cos_mean),
                    }

            rho_eff.append(float(estimate_lag1_rho(g_lp, int(cfg_template.burnin_steps))))
            lambda_used_mean.append(float(debug.get("lambda_used_mean", float("nan"))))
            alpha_hat_mean.append(float(debug.get("alpha_hat_mean", float("nan"))))

        results.append(
            {
                "seed": int(s),
                "gcrit": float(gcrit),
                "lle_at_gcrit": float(lle_at_gcrit),
                "gcrit_scan": gcrit_info,
                "tbptt_k_results": per_k,
                "rho_eff": rho_eff,
                "lambda_used_mean": lambda_used_mean,
                "alpha_hat_mean": alpha_hat_mean,
            }
        )

    # ``seeds`` is non-empty (validated above), so ``cfg_template`` is bound here;
    # it is the config used for the last seed.
    sweep_cfg = {
        "input_mode": str(input_mode),
        "seeds": {"start": seeds[0], "end": seeds[-1], "count": len(seeds)},
        "target_ar1_rho": float(args.target_ar1_rho),
        "tbptt_ks": tbptt_ks,
        "error_lp_rho_grid": error_lp_grid.tolist(),
        "threshold_cos": thr,
        "gcrit_scan": {
            "coarse_min": float(args.gain_coarse_min),
            "coarse_max": float(args.gain_coarse_max),
            "coarse_points": int(args.gain_coarse_points),
            "fine_window": float(args.gain_fine_window),
            "fine_step": float(args.gain_fine_step),
            "lle_trials": int(args.lle_trials),
        },
        "base_config": config_asdict(cfg_template),
    }

    return {"sweep_cfg": sweep_cfg, "seeds": results}


def _apply_plot_style() -> None:
    """Apply the paper-figure style to matplotlib's global ``rcParams``.

    Text is serif with Times New Roman preferred; the remaining ``font.serif``
    entries are Times-like fallbacks (STIXGeneral and DejaVu Serif ship with
    matplotlib), so machines without Times New Roman still get a serif face
    consistent with the STIX math font rather than the sans-serif default and a
    "findfont" warning.  Also sets 9-10 pt text, outward ticks, thin axes and
    TrueType (type 42) font embedding for PDF/PS output.
    """
    import matplotlib as mpl

    mpl.rcParams.update(
        {
            "font.family": ["serif"],
            "font.serif": [
                "Times New Roman",
                "Times",
                "Nimbus Roman",
                "Nimbus Roman No9 L",
                "STIXGeneral",
                "DejaVu Serif",
            ],
            "font.size": 10,
            "axes.titlesize": 10,
            "axes.labelsize": 10,
            "xtick.labelsize": 9,
            "ytick.labelsize": 9,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "axes.linewidth": 0.8,
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "mathtext.fontset": "stix",
        }
    )


def _collect_required_rho(
    seeds: list[dict[str, Any]], tbptt_ks: list[int]
) -> tuple[list[float], list[float]]:
    """Aggregate, per TBPTT K, the ``min_error_lp.error_lp_rho`` recorded for each seed.

    Entries are matched to ``tbptt_ks`` by position in each seed's
    ``tbptt_k_results`` list.  Seeds with a missing entry, a non-dict entry, or
    ``min_error_lp`` unset (threshold never reached) are skipped.

    Returns:
        ``(means, maxes)``, one value per K; NaN where no seed contributed.
    """
    means: list[float] = []
    maxes: list[float] = []
    for pos, _k in enumerate(tbptt_ks):
        rhos: list[float] = []
        for seed_result in seeds:
            per_k = seed_result.get("tbptt_k_results", [])
            if pos >= len(per_k) or not isinstance(per_k[pos], dict):
                continue
            hit = per_k[pos].get("min_error_lp")
            if not hit:
                continue
            rhos.append(float(hit.get("error_lp_rho", float("nan"))))
        if rhos:
            means.append(float(np.nanmean(rhos)))
            maxes.append(float(np.nanmax(rhos)))
        else:
            means.append(float("nan"))
            maxes.append(float("nan"))
    return means, maxes


def plot_results(results: dict[str, Any], out_dir: Path) -> None:
    """Plot the required error low-pass coefficient versus TBPTT truncation K.

    ``results`` has the structure returned by ``run_sweep`` (or its JSON dump).
    Writes ``stage2_tbptt_boundary.pdf`` and ``stage2_tbptt_boundary.png`` into
    ``out_dir``: a line for the mean over seeds and triangle markers for the
    max over seeds.  Plotting is best-effort: if matplotlib cannot be imported,
    or ``results`` has no seeds, a message is printed and nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker
    except Exception as exc:  # deliberate best-effort: plotting is optional
        print(f"[PLOT] matplotlib unavailable ({exc}); skipping plots.")
        return

    seeds = results.get("seeds", [])
    if not seeds:
        print("[PLOT] No seeds found.")
        return

    _apply_plot_style()

    sweep_cfg = results["sweep_cfg"]
    tbptt_ks = sweep_cfg["tbptt_ks"]
    req_rho_mean, req_rho_max = _collect_required_rho(seeds, tbptt_ks)

    fig, ax = plt.subplots(1, 1, figsize=(6.0, 3.2), constrained_layout=True)
    try:
        ax.plot(tbptt_ks, req_rho_mean, marker="o", markersize=4, linewidth=1.2, label="mean over seeds")
        ax.scatter(tbptt_ks, req_rho_max, marker="^", s=28, label="max over seeds")
        ax.set_ylabel(r"Required $\rho_e$")
        ax.set_xlabel(r"TBPTT truncation $K$")
        ax.set_xscale("log", base=2)
        ax.set_xticks(tbptt_ks)
        ax.get_xaxis().set_major_formatter(mticker.FuncFormatter(lambda v, _: f"{int(v)}"))
        ax.grid(True, linestyle=":", alpha=0.3)
        ax.legend(loc="best", frameon=False)
        fig.suptitle(rf"TBPTT sweep (threshold $|\cos| \geq {sweep_cfg['threshold_cos']:.2f}$)")
        fig.savefig(out_dir / "stage2_tbptt_boundary.pdf", bbox_inches="tight")
        fig.savefig(out_dir / "stage2_tbptt_boundary.png", dpi=240, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak.
        plt.close(fig)


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the Stage-2 TBPTT-K sweep."""
    parser = argparse.ArgumentParser(description="Stage-2: TBPTT-K sweep")
    parser.add_argument("--seed-start", type=int, default=120)
    parser.add_argument("--seed-end", type=int, default=121)
    parser.add_argument("--input-mode", type=str, default="gaussian", choices=["gaussian", "poisson"])

    parser.add_argument("--hidden-size", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--time-steps", type=int, default=200)
    parser.add_argument("--burnin-steps", type=int, default=20)
    parser.add_argument("--fit-burnin-steps", type=int, default=20)

    parser.add_argument("--input-std", type=float, default=0.05)
    parser.add_argument("--poisson-rate", type=float, default=5.0)
    parser.add_argument("--gext-std", type=float, default=1.0)

    parser.add_argument("--low-rank-rank", type=int, default=1)
    parser.add_argument("--low-rank-frac", type=float, default=0.95)

    parser.add_argument("--output-dim", type=int, default=16)
    parser.add_argument("--output-weight-scale", type=float, default=1.0)
    parser.add_argument("--output-basis-rank", type=int, default=0)

    parser.add_argument("--target-std", type=float, default=1.0)
    parser.add_argument("--target-ar1-rho", type=float, default=0.995)

    parser.add_argument("--alpha-rho", type=float, default=0.995)
    parser.add_argument("--alpha-source", type=str, default="h", choices=["h", "g"])
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument("--eps-lambda", type=float, default=1e-8)
    parser.add_argument("--lam-cap", type=float, default=0.99)
    parser.add_argument("--denom-floor", type=float, default=1e-3)
    # ``store_true`` with ``default=True`` alone can never be switched off, so a
    # paired ``--no-use-safe-cap`` writes False to the same destination.
    parser.add_argument(
        "--use-safe-cap",
        dest="use_safe_cap",
        action="store_true",
        default=True,
        help="enable the safe lambda cap (default: enabled)",
    )
    parser.add_argument(
        "--no-use-safe-cap",
        dest="use_safe_cap",
        action="store_false",
        help="disable the safe lambda cap",
    )

    parser.add_argument("--gain-coarse-min", type=float, default=0.8)
    parser.add_argument("--gain-coarse-max", type=float, default=1.3)
    parser.add_argument("--gain-coarse-points", type=int, default=21)
    parser.add_argument("--gain-fine-window", type=float, default=0.02)
    parser.add_argument("--gain-fine-step", type=float, default=0.002)
    parser.add_argument("--lle-trials", type=int, default=1)

    parser.add_argument("--error-lp-values", type=float, nargs="+", default=[0.0, 0.9])
    parser.add_argument("--cos-threshold", type=float, default=0.70)

    parser.add_argument("--tbptt-k-values", type=int, nargs="+", default=[8, 16, 32])
    parser.add_argument("--plot-only", type=str, default=None)
    parser.add_argument("--out-dir", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")
    return parser


def main() -> None:
    """CLI entry point: run the sweep and plot it, or re-plot a saved ``results.json``.

    Output goes to ``--out-dir`` (which must not already exist unless
    ``--overwrite`` is given) or to a fresh ``paper_figure/stage2_tbptt_sweep_<timestamp>``.
    With ``--plot-only PATH`` the JSON at PATH is plotted and no sweep is run;
    otherwise the sweep results are written to ``results.json`` and plotted.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.plot_only is None and int(args.seed_end) < int(args.seed_start):
        # Reject before any output directory is created.
        parser.error(f"--seed-end ({args.seed_end}) must be >= --seed-start ({args.seed_start})")

    stamp = _timestamp()
    if args.out_dir is not None:
        out_dir = Path(str(args.out_dir))
        if out_dir.exists() and not bool(args.overwrite):
            raise SystemExit(f"[OUT] {out_dir} exists; pass --overwrite to write into it.")
        _ensure_dir(out_dir)
    else:
        out_dir = Path("paper_figure") / f"stage2_tbptt_sweep_{stamp}"
        _ensure_dir(out_dir)
    print(f"[OUT] {out_dir}")

    if args.plot_only is not None:
        data = json.loads(Path(args.plot_only).read_text(encoding="utf-8"))
        plot_results(data, out_dir)
        return

    results = run_sweep(args)
    (out_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    plot_results(results, out_dir)


# Script entry point: run the sweep (or re-plot an existing results file via
# --plot-only) only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
