from __future__ import annotations

import argparse
import dataclasses
import json
import math
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Literal

import numpy as np


def _timestamp() -> str:
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _set_seeds(seed: int) -> np.random.Generator:
    return np.random.default_rng(int(seed))


def _clip01(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    return np.clip(x, eps, 1.0 - eps)


def _safe_norm(x: np.ndarray, eps: float = 1e-12) -> float:
    return float(np.linalg.norm(x) + eps)


def _cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    denom = (np.linalg.norm(a) * np.linalg.norm(b)) + eps
    return float(np.dot(a, b) / denom)


def _relerr(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    return float(np.linalg.norm(a - b) / (np.linalg.norm(a) + eps))


def estimate_alpha_ar1(h: np.ndarray, clip: float = 0.99, eps: float = 1e-12) -> np.ndarray:
    """
    Per-unit least-squares AR(1) coefficient:
      h_{t+1}^{(i)} ≈ α_i h_t^{(i)}  =>  α_i = <h_{t+1} h_t> / <h_t h_t>
    Inputs:
      h: (T, N) or (T, N, B); averages run over time and, if present, the batch axis
      clip: α is clipped to [-clip, clip]
      eps: added to the denominator to avoid division by zero
    Returns:
      alpha: (N,) float64
    Raises:
      ValueError: if h is not 2-D or 3-D
    """
    # Time is axis 0; a third axis (batch) is averaged as well.
    reduce_axes = {2: 0, 3: (0, 2)}.get(h.ndim)
    if reduce_axes is None:
        raise ValueError(f"Unsupported h shape: {h.shape}")
    lagged, lead = h[:-1], h[1:]
    cross = np.mean(lead * lagged, axis=reduce_axes)
    power = np.mean(lagged * lagged, axis=reduce_axes)
    return np.clip(cross / (power + eps), -clip, clip).astype(np.float64)


def simulate_rnn_tanh(
    W_hh: np.ndarray,
    W_xh: np.ndarray,
    b_h: np.ndarray,
    inputs: np.ndarray,
    h0: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run a vanilla tanh RNN forward over a batch of input sequences:
      x_t = W_hh h_{t-1} + W_xh I_t + b_h
      h_t = tanh(x_t)
    Inputs:
      inputs: (T, D, B)
      h0: (N, B) initial state; zeros when None
    Returns:
      h: (T, N, B) hidden states
      u: (T, N, B) tanh slopes, u_t = 1 - h_t^2
    """
    steps, _, batch = inputs.shape
    n_hidden = W_hh.shape[0]
    if h0 is not None:
        state = np.asarray(h0, dtype=np.float64)
    else:
        state = np.zeros((n_hidden, batch), dtype=np.float64)

    states = np.zeros((steps, n_hidden, batch), dtype=np.float64)
    for step, drive in enumerate(inputs):
        state = np.tanh((W_hh @ state) + (W_xh @ drive) + b_h)
        states[step] = state

    # Elementwise derivative of tanh at every stored state.
    slopes = 1.0 - states * states
    return states, slopes


def lyapunov_max_qr(
    W_hh: np.ndarray,
    W_xh: np.ndarray,
    b_h: np.ndarray,
    driver_input: np.ndarray,
    h0: np.ndarray | None = None,
    log_floor: float = 1e-12,
) -> float:
    """
    Benettin-style QR estimate of the maximal Lyapunov exponent along one driven trajectory.
    driver_input: (D, T), one input column per step
    h0: optional initial state (reshaped to (N, 1)); zeros when None
    log_floor: lower clip for |R_ii| before taking the log
    Returns the largest time-averaged log stretching rate, or NaN if a QR step fails.
    """
    n = int(W_hh.shape[0])
    d, steps = driver_input.shape
    if h0 is not None:
        state = np.asarray(h0, dtype=np.float64).reshape(n, 1)
    else:
        state = np.zeros((n, 1), dtype=np.float64)

    basis = np.eye(n, dtype=np.float64)
    growth = np.zeros(n, dtype=np.float64)
    for t in range(steps):
        drive = driver_input[:, t].reshape(d, 1).astype(np.float64)
        state = np.tanh((W_hh @ state) + (W_xh @ drive) + b_h)
        # Jacobian of h_t w.r.t. h_{t-1}: diag(1 - h_t^2) W_hh.
        slope = (1.0 - state * state).reshape(-1)
        jacobian = slope[:, None] * W_hh
        propagated = jacobian @ basis
        try:
            basis, upper = np.linalg.qr(propagated)
        except np.linalg.LinAlgError:
            return float("nan")
        growth += np.log(np.clip(np.abs(np.diag(upper)), log_floor, None))
    return float(np.max(growth / max(steps, 1)))


def build_A_t(
    W_hh: np.ndarray,
    u_t: np.ndarray,
    alpha: np.ndarray,
) -> np.ndarray:
    """
    Build A_t = diag(u_t) W_hh^T diag(alpha).
    u_t: (N,) scales rows; alpha: (N,) scales columns.
    """
    row_scaled = u_t[:, None] * W_hh.T
    return row_scaled * alpha[None, :]


def build_B_t(
    W_hh: np.ndarray,
    u_t: np.ndarray,
) -> np.ndarray:
    """
    Build B_t = diag(u_t) W_hh^T, the exact one-step adjoint/Jacobian factor in BPTT.
    u_t: (N,) scales the rows of W_hh^T.
    """
    return np.multiply(u_t[:, None], W_hh.T)


def build_operator(
    W_hh: np.ndarray,
    u_t: np.ndarray,
    alpha: np.ndarray,
    mode: str,
) -> np.ndarray:
    """
    Select the backward operator used for resolvent analysis.
    mode "B"      -> diag(u_t) W_hh^T            (build_B_t)
    mode "Aalpha" -> diag(u_t) W_hh^T diag(alpha) (build_A_t)
    Raises ValueError for any other mode.
    """
    if mode == "Aalpha":
        return build_A_t(W_hh, u_t, alpha)
    if mode != "B":
        raise ValueError("operator mode must be 'B' or 'Aalpha'")
    return build_B_t(W_hh, u_t)

def dominant_eigpair(A: np.ndarray) -> tuple[complex, np.ndarray]:
    """
    Return the eigenvalue of ``A`` with largest modulus and a real, unit-norm
    representative of its eigenvector.

    ``np.linalg.eig`` returns complex eigenvectors with an arbitrary global phase,
    so simply taking the real part can give an arbitrary (even near-zero) vector.
    For complex eigenvectors the phase is first rotated to maximise ||Re(v)||;
    real eigenvectors are used unchanged.

    Returns:
        (mu, v): mu as a Python complex, v as a float64 vector of norm ~1.
    """
    eigvals, eigvecs = np.linalg.eig(A)
    idx = int(np.argmax(np.abs(eigvals)))
    mu = eigvals[idx]
    v = eigvecs[:, idx]
    if np.iscomplexobj(v):
        # ||Re(e^{-i phi} v)||^2 = (||v||^2 + Re(e^{-2i phi} sum_k v_k^2)) / 2,
        # which is maximised at phi = angle(sum_k v_k^2) / 2.
        phase = 0.5 * float(np.angle(np.sum(v * v)))
        v = v * np.exp(-1j * phase)
    v = np.real(v).astype(np.float64)
    # Same normalisation as _safe_norm: divide by (||v|| + 1e-12).
    v = v / (np.linalg.norm(v) + 1e-12)
    return complex(mu), v


def rayleigh_scalar(A: np.ndarray, s: np.ndarray, eps: float = 1e-12) -> float:
    """Rayleigh quotient ``s·(A s) / (s·s)``: the best single scalar approximating A along s."""
    image = A @ s
    return float(np.dot(s, image)) / float(np.dot(s, s) + eps)


def componentwise_mu(A: np.ndarray, s: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Per-component ratio ``(A s)_i / s_i``; entries where ``|s_i| < eps`` are NaN."""
    safe_s = np.where(np.abs(s) < eps, np.nan, s)
    return ((A @ s) / safe_s).astype(np.float64)


def solve_resolvent(A: np.ndarray, s: np.ndarray) -> np.ndarray | None:
    """Solve ``(I - A) x = s``; return None when ``I - A`` is singular."""
    shifted = np.eye(A.shape[0], dtype=np.float64) - A
    try:
        solution = np.linalg.solve(shifted, s)
    except np.linalg.LinAlgError:
        return None
    return solution


def diagonal_scalar_predict(s: np.ndarray, mu: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    Per-component single-pole prediction ``s / (1 - mu)``.
    Denominators with magnitude below 1e-6 are replaced by ±1e-6 (sign of ``1 - mu + eps``).
    """
    gap = 1.0 - mu
    too_small = np.abs(gap) < 1e-6
    floored = np.sign(gap + eps) * 1e-6
    return s / np.where(too_small, floored, gap)


def scalar_predict(s: np.ndarray, mu: float, eps: float = 1e-12, floor: float = 1e-6) -> np.ndarray:
    """
    Scalar single-pole prediction ``s / (1 - mu)``.

    Args:
        s: source vector.
        mu: scalar pole (converted with ``float``).
        eps: unused; kept for backward compatibility of the signature.
        floor: denominators with ``|1 - mu| < floor`` are replaced by ``±floor``
            (``+floor`` when ``1 - mu >= 0``).

    A NaN ``mu`` propagates to a NaN prediction, matching
    ``diagonal_scalar_predict``. It is not silently mapped to a ``-floor``
    denominator, which would produce huge finite values.
    """
    denom = 1.0 - float(mu)
    # abs(NaN) < floor is False, so NaN skips the floor and propagates.
    if abs(denom) < floor:
        denom = floor if denom >= 0 else -floor
    return s / denom


def fit_lambda_diag(
    A: np.ndarray,
    u_t: np.ndarray,
    sources: list[np.ndarray],
    eps: float = 1e-12,
) -> np.ndarray:
    """
    Least-squares per-unit gain λ for the diagonal approximation
      A s ≈ (λ ⊙ u_t) ⊙ s
    over an ensemble of sources:
      λ_i = Σ_s s_i (A s)_i / (u_i Σ_s s_i^2)
    with ``eps`` guarding both factors of the denominator.
    An empty ensemble yields λ = 0. Returns λ with shape (N,).
    """
    n_units = int(A.shape[0])
    slope = u_t.astype(np.float64)
    cross = np.zeros(n_units, dtype=np.float64)
    power = np.zeros(n_units, dtype=np.float64)
    for src in sources:
        response = A @ src
        cross += src * response
        power += src * src
    return cross / ((slope + eps) * (power + eps))


def lambda_action_residual(A: np.ndarray, u_t: np.ndarray, lam: np.ndarray, s: np.ndarray) -> float:
    """Relative residual ``||A s - (λ ⊙ u_t) ⊙ s|| / ||A s||`` of the diagonal-gain approximation."""
    exact = A @ s
    mismatch = exact - (lam * u_t) * s
    return float(np.linalg.norm(mismatch) / _safe_norm(exact))


@dataclass
class ProbeResult:
    """Quality of one approximation of the resolvent response for one probe source."""

    probe_type: str  # "<source mode>/<approximation>", e.g. "random_s/rayleigh"
    relerr: float  # relative error of the approximation vs the exact response R @ s
    cos: float  # cosine similarity between the exact and approximate responses
    residual: float  # operator-action residual of the approximation; NaN when not defined


@dataclass
class SnapshotResult:
    """Resolvent and spectral diagnostics of the backward operator at one probe time."""

    t: int  # snapshot time index within the probe window
    rho_op: float  # spectral radius of the chosen operator
    gap_ratio_op: float  # |mu1| / |mu2| of the chosen operator's two largest eigenvalues
    rho_B: float  # spectral radius of B_t = diag(u_t) W_hh^T
    rho_Aalpha: float  # spectral radius of A_t = diag(u_t) W_hh^T diag(alpha)
    evr1: float  # share of sum(sigma^2) in the top singular value of R = (I - op)^{-1}
    evr5: float  # share of sum(sigma^2) in the top five singular values of R
    diag_power_ratio: float  # ||diag(R)||^2 / ||R||_F^2
    probes: list[ProbeResult]  # per-probe approximation results at this snapshot


@dataclass
class GainResult:
    """All diagnostics produced for one recurrent gain ``g``."""

    g: float  # recurrent gain multiplying W_hh_base
    lyap: float  # finite-time maximal Lyapunov exponent
    alpha_median: float  # median per-unit AR(1) coefficient over the probe window
    alpha_p95: float  # 95th percentile of the per-unit AR(1) coefficients
    energy_r1: float  # lag-1 autocorrelation of energy(t) = mean(h_t^2); NaN if undefined
    energy_tau_int: float  # integrated autocorrelation time of energy(t); NaN if undefined
    bptt_loss: str  # synthetic loss used for the BPTT-vs-OLL check ("none" disables it)
    bptt_relerr_median: float  # median over time of relerr(delta_BPTT, delta_OLL)
    bptt_cos_median: float  # median over time of cos(delta_BPTT, delta_OLL)
    bptt_relerr_p90: float  # 90th percentile over time of relerr(delta_BPTT, delta_OLL)
    snapshots: list[SnapshotResult]  # per-snapshot resolvent diagnostics


def _gain_energy_proxies(h_probe: np.ndarray) -> tuple[float, float]:
    """Critical-slowing-down proxies from the scalar observable energy(t) = mean(h_t^2).

    Returns (lag-1 autocorrelation, integrated autocorrelation time). Each is NaN
    when the series is too short or numerically constant.
    """
    energy = np.mean(h_probe * h_probe, axis=(1, 2)).astype(np.float64)
    energy_centered = energy - float(np.mean(energy))
    if energy_centered.size >= 3 and float(np.std(energy_centered)) > 1e-12:
        energy_r1 = float(np.corrcoef(energy_centered[:-1], energy_centered[1:])[0, 1])
    else:
        energy_r1 = float("nan")

    # Integrated autocorrelation time (simple cutoff when acf becomes negative).
    max_lag = min(200, energy_centered.size - 2)
    if max_lag >= 2 and float(np.std(energy_centered)) > 1e-12:
        var = float(np.var(energy_centered))
        tau_int = 0.5
        for k in range(1, max_lag + 1):
            c = float(np.mean(energy_centered[:-k] * energy_centered[k:])) / (var + 1e-12)
            if not np.isfinite(c) or c <= 0.0:
                break
            tau_int += c
        energy_tau_int = float(tau_int)
    else:
        energy_tau_int = float("nan")
    return energy_r1, energy_tau_int


def _gain_bptt_check(
    *,
    h: np.ndarray,
    h_probe: np.ndarray,
    u_probe: np.ndarray,
    t_burn: int,
    W_hh: np.ndarray,
    rng: np.random.Generator,
    N: int,
    B: int,
    bptt_loss: str,
    lambda_window: int,
) -> tuple[float, float, float]:
    """Compare the forward-time OLL echo estimate of the adjoint with exact BPTT.

    Tests the target approximation
      δ^e_t = (u_t ⊙ g_t) ⊘ (1 - λ ⊙ u_t)
    against exact BPTT
      δ_t = u_t ⊙ (g_t + W_hh^T δ_{t+1})
    for a synthetic, label-free loss gradient g_t.

    Returns (relerr median, relerr p90, cosine median) over time, computed on
    batch-mean vectors. All three are NaN when ``bptt_loss == "none"`` or the probe
    window has fewer than 2 steps. Draws from ``rng`` only for the
    "readout_energy" and "lowrank_ar1" losses.

    Raises:
        ValueError: unknown ``bptt_loss``.
    """
    nan = float("nan")
    T = int(h_probe.shape[0])
    if bptt_loss == "none" or T < 2:
        return nan, nan, nan
    if bptt_loss not in {"state_energy", "readout_energy", "lowrank_ar1"}:
        raise ValueError("bptt_loss must be one of: none,state_energy,readout_energy,lowrank_ar1")

    # g_t: (T, N, B)
    if bptt_loss == "state_energy":
        g_seq = h_probe.copy()
    elif bptt_loss == "readout_energy":
        out_dim = max(4, N // 8)
        W_y = rng.standard_normal((out_dim, N)).astype(np.float64) * (1.0 / math.sqrt(N))
        g_seq = np.zeros_like(h_probe, dtype=np.float64)
        for t in range(T):
            y = W_y @ h_probe[t]
            g_seq[t] = W_y.T @ y
    else:
        # Low-rank AR(1) forcing as an "order-parameter" synthetic source (no labels/tasks).
        v = rng.standard_normal(N).astype(np.float64)
        v /= _safe_norm(v)
        a = 0.99
        q = rng.standard_normal(B).astype(np.float64)
        g_seq = np.zeros_like(h_probe, dtype=np.float64)
        for t in range(T):
            eps_q = rng.standard_normal(B).astype(np.float64)
            q = a * q + math.sqrt(max(1e-8, 1.0 - a * a)) * eps_q
            g_seq[t] = v[:, None] * q[None, :]

    # OLL forward-time λ update + δ^e computation
    alpha_rho = 0.995
    alpha_clip = 0.99
    denom_floor = 1e-3
    lambda_cap = 0.99
    eps = 1e-8
    eps_lambda = 1e-8
    lambda_window = max(2, int(lambda_window))
    # EMA factor for the λ statistics, corresponding to a window of ~lambda_window steps.
    lambda_rho = (lambda_window - 1) / lambda_window

    alpha_num = np.zeros((N, 1), dtype=np.float64)
    alpha_den = np.zeros((N, 1), dtype=np.float64)
    S_A2 = np.zeros((N, 1), dtype=np.float64)
    S_AB = np.zeros((N, 1), dtype=np.float64)
    lambda_vals = np.zeros((N, 1), dtype=np.float64)

    h_prev = h[t_burn - 1] if t_burn > 0 else np.zeros((N, B), dtype=np.float64)
    prev_g = None
    prev_u = None
    delta_est = np.zeros((T, N, B), dtype=np.float64)

    for t in range(T):
        h_t = h_probe[t]
        u_t = u_probe[t]
        g_t = g_seq[t]

        # Online per-unit AR(1) coefficient (EMA of lagged cross- and auto-moments).
        hthp_mean = np.mean(h_t * h_prev, axis=1, keepdims=True)
        hphp_mean = np.mean(h_prev * h_prev, axis=1, keepdims=True)
        alpha_num = alpha_rho * alpha_num + (1.0 - alpha_rho) * hthp_mean
        alpha_den = alpha_rho * alpha_den + (1.0 - alpha_rho) * hphp_mean
        raw_alpha = alpha_num / (alpha_den + eps)
        alpha_hat = np.clip(raw_alpha, -alpha_clip, alpha_clip)

        # Echo estimate uses the λ available *before* this step's update (causal).
        denom = 1.0 - (lambda_vals * u_t)
        denom = np.where(np.abs(denom) < denom_floor, np.sign(denom + 1e-12) * denom_floor, denom)
        delta_est[t] = (u_t * g_t) / denom

        if prev_g is not None and prev_u is not None:
            A_s = prev_u * u_t * (alpha_hat * prev_g - g_t)
            B_s = alpha_hat * prev_u * prev_g - u_t * g_t
            A2_mean = np.mean(A_s * A_s, axis=1, keepdims=True)
            AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)
            S_A2 = lambda_rho * S_A2 + (1.0 - lambda_rho) * A2_mean
            S_AB = lambda_rho * S_AB + (1.0 - lambda_rho) * AB_mean
            lambda_unproj = S_AB / (S_A2 + eps_lambda)
            # Project λ so that |λ u| stays below 1 - denom_floor (and |λ| <= lambda_cap).
            u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
            safe_cap = (1.0 - denom_floor) / u_abs_max
            cap = np.minimum(safe_cap, lambda_cap)
            lambda_vals = np.clip(lambda_unproj, -cap, cap)

        prev_g = g_t
        prev_u = u_t
        h_prev = h_t

    # Exact BPTT recursion (backwards)
    delta_true = np.zeros((T, N, B), dtype=np.float64)
    delta_next = np.zeros((N, B), dtype=np.float64)
    W_T = W_hh.T
    for t in range(T - 1, -1, -1):
        delta_t = u_probe[t] * (g_seq[t] + (W_T @ delta_next))
        delta_true[t] = delta_t
        delta_next = delta_t

    # Aggregate metrics across time using batch-mean vectors
    relerrs = []
    coss = []
    for t in range(T):
        dt = np.mean(delta_true[t], axis=1)
        de = np.mean(delta_est[t], axis=1)
        relerrs.append(_relerr(dt, de))
        coss.append(_cosine(dt, de))
    relerr_arr = np.asarray(relerrs, dtype=np.float64)
    cos_arr = np.asarray(coss, dtype=np.float64)
    return (
        float(np.nanmedian(relerr_arr)),
        float(np.nanpercentile(relerr_arr, 90)),
        float(np.nanmedian(cos_arr)),
    )


def _gain_snapshot(
    *,
    t_idx: int,
    W_hh: np.ndarray,
    h_probe: np.ndarray,
    u_probe: np.ndarray,
    alpha: np.ndarray,
    operator_mode: str,
    rng: np.random.Generator,
    N: int,
    probes_per_snapshot: int,
    source_modes: list[str],
) -> SnapshotResult | None:
    """Resolvent, spectral and probe diagnostics of the backward operator at probe time ``t_idx``.

    The operator is built from the batch-mean slope u_t. Returns None (without
    drawing from ``rng``) when ``I - op`` is singular.

    Raises:
        ValueError: unknown ``operator_mode`` or source mode.
    """
    u_t_mean = np.mean(u_probe[t_idx], axis=1)  # (N,)
    B_t = build_B_t(W_hh, u_t_mean)
    A_t = build_A_t(W_hh, u_t_mean, alpha)
    op = build_operator(W_hh, u_t_mean, alpha, operator_mode)

    eye = np.eye(N, dtype=np.float64)
    try:
        R = np.linalg.solve(eye - op, eye)  # (I - op)^{-1}
    except np.linalg.LinAlgError:
        return None
    diag_R = np.diag(R).astype(np.float64)
    diag_power_ratio = float(np.sum(diag_R * diag_R) / (np.sum(R * R) + 1e-12))

    # A single eigendecomposition of op serves both the spectral summary and the
    # eigen-aligned probe (previously computed twice via eigvals + eig).
    eigvals_full, eigvecs_full = np.linalg.eig(op)
    abs_eigs = np.sort(np.abs(eigvals_full))[::-1]
    mu1_abs = float(abs_eigs[0]) if abs_eigs.size > 0 else float("nan")
    mu2_abs = float(abs_eigs[1]) if abs_eigs.size > 1 else float("nan")
    gap_ratio = (
        float(mu1_abs / (mu2_abs + 1e-12)) if np.isfinite(mu1_abs) and np.isfinite(mu2_abs) else float("nan")
    )
    rho_op = float(mu1_abs)
    rho_B = float(np.max(np.abs(np.linalg.eigvals(B_t))))
    rho_Aalpha = float(np.max(np.abs(np.linalg.eigvals(A_t))))

    # Rank-1 dominance of the resolvent response (slaving proxy):
    # for isotropic sources, response covariance ~ R R^T, so EVR is given by singular values of R.
    try:
        U_svd, sing, Vt_svd = np.linalg.svd(R, full_matrices=False)
        power = sing * sing
        evr = power / (np.sum(power) + 1e-12)
        evr1 = float(evr[0])
        evr5 = float(np.sum(evr[:5]))
        u1_svd = U_svd[:, 0].astype(np.float64)
        v1_svd = Vt_svd[0, :].astype(np.float64)
        sigma1 = float(sing[0])
    except np.linalg.LinAlgError:
        evr1 = float("nan")
        evr5 = float("nan")
        u1_svd = None
        v1_svd = None
        sigma1 = float("nan")

    probes: list[ProbeResult] = []

    # Dominant eigenpair of the chosen operator for eigen-aligned probes.
    idx1 = int(np.argmax(np.abs(eigvals_full)))
    mu1 = complex(eigvals_full[idx1])
    v1_raw = eigvecs_full[:, idx1]
    if np.iscomplexobj(v1_raw):
        # Complex eigenvectors have an arbitrary global phase; rotate it to maximise
        # ||Re(v)|| (phi = angle(sum v_k^2) / 2) before discarding the imaginary part.
        v1_raw = v1_raw * np.exp(-0.5j * np.angle(np.sum(v1_raw * v1_raw)))
    v1 = np.real(v1_raw).astype(np.float64)
    v1 /= _safe_norm(v1)

    # Fit λ from an ensemble of isotropic sources (tests whether a source-independent diagonal gain exists)
    fit_sources: list[np.ndarray] = []
    fit_k = max(8, int(probes_per_snapshot) * 4)
    for _ in range(fit_k):
        s_fit = rng.standard_normal(N).astype(np.float64)
        s_fit /= _safe_norm(s_fit)
        fit_sources.append(s_fit)
    lam_hat = fit_lambda_diag(op, u_t_mean, fit_sources)

    # Loop-invariant for this snapshot: batch-mean hidden state used by the energy sources.
    h_t_mean = np.mean(h_probe[t_idx], axis=1)

    for _ in range(int(probes_per_snapshot)):
        for mode in source_modes:
            if mode == "random_s":
                s = rng.standard_normal(N).astype(np.float64)
                s /= _safe_norm(s)
                probe_type = "random_s"
            elif mode == "eigvec":
                s = v1.copy()
                probe_type = "eigvec"
            elif mode == "state_energy":
                g_vec = h_t_mean
                s = (u_t_mean * g_vec).astype(np.float64)
                s /= _safe_norm(s)
                probe_type = "state_energy"
            elif mode == "readout_energy":
                # Synthetic loss: 0.5||y||^2 with random readout => g = W_y^T W_y h
                out_dim = max(4, N // 16)
                W_y = rng.standard_normal((out_dim, N)).astype(np.float64) * (1.0 / math.sqrt(N))
                y = W_y @ h_t_mean
                g_vec = W_y.T @ y
                s = (u_t_mean * g_vec).astype(np.float64)
                s /= _safe_norm(s)
                probe_type = "readout_energy"
            else:
                raise ValueError(f"Unknown source mode: {mode}")

            delta = R @ s

            # Directional single-pole (Rayleigh scalar) + residual
            mu_s = rayleigh_scalar(op, s)
            delta_scalar = scalar_predict(s, mu_s)
            res_scalar = np.linalg.norm((op @ s) - (mu_s * s)) / (_safe_norm(op @ s))
            probes.append(
                ProbeResult(
                    probe_type=f"{probe_type}/rayleigh",
                    relerr=_relerr(delta, delta_scalar),
                    cos=_cosine(delta, delta_scalar),
                    residual=float(res_scalar),
                )
            )

            # Diagonal echo with source-independent λ̂ (this matches the OLL structural claim)
            mu_hat = lam_hat * u_t_mean
            delta_diag = diagonal_scalar_predict(s, mu_hat)
            res_diag = lambda_action_residual(op, u_t_mean, lam_hat, s)
            probes.append(
                ProbeResult(
                    probe_type=f"{probe_type}/diag_lamhat",
                    relerr=_relerr(delta, delta_diag),
                    cos=_cosine(delta, delta_diag),
                    residual=float(res_diag),
                )
            )

            # Diagonal susceptibility (keep only diag of the true resolvent R)
            delta_diagR = diag_R * s
            probes.append(
                ProbeResult(
                    probe_type=f"{probe_type}/diagR",
                    relerr=_relerr(delta, delta_diagR),
                    cos=_cosine(delta, delta_diagR),
                    residual=float("nan"),
                )
            )

            # Rank-1 susceptibility (dominant singular mode of the resolvent)
            if u1_svd is not None and v1_svd is not None and np.isfinite(sigma1):
                delta_rank1 = (sigma1 * u1_svd) * float(np.dot(v1_svd, s))
                probes.append(
                    ProbeResult(
                        probe_type=f"{probe_type}/rank1",
                        relerr=_relerr(delta, delta_rank1),
                        cos=_cosine(delta, delta_rank1),
                        residual=float("nan"),
                    )
                )

            # Eigenvalue-based best-case scalar (exact only for eigen-aligned sources)
            delta_mu1 = scalar_predict(s, float(np.real(mu1)))
            res_mu1 = np.linalg.norm((op @ s) - (np.real(mu1) * s)) / (_safe_norm(op @ s))
            probes.append(
                ProbeResult(
                    probe_type=f"{probe_type}/mu1",
                    relerr=_relerr(delta, delta_mu1),
                    cos=_cosine(delta, delta_mu1),
                    residual=float(res_mu1),
                )
            )

    return SnapshotResult(
        t=int(t_idx),
        rho_op=rho_op,
        gap_ratio_op=gap_ratio,
        rho_B=rho_B,
        rho_Aalpha=rho_Aalpha,
        evr1=evr1,
        evr5=evr5,
        diag_power_ratio=diag_power_ratio,
        probes=probes,
    )


def run_gain(
    *,
    g: float,
    rng: np.random.Generator,
    hidden: int,
    input_size: int,
    batch: int,
    t_burn: int,
    t_probe: int,
    sigma_in: float,
    w_xh_scale: float,
    snapshot_count: int,
    probes_per_snapshot: int,
    source_modes: list[str],
    W_hh_base: np.ndarray | None = None,
    W_xh: np.ndarray | None = None,
    inputs: np.ndarray | None = None,
    bptt_loss: str = "state_energy",
    lambda_window: int = 50,
    operator_mode: str = "B",
) -> GainResult:
    """Run every Stage-0 diagnostic for one recurrent gain ``g``.

    Steps, in this order (which also fixes the order of draws from ``rng``):
      1. Build W_hh = g * W_hh_base (W_hh_base ~ N(0, 1/N) entries when not given).
         W_xh has entries ~ N(0, w_xh_scale^2 / D) and inputs ~ N(0, sigma_in^2)
         with shape (t_burn + t_probe, input_size, batch), each generated only
         when not supplied. Then simulate the tanh RNN.
      2. Finite-time maximal Lyapunov exponent along the batch-mean input.
      3. Per-unit AR(1) coefficients and energy-based critical-slowing-down
         proxies over the probe window h[t_burn:].
      4. Optional BPTT-vs-OLL check on a synthetic loss (``bptt_loss``;
         ``lambda_window`` sets the λ EMA window).
      5. At ``snapshot_count`` evenly spaced probe times, resolvent diagnostics of
         the operator chosen by ``operator_mode`` and approximation errors for
         the probe sources in ``source_modes``. Snapshots whose (I - op) is
         singular are skipped.

    Returns:
        GainResult with the scalar diagnostics and per-snapshot results.

    Raises:
        ValueError: unknown ``bptt_loss``, ``operator_mode`` or source mode.
    """
    N = int(hidden)
    D = int(input_size)
    B = int(batch)
    T_total = int(t_burn + t_probe)

    if W_hh_base is None:
        W_hh_base = rng.standard_normal((N, N)).astype(np.float64) * (1.0 / math.sqrt(N))
    W_hh = np.asarray(W_hh_base, dtype=np.float64) * float(g)

    if W_xh is None:
        W_xh = rng.standard_normal((N, D)).astype(np.float64) * (float(w_xh_scale) / math.sqrt(max(1, D)))
    b_h = np.zeros((N, 1), dtype=np.float64)

    if inputs is None:
        inputs = rng.standard_normal((T_total, D, B)).astype(np.float64) * float(sigma_in)
    h, u = simulate_rnn_tanh(W_hh, W_xh, b_h, inputs)

    # NOTE(review): the Lyapunov estimate is driven by the batch-MEAN input. When the
    # inputs are generated above (i.i.d. across batch), its per-step std is about
    # sigma_in / sqrt(batch), i.e. weaker than the drive of each trajectory in `h`.
    # Confirm this is intended (inputs[:, :, 0].T would follow one actual trajectory).
    driver = np.mean(inputs, axis=2).T  # (D, T)
    lyap = lyapunov_max_qr(W_hh, W_xh, b_h, driver)

    h_probe = h[t_burn:]
    u_probe = u[t_burn:]
    alpha = estimate_alpha_ar1(h_probe, clip=0.99)
    alpha_median = float(np.median(alpha))
    alpha_p95 = float(np.percentile(alpha, 95))

    energy_r1, energy_tau_int = _gain_energy_proxies(h_probe)

    bptt_relerr_median, bptt_relerr_p90, bptt_cos_median = _gain_bptt_check(
        h=h,
        h_probe=h_probe,
        u_probe=u_probe,
        t_burn=t_burn,
        W_hh=W_hh,
        rng=rng,
        N=N,
        B=B,
        bptt_loss=bptt_loss,
        lambda_window=lambda_window,
    )

    # Snapshot times (in probe coordinates)
    if snapshot_count <= 1:
        snap_ts = [int(t_probe // 2)]
    else:
        snap_ts = np.linspace(0, t_probe - 1, num=int(snapshot_count), dtype=int).tolist()

    snapshots: list[SnapshotResult] = []
    for t_idx in snap_ts:
        snapshot = _gain_snapshot(
            t_idx=t_idx,
            W_hh=W_hh,
            h_probe=h_probe,
            u_probe=u_probe,
            alpha=alpha,
            operator_mode=operator_mode,
            rng=rng,
            N=N,
            probes_per_snapshot=probes_per_snapshot,
            source_modes=source_modes,
        )
        if snapshot is not None:
            snapshots.append(snapshot)

    return GainResult(
        g=float(g),
        lyap=float(lyap),
        alpha_median=alpha_median,
        alpha_p95=alpha_p95,
        energy_r1=energy_r1,
        energy_tau_int=energy_tau_int,
        bptt_loss=str(bptt_loss),
        bptt_relerr_median=bptt_relerr_median,
        bptt_cos_median=bptt_cos_median,
        bptt_relerr_p90=bptt_relerr_p90,
        snapshots=snapshots,
    )


def summarize_gain(result: GainResult) -> dict[str, Any]:
    """Flatten a ``GainResult`` into a JSON-friendly dict.

    Snapshot-level metrics are reduced to the median of their finite values
    (NaN when none are finite). Probe results are grouped by ``probe_type`` into
    relerr median/p90, cosine median and median of finite residuals.
    """
    snaps = result.snapshots
    nan = float("nan")

    def finite_median(attr: str) -> float:
        values = [getattr(snap, attr) for snap in snaps if np.isfinite(getattr(snap, attr))]
        return float(np.median(values)) if values else nan

    grouped: dict[str, list[ProbeResult]] = {}
    for snap in snaps:
        for probe in snap.probes:
            grouped.setdefault(probe.probe_type, []).append(probe)

    def probe_stats(items: list[ProbeResult]) -> dict[str, float]:
        rel = np.array([p.relerr for p in items], dtype=np.float64)
        cos_vals = np.array([p.cos for p in items], dtype=np.float64)
        resid = np.array([p.residual for p in items], dtype=np.float64)
        resid = resid[np.isfinite(resid)]
        return {
            "relerr_median": float(np.nanmedian(rel)),
            "relerr_p90": float(np.nanpercentile(rel, 90)),
            "cos_median": float(np.nanmedian(cos_vals)),
            "residual_median": float(np.median(resid)) if resid.size else nan,
        }

    probes_summary = {ptype: probe_stats(items) for ptype, items in grouped.items() if items}
    return {
        "g": result.g,
        "lyap": result.lyap,
        "alpha_median": result.alpha_median,
        "alpha_p95": result.alpha_p95,
        "energy_r1": result.energy_r1,
        "energy_tau_int": result.energy_tau_int,
        "bptt_loss": result.bptt_loss,
        "bptt_relerr_median": result.bptt_relerr_median,
        "bptt_relerr_p90": result.bptt_relerr_p90,
        "bptt_cos_median": result.bptt_cos_median,
        "rho_op_median": finite_median("rho_op"),
        "gap_ratio_op_median": finite_median("gap_ratio_op"),
        "rho_B_median": finite_median("rho_B"),
        "rho_Aalpha_median": finite_median("rho_Aalpha"),
        "rank1_evr1_median": finite_median("evr1"),
        "rank1_evr5_median": finite_median("evr5"),
        "diag_power_ratio_median": finite_median("diag_power_ratio"),
        "probes": probes_summary,
    }


def plot_curves(summary: list[dict[str, Any]], out_dir: Path) -> None:
    """Render sweep-level diagnostic figures as PNG files in ``out_dir``.

    Each entry of ``summary`` is one gain's summary dict. The keys read are those
    produced by ``summarize_gain``; a missing key plots as NaN. Besides the
    per-gain curves, one relerr-vs-Lyapunov and one cosine-vs-Lyapunov scatter is
    written per probe type. If matplotlib cannot be imported, nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception:
        return

    def series(key: str) -> list[float]:
        return [float(row.get(key, float("nan"))) for row in summary]

    # Read every column up front so malformed input fails before any file is written.
    column_keys = (
        "g",
        "lyap",
        "rho_op_median",
        "rho_B_median",
        "rho_Aalpha_median",
        "alpha_median",
        "energy_r1",
        "energy_tau_int",
        "gap_ratio_op_median",
        "rank1_evr1_median",
        "rank1_evr5_median",
        "diag_power_ratio_median",
        "bptt_relerr_median",
        "bptt_cos_median",
    )
    cols = {key: series(key) for key in column_keys}
    gains = cols["g"]
    lyaps = cols["lyap"]

    def open_figure() -> None:
        plt.figure(figsize=(7.2, 4.4), constrained_layout=True)

    def finalize(filename: str, xlabel: str, ylabel: str, title: str, *, legend: bool = False) -> None:
        # Label, grid (and optional legend), then write the current figure and release it.
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        if legend:
            plt.legend()
        plt.grid(True)
        plt.savefig(out_dir / filename)
        plt.close()

    def gain_curve(key: str, ylabel: str, title: str, filename: str, **line_kwargs: Any) -> None:
        open_figure()
        plt.plot(gains, cols[key], marker="o", **line_kwargs)
        finalize(filename, "Gain g", ylabel, title)

    # Lyapunov exponent, with the chaos threshold at 0.
    open_figure()
    plt.plot(gains, lyaps, marker="o")
    plt.axhline(0.0, color="k", linewidth=1.0, linestyle=":")
    finalize("lyapunov_vs_g.png", "Gain g", "Max Lyapunov Exponent (finite-time)", "Driven Lyapunov vs Gain")

    gain_curve(
        "alpha_median",
        "AR(1) α median",
        "Critical Slowing Down Proxy: α vs Gain",
        "alpha_vs_g.png",
        label="median",
    )
    gain_curve(
        "energy_r1",
        "lag-1 autocorr of energy(t)",
        "Critical Slowing Down Proxy: lag-1 autocorr (energy)",
        "energy_r1_vs_g.png",
        label="corr(e_t,e_{t+1})",
    )
    gain_curve(
        "energy_tau_int",
        "integrated autocorr time (energy)",
        "Critical Slowing Down Proxy: τ_int (energy)",
        "energy_tau_vs_g.png",
    )

    # Spectral radius of the chosen operator plus B_t / A_t references when available.
    rho_B = cols["rho_B_median"]
    rho_A = cols["rho_Aalpha_median"]
    open_figure()
    plt.plot(gains, cols["rho_op_median"], marker="o", label="ρ(op)")
    if not np.all(np.isnan(rho_B)):
        plt.plot(gains, rho_B, marker="o", linestyle="--", label="ρ(B_t)")
    if not np.all(np.isnan(rho_A)):
        plt.plot(gains, rho_A, marker="o", linestyle=":", label="ρ(A_t= B_t diag(alpha))")
    plt.axhline(1.0, color="k", linewidth=1.0, linestyle=":")
    finalize(
        "rhoA_vs_g.png",
        "Gain g",
        "median spectral radius",
        "Spectral Radius vs Gain (operator + references)",
        legend=True,
    )

    gain_curve(
        "gap_ratio_op_median",
        "median |μ1|/|μ2|",
        "Spectral Gap Proxy (chosen operator)",
        "gap_ratio_vs_g.png",
    )

    open_figure()
    plt.plot(gains, cols["rank1_evr1_median"], marker="o", label="EVR1")
    plt.plot(gains, cols["rank1_evr5_median"], marker="o", label="EVR1-5", linestyle="--")
    finalize(
        "rank1_evr_vs_g.png",
        "Gain g",
        "Explained variance ratio",
        "Rank-1 Dominance of Resolvent Responses (δ = (I-op)^{-1}s)",
        legend=True,
    )

    gain_curve(
        "diag_power_ratio_median",
        "||diag(R)||_F^2 / ||R||_F^2",
        "Diagonal Dominance of the Susceptibility (R = (I-op)^{-1})",
        "diag_power_ratio_vs_g.png",
    )
    gain_curve(
        "bptt_relerr_median",
        "Median RelErr",
        "BPTT vs OLL: relerr (synthetic loss)",
        "bptt_relerr_vs_g.png",
    )
    gain_curve(
        "bptt_cos_median",
        "Median Cosine",
        "BPTT vs OLL: cosine (synthetic loss)",
        "bptt_cos_vs_g.png",
    )

    # Scalarization error / alignment vs Lyapunov exponent, one pair of scatters per probe type.
    probe_keys = sorted({key for row in summary for key in (row.get("probes") or {})})
    for pk in probe_keys:
        per_gain = [(row.get("probes") or {}).get(pk) or {} for row in summary]
        relerrs = [float(p.get("relerr_median", float("nan"))) for p in per_gain]
        cos = [float(p.get("cos_median", float("nan"))) for p in per_gain]
        safe_name = pk.replace("/", "_")
        scatter_specs = (
            (relerrs, "Median RelErr", f"Scalarization Error vs Lyapunov ({pk})", f"relerr_vs_lyap__{safe_name}.png"),
            (cos, "Median Cosine", f"Alignment vs Lyapunov ({pk})", f"cos_vs_lyap__{safe_name}.png"),
        )
        for ys, ylabel, title, filename in scatter_specs:
            open_figure()
            plt.scatter(lyaps, ys)
            plt.axvline(0.0, color="k", linewidth=1.0, linestyle=":")
            finalize(filename, "Max Lyapunov Exponent", ylabel, title)


def _build_w_hh_base(
    rng: np.random.Generator,
    n: int,
    structure: str,
    density: float,
    diag_strength: float,
) -> np.ndarray:
    """Draw the base (gain-free) recurrent matrix W_hh of shape (n, n).

    The RNG draws are made in the same order for each structure as the
    original inline construction, so a given seed yields the same matrix.

    Args:
        rng: Generator consumed for the draws.
        n: Hidden size.
        structure: "dense", "sparse" or "diag_plus_dense".
        density: Connection probability (used only for "sparse"; assumed
            already validated to lie in (0, 1]).
        diag_strength: Multiple of the identity added for "diag_plus_dense".

    Raises:
        ValueError: If ``structure`` is not recognised.
    """
    if structure == "dense":
        return rng.standard_normal((n, n)).astype(np.float64) * (1.0 / math.sqrt(n))
    if structure == "sparse":
        # Bernoulli(density) mask first, then Gaussian weights; rescale by
        # 1/sqrt(density * n) so the expected row variance matches the dense case.
        mask = (rng.random((n, n)) < density).astype(np.float64)
        w = rng.standard_normal((n, n)).astype(np.float64) * mask
        return w * (1.0 / math.sqrt(max(1e-12, density * n)))
    if structure == "diag_plus_dense":
        w = rng.standard_normal((n, n)).astype(np.float64) * (1.0 / math.sqrt(n))
        return w + diag_strength * np.eye(n, dtype=np.float64)
    raise ValueError(f"Unknown --whh-structure: {structure}")


def main(argv: list[str] | None = None) -> int:
    """Run the Stage-0 gain sweep from the command line.

    Parses ``argv``, validates the arguments, and draws one fixed
    realization of W_hh (base, before gain), W_xh and the input drive. It
    then calls ``run_gain`` for every gain in ``--gains``. The output
    directory receives:

    * ``config.json``: the parsed CLI arguments, including the W_hh
      structure flags, which are not otherwise recorded.
    * ``summary.json``: one ``summarize_gain`` entry per gain.
    * The plots produced by ``plot_curves``.

    Args:
        argv: Argument list; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        Process exit code (0 on success).

    Raises:
        ValueError: On invalid argument values. These are checked before
            any output directory is created.
    """
    parser = argparse.ArgumentParser(description="Stage-0 (no task) physical probe for OLL echo hypotheses.")
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--input-size", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--burn", type=int, default=300)
    parser.add_argument("--probe", type=int, default=1200)
    parser.add_argument("--sigma-in", type=float, default=0.5)
    parser.add_argument("--w-xh-scale", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--gains", type=str, default="0.5,0.7,0.9,1.1,1.3,1.5")
    parser.add_argument("--snapshots", type=int, default=6)
    parser.add_argument("--probes-per-snapshot", type=int, default=2)
    parser.add_argument(
        "--sources",
        type=str,
        default="random_s,eigvec,state_energy,readout_energy",
        help="Comma-separated: random_s,eigvec,state_energy,readout_energy",
    )
    parser.add_argument("--out-dir", type=str, default="")
    parser.add_argument(
        "--bptt-loss",
        type=str,
        default="none",
        choices=["none", "state_energy", "readout_energy", "lowrank_ar1"],
        help="Synthetic label-free loss used to generate g_t for the BPTT-vs-OLL check.",
    )
    parser.add_argument("--lambda-window", type=int, default=50)
    parser.add_argument(
        "--operator",
        type=str,
        default="B",
        choices=["B", "Aalpha"],
        help="Operator used for snapshot resolvent analysis: B=diag(u)W^T, Aalpha=diag(u)W^T diag(alpha).",
    )
    parser.add_argument(
        "--whh-structure",
        type=str,
        default="dense",
        choices=["dense", "sparse", "diag_plus_dense"],
        help="Structure of the recurrent weight matrix W_hh (Stage-0).",
    )
    parser.add_argument(
        "--whh-density",
        type=float,
        default=0.1,
        help="Connection probability for --whh-structure sparse.",
    )
    parser.add_argument(
        "--whh-diag",
        type=float,
        default=0.0,
        help="Diagonal strength added when --whh-structure diag_plus_dense.",
    )
    args = parser.parse_args(argv)

    gains = [float(x.strip()) for x in str(args.gains).split(",") if x.strip()]
    source_modes = [x.strip() for x in str(args.sources).split(",") if x.strip()]

    # Validate everything before touching the filesystem, so a bad
    # invocation does not leave an empty timestamped output directory behind.
    N = int(args.hidden)
    whh_structure = str(args.whh_structure)
    density = float(args.whh_density)
    if N <= 0:
        raise ValueError("--hidden must be a positive integer.")
    if not gains:
        raise ValueError("--gains must contain at least one value.")
    if whh_structure == "sparse" and not (0.0 < density <= 1.0):
        raise ValueError("--whh-density must be in (0, 1].")

    rng = _set_seeds(args.seed)

    out_dir = Path(args.out_dir) if args.out_dir else (Path("plots") / f"oll_stage0_probe_{_timestamp()}")
    _ensure_dir(out_dir)
    # The W_hh structure flags are not passed to run_gain, so they would
    # otherwise be missing from the outputs. Persist the full CLI config.
    with open(out_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2, ensure_ascii=False)

    # Fix a base realization and a fixed input drive so curves reflect only gain changes.
    D = int(args.input_size)
    B = int(args.batch)
    T_total = int(args.burn + args.probe)
    W_hh_base = _build_w_hh_base(rng, N, whh_structure, density, float(args.whh_diag))
    W_xh = rng.standard_normal((N, D)).astype(np.float64) * (float(args.w_xh_scale) / math.sqrt(max(1, D)))
    # Input drive laid out as (time, input_dim, batch).
    inputs = rng.standard_normal((T_total, D, B)).astype(np.float64) * float(args.sigma_in)

    results: list[GainResult] = []
    for g in gains:
        print(f"[Stage0] g={g:.3f} ...")
        res = run_gain(
            g=g,
            rng=rng,
            hidden=args.hidden,
            input_size=args.input_size,
            batch=args.batch,
            t_burn=args.burn,
            t_probe=args.probe,
            sigma_in=args.sigma_in,
            w_xh_scale=args.w_xh_scale,
            snapshot_count=args.snapshots,
            probes_per_snapshot=args.probes_per_snapshot,
            source_modes=source_modes,
            W_hh_base=W_hh_base,
            W_xh=W_xh,
            inputs=inputs,
            bptt_loss=str(args.bptt_loss),
            lambda_window=int(args.lambda_window),
            operator_mode=str(args.operator),
        )
        results.append(res)
        print(
            f"  lyap={res.lyap:.4f} | alpha_med={res.alpha_median:.3f} | "
            f"energy_r1={res.energy_r1:.3f} | tau_int={res.energy_tau_int:.2f} | "
            f"bptt_relerr={res.bptt_relerr_median:.3f} | bptt_cos={res.bptt_cos_median:.3f}"
        )

    summary = [summarize_gain(r) for r in results]
    with open(out_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    plot_curves(summary, out_dir)
    print(f"[Stage0] Wrote {out_dir / 'summary.json'}")
    return 0


if __name__ == "__main__":
    # Script entry point: forward main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
