﻿from __future__ import annotations

import argparse
import json
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np

import sys

# Absolute, symlink-resolved path of this script.
HERE = Path(__file__).resolve()
# Sibling directory (two levels up from this file, then "oll_stage1_probe")
# that is expected to hold the ``probe_core`` module imported just below.
CORE_DIR = HERE.parents[1] / "oll_stage1_probe"
# Put CORE_DIR at the front of the import path so ``probe_core`` resolves
# without installation; the membership check keeps sys.path from growing when
# this module is imported more than once.
if str(CORE_DIR) not in sys.path:
    sys.path.insert(0, str(CORE_DIR))

from probe_core import (
    ScanConfig,
    compute_delta_true,
    compute_lambda_used_seq,
    config_asdict,
    generate_inputs,
    generate_targets,
    init_base_weights,
    init_output_weights,
    lowpass_ema,
    simulate_tanh_rnn,
    split_rngs,
)


def _timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


def _ensure_dir(path: Path) -> None:
    """Create directory ``path`` together with any missing parents.

    An already existing directory is left untouched and is not an error.
    """
    path.mkdir(exist_ok=True, parents=True)


def _apply_plot_style() -> None:
    """Install the matplotlib rcParams shared by all figures of this script.

    matplotlib is imported lazily so the module can be imported (and the
    numerical part run) on machines without it.
    """
    from matplotlib import rcParams

    text_style = {
        "font.family": ["Times New Roman"],
        "font.size": 10,
        "axes.titlesize": 10,
        "axes.labelsize": 10,
        "mathtext.fontset": "stix",
    }
    axis_style = {
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "xtick.direction": "out",
        "ytick.direction": "out",
        "axes.linewidth": 0.8,
    }
    # Font type 42 asks the PDF/PS backends for TrueType embedding.
    export_style = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    for group in (text_style, axis_style, export_style):
        rcParams.update(group)


def _find_gcrit(w_hh_base, w_xh, bias, inputs0, burnin_steps, seed, coarse_min, coarse_max, coarse_points, fine_window, fine_step, lle_trials):
    """Search for the recurrent gain whose estimated LLE is closest to zero.

    The search is two-stage: a coarse grid of ``coarse_points`` gains on
    ``[coarse_min, coarse_max]``, then a fine grid centred on the best coarse
    gain that spans ``+/- fine_window`` with spacing ``fine_step``. For each
    gain ``g`` the network with recurrent weights ``g * w_hh_base`` is
    simulated on ``inputs0`` and the largest Lyapunov exponent (LLE) is
    estimated ``max(1, lle_trials)`` times with deterministically seeded
    tangent vectors; the NaN-ignoring mean of the trials is the grid value.

    Args:
        w_hh_base: Un-scaled recurrent weight matrix.
        w_xh, bias, inputs0: Forwarded unchanged to ``simulate_tanh_rnn``.
        burnin_steps: Leading steps excluded from the LLE average.
        seed: Base seed; combined with grid index and trial index.
        coarse_min, coarse_max, coarse_points: Coarse grid definition.
        fine_window: Half-width of the fine grid around the coarse optimum.
        fine_step: Requested spacing of the fine grid.
        lle_trials: Number of LLE estimates averaged per gain (at least 1).

    Returns:
        ``{"gcrit": gain, "lle_at_gcrit": lle}`` for the fine-grid gain with
        the smallest ``|LLE|``.

    Raises:
        ValueError: Propagated from ``np.nanargmin`` when every LLE on a grid
            is NaN (``_estimate_lle`` returns NaN when no step survives the
            burn-in).
    """
    num_trials = max(1, int(lle_trials))

    def eval_grid(gains):
        lles = np.full((len(gains),), np.nan, dtype=np.float64)
        for i, g in enumerate(gains):
            w_hh = float(g) * w_hh_base
            # Only the third return value is needed here; it is what
            # ``_estimate_lle`` consumes as ``u_seq``.
            _pre, _states, u_seq0 = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
            vals = []
            for j in range(num_trials):
                # Distinct, reproducible stream per (seed, grid index, trial).
                lle_rng = np.random.default_rng(int(seed) * 10_000 + i * 97 + j * 997)
                vals.append(_estimate_lle(w_hh, u_seq0, burnin_steps, lle_rng))
            lles[i] = float(np.nanmean(vals))
        return lles

    coarse = np.linspace(float(coarse_min), float(coarse_max), int(coarse_points), dtype=np.float64)
    lles_coarse = eval_grid(coarse)
    idx0 = int(np.nanargmin(np.abs(lles_coarse)))
    g0 = float(coarse[idx0])

    # The fine grid covers a span of 2 * fine_window, so it needs
    # 2 * fine_window / fine_step intervals for its spacing to equal fine_step
    # (using fine_window / fine_step intervals silently doubles the spacing).
    # round() instead of int() so floating-point error such as
    # 0.03 / 0.001 == 29.999... cannot drop an interval. At least 3 intervals
    # (4 points) are always evaluated.
    half_width = float(fine_window)
    num_intervals = max(3, int(round(2.0 * half_width / float(fine_step))))
    fine = np.linspace(g0 - half_width, g0 + half_width, num_intervals + 1)
    lles_fine = eval_grid(fine)
    idx1 = int(np.nanargmin(np.abs(lles_fine)))
    g1 = float(fine[idx1])
    lle1 = float(lles_fine[idx1])

    return {"gcrit": g1, "lle_at_gcrit": lle1}


def _estimate_lle(w_hh: np.ndarray, u_seq: np.ndarray, burnin_steps: int, rng: np.random.Generator, eps: float = 1e-12) -> float:
    """Estimate the largest Lyapunov exponent by tangent-vector propagation.

    A random unit vector is pushed through ``v <- mean_batch(u_t) * (w_hh @ v)``
    for every time step and re-normalised after each step. The returned value
    is the mean log growth factor over the steps with index >= ``burnin_steps``.

    Args:
        w_hh: Recurrent weight matrix.
        u_seq: Array unpacked as (time, hidden, batch); the last axis is averaged.
        burnin_steps: Number of leading steps excluded from the average.
        rng: Generator used to draw the initial tangent vector.
        eps: Small constant added to every norm to avoid division by zero.

    Returns:
        The mean log growth, or NaN when no step remains after the burn-in.
    """
    _num_steps, num_units, _num_batch = u_seq.shape
    first_kept = int(burnin_steps)

    tangent = rng.normal(0.0, 1.0, size=(num_units,)).astype(np.float64)
    tangent /= float(np.linalg.norm(tangent) + eps)

    log_growth = []
    for step, u_step in enumerate(u_seq):
        tangent = u_step.mean(axis=1) * (w_hh @ tangent)
        growth = float(np.linalg.norm(tangent) + eps)
        tangent /= growth
        if step >= first_kept:
            log_growth.append(np.log(growth))

    return float(np.mean(log_growth)) if log_growth else float("nan")


def _effective_rank(X: np.ndarray, eps: float = 1e-12) -> float:
    """Participation-ratio effective rank of a matrix.

    With singular values sigma_k, define p_k = sigma_k^2 / sum_j sigma_j^2 and
    return 1 / sum_k p_k^2. ``eps`` is added to both denominators. An empty
    matrix yields NaN.
    """
    if X.size == 0:
        return float("nan")

    singular_values = np.linalg.svd(X, full_matrices=False)[1]
    energy = singular_values**2
    weights = energy / float(np.sum(energy) + eps)
    inverse_participation = np.sum(weights**2) + eps
    return float(1.0 / inverse_participation)


def _effective_rank_from_svals(S: np.ndarray, eps: float = 1e-12) -> float:
    """Participation-ratio effective rank from precomputed singular values.

    Computes 1 / sum_k p_k^2 with p_k = S_k^2 / sum_j S_j^2; ``eps`` is added
    to both denominators. An empty ``S`` yields NaN.
    """
    if S.size == 0:
        return float("nan")

    energy = S**2
    weights = energy / float(np.sum(energy) + eps)
    inverse_participation = np.sum(weights**2) + eps
    return float(1.0 / inverse_participation)


def _matrix_rank_from_svals(S: np.ndarray, shape: tuple[int, int], eps: float = 1e-12) -> int:
    """Count the singular values above a NumPy-style rank tolerance.

    The tolerance is ``S.max() * max(m, n) * machine_eps(S.dtype)`` where
    ``(m, n) = shape``. If that evaluates to exactly zero (all singular values
    are zero), ``eps`` is used instead. An empty ``S`` has rank 0.
    """
    if S.size == 0:
        return 0

    largest_dim = max(int(shape[0]), int(shape[1]))
    threshold = float(np.max(S) * largest_dim * np.finfo(S.dtype).eps)
    if threshold == 0.0:
        threshold = float(eps)
    return int(np.count_nonzero(S > threshold))


def _mean_abs_offdiag_corr(X: np.ndarray) -> float:
    """Average |correlation| over all distinct pairs of columns of ``X``.

    Rows are treated as samples and columns as units. Non-finite correlation
    entries (e.g. from constant columns) are ignored. NaN is returned when
    ``X`` is not 2-D, has fewer than two rows or columns, or when no finite
    off-diagonal entry remains.
    """
    nan = float("nan")
    if X.ndim != 2:
        return nan

    n_rows, n_cols = X.shape
    if n_rows < 2 or n_cols < 2:
        return nan

    corr = np.corrcoef(X, rowvar=False)
    if corr.shape != (n_cols, n_cols):
        return nan

    off_diagonal = np.abs(corr[~np.eye(n_cols, dtype=bool)])
    finite = off_diagonal[np.isfinite(off_diagonal)]
    if finite.size == 0:
        return nan
    return float(np.mean(finite))


def _mean_abs_unitwise_corr(X: np.ndarray, Y: np.ndarray, eps: float = 1e-12) -> float:
    """Average |corr(X[:, i], Y[:, i])| over the columns ``i``.

    ``X`` and ``Y`` are (samples, units) arrays of identical shape. Columns
    whose standard deviation is below ``eps`` in either array, or whose
    correlation is not finite, are skipped. NaN is returned for mismatching or
    non-2-D inputs, for fewer than two samples, or when every column is skipped.
    """
    nan = float("nan")
    if X.ndim != 2 or Y.ndim != 2 or X.shape != Y.shape:
        return nan

    n_rows, n_cols = X.shape
    if n_rows < 2 or n_cols < 1:
        return nan

    magnitudes = []
    for x_col, y_col in zip(X.T, Y.T):
        x_flat = float(np.std(x_col)) < eps
        if x_flat or float(np.std(y_col)) < eps:
            continue
        r = float(np.corrcoef(x_col, y_col)[0, 1])
        if np.isfinite(r):
            magnitudes.append(abs(r))

    if not magnitudes:
        return nan
    return float(np.mean(magnitudes))


def _apply_A(u_bar: np.ndarray, w_hh: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Return ``u_bar * (w_hh.T @ v)``: transpose-multiply, then scale elementwise."""
    back_projected = w_hh.T @ v
    return u_bar * back_projected


def _apply_A_alpha(u_bar: np.ndarray, w_hh: np.ndarray, alpha: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Return ``u_bar * (w_hh.T @ (alpha * v))`` for a vector or a matrix ``v``.

    For a 1-D ``v`` all products are elementwise over units. For any other
    rank, ``alpha`` and ``u_bar`` are broadcast along a trailing axis so that
    each column of ``v`` is transformed independently.
    """
    is_vector = v.ndim == 1
    gate = alpha if is_vector else alpha[:, None]
    slope = u_bar if is_vector else u_bar[:, None]
    return slope * (w_hh.T @ (gate * v))


def _compute_ar1_short_window(
    h_bar: np.ndarray,
    window: int,
    with_bias: bool,
    eps: float = 1e-12,
) -> tuple[np.ndarray, float, float]:
    """Sliding-window AR(1) fit, independently for every column of ``h_bar``.

    For each ``t`` in ``[W, T - 2]`` (``W = max(2, window)``, ``T`` rows) the
    model ``y ~= a * x (+ b)`` is fitted by least squares on
    ``x = h_bar[t-W:t]`` and ``y = h_bar[t-W+1:t+1]``.

    Args:
        h_bar: (time, units) trajectory.
        window: Window length; values below 2 are raised to 2.
        with_bias: Fit an intercept ``b`` in addition to the slope ``a``.
        eps: Added to every denominator.

    Returns:
        ``(alpha_seq, rel_mse, r2)`` where
        * ``alpha_seq`` is (time, units) with the fitted slope at each fitted
          ``t`` and zeros everywhere else (including the last row);
        * ``rel_mse`` is the in-window MSE divided by the in-window mean of
          ``y^2``, averaged over windows and units;
        * ``r2`` is the coefficient of determination of predicting
          ``h_bar[t]`` from ``h_bar[t-1]`` with the window-``t`` parameters,
          computed per unit over all fitted ``t`` and averaged over units.
        Both scalars are NaN when the trajectory is too short for any window.
    """
    num_steps, num_units = h_bar.shape
    win = max(2, int(window))
    alpha_seq = np.zeros((num_steps, num_units), dtype=np.float64)

    rel_mse_per_window = []
    one_step_preds = []
    one_step_truth = []
    for t in range(win, num_steps - 1):
        past = h_bar[t - win : t]
        future = h_bar[t - win + 1 : t + 1]

        if with_bias:
            past_mean = np.mean(past, axis=0)
            future_mean = np.mean(future, axis=0)
            cov = np.mean((past - past_mean) * (future - future_mean), axis=0)
            var = np.mean((past - past_mean) ** 2, axis=0) + eps
            slope = cov / var
            intercept = future_mean - slope * past_mean
            fitted = slope[None, :] * past + intercept[None, :]
            one_step = slope * h_bar[t - 1] + intercept
        else:
            cross = np.mean(past * future, axis=0)
            power = np.mean(past * past, axis=0) + eps
            slope = cross / power
            fitted = slope[None, :] * past
            one_step = slope * h_bar[t - 1]

        alpha_seq[t] = slope
        window_mse = np.mean((future - fitted) ** 2, axis=0)
        rel_mse_per_window.append(window_mse / (np.mean(future**2, axis=0) + eps))

        # One-step prediction of h_bar[t] from h_bar[t-1] with this window's fit.
        one_step_preds.append(one_step)
        one_step_truth.append(h_bar[t])

    # Every loop iteration appends to all three lists, so they are empty together.
    if not one_step_preds:
        return alpha_seq, float("nan"), float("nan")

    rel_mean = float(np.mean(rel_mse_per_window))
    preds = np.stack(one_step_preds, axis=0)
    trues = np.stack(one_step_truth, axis=0)
    sse = np.mean((trues - preds) ** 2, axis=0)
    sst = np.mean((trues - np.mean(trues, axis=0)) ** 2, axis=0) + eps
    r2_mean = float(np.mean(1.0 - (sse / sst)))
    return alpha_seq, rel_mean, r2_mean


def _lambda_self_consistency(
    u_seq: np.ndarray,
    g_seq: np.ndarray,
    alpha_seq: np.ndarray,
    lambda_used_seq: np.ndarray,
    eps: float = 1e-12,
) -> float:
    """Mean relative residual of ``B ~= lambda * A`` between consecutive steps.

    With the previous step's ``(g', u')`` and the current ``(g, u, alpha)``:

        A = u' * u * (alpha * g' - g)
        B = alpha * u' * g' - u * g

    and the per-step residual is ``mean((B - lambda * A)^2) / (mean(B^2) + eps)``.
    ``alpha`` (one value per unit) is broadcast along a trailing axis. Steps
    whose ``alpha`` row is non-finite or entirely zero, and the very first
    step, contribute no residual but still become the "previous" step for the
    next one.

    Returns:
        The mean of the per-step residuals, or NaN if there are none.
    """
    residuals = []
    previous = None  # (g, u) of the preceding time step, once there is one
    for t in range(int(u_seq.shape[0])):
        g_now = g_seq[t]
        u_now = u_seq[t]
        alpha_now = alpha_seq[t]

        alpha_usable = bool(np.all(np.isfinite(alpha_now))) and not bool(np.all(alpha_now == 0))
        if alpha_usable and previous is not None:
            g_prev, u_prev = previous
            alpha_col = alpha_now[:, None]
            a_term = u_prev * u_now * (alpha_col * g_prev - g_now)
            b_term = alpha_col * u_prev * g_prev - u_now * g_now
            scale = float(np.mean(b_term**2) + eps)
            mismatch = b_term - lambda_used_seq[t] * a_term
            residuals.append(float(np.mean(mismatch**2) / scale))

        previous = (g_now, u_now)

    if not residuals:
        return float("nan")
    return float(np.mean(residuals))


def _spectral_radius(
    w_hh: np.ndarray,
    u_bar: np.ndarray,
    alpha: np.ndarray | None = None,
    iters: int = 20,
    eps: float = 1e-12,
) -> float:
    """Power-iteration estimate of the dominant gain of the linear map A.

    ``A`` is ``_apply_A`` when ``alpha`` is None and ``_apply_A_alpha``
    otherwise. Starting from a fixed-seed random unit vector, the map is
    applied ``iters`` times with re-normalisation after every application.

    Returns:
        The norm growth (plus ``eps``) observed in the final iteration, or
        0.0 when ``iters`` is not positive.
    """
    if alpha is None:

        def step(x: np.ndarray) -> np.ndarray:
            return _apply_A(u_bar, w_hh, x)

    else:

        def step(x: np.ndarray) -> np.ndarray:
            return _apply_A_alpha(u_bar, w_hh, alpha, x)

    # Fixed seed: the estimate is deterministic for given inputs.
    vec = np.random.default_rng(0).normal(0.0, 1.0, size=(w_hh.shape[0],)).astype(np.float64)
    vec /= float(np.linalg.norm(vec) + eps)

    growth = 0.0
    for _ in range(int(iters)):
        vec = step(vec)
        growth = float(np.linalg.norm(vec) + eps)
        vec /= growth
    return float(growth)


def _compute_metrics_for_seed(cfg: ScanConfig, seed: int, args: argparse.Namespace) -> dict[str, Any]:
    """Run the stage-4 diagnostics for one seed and return its scalar metrics.

    Steps, in order:
      1. Draw inputs and base weights from seed-derived generators, find the
         critical gain with ``_find_gcrit`` and scale the recurrent weights.
      2. Simulate the network and build the hidden-space error signal
         ``w_out.T @ (w_out @ h_t - target_t)``, low-pass it, and obtain
         ``lambda_used_seq`` and ``delta_true`` from ``probe_core``.
      3. Average everything after ``cfg.fit_burnin_steps`` over axis 2 and
         compute rank, AR(1), closure, self-consistency, slope-variation and
         stability metrics on those (time, unit) matrices.

    Args:
        cfg: Scan configuration forwarded to the ``probe_core`` helpers.
        seed: Seed for ``split_rngs`` and for the LLE search.
        args: Parsed CLI namespace; this function reads ``burnin_steps``, the
            ``gain_*`` options, ``lle_trials``, ``ar_window``, ``time_stride``,
            ``denom_floor`` and ``spectral_iters``.

    Returns:
        A flat dict of Python scalars keyed by metric name (see the return
        statement). The ``*_row`` / ``*_col`` entries duplicate one value.
    """
    rngs = split_rngs(int(seed))
    inputs0 = generate_inputs(rngs["inputs"], cfg)
    w_hh_base, w_xh, bias, low_rank_basis = init_base_weights(rngs["recurrent"], cfg)

    # Gain at which the estimated largest Lyapunov exponent is closest to zero.
    gcrit_info = _find_gcrit(
        w_hh_base,
        w_xh,
        bias,
        inputs0,
        int(args.burnin_steps),
        int(seed),
        float(args.gain_coarse_min),
        float(args.gain_coarse_max),
        int(args.gain_coarse_points),
        float(args.gain_fine_window),
        float(args.gain_fine_step),
        int(args.lle_trials),
    )
    gcrit = float(gcrit_info["gcrit"])
    w_hh = gcrit * w_hh_base

    w_out, _b_out = init_output_weights(rngs["output"], cfg, w_hh, low_rank_basis)
    targets = generate_targets(
        rngs["targets"],
        int(cfg.time_steps),
        int(cfg.output_dim),
        int(cfg.batch_size),
        float(cfg.target_std),
        float(cfg.target_ar1_rho),
    )

    # NOTE(review): states/u_seq are treated as (time, unit, batch) below
    # (axis-2 means here, and the shape unpacking in _estimate_lle) -- confirm
    # against probe_core.simulate_tanh_rnn.
    _pre, states, u_seq = simulate_tanh_rnn(w_hh, w_xh, bias, inputs0)
    # Output error mapped back to hidden space: w_out.T @ (w_out @ h_t - target_t).
    # The output bias returned above (_b_out) is not used in y_t.
    g_seq = np.zeros_like(u_seq)
    for t in range(int(cfg.time_steps)):
        y_t = w_out @ states[t]
        e_t = y_t - targets[t]
        g_seq[t] = w_out.T @ e_t

    g_lp = lowpass_ema(g_seq, float(cfg.error_lp_rho))
    lambda_used_seq, debug = compute_lambda_used_seq(states, u_seq, g_lp, cfg)
    delta_true = compute_delta_true(w_hh, u_seq, g_lp)

    # From here on, time index 0 corresponds to absolute step `burn`, and
    # axis 2 has been averaged away, leaving (time, unit) matrices.
    burn = int(cfg.fit_burnin_steps)
    h_bar = np.mean(states[burn:], axis=2)
    u_bar_seq = np.mean(u_seq[burn:], axis=2)
    s_bar = np.mean(u_seq[burn:] * g_lp[burn:], axis=2)
    delta_bar = np.mean(delta_true[burn:], axis=2)

    # Time-by-unit matrices (rows=time, cols=units), matching Appendix definitions.
    S_mat = s_bar
    D_mat = delta_bar
    H_mat = h_bar

    # Singular values only; they feed both the effective and the numerical rank.
    s_svals = np.linalg.svd(S_mat, full_matrices=False, compute_uv=False)
    d_svals = np.linalg.svd(D_mat, full_matrices=False, compute_uv=False)
    h_svals = np.linalg.svd(H_mat, full_matrices=False, compute_uv=False)

    eff_rank_s = _effective_rank_from_svals(s_svals)
    eff_rank_d = _effective_rank_from_svals(d_svals)
    eff_rank_h = _effective_rank_from_svals(h_svals)

    # Row/column rank are equal, but we report both to make the interpretation explicit.
    rank_s = _matrix_rank_from_svals(s_svals, S_mat.shape)
    rank_d = _matrix_rank_from_svals(d_svals, D_mat.shape)
    rank_h = _matrix_rank_from_svals(h_svals, H_mat.shape)

    # Short-window AR(1) (temporal smoothness), with and without an intercept.
    # The derivation in Appendix B.2 uses a no-bias continuation for the teaching signal,
    # while a bias can capture slow mean drift in state trajectories.
    # Of the four slope sequences returned below, only alpha_delta_seq_bias is
    # used later in this function; the other three are computed for their
    # scalar fit metrics only.
    alpha_seq_nobias, ar1_rel_mse_nobias, ar1_r2_nobias = _compute_ar1_short_window(h_bar, int(args.ar_window), with_bias=False)
    alpha_seq_bias, ar1_rel_mse_bias, ar1_r2_bias = _compute_ar1_short_window(h_bar, int(args.ar_window), with_bias=True)

    alpha_delta_seq_nobias, delta_ar1_rel_mse_nobias, delta_ar1_r2_nobias = _compute_ar1_short_window(
        delta_bar, int(args.ar_window), with_bias=False
    )
    alpha_delta_seq_bias, delta_ar1_rel_mse_bias, delta_ar1_r2_bias = _compute_ar1_short_window(delta_bar, int(args.ar_window), with_bias=True)

    # Closure diagnostics (matches Appendix B.4)
    lambda_bar_seq = np.mean(lambda_used_seq[burn:], axis=2)
    implicit_vals = []
    closure_vals = []
    mu_rel_mse_vals = []

    # Sub-sample time with `stride`, starting after the first AR window so that
    # a fitted slope exists (slopes before that index are zero).
    stride = max(1, int(args.time_stride))
    start_t = max(1, int(args.ar_window))

    # Spatial coherence across units (correlations over time).
    idx = np.arange(start_t, delta_bar.shape[0], stride, dtype=np.int64)
    delta_samples = delta_bar[idx]
    z_samples = np.zeros_like(delta_samples)
    for n, t in enumerate(idx.tolist()):
        alpha_t = alpha_delta_seq_bias[t]
        z_samples[n] = w_hh.T @ (alpha_t * delta_bar[t])

    delta_offdiag_abs_corr = _mean_abs_offdiag_corr(delta_samples)
    z_offdiag_abs_corr = _mean_abs_offdiag_corr(z_samples)
    z_delta_unit_abs_corr = _mean_abs_unitwise_corr(delta_samples, z_samples)

    for t in range(start_t, s_bar.shape[0], stride):
        u_bar = u_bar_seq[t]
        alpha_t = alpha_delta_seq_bias[t]
        # Skip time steps with non-finite slopes or gains.
        if not (np.all(np.isfinite(u_bar)) and np.all(np.isfinite(alpha_t))):
            continue

        delta_t = delta_bar[t]
        s_t = s_bar[t]
        # Implicit-system residual in the full space: ||(I-A)delta - s|| / (||delta||+||s||)
        A_delta = _apply_A_alpha(u_bar, w_hh, alpha_t, delta_t)
        denom = float(np.linalg.norm(delta_t) + np.linalg.norm(s_t) + 1e-12)
        implicit_vals.append(float(np.linalg.norm(delta_t - (s_t + A_delta)) / denom))

        # Per-unit loop-gain closure residual: z ~= lambda * delta, where z = W^T (alpha * delta)
        z_t = w_hh.T @ (alpha_t * delta_t)
        lam_t = lambda_bar_seq[t]
        z_hat = lam_t * delta_t
        denom_z = float(np.linalg.norm(z_t) + np.linalg.norm(z_hat) + 1e-12)
        # Only count steps whose combined norm exceeds the configured floor.
        if denom_z > float(args.denom_floor):
            closure_vals.append(float(np.linalg.norm(z_t - z_hat) / denom_z))

        # Oracle-vs-estimated mu (mu := u * lambda) relative MSE, masked to avoid 0/0.
        mask = np.abs(delta_t) > 1e-6
        if np.any(mask):
            mu_oracle = u_bar[mask] * (z_t[mask] / delta_t[mask])
            mu_used = u_bar[mask] * lam_t[mask]
            mu_denom = float(np.mean(mu_oracle**2) + 1e-12)
            mu_rel_mse_vals.append(float(np.mean((mu_used - mu_oracle) ** 2) / mu_denom))

    implicit_resid = float(np.mean(implicit_vals)) if implicit_vals else float("nan")
    closure_resid = float(np.mean(closure_vals)) if closure_vals else float("nan")
    mu_rel_mse = float(np.mean(mu_rel_mse_vals)) if mu_rel_mse_vals else float("nan")

    # Lambda self-consistency residual
    # All four arguments are indexed relative to `burn` (the slope sequence was
    # fitted on delta_bar, which is already burn-in-trimmed).
    lambda_self_resid = _lambda_self_consistency(
        u_seq[burn:],
        g_lp[burn:],
        alpha_delta_seq_bias,
        lambda_used_seq[burn:],
    )

    # Slope variation: log change in u
    u_abs = np.abs(u_bar_seq) + 1e-12
    log_diff = np.abs(np.log(u_abs[1:]) - np.log(u_abs[:-1]))
    slope_log_diff = float(np.mean(log_diff))

    # Stability: spectral radius and max |lambda u|
    spec_vals = []
    for t in range(start_t, u_bar_seq.shape[0], stride):
        spec_vals.append(_spectral_radius(w_hh, u_bar_seq[t], alpha_delta_seq_bias[t], iters=int(args.spectral_iters)))
    spec_mean = float(np.mean(spec_vals)) if spec_vals else float("nan")
    spec_max = float(np.max(spec_vals)) if spec_vals else float("nan")

    # Maximum over every post-burn-in entry (not batch-averaged).
    lam_u = np.abs(lambda_used_seq[burn:] * u_seq[burn:])
    lam_u_max = float(np.max(lam_u)) if lam_u.size > 0 else float("nan")

    return {
        "seed": int(seed),
        "gcrit": float(gcrit),
        "lle_at_gcrit": float(gcrit_info["lle_at_gcrit"]),
        "alpha_hat_mean": float(debug.get("alpha_hat_mean", float("nan"))),
        "lambda_used_mean": float(debug.get("lambda_used_mean", float("nan"))),
        "eff_rank_source": eff_rank_s,
        "eff_rank_source_row": eff_rank_s,
        "eff_rank_source_col": eff_rank_s,
        "rank_source_row": int(rank_s),
        "rank_source_col": int(rank_s),
        "eff_rank_delta": eff_rank_d,
        "eff_rank_delta_row": eff_rank_d,
        "eff_rank_delta_col": eff_rank_d,
        "rank_delta_row": int(rank_d),
        "rank_delta_col": int(rank_d),
        "eff_rank_state": eff_rank_h,
        "eff_rank_state_row": eff_rank_h,
        "eff_rank_state_col": eff_rank_h,
        "rank_state_row": int(rank_h),
        "rank_state_col": int(rank_h),
        "delta_offdiag_abs_corr": float(delta_offdiag_abs_corr),
        "z_offdiag_abs_corr": float(z_offdiag_abs_corr),
        "z_delta_unit_abs_corr": float(z_delta_unit_abs_corr),
        "ar1_rel_mse": ar1_rel_mse_bias,
        "ar1_rel_mse_nobias": ar1_rel_mse_nobias,
        "ar1_r2": ar1_r2_bias,
        "ar1_r2_nobias": ar1_r2_nobias,
        "delta_ar1_rel_mse": delta_ar1_rel_mse_bias,
        "delta_ar1_rel_mse_nobias": delta_ar1_rel_mse_nobias,
        "delta_ar1_r2": delta_ar1_r2_bias,
        "delta_ar1_r2_nobias": delta_ar1_r2_nobias,
        "implicit_resid": implicit_resid,
        "closure_resid": closure_resid,
        "mu_rel_mse": mu_rel_mse,
        "lambda_self_resid": lambda_self_resid,
        "slope_log_diff": slope_log_diff,
        "spectral_radius_mean": spec_mean,
        "spectral_radius_max": spec_max,
        "lambda_u_max": lam_u_max,
    }


def run_diagnostics(args: argparse.Namespace) -> dict[str, Any]:
    """Run every requested scenario over the inclusive seed range.

    For each scenario named in ``args.scenarios`` ("low_rank" and/or "iid") a
    ``ScanConfig`` is built from ``args``, ``_compute_metrics_for_seed`` is
    evaluated for each seed in ``[args.seed_start, args.seed_end]``, and the
    per-seed metrics are averaged (ignoring missing and NaN values).

    Returns:
        ``{"scenarios": [...], "seeds": [...], "ar_window": int}`` where each
        scenario entry holds ``name``, ``config``, per-seed ``seeds`` rows and
        ``summary_means``.
    """
    seeds = list(range(int(args.seed_start), int(args.seed_end) + 1))

    scenario_specs: list[dict[str, Any]] = []
    if "low_rank" in args.scenarios:
        scenario_specs.append(
            {
                "name": f"low_rank_R{int(args.rank)}",
                "w_hh_mode": "low_rank",
                "output_weight_mode": "align_low_rank",
                "rank": int(args.rank),
            }
        )
    if "iid" in args.scenarios:
        scenario_specs.append(
            {
                "name": "iid_random",
                "w_hh_mode": "iid",
                "output_weight_mode": "random",
                "rank": 0,
            }
        )

    # Metrics that are averaged across seeds, in output order.
    summary_keys = (
        "eff_rank_source",
        "rank_source_row",
        "rank_source_col",
        "eff_rank_delta",
        "rank_delta_row",
        "rank_delta_col",
        "eff_rank_state",
        "rank_state_row",
        "rank_state_col",
        "delta_offdiag_abs_corr",
        "z_offdiag_abs_corr",
        "z_delta_unit_abs_corr",
        "implicit_resid",
        "closure_resid",
        "mu_rel_mse",
        "ar1_rel_mse",
        "ar1_rel_mse_nobias",
        "ar1_r2",
        "ar1_r2_nobias",
        "delta_ar1_rel_mse",
        "delta_ar1_rel_mse_nobias",
        "delta_ar1_r2",
        "delta_ar1_r2_nobias",
        "slope_log_diff",
        "spectral_radius_mean",
        "spectral_radius_max",
        "lambda_u_max",
        "lambda_self_resid",
    )

    def _mean_over_seeds(seed_rows: list[dict[str, Any]], key: str) -> float:
        # Drop missing entries, then NaNs (NaN is the only value with v != v).
        present = [row.get(key) for row in seed_rows]
        usable = [v for v in present if v is not None and v == v]
        return float(np.mean(usable)) if usable else float("nan")

    scenario_results: list[dict[str, Any]] = []
    for spec in scenario_specs:
        # Both spec dicts above always carry "rank".
        rank = int(spec["rank"])
        cfg = ScanConfig(
            hidden_size=int(args.hidden_size),
            batch_size=int(args.batch_size),
            time_steps=int(args.time_steps),
            burnin_steps=int(args.burnin_steps),
            input_dim=1,
            input_mode=str(args.input_mode),
            input_std=float(args.input_std),
            poisson_rate=float(args.poisson_rate),
            gext_mode="output_mse",
            gext_std=float(args.gext_std),
            gext_ar1_rho=0.0,
            w_hh_mode=str(spec["w_hh_mode"]),
            low_rank_rank=rank,
            low_rank_frac=float(args.low_rank_frac),
            output_dim=int(args.output_dim),
            target_std=float(args.target_std),
            target_ar1_rho=float(args.target_ar1_rho),
            error_lp_rho=float(args.error_lp_rho),
            output_weight_scale=float(args.output_weight_scale),
            output_weight_mode=str(spec["output_weight_mode"]),
            output_basis_rank=rank,
            alpha_rho=float(args.alpha_rho),
            alpha_source=str(args.alpha_source),
            lambda_window=int(args.lambda_window),
            eps_lambda=float(args.eps_lambda),
            lam_cap=float(args.lam_cap),
            denom_floor=float(args.denom_floor),
            fit_burnin_steps=int(args.fit_burnin_steps),
            use_safe_cap=bool(args.use_safe_cap),
        )

        rows = [_compute_metrics_for_seed(cfg, seed, args) for seed in seeds]
        summary_means = {key: _mean_over_seeds(rows, key) for key in summary_keys}

        scenario_results.append(
            {
                "name": spec["name"],
                "config": config_asdict(cfg),
                "seeds": rows,
                "summary_means": summary_means,
            }
        )

    return {
        "scenarios": scenario_results,
        "seeds": seeds,
        "ar_window": int(args.ar_window),
    }


def plot_results(results: dict[str, Any], out_dir: Path) -> None:
    """Render the four stage-4 summary figures into ``out_dir``.

    Each figure is saved as both PDF and PNG (240 dpi):
    ``stage4_effective_rank``, ``stage4_spatial_corr``, ``stage4_closure`` and
    ``stage4_temporal_stability``. Every bar is the NaN-ignoring mean of one
    metric over the seeds of one scenario. If matplotlib cannot be imported,
    or ``results`` has no scenarios, a message is printed and nothing is drawn.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] matplotlib unavailable ({exc}); skipping plots.")
        return

    scenarios = results.get("scenarios", [])
    if not scenarios:
        print("[PLOT] No scenarios found.")
        return

    _apply_plot_style()

    def _prettify(name: str) -> str:
        # "low_rank_R1" -> "rank R=1"; names without "rank" only lose underscores.
        text = name.replace("_", " ").replace("low rank", "rank").replace("low-rank", "rank")
        if "rank" in text:
            text = text.replace(" R", " R=")
        return text

    def _seed_means(key: str) -> list[float]:
        # One NaN-ignoring mean per scenario, in scenario order.
        return [float(np.nanmean([row.get(key, float("nan")) for row in scenario["seeds"]])) for scenario in scenarios]

    def _save_and_close(figure, stem: str) -> None:
        figure.savefig(out_dir / f"{stem}.pdf", bbox_inches="tight")
        figure.savefig(out_dir / f"{stem}.png", dpi=240, bbox_inches="tight")
        plt.close(figure)

    labels = [scenario["name"] for scenario in scenarios]
    pretty_labels = [_prettify(name) for name in labels]
    colors = plt.cm.tab10(np.linspace(0.0, 0.8, len(labels)))
    bar_colors = colors[: len(labels)]

    def _bar_panel(panel, key: str, title: str, ylabel: str | None = None) -> None:
        panel.bar(pretty_labels, _seed_means(key), color=bar_colors)
        panel.set_title(title)
        if ylabel is not None:
            panel.set_ylabel(ylabel)
        panel.grid(True, axis="y", linestyle=":", alpha=0.3)
        panel.tick_params(axis="x", labelrotation=0)

    # Effective rank (source / delta / state): grouped bars per scenario.
    fig, rank_ax = plt.subplots(1, 1, figsize=(6.4, 3.2), constrained_layout=True)
    positions = np.arange(len(labels))
    width = 0.25
    for offset, key in enumerate(["eff_rank_source", "eff_rank_delta", "eff_rank_state"]):
        rank_ax.bar(positions + (offset - 1) * width, _seed_means(key), width=width, label=key.replace("eff_rank_", ""))
    rank_ax.set_xticks(positions)
    rank_ax.set_xticklabels(pretty_labels, rotation=0)
    rank_ax.set_ylabel("Effective rank")
    rank_ax.legend(frameon=False, fontsize=8)
    rank_ax.grid(True, axis="y", linestyle=":", alpha=0.3)
    _save_and_close(fig, "stage4_effective_rank")

    # Spatial coherence (correlation across units) and elementwise alignment (z vs delta).
    spatial_panels = [
        ("delta_offdiag_abs_corr", "Delta spatial coherence", "Mean |corr| (off-diag)"),
        ("z_delta_unit_abs_corr", "Elementwise z--delta alignment", "Mean |corr| (per unit)"),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(8.4, 3.2), constrained_layout=True)
    for panel, (key, title, ylabel) in zip(axes, spatial_panels):
        _bar_panel(panel, key, title, ylabel)
    _save_and_close(fig, "stage4_spatial_corr")

    # Closure diagnostics (Appendix B.4)
    closure_panels = [
        ("closure_resid", "Closure residual"),
        ("mu_rel_mse", "Mu rel. MSE (oracle vs used)"),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(7.6, 3.2), constrained_layout=True)
    for panel, (key, title) in zip(axes, closure_panels):
        _bar_panel(panel, key, title)
    _save_and_close(fig, "stage4_closure")

    # Temporal smoothness + stability.
    #
    # The paper no longer uses (or assumes) AR(1) structure on the hidden state.
    # Short-window AR(1) structure is validated directly on the true BPTT
    # teaching signal (delta), the quantity used in Appendix B.2.
    ar_window = results.get("ar_window", None)
    window_suffix = "" if ar_window is None else f" (window {ar_window})"
    stability_panels = [
        ("delta_ar1_rel_mse", "Delta AR(1)+bias rel. MSE" + window_suffix),
        ("delta_ar1_rel_mse_nobias", "Delta AR(1) no-bias rel. MSE" + window_suffix),
        ("slope_log_diff", "Slope log-diff"),
        ("implicit_resid", "Implicit residual"),
        ("lambda_self_resid", "Lambda self-consistency"),
        ("spectral_radius_max", "Spectral radius (max)"),
        ("lambda_u_max", "Max |lambda u|"),
        ("closure_resid", "Closure residual"),
    ]
    fig, axes = plt.subplots(2, 4, figsize=(11.2, 4.8), constrained_layout=True)
    for panel, (key, title) in zip(axes.flatten(), stability_panels):
        _bar_panel(panel, key, title)
    _save_and_close(fig, "stage4_temporal_stability")


def main() -> None:
    """Command-line entry point for the stage-4 diagnostics.

    Parses the CLI options, prepares the output directory, and then either
    re-plots an existing ``results.json`` (``--plot-only PATH``) or runs the
    diagnostics, prints a short per-scenario summary, writes ``results.json``
    and renders the figures.

    Raises:
        SystemExit: If ``--out-dir`` already exists and ``--overwrite`` was
            not given (and on argparse errors).
    """
    parser = argparse.ArgumentParser(description="Stage-4 diagnostics: assumptions and stability")
    parser.add_argument("--seed-start", type=int, default=120)
    parser.add_argument("--seed-end", type=int, default=140)
    parser.add_argument("--input-mode", type=str, default="gaussian", choices=["gaussian", "poisson"])

    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--time-steps", type=int, default=700)
    parser.add_argument("--burnin-steps", type=int, default=100)
    parser.add_argument("--fit-burnin-steps", type=int, default=100)

    parser.add_argument("--input-std", type=float, default=0.05)
    parser.add_argument("--poisson-rate", type=float, default=5.0)
    parser.add_argument("--gext-std", type=float, default=1.0)

    parser.add_argument("--low-rank-frac", type=float, default=0.95)
    parser.add_argument("--rank", type=int, default=1)

    parser.add_argument("--output-dim", type=int, default=16)
    parser.add_argument("--target-std", type=float, default=1.0)
    parser.add_argument("--target-ar1-rho", type=float, default=0.995)
    parser.add_argument("--output-weight-scale", type=float, default=1.0)

    parser.add_argument("--alpha-rho", type=float, default=0.995)
    parser.add_argument("--alpha-source", type=str, default="h", choices=["h", "g"])
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument("--eps-lambda", type=float, default=1e-8)
    parser.add_argument("--lam-cap", type=float, default=0.99)
    parser.add_argument("--denom-floor", type=float, default=1e-3)
    # ``store_true`` combined with ``default=True`` can never produce False, so
    # the option could not be switched off. The original flag is kept for
    # backward compatibility and an explicit negation writes to the same dest.
    parser.add_argument("--use-safe-cap", dest="use_safe_cap", action="store_true", default=True)
    parser.add_argument(
        "--no-use-safe-cap",
        dest="use_safe_cap",
        action="store_false",
        help="Pass use_safe_cap=False to the scan config (default: True).",
    )

    parser.add_argument("--error-lp-rho", type=float, default=0.9995)
    parser.add_argument("--ar-window", type=int, default=20)

    parser.add_argument("--gain-coarse-min", type=float, default=0.8)
    parser.add_argument("--gain-coarse-max", type=float, default=1.3)
    parser.add_argument("--gain-coarse-points", type=int, default=31)
    parser.add_argument("--gain-fine-window", type=float, default=0.02)
    parser.add_argument("--gain-fine-step", type=float, default=0.002)
    parser.add_argument("--lle-trials", type=int, default=1)

    parser.add_argument("--time-stride", type=int, default=5)
    parser.add_argument("--spectral-iters", type=int, default=20)

    parser.add_argument("--scenarios", type=str, default="low_rank", help="Comma-separated list: low_rank,iid")
    parser.add_argument("--plot-only", type=str, default=None)
    parser.add_argument("--out-dir", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")

    args = parser.parse_args()
    # Turn the comma-separated string into a clean list of scenario names.
    args.scenarios = [s.strip() for s in str(args.scenarios).split(",") if s.strip()]

    stamp = _timestamp()
    if args.out_dir is not None:
        out_dir = Path(str(args.out_dir))
        # Refuse to write into an existing directory unless explicitly allowed.
        if out_dir.exists() and not bool(args.overwrite):
            raise SystemExit(f"[OUT] {out_dir} exists; pass --overwrite to write into it.")
    else:
        out_dir = Path("paper_figure") / f"stage4_diagnostics_{stamp}"
    _ensure_dir(out_dir)
    print(f"[OUT] {out_dir}")

    if args.plot_only is not None:
        # Re-plot previously saved results without recomputing anything.
        data = json.loads(Path(args.plot_only).read_text(encoding="utf-8"))
        plot_results(data, out_dir)
        return

    results = run_diagnostics(args)
    for scenario in results.get("scenarios", []):
        name = scenario.get("name", "unknown")
        sm = scenario.get("summary_means", {}) or {}
        print(
            "[SUMMARY] "
            f"{name}: "
            f"eff_rank_delta={sm.get('eff_rank_delta', float('nan')):.4g}, "
            f"rank_delta_row={sm.get('rank_delta_row', float('nan')):.4g}, "
            f"rank_delta_col={sm.get('rank_delta_col', float('nan')):.4g}"
        )
    # Persist the numbers before plotting so a plotting failure cannot lose them.
    (out_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    plot_results(results, out_dir)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
