﻿from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np

HERE = Path(__file__).resolve()
# ``probe_core`` lives in ``oll_stage1_probe``, a sibling of this script's
# directory. Prepend that directory to ``sys.path`` (only once) so that the
# ``from probe_core import ...`` statement below can resolve it.
CORE_DIR = HERE.parents[1] / "oll_stage1_probe"
if str(CORE_DIR) not in sys.path:
    sys.path.insert(0, str(CORE_DIR))

from probe_core import (
    ScanConfig,
    compute_delta_true,
    compute_lambda_used_seq,
    config_asdict,
    cosine_mean_windows,
    estimate_lag1_rho,
    estimate_lle_benettin,
    generate_inputs,
    generate_targets,
    init_base_weights,
    init_output_weights,
    lowpass_ema,
    simulate_tanh_rnn,
    split_rngs,
)


def _timestamp() -> str:
    """Return the current local time as ``YYYYmmdd_HHMMSS`` (e.g. for output dir names)."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


def _ensure_dir(path: Path) -> None:
    """Create directory ``path`` including missing parents; a no-op if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def _find_gcrit(
    w_hh_base: np.ndarray,
    w_xh: np.ndarray,
    bias: np.ndarray,
    inputs0: np.ndarray,
    burnin_steps: int,
    seed: int,
    coarse_min: float,
    coarse_max: float,
    coarse_points: int,
    fine_window: float,
    fine_step: float,
    lle_trials: int,
) -> dict[str, Any]:
    """Find the recurrent gain whose estimated LLE is closest to zero.

    The recurrent matrix is ``g * w_hh_base``. For each candidate gain the RNN
    is simulated on ``inputs0`` with ``simulate_tanh_rnn``. The LLE is then
    estimated with ``estimate_lle_benettin``, averaged (``nanmean``) over
    ``max(1, lle_trials)`` differently seeded trials.

    A coarse grid ``linspace(coarse_min, coarse_max, coarse_points)`` picks a
    centre ``g0``. A fine grid spanning ``g0 - fine_window .. g0 + fine_window``
    with spacing ``fine_step`` (rounded to fit the window exactly, and always
    containing ``g0``) then refines it.

    Args:
        w_hh_base: unscaled recurrent weight matrix.
        w_xh, bias: input weights and bias passed to ``simulate_tanh_rnn``.
        inputs0: input sequence passed to ``simulate_tanh_rnn``.
        burnin_steps: burn-in passed to ``estimate_lle_benettin``.
        seed: base seed for the per-trial LLE random generators.
        coarse_min, coarse_max, coarse_points: coarse gain grid.
        fine_window: half-width of the fine grid around the coarse optimum.
        fine_step: target spacing of the fine grid.
        lle_trials: number of LLE estimates averaged per gain (at least 1).

    Returns:
        ``{"gcrit": g, "lle_at_gcrit": lle}`` for the fine-grid gain with the
        smallest ``|LLE|``.

    Raises:
        ValueError: if ``coarse_points < 1``, ``fine_step <= 0``,
            ``fine_window < 0``, or every LLE estimate on a grid is NaN.
    """
    if int(coarse_points) < 1:
        raise ValueError("coarse_points must be >= 1")
    if not float(fine_step) > 0.0:
        raise ValueError("fine_step must be > 0")
    if float(fine_window) < 0.0:
        raise ValueError("fine_window must be >= 0")

    def eval_grid(gains: np.ndarray) -> np.ndarray:
        # Mean LLE estimate for every gain in ``gains`` (NaN-aware over trials).
        lles = np.full((len(gains),), np.nan, dtype=np.float64)
        for i, g in enumerate(gains):
            w_hh = float(g) * w_hh_base
            _pre, _states, u_seq0 = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            vals = []
            for j in range(max(1, int(lle_trials))):
                lle_rng = np.random.default_rng(int(seed) * 10_000 + i * 97 + j * 997)
                vals.append(estimate_lle_benettin(w_hh, u_seq0, burnin_steps, lle_rng))
            lles[i] = float(np.nanmean(vals))
        return lles

    def closest_to_zero(lles: np.ndarray, label: str) -> int:
        # np.nanargmin fails with an opaque message on an all-NaN slice.
        if np.all(np.isnan(lles)):
            raise ValueError(f"All LLE estimates on the {label} gain grid are NaN")
        return int(np.nanargmin(np.abs(lles)))

    coarse = np.linspace(float(coarse_min), float(coarse_max), int(coarse_points), dtype=np.float64)
    lles_coarse = eval_grid(coarse)
    g0 = float(coarse[closest_to_zero(lles_coarse, "coarse")])

    # 2*half intervals over the 2*fine_window span give a spacing of ~fine_step.
    # An odd point count keeps g0 itself on the grid. round() (not int()) avoids
    # truncating ratios such as 0.7/0.1 == 6.999... down by one.
    half = max(1, int(round(float(fine_window) / float(fine_step))))
    fine = np.linspace(g0 - float(fine_window), g0 + float(fine_window), 2 * half + 1)
    lles_fine = eval_grid(fine)
    idx1 = closest_to_zero(lles_fine, "fine")
    g1 = float(fine[idx1])
    lle1 = float(lles_fine[idx1])

    return {"gcrit": g1, "lle_at_gcrit": lle1}


def _compute_g_seq(w_out: np.ndarray, states: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Project the linear-readout error back onto the hidden units at each step.

    For every step ``t`` this computes ``w_out.T @ (w_out @ states[t] - targets[t])``.
    That is the gradient of ``0.5 * ||w_out @ h - targets[t]||^2`` with respect
    to ``h = states[t]``.

    Args:
        w_out: readout matrix, applied as ``w_out @ states[t]``.
        states: hidden states; the first three dimensions size the output
            (named time, hidden, batch here).
        targets: readout targets indexed by time step.

    Returns:
        float64 array of shape ``(states.shape[0], states.shape[1], states.shape[2])``.
    """
    n_steps = int(states.shape[0])
    out_shape = (n_steps, int(states.shape[1]), int(states.shape[2]))
    grads = np.zeros(out_shape, dtype=np.float64)
    w_out_t = w_out.T
    t = 0
    while t < n_steps:
        residual = w_out @ states[t] - targets[t]
        grads[t] = w_out_t @ residual
        t += 1
    return grads


def run_sweep(args: argparse.Namespace) -> dict[str, Any]:
    seeds = list(range(int(args.seed_start), int(args.seed_end) + 1))
    ranks = [int(x) for x in args.ranks]

    eval_windows = sorted({int(x) for x in args.eval_window_values})
    if not eval_windows:
        raise ValueError("No eval_window_values provided")

    error_lp_grid = np.array([float(x) for x in args.error_lp_values], dtype=np.float64)
    if error_lp_grid.size == 0:
        raise ValueError("No error_lp_values configured")

    thr = float(args.cos_threshold)

    blocks: list[dict[str, Any]] = []
    for r in ranks:
        cfg_template = ScanConfig(
            hidden_size=int(args.hidden_size),
            batch_size=int(args.batch_size),
            time_steps=int(args.time_steps),
            burnin_steps=int(args.burnin_steps),
            input_dim=1,
            input_mode=str(args.input_mode),
            input_std=float(args.input_std),
            poisson_rate=float(args.poisson_rate),
            gext_mode="output_mse",
            gext_std=float(args.gext_std),
            gext_ar1_rho=0.0,
            w_hh_mode="low_rank",
            low_rank_rank=int(r),
            low_rank_frac=float(args.low_rank_frac),
            output_dim=int(args.output_dim),
            target_std=float(args.target_std),
            target_ar1_rho=float(args.target_ar1_rho),
            error_lp_rho=float(error_lp_grid[0]),
            output_weight_scale=float(args.output_weight_scale),
            output_weight_mode="align_low_rank",
            output_basis_rank=int(r),
            alpha_rho=float(args.alpha_rho),
            alpha_source=str(args.alpha_source),
            lambda_window=int(args.lambda_window),
            eps_lambda=float(args.eps_lambda),
            lam_cap=float(args.lam_cap),
            denom_floor=float(args.denom_floor),
            fit_burnin_steps=int(args.fit_burnin_steps),
            use_safe_cap=bool(args.use_safe_cap),
        )

        seed_rows: list[dict[str, Any]] = []
        for s in seeds:
            rngs = split_rngs(int(s))
            inputs0 = generate_inputs(rngs["inputs"], cfg_template)
            w_hh_base, w_xh, bias, low_rank_basis = init_base_weights(rngs["recurrent"], cfg_template)

            gcrit_info = _find_gcrit(
                w_hh_base,
                w_xh,
                bias,
                inputs0,
                int(args.burnin_steps),
                int(s),
                float(args.gain_coarse_min),
                float(args.gain_coarse_max),
                int(args.gain_coarse_points),
                float(args.gain_fine_window),
                float(args.gain_fine_step),
                int(args.lle_trials),
            )
            gcrit = float(gcrit_info["gcrit"])
            lle_at_gcrit = float(gcrit_info["lle_at_gcrit"])
            w_hh = gcrit * w_hh_base

            w_out, _b_out = init_output_weights(
                rngs["output"],
                cfg_template,
                w_hh,
                low_rank_basis,
            )

            targets = generate_targets(
                rngs["targets"],
                int(cfg_template.time_steps),
                int(cfg_template.output_dim),
                int(cfg_template.batch_size),
                float(cfg_template.target_std),
                float(cfg_template.target_ar1_rho),
            )

            _pre, states, u_seq = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            g_seq = _compute_g_seq(w_out, states, targets)

            start_eval = [int(cfg_template.time_steps - k) for k in eval_windows]
            if min(start_eval) < 0:
                raise ValueError("time_steps too small for eval_window_values")

            cos_by_error = np.zeros((len(error_lp_grid), len(eval_windows)), dtype=np.float64)
            rho_eff = []
            lambda_used_mean = []
            alpha_hat_mean = []

            for ri, elp in enumerate(error_lp_grid.tolist()):
                g_lp = lowpass_ema(g_seq, float(elp))
                lambda_used_seq, debug = compute_lambda_used_seq(states, u_seq, g_lp, cfg_template)
                delta_true = compute_delta_true(w_hh, u_seq, g_lp)
                cos_mean_wk = cosine_mean_windows(
                    delta_true,
                    u_seq,
                    g_lp,
                    lambda_used_seq,
                    float(cfg_template.denom_floor),
                    int(cfg_template.fit_burnin_steps),
                    start_eval,
                )
                cos_by_error[ri] = cos_mean_wk

                rho_eff.append(float(estimate_lag1_rho(g_lp, int(cfg_template.burnin_steps))))
                lambda_used_mean.append(float(debug.get("lambda_used_mean", float("nan"))))
                alpha_hat_mean.append(float(debug.get("alpha_hat_mean", float("nan"))))

            best_cos = np.max(cos_by_error, axis=0)
            min_error_lp = []
            for wi in range(len(eval_windows)):
                idx = np.where(cos_by_error[:, wi] >= thr)[0]
                if idx.size == 0:
                    min_error_lp.append(None)
                else:
                    min_error_lp.append({"error_lp_rho": float(error_lp_grid[idx[0]]), "cos": float(cos_by_error[idx[0], wi])})

            seed_rows.append(
                {
                    "seed": int(s),
                    "gcrit": float(gcrit),
                    "lle_at_gcrit": float(lle_at_gcrit),
                    "gcrit_scan": gcrit_info,
                    "cos_by_error": cos_by_error.tolist(),
                    "best_cos": best_cos.tolist(),
                    "min_error_lp": min_error_lp,
                    "rho_eff": rho_eff,
                    "lambda_used_mean": lambda_used_mean,
                    "alpha_hat_mean": alpha_hat_mean,
                }
            )

        best_cos_stats: list[dict[str, Any]] = []
        min_error_stats: list[dict[str, Any]] = []
        for wi, k in enumerate(eval_windows):
            vals = [float(row["best_cos"][wi]) for row in seed_rows]
            best_cos_stats.append(
                {
                    "eval_window": int(k),
                    "best_cos_min": float(np.min(vals)),
                    "best_cos_mean": float(np.mean(vals)),
                    "best_cos_max": float(np.max(vals)),
                    "best_cos_pass_fraction": float(np.mean(np.array(vals) >= thr)),
                }
            )

            min_vals = []
            for row in seed_rows:
                mp = row["min_error_lp"][wi]
                if mp is None:
                    continue
                min_vals.append(float(mp["error_lp_rho"]))
            if min_vals:
                min_error_stats.append(
                    {
                        "eval_window": int(k),
                        "min_error_lp_mean": float(np.mean(min_vals)),
                        "min_error_lp_max": float(np.max(min_vals)),
                    }
                )
            else:
                min_error_stats.append(
                    {
                        "eval_window": int(k),
                        "min_error_lp_mean": float("nan"),
                        "min_error_lp_max": float("nan"),
                    }
                )

        blocks.append(
            {
                "rank": int(r),
                "error_lp_rho_grid": error_lp_grid.tolist(),
                "eval_windows": eval_windows,
                "best_cos_stats": best_cos_stats,
                "min_error_lp_stats": min_error_stats,
                "seeds": seed_rows,
            }
        )

    sweep_cfg = {
        "seeds": {"start": seeds[0], "end": seeds[-1], "count": len(seeds)},
        "ranks": ranks,
        "error_lp_rho_grid": error_lp_grid.tolist(),
        "eval_windows": eval_windows,
        "threshold_cos": thr,
        "base_config": config_asdict(cfg_template),
    }

    return {"sweep_cfg": sweep_cfg, "blocks": blocks}


def _apply_plot_style() -> None:
    """Apply compact publication-style matplotlib rcParams (small sans fonts, TrueType embedding)."""
    import matplotlib as mpl

    fonts = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "font.size": 8.5,
        "axes.titlesize": 8.5,
        "axes.labelsize": 8.5,
        "legend.fontsize": 7.5,
        "mathtext.fontset": "stixsans",
    }
    ticks_and_axes = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "xtick.direction": "out",
        "ytick.direction": "out",
        "axes.linewidth": 0.8,
    }
    # Type 42 keeps text editable/embeddable in PDF and PS output.
    embedding = {"pdf.fonttype": 42, "ps.fonttype": 42}

    mpl.rcParams.update({**fonts, **ticks_and_axes, **embedding})


def plot_results(results: dict[str, Any], out_dir: Path) -> None:
    """Render the three stage-3 summary figures (PDF + PNG) into ``out_dir``.

    Figures, one line per rank block:
      * best mean cosine vs evaluation window K (log2 x-axis, threshold line);
      * seed-averaged cosine vs error low-pass rho_e at the largest K;
      * mean minimal passing rho_e vs K (log2 x-axis).

    Prints a message and returns early if matplotlib cannot be imported or
    ``results`` contains no blocks.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.ticker as mticker
    except Exception as exc:
        print(f"[PLOT] matplotlib unavailable ({exc}); skipping plots.")
        return

    rank_blocks = results.get("blocks", [])
    if not rank_blocks:
        print("[PLOT] No blocks found.")
        return

    _apply_plot_style()

    windows = rank_blocks[0]["eval_windows"]
    k_largest = max(windows)
    k_largest_idx = int(np.argmax(np.array(windows)))
    palette = plt.cm.tab10(np.linspace(0.0, 0.9, len(rank_blocks)))

    def new_axes():
        return plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)

    def draw(ax, xs, ys, idx, blk):
        ax.plot(xs, ys, marker="o", markersize=3.5, linewidth=1.1, color=palette[idx], label=rf"$R={blk['rank']}$")

    def use_log2_k_axis(ax):
        ax.set_xscale("log", base=2)
        ax.set_xticks(windows)
        ax.get_xaxis().set_major_formatter(mticker.FuncFormatter(lambda v, _: f"{int(v)}"))

    def finish(fig, ax, legend_loc, stem):
        ax.grid(True, linestyle=":", alpha=0.3)
        ax.legend(frameon=False, ncol=2, loc=legend_loc)
        fig.savefig(out_dir / f"{stem}.pdf", bbox_inches="tight")
        fig.savefig(out_dir / f"{stem}.png", dpi=240, bbox_inches="tight")
        plt.close(fig)

    # Figure 1: best mean cosine vs K.
    fig, ax = new_axes()
    for idx, blk in enumerate(rank_blocks):
        means = [float(s.get("best_cos_mean", float("nan"))) for s in blk.get("best_cos_stats", [])]
        draw(ax, windows, means, idx, blk)
    threshold = float(results["sweep_cfg"]["threshold_cos"])
    ax.axhline(threshold, color="black", linestyle="--", linewidth=1.0, alpha=0.6)
    ax.set_xlabel(r"Evaluation window $K$")
    ax.set_ylabel(r"Best mean $|\cos|$")
    use_log2_k_axis(ax)
    ax.set_ylim(0.0, 1.0)
    finish(fig, ax, "lower right", "stage3_lowrank_best_cos_vs_K")

    # Figure 2: seed-averaged cosine vs error low-pass at the largest K.
    fig, ax = new_axes()
    for idx, blk in enumerate(rank_blocks):
        lp_grid = np.array(blk["error_lp_rho_grid"], dtype=np.float64)
        per_seed = [np.array(s["cos_by_error"], dtype=np.float64)[:, k_largest_idx] for s in blk["seeds"]]
        if per_seed:
            draw(ax, lp_grid, np.mean(np.stack(per_seed, axis=0), axis=0), idx, blk)
    ax.axhline(threshold, color="black", linestyle="--", linewidth=1.0, alpha=0.6)
    ax.set_xlabel(r"Error low-pass $\rho_e$")
    ax.set_ylabel(rf"Mean $|\cos|$ (K={k_largest})")
    ax.set_ylim(0.0, 1.0)
    finish(fig, ax, "lower right", f"stage3_lowrank_cos_vs_error_lp_K{k_largest}")

    # Figure 3: mean minimal passing error low-pass vs K.
    fig, ax = new_axes()
    for idx, blk in enumerate(rank_blocks):
        mins = [float(s.get("min_error_lp_mean", float("nan"))) for s in blk.get("min_error_lp_stats", [])]
        draw(ax, windows, mins, idx, blk)
    ax.set_xlabel(r"Evaluation window $K$")
    ax.set_ylabel(r"Min required $\rho_e$")
    use_log2_k_axis(ax)
    ax.set_ylim(0.0, 1.01)
    finish(fig, ax, "upper right", "stage3_lowrank_min_error_lp_vs_K")


def main() -> None:
    """CLI entry point: parse arguments, run (or reload) the sweep, and write outputs.

    Outputs go to ``--out-dir`` or, by default, to
    ``paper_figure/stage3_lowrank_sweep_<timestamp>``. An existing ``--out-dir``
    is refused unless ``--overwrite`` is given. A normal run writes
    ``results.json`` and the figures. ``--plot-only PATH`` re-plots an existing
    results JSON without running the sweep.
    """
    parser = argparse.ArgumentParser(description="Stage-3: low-rank multi-R sweep")
    parser.add_argument("--seed-start", type=int, default=120)
    parser.add_argument("--seed-end", type=int, default=140)
    parser.add_argument("--input-mode", type=str, default="gaussian", choices=["gaussian", "poisson"])

    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--time-steps", type=int, default=700)
    parser.add_argument("--burnin-steps", type=int, default=100)
    parser.add_argument("--fit-burnin-steps", type=int, default=100)

    parser.add_argument("--input-std", type=float, default=0.05)
    parser.add_argument("--poisson-rate", type=float, default=5.0)
    parser.add_argument("--gext-std", type=float, default=1.0)

    parser.add_argument("--ranks", type=int, nargs="+", default=[1, 2, 4, 8, 16])
    parser.add_argument("--low-rank-frac", type=float, default=0.95)

    parser.add_argument("--output-dim", type=int, default=16)
    parser.add_argument("--target-std", type=float, default=1.0)
    parser.add_argument("--target-ar1-rho", type=float, default=0.995)
    parser.add_argument("--output-weight-scale", type=float, default=1.0)

    parser.add_argument("--alpha-rho", type=float, default=0.995)
    parser.add_argument("--alpha-source", type=str, default="h", choices=["h", "g"])
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument("--eps-lambda", type=float, default=1e-8)
    parser.add_argument("--lam-cap", type=float, default=0.99)
    parser.add_argument("--denom-floor", type=float, default=1e-3)
    parser.add_argument("--use-safe-cap", action="store_true", default=True)
    # --use-safe-cap already defaults to True, so on its own it can never turn
    # the cap off. Provide an explicit negative switch writing the same dest.
    parser.add_argument("--no-use-safe-cap", dest="use_safe_cap", action="store_false")

    parser.add_argument("--gain-coarse-min", type=float, default=0.8)
    parser.add_argument("--gain-coarse-max", type=float, default=1.3)
    parser.add_argument("--gain-coarse-points", type=int, default=31)
    parser.add_argument("--gain-fine-window", type=float, default=0.02)
    parser.add_argument("--gain-fine-step", type=float, default=0.002)
    parser.add_argument("--lle-trials", type=int, default=1)

    parser.add_argument("--error-lp-values", type=float, nargs="+", default=[0.0, 0.9, 0.99, 0.995, 0.998, 0.999])
    parser.add_argument("--cos-threshold", type=float, default=0.70)

    parser.add_argument("--eval-window-values", type=int, nargs="+", default=[8, 32, 128, 512])
    parser.add_argument("--plot-only", type=str, default=None)
    parser.add_argument("--out-dir", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    # Load --plot-only input before touching the filesystem, so a bad path
    # fails without leaving an empty output directory behind.
    plot_data = None
    if args.plot_only is not None:
        plot_data = json.loads(Path(args.plot_only).read_text(encoding="utf-8"))

    stamp = _timestamp()
    if args.out_dir is not None:
        out_dir = Path(str(args.out_dir))
        if out_dir.exists() and not bool(args.overwrite):
            raise SystemExit(f"[OUT] {out_dir} exists; pass --overwrite to write into it.")
        _ensure_dir(out_dir)
    else:
        out_dir = Path("paper_figure") / f"stage3_lowrank_sweep_{stamp}"
        _ensure_dir(out_dir)
    print(f"[OUT] {out_dir}")

    if plot_data is not None:
        plot_results(plot_data, out_dir)
        return

    results = run_sweep(args)
    (out_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    plot_results(results, out_dir)


if __name__ == "__main__":
    main()
