﻿from __future__ import annotations

import argparse
import json
import math
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np

# Put this script's own directory on sys.path (once) so the ``probe_core``
# import that follows can be resolved from here, whatever the current working
# directory is. NOTE(review): presumably probe_core.py lives next to this
# file -- confirm against the repository layout.
HERE = Path(__file__).resolve()
CORE_DIR = HERE.parent
if str(CORE_DIR) not in __import__("sys").path:
    __import__("sys").path.insert(0, str(CORE_DIR))

from probe_core import (
    ScanConfig,
    compute_delta_true,
    compute_lambda_used_seq,
    config_asdict,
    cosine_mean_windows,
    estimate_lag1_rho,
    estimate_lle_benettin,
    generate_inputs,
    generate_targets,
    init_base_weights,
    init_output_weights,
    lowpass_ema,
    simulate_tanh_rnn,
    split_rngs,
)


def _timestamp() -> str:
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _arange_inclusive(start: float, stop: float, step: float) -> np.ndarray:
    start_f, stop_f, step_f = float(start), float(stop), float(step)
    if step_f <= 0:
        raise ValueError("step must be > 0")
    n = int(math.floor((stop_f - start_f) / step_f + 1e-12)) + 1
    if n <= 0:
        return np.array([], dtype=np.float64)
    x = start_f + step_f * np.arange(n, dtype=np.float64)
    x = x[x <= (stop_f + 1e-12)]
    return x


def _find_gcrit(
    w_hh_base: np.ndarray,
    w_xh: np.ndarray,
    bias: np.ndarray,
    inputs0: np.ndarray,
    burnin_steps: int,
    seed: int,
    coarse_min: float,
    coarse_max: float,
    coarse_points: int,
    fine_window: float,
    fine_step: float,
    lle_trials: int,
) -> dict[str, Any]:
    """Locate the recurrent gain whose LLE estimate is closest to zero.

    The search runs in two passes: an evenly spaced coarse grid over
    ``[coarse_min, coarse_max]`` and then a fine grid of spacing ``fine_step``
    spanning ``fine_window`` on either side of the best coarse gain. For each
    gain the recurrent matrix is ``gain * w_hh_base``; the network is simulated
    on ``inputs0`` and the value returned by ``estimate_lle_benettin`` is
    averaged (NaN-aware) over ``max(1, lle_trials)`` trials. Each trial uses
    its own generator, seeded deterministically from ``seed``, the position of
    the gain in its grid and the trial number.

    Returns:
        A dict with the full ``"coarse"`` and ``"fine"`` scans (gains, LLEs,
        best gain and its LLE) plus ``"gcrit"`` and ``"lle_at_gcrit"`` taken
        from the fine scan.
    """

    def mean_lle_per_gain(gains: np.ndarray) -> np.ndarray:
        n_trials = max(1, int(lle_trials))
        seed_base = int(seed) * 10_000
        means = np.full((len(gains),), np.nan, dtype=np.float64)
        for pos, gain in enumerate(gains):
            w_hh = float(gain) * w_hh_base
            _pre, _states, u_seq0 = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            trial_lles = [
                estimate_lle_benettin(
                    w_hh,
                    u_seq0,
                    burnin_steps,
                    np.random.default_rng(seed_base + pos * 97 + trial * 997),
                )
                for trial in range(n_trials)
            ]
            means[pos] = float(np.nanmean(trial_lles))
        return means

    def closest_to_zero(gains: np.ndarray, lles: np.ndarray) -> tuple[float, float]:
        best = int(np.nanargmin(np.abs(lles)))
        return float(gains[best]), float(lles[best])

    # Pass 1: coarse scan across the requested gain range.
    coarse_gains = np.linspace(float(coarse_min), float(coarse_max), int(coarse_points), dtype=np.float64)
    coarse_lles = mean_lle_per_gain(coarse_gains)
    coarse_gain, coarse_lle = closest_to_zero(coarse_gains, coarse_lles)

    # Pass 2: fine scan centred on the best coarse gain.
    half_width = float(fine_window)
    fine_gains = _arange_inclusive(coarse_gain - half_width, coarse_gain + half_width, float(fine_step))
    fine_lles = mean_lle_per_gain(fine_gains)
    fine_gain, fine_lle = closest_to_zero(fine_gains, fine_lles)

    coarse_record = {
        "gains": coarse_gains.tolist(),
        "lles": coarse_lles.tolist(),
        "best_gain": coarse_gain,
        "best_lle": coarse_lle,
    }
    fine_record = {
        "gains": fine_gains.tolist(),
        "lles": fine_lles.tolist(),
        "best_gain": fine_gain,
        "best_lle": fine_lle,
    }
    return {
        "coarse": coarse_record,
        "fine": fine_record,
        "gcrit": fine_gain,
        "lle_at_gcrit": fine_lle,
    }


def _compute_g_seq(w_out: np.ndarray, states: np.ndarray, targets: np.ndarray) -> np.ndarray:
    time_steps = int(states.shape[0])
    hidden_size = int(states.shape[1])
    batch_size = int(states.shape[2])
    g_seq = np.zeros((time_steps, hidden_size, batch_size), dtype=np.float64)
    for t in range(time_steps):
        y_t = w_out @ states[t]
        e_t = y_t - targets[t]
        g_seq[t] = w_out.T @ e_t
    return g_seq


def _build_scan_config(
    args: argparse.Namespace,
    target_rhos: list[float],
    error_lp_grid: np.ndarray,
) -> ScanConfig:
    """Build the base ``ScanConfig`` for the sweep from parsed CLI arguments.

    The same construction is used for the per-seed template and for the
    ``base_config`` entry stored alongside the results, so it lives in one
    place.

    Args:
        args: Parsed command-line namespace.
        target_rhos: Target AR(1) coefficients; the first one seeds the
            template (individual runs override it).
        error_lp_grid: Non-empty grid of error low-pass values; the first one
            seeds the template.
    """
    return ScanConfig(
        hidden_size=int(args.hidden_size),
        batch_size=int(args.batch_size),
        time_steps=int(args.time_steps),
        burnin_steps=int(args.burnin_steps),
        input_dim=1,
        input_mode=str(args.input_mode),
        input_std=float(args.input_std),
        poisson_rate=float(args.poisson_rate),
        gext_mode="output_mse",
        gext_std=float(args.gext_std),
        gext_ar1_rho=0.0,
        w_hh_mode="low_rank",
        low_rank_rank=int(args.low_rank_rank),
        low_rank_frac=float(args.low_rank_frac),
        output_dim=int(args.output_dim),
        target_std=float(args.target_std),
        # With no target rho configured, fall back to an uncorrelated target
        # (rho = 0) rather than reusing the target standard deviation, which
        # is not an AR(1) coefficient.
        target_ar1_rho=float(target_rhos[0]) if target_rhos else 0.0,
        error_lp_rho=float(error_lp_grid[0]),
        output_weight_scale=float(args.output_weight_scale),
        output_weight_mode="align_low_rank",
        output_basis_rank=int(args.output_basis_rank),
        alpha_rho=float(args.alpha_rho),
        alpha_source=str(args.alpha_source),
        lambda_window=int(args.lambda_window),
        eps_lambda=float(args.eps_lambda),
        lam_cap=float(args.lam_cap),
        denom_floor=float(args.denom_floor),
        fit_burnin_steps=int(args.fit_burnin_steps),
        use_safe_cap=bool(args.use_safe_cap),
    )


def _scan_error_lp_grid(
    cfg: ScanConfig,
    w_hh: np.ndarray,
    states: np.ndarray,
    u_seq: np.ndarray,
    g_seq: np.ndarray,
    error_lp_grid: np.ndarray,
    thr: float,
) -> dict[str, Any]:
    """Evaluate every error low-pass value for one simulated trajectory.

    Returns a dict with per-grid-point lists (``cos_mean``, ``pass``,
    ``rho_eff``, ``lambda_used_mean``, ``alpha_hat_mean``) and
    ``min_error_lp``: the smallest grid value whose mean cosine reaches
    ``thr`` (with its cosine), or ``None`` when none passes.
    """
    cos_vals: list[float] = []
    pass_vals: list[bool] = []
    min_error_lp: dict[str, Any] | None = None
    rho_eff: list[float] = []
    lambda_used_mean: list[float] = []
    alpha_hat_mean: list[float] = []

    for elp in error_lp_grid.tolist():
        g_lp = lowpass_ema(g_seq, float(elp))
        lambda_used_seq, debug = compute_lambda_used_seq(states, u_seq, g_lp, cfg)
        delta_true = compute_delta_true(w_hh, u_seq, g_lp)

        cos_mean = cosine_mean_windows(
            delta_true,
            u_seq,
            g_lp,
            lambda_used_seq,
            float(cfg.denom_floor),
            int(cfg.fit_burnin_steps),
            [int(cfg.fit_burnin_steps)],
        )[0]

        cos_vals.append(float(cos_mean))
        # A NaN cosine compares False and therefore counts as a failure.
        passed = bool(cos_mean >= thr)
        pass_vals.append(passed)

        # Track the smallest passing value, not merely the first one seen, so
        # the result does not depend on the order of the grid.
        if passed and (min_error_lp is None or float(elp) < min_error_lp["error_lp_rho"]):
            min_error_lp = {"error_lp_rho": float(elp), "cos": float(cos_mean)}

        rho_eff.append(float(estimate_lag1_rho(g_lp, int(cfg.burnin_steps))))
        lambda_used_mean.append(float(debug.get("lambda_used_mean", float("nan"))))
        alpha_hat_mean.append(float(debug.get("alpha_hat_mean", float("nan"))))

    return {
        "cos_mean": cos_vals,
        "pass": pass_vals,
        "min_error_lp": min_error_lp,
        "rho_eff": rho_eff,
        "lambda_used_mean": lambda_used_mean,
        "alpha_hat_mean": alpha_hat_mean,
    }


def run_sweep(args: argparse.Namespace) -> dict[str, Any]:
    """Run the stage-1 threshold sweep over seeds, target rhos and error low-pass values.

    For every seed in ``[seed_start, seed_end]`` the base weights and inputs
    are generated, the gain returned by ``_find_gcrit`` is applied to the
    recurrent matrix, and for every target AR(1) coefficient the error
    low-pass grid is scanned and compared against ``cos_threshold``.

    Args:
        args: Parsed command-line namespace (see ``main`` for the options).

    Returns:
        ``{"sweep_cfg": ..., "seeds": [...]}`` where ``sweep_cfg`` records the
        sweep settings and ``seeds`` holds one record per seed with its gain
        scan and per-target results.

    Raises:
        ValueError: If ``error_lp_values`` is empty or the seed range is empty
            (``seed_end < seed_start``).
    """
    seeds = list(range(int(args.seed_start), int(args.seed_end) + 1))
    target_rhos = [float(x) for x in args.target_ar1_rho]

    error_lp_grid = np.array([float(x) for x in args.error_lp_values], dtype=np.float64)
    if error_lp_grid.size == 0:
        raise ValueError("No error_lp_values configured")
    if not seeds:
        # Without this check the loop is skipped and the summary below would
        # fail with a bare IndexError on seeds[0].
        raise ValueError(
            f"Empty seed range: seed_end ({int(args.seed_end)}) < seed_start ({int(args.seed_start)})"
        )

    thr = float(args.cos_threshold)

    results_seeds: list[dict[str, Any]] = []
    for s in seeds:
        rngs = split_rngs(int(s))
        cfg_template = _build_scan_config(args, target_rhos, error_lp_grid)

        inputs0 = generate_inputs(rngs["inputs"], cfg_template)
        w_hh_base, w_xh, bias, low_rank_basis = init_base_weights(rngs["recurrent"], cfg_template)

        gcrit_info = _find_gcrit(
            w_hh_base,
            w_xh,
            bias,
            inputs0,
            int(args.burnin_steps),
            int(s),
            float(args.gain_coarse_min),
            float(args.gain_coarse_max),
            int(args.gain_coarse_points),
            float(args.gain_fine_window),
            float(args.gain_fine_step),
            int(args.lle_trials),
        )
        gcrit = float(gcrit_info["gcrit"])
        lle_at_gcrit = float(gcrit_info["lle_at_gcrit"])
        w_hh = gcrit * w_hh_base

        targets_by_rho: list[dict[str, Any]] = []
        for rho_y in target_rhos:
            cfg = dataclasses_replace(cfg_template, target_ar1_rho=float(rho_y))

            # The "output" and "targets" generators are shared across the rho
            # loop, so the draws for one rho depend on the rhos before it.
            w_out, _b_out = init_output_weights(rngs["output"], cfg, w_hh, low_rank_basis)
            targets = generate_targets(
                rngs["targets"],
                int(cfg.time_steps),
                int(cfg.output_dim),
                int(cfg.batch_size),
                float(cfg.target_std),
                float(cfg.target_ar1_rho),
            )

            # NOTE(review): this call does not depend on rho_y; it could be
            # hoisted out of the loop if simulate_tanh_rnn is pure and nothing
            # downstream mutates its outputs -- confirm before changing.
            _pre, states, u_seq = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            g_seq = _compute_g_seq(w_out, states, targets)

            metrics = _scan_error_lp_grid(cfg, w_hh, states, u_seq, g_seq, error_lp_grid, thr)
            targets_by_rho.append(
                {
                    "target_ar1_rho": float(rho_y),
                    "error_lp_rho_grid": error_lp_grid.tolist(),
                    **metrics,
                }
            )

        results_seeds.append(
            {
                "seed": int(s),
                "gcrit": float(gcrit),
                "lle_at_gcrit": float(lle_at_gcrit),
                "gcrit_scan": gcrit_info,
                "target_results": targets_by_rho,
            }
        )

    sweep_cfg = {
        "seeds": {"start": seeds[0], "end": seeds[-1], "count": len(seeds)},
        "target_ar1_rho": target_rhos,
        "error_lp_rho_grid": error_lp_grid.tolist(),
        "threshold_cos": thr,
        "gcrit_scan": {
            "coarse_min": float(args.gain_coarse_min),
            "coarse_max": float(args.gain_coarse_max),
            "coarse_points": int(args.gain_coarse_points),
            "fine_window": float(args.gain_fine_window),
            "fine_step": float(args.gain_fine_step),
            "lle_trials": int(args.lle_trials),
        },
        "base_config": config_asdict(_build_scan_config(args, target_rhos, error_lp_grid)),
    }

    return {"sweep_cfg": sweep_cfg, "seeds": results_seeds}


def dataclasses_replace(cfg: ScanConfig, **kwargs: Any) -> ScanConfig:
    """Return a copy of *cfg* with the given fields overridden.

    Thin wrapper around ``dataclasses.replace``. Unlike a round trip through
    ``asdict``, this leaves nested dataclass fields intact (``asdict`` converts
    them to plain dicts) and copes with fields declared ``init=False``.

    Args:
        cfg: The configuration to copy; it is not modified.
        **kwargs: Field names and their replacement values.

    Returns:
        A new instance of the same dataclass type as *cfg*.

    Raises:
        TypeError: If a keyword does not name an init field of *cfg*.
    """
    # Local import: the module-level dataclasses import only brings in asdict.
    from dataclasses import replace

    return replace(cfg, **kwargs)


def _apply_plot_style() -> None:
    """Install the shared figure style into matplotlib's global ``rcParams``.

    Sets the sans-serif font stack, font sizes, outward ticks, axis line
    width, font type 42 for PDF/PS output and the ``stixsans`` math font.
    The change is process-wide and is not undone afterwards.
    """
    import matplotlib as mpl

    fonts = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "font.size": 8.5,
        "axes.titlesize": 8.5,
        "axes.labelsize": 8.5,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 7.5,
    }
    axes_and_ticks = {
        "xtick.direction": "out",
        "ytick.direction": "out",
        "axes.linewidth": 0.8,
    }
    export = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "mathtext.fontset": "stixsans",
    }
    mpl.rcParams.update({**fonts, **axes_and_ticks, **export})


def plot_results(results: dict[str, Any], out_dir: Path) -> None:
    """Draw the stage-1 summary figures and save each as PDF and PNG in *out_dir*.

    Three figures are produced:

    * ``stage1_gcrit_lle``: per-seed ``gcrit`` and ``|LLE|`` on twin axes.
    * ``stage1_pass_fraction``: fraction of seeds passing versus the error
      low-pass value, one line per target rho.
    * ``stage1_min_error_lp``: mean and standard deviation over seeds of the
      recorded ``min_error_lp`` value, per target rho.

    Plotting is skipped with a printed notice when matplotlib cannot be
    imported or when *results* contains no seeds. Only the first figure is
    written when the first seed has no target results.

    Args:
        results: Sweep output with a ``"seeds"`` list, as returned by
            ``run_sweep`` or loaded back from its JSON dump.
        out_dir: Existing directory that receives the image files.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] matplotlib unavailable ({exc}); skipping plots.")
        return

    seeds = results.get("seeds", [])
    if not seeds:
        print("[PLOT] No seeds found.")
        return

    _apply_plot_style()

    def new_axes():
        # Every figure shares the same single-panel geometry.
        return plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)

    def save_and_close(figure, stem: str) -> None:
        figure.savefig(out_dir / f"{stem}.pdf", bbox_inches="tight")
        figure.savefig(out_dir / f"{stem}.png", dpi=240, bbox_inches="tight")
        plt.close(figure)

    # ---- Figure 1: gcrit and |LLE| per seed ---------------------------------
    seed_ids = [int(rec.get("seed", pos)) for pos, rec in enumerate(seeds)]
    gcrit_vals = np.array([rec.get("gcrit", float("nan")) for rec in seeds], dtype=np.float64)
    lle_raw = np.array([rec.get("lle_at_gcrit", float("nan")) for rec in seeds], dtype=np.float64)
    lle_abs = np.abs(lle_raw)

    gcrit_color = "#1f77b4"
    lle_color = "#ff7f0e"

    fig, ax_gain = new_axes()
    ax_gain.plot(
        seed_ids,
        gcrit_vals,
        marker="o",
        markersize=3.5,
        linewidth=1.1,
        color=gcrit_color,
        label=r"$g_{\mathrm{crit}}$",
    )
    ax_gain.set_xlabel("Seed")
    ax_gain.set_ylabel(r"$g_{\mathrm{crit}}$", color=gcrit_color)
    ax_gain.tick_params(axis="y", labelcolor=gcrit_color)
    ax_gain.grid(True, linestyle=":", alpha=0.3)

    ax_lle = ax_gain.twinx()
    ax_lle.plot(
        seed_ids,
        lle_abs,
        marker="s",
        markersize=3.2,
        linewidth=1.1,
        linestyle="--",
        color=lle_color,
        label=r"$|\mathrm{LLE}|$",
    )
    ax_lle.set_ylabel(r"$|\mathrm{LLE}|$", color=lle_color)
    ax_lle.tick_params(axis="y", labelcolor=lle_color)
    # Switch to a log axis only when every value is finite and something is positive.
    if np.all(np.isfinite(lle_abs)) and float(np.nanmax(lle_abs)) > 0.0:
        ax_lle.set_yscale("log")

    if len(seed_ids) > 1:
        # Roughly four ticks, always ending on the last seed.
        stride = max(1, len(seed_ids) // 4)
        ticks = seed_ids[::stride]
        if ticks[-1] != seed_ids[-1]:
            ticks = ticks + [seed_ids[-1]]
        ax_gain.set_xticks(ticks)

    gain_handles, gain_labels = ax_gain.get_legend_handles_labels()
    lle_handles, lle_labels = ax_lle.get_legend_handles_labels()
    ax_gain.legend(gain_handles + lle_handles, gain_labels + lle_labels, frameon=False, loc="upper right")
    save_and_close(fig, "stage1_gcrit_lle")

    # The first seed defines which target rhos (and which grid) are plotted.
    target_rhos = [entry.get("target_ar1_rho") for entry in seeds[0].get("target_results", [])]
    if not target_rhos:
        return

    error_lp = np.array(seeds[0]["target_results"][0]["error_lp_rho_grid"], dtype=np.float64)

    # ---- Figure 2: pass fraction versus error low-pass ----------------------
    fig, ax = new_axes()
    colors = plt.cm.tab10(np.linspace(0.0, 0.9, len(target_rhos)))
    for i, rho_y in enumerate(target_rhos):
        pass_rows = [
            np.array(rec.get("target_results", [])[i].get("pass", []), dtype=float)
            for rec in seeds
        ]
        if not pass_rows:
            continue
        frac = np.mean(np.stack(pass_rows, axis=0), axis=0)
        ax.plot(
            error_lp,
            frac,
            marker="o",
            markersize=3.2,
            linewidth=1.1,
            color=colors[i],
            label=rf"$\rho_y={rho_y}$",
        )
    ax.set_xlabel(r"Error low-pass $\rho_e$")
    ax.set_ylabel("Pass fraction")
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, linestyle=":", alpha=0.3)
    ax.legend(frameon=False, ncol=1)
    save_and_close(fig, "stage1_pass_fraction")

    # ---- Figure 3: recorded minimum error low-pass per target rho -----------
    fig, ax = new_axes()
    for i, rho_y in enumerate(target_rhos):
        recorded = (rec.get("target_results", [])[i].get("min_error_lp") for rec in seeds)
        # Seeds where nothing passed carry None and are left out of the stats.
        vals = [float(mp.get("error_lp_rho", float("nan"))) for mp in recorded if mp is not None]
        if not vals:
            continue
        mean_v = float(np.mean(vals))
        std_v = float(np.std(vals))
        ax.errorbar([rho_y], [mean_v], yerr=[std_v], fmt="o", markersize=3.5, capsize=2.5, color=colors[i])
    ax.set_xlabel(r"Target AR(1) $\rho_y$")
    ax.set_ylabel(r"Min required $\rho_e$")
    ax.set_ylim(0.0, 1.01)
    ax.grid(True, linestyle=":", alpha=0.3)
    save_and_close(fig, "stage1_min_error_lp")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the stage-1 threshold sweep."""
    parser = argparse.ArgumentParser(description="Stage-1 threshold sweep (low-rank existence + gcrit/LLE)")
    parser.add_argument("--seed-start", type=int, default=120)
    parser.add_argument("--seed-end", type=int, default=140)
    parser.add_argument("--input-mode", type=str, default="gaussian", choices=["gaussian", "poisson"])

    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--time-steps", type=int, default=800)
    parser.add_argument("--burnin-steps", type=int, default=100)
    parser.add_argument("--fit-burnin-steps", type=int, default=100)

    parser.add_argument("--input-std", type=float, default=0.05)
    parser.add_argument("--poisson-rate", type=float, default=5.0)
    parser.add_argument("--gext-std", type=float, default=1.0)

    parser.add_argument("--low-rank-rank", type=int, default=1)
    parser.add_argument("--low-rank-frac", type=float, default=0.95)

    parser.add_argument("--output-dim", type=int, default=16)
    parser.add_argument("--output-weight-scale", type=float, default=1.0)
    parser.add_argument("--output-basis-rank", type=int, default=0)

    parser.add_argument("--target-std", type=float, default=1.0)
    parser.add_argument("--target-ar1-rho", type=float, nargs="+", default=[0.995, 0.99, 0.95, 0.9])

    parser.add_argument("--alpha-rho", type=float, default=0.995)
    parser.add_argument("--alpha-source", type=str, default="h", choices=["h", "g"])
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument("--eps-lambda", type=float, default=1e-8)
    parser.add_argument("--lam-cap", type=float, default=0.99)
    parser.add_argument("--denom-floor", type=float, default=1e-3)
    # The safe cap is on by default, so --use-safe-cap alone can never switch
    # it off; --no-use-safe-cap writes False into the same destination.
    parser.add_argument("--use-safe-cap", action="store_true", default=True)
    parser.add_argument(
        "--no-use-safe-cap",
        dest="use_safe_cap",
        action="store_false",
        help="Disable the safe cap (enabled by default).",
    )

    parser.add_argument("--gain-coarse-min", type=float, default=0.8)
    parser.add_argument("--gain-coarse-max", type=float, default=1.3)
    parser.add_argument("--gain-coarse-points", type=int, default=51)
    parser.add_argument("--gain-fine-window", type=float, default=0.02)
    parser.add_argument("--gain-fine-step", type=float, default=0.001)
    parser.add_argument("--lle-trials", type=int, default=2)

    parser.add_argument("--error-lp-values", type=float, nargs="+", default=[0.99, 0.991, 0.992, 0.993, 0.994, 0.995])
    parser.add_argument("--cos-threshold", type=float, default=0.70)

    parser.add_argument("--plot-only", type=str, default=None)
    parser.add_argument("--out-dir", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")
    return parser


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: run the sweep (or re-plot saved results).

    Without ``--plot-only`` the sweep is executed, its results are written to
    ``results.json`` in the output directory and the figures are drawn. With
    ``--plot-only PATH`` the JSON file at PATH is loaded and only the figures
    are regenerated.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid arguments, or when ``--out-dir`` already exists
            and ``--overwrite`` was not given.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Reject an empty seed range up front, before any directory is created.
    # The seed options are irrelevant when only re-plotting.
    if args.plot_only is None and int(args.seed_end) < int(args.seed_start):
        parser.error(f"--seed-end ({args.seed_end}) must be >= --seed-start ({args.seed_start})")

    stamp = _timestamp()
    if args.out_dir is not None:
        out_dir = Path(str(args.out_dir))
        if out_dir.exists() and not bool(args.overwrite):
            raise SystemExit(f"[OUT] {out_dir} exists; pass --overwrite to write into it.")
        _ensure_dir(out_dir)
    else:
        # Default: a fresh timestamped directory under paper_figure/.
        out_dir = Path("paper_figure") / f"stage1_threshold_sweep_{stamp}"
        _ensure_dir(out_dir)
    print(f"[OUT] {out_dir}")

    if args.plot_only is not None:
        data = json.loads(Path(args.plot_only).read_text(encoding="utf-8"))
        plot_results(data, out_dir)
        return

    results = run_sweep(args)
    (out_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    plot_results(results, out_dir)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
