from __future__ import annotations

import argparse
import math
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np
import torch

# Repository root, taken as parents[2] of this file's resolved path
# (i.e. two levels above the directory that contains this file).
ROOT_DIR = Path(__file__).resolve().parents[2]
TASK_DIR = ROOT_DIR / "task"
# Prepend the task directory to sys.path (only once) -- presumably so that the
# `common.sequence_core` import below can be resolved; confirm against the repo layout.
if str(TASK_DIR) not in sys.path:
    sys.path.insert(0, str(TASK_DIR))

# Set before the project import below. NOTE(review): presumably consumed by project
# code to skip Lyapunov work during probing; the consumer is not visible in this file.
os.environ["OLL_PHYSICAL_PROBE_NO_LYAPUNOV"] = "1"

from common.sequence_core import (  # noqa: E402
    DEFAULT_DEVICE,
    TorchBPTTRNN,
    TorchLocalRuleRNN,
    build_lyapunov_driver,
    build_repeated_targets,
    calculate_lyapunov_exponent_numpy,
    evaluate_classifier_final_step,
    evaluate_regression_mse,
    evaluate_regression_mse_rollout,
    extract_params,
    iterate_minibatches,
    load_mnist_images,
    load_params,
    generate_lorenz_sequences,
    split_train_val,
    train_batches,
)


@dataclass
class ProbeSnapshot:
    """NumPy copy of the quantities recorded by one probed forward/backward pass.

    As assembled by ``PhysicalProbe.build_snapshot``, every per-step array is
    the stack of the per-step tensors along a new leading axis (one entry per
    time step).
    """

    # Pre-activations ``W_hh @ h_prev + W_xh @ I_t + b_h`` for each step.
    x: np.ndarray
    # Hidden states ``tanh(x)``.
    h: np.ndarray
    # Activation derivative ``1 - h**2``.
    u: np.ndarray
    # Output error projected back through the readout: ``W_hy.T @ dL/dy_hat``.
    g: np.ndarray
    # Local source term ``u * g``.
    s: np.ndarray
    # Gradient of the loss w.r.t. ``x`` captured by autograd hooks
    # (zeros for any step whose hook never fired).
    delta_true: np.ndarray
    # Closed-form estimate ``(u * g) / (1 - lambda * u)`` with a floored denominator.
    delta_approx: np.ndarray
    # Recurrent weight matrix at the time of the probe.
    W_hh: np.ndarray
    # Initialised to zeros (one entry per hidden unit) by ``build_snapshot``;
    # presumably filled in later by the caller -- TODO confirm.
    alpha_hat: np.ndarray
    # Initialised to zeros (one entry per hidden unit) by ``build_snapshot``;
    # presumably filled in later by the caller -- TODO confirm.
    lambda_vals: np.ndarray
    # Scalar loss of the probed batch.
    loss: float


@dataclass(frozen=True)
class HypothesisThresholds:
    """Immutable bundle of numeric cut-offs used by the threshold tiers.

    NOTE(review): the code that compares metrics against these fields is not
    visible here. The ``_min`` / ``_max`` suffixes suggest lower and upper
    bounds respectively -- confirm against the consumer.
    """

    rank_top2_median_min: float
    rank_pc1_median_min: float
    rank90_median_max: float
    h_corr_median_min: float
    h_r2_centered_median_min: float
    grad_smooth_corr_median_min: float
    eff_gain_p95_max: float
    neumann_median_max: float
    neumann_p90_max: float
    grad_cos_median_min: float
    grad_sign_median_min: float
    lambda_vol_p95_max: float


# Named threshold tiers as (name, thresholds) pairs. Going down the list every
# ``*_min`` value decreases and every ``*_max`` value increases, so each tier
# is looser than the one before it. ``select_threshold_tier`` indexes this
# list by position, so the order is significant.
THRESHOLD_TIERS: List[Tuple[str, HypothesisThresholds]] = [
    (
        "strict",
        HypothesisThresholds(
            rank_top2_median_min=0.60,
            rank_pc1_median_min=0.45,
            rank90_median_max=5.0,
            h_corr_median_min=0.80,
            h_r2_centered_median_min=0.55,
            grad_smooth_corr_median_min=0.30,
            eff_gain_p95_max=0.95,
            neumann_median_max=0.40,
            neumann_p90_max=2.5,
            grad_cos_median_min=0.60,
            grad_sign_median_min=0.85,
            lambda_vol_p95_max=0.40,
        ),
    ),
    (
        "tight",
        HypothesisThresholds(
            rank_top2_median_min=0.58,
            rank_pc1_median_min=0.40,
            rank90_median_max=6.0,
            h_corr_median_min=0.78,
            h_r2_centered_median_min=0.50,
            grad_smooth_corr_median_min=0.20,
            eff_gain_p95_max=0.97,
            neumann_median_max=0.50,
            neumann_p90_max=3.0,
            grad_cos_median_min=0.45,
            grad_sign_median_min=0.80,
            lambda_vol_p95_max=0.45,
        ),
    ),
    (
        "moderate",
        HypothesisThresholds(
            rank_top2_median_min=0.56,
            rank_pc1_median_min=0.35,
            rank90_median_max=7.0,
            h_corr_median_min=0.75,
            h_r2_centered_median_min=0.45,
            grad_smooth_corr_median_min=0.12,
            eff_gain_p95_max=0.985,
            neumann_median_max=0.80,
            neumann_p90_max=4.0,
            grad_cos_median_min=0.30,
            grad_sign_median_min=0.75,
            lambda_vol_p95_max=0.55,
        ),
    ),
    (
        "relaxed",
        HypothesisThresholds(
            rank_top2_median_min=0.54,
            rank_pc1_median_min=0.30,
            rank90_median_max=8.0,
            h_corr_median_min=0.72,
            h_r2_centered_median_min=0.35,
            grad_smooth_corr_median_min=0.08,
            eff_gain_p95_max=0.995,
            neumann_median_max=1.00,
            neumann_p90_max=5.0,
            grad_cos_median_min=0.15,
            grad_sign_median_min=0.68,
            lambda_vol_p95_max=0.65,
        ),
    ),
    (
        "lenient",
        HypothesisThresholds(
            rank_top2_median_min=0.52,
            rank_pc1_median_min=0.28,
            rank90_median_max=9.0,
            h_corr_median_min=0.70,
            h_r2_centered_median_min=0.25,
            grad_smooth_corr_median_min=0.05,
            eff_gain_p95_max=0.999,
            neumann_median_max=1.20,
            neumann_p90_max=6.0,
            grad_cos_median_min=0.10,
            grad_sign_median_min=0.65,
            lambda_vol_p95_max=0.75,
        ),
    ),
]


def select_threshold_tier(
    attempt: int, attempts_per_tier: int
) -> Tuple[int, str, HypothesisThresholds]:
    """Map a 1-based attempt number onto an entry of ``THRESHOLD_TIERS``.

    Each tier covers ``attempts_per_tier`` consecutive attempts; attempts past
    the end of the table stay on the last tier. Both arguments are coerced to
    ``int`` and clamped to at least 1.

    Returns:
        ``(tier_index, tier_name, thresholds)``.
    """
    attempt_no = int(attempt)
    if attempt_no < 1:
        attempt_no = 1

    per_tier = int(attempts_per_tier)
    if per_tier < 1:
        per_tier = 1

    last_idx = len(THRESHOLD_TIERS) - 1
    tier_idx = (attempt_no - 1) // per_tier
    if tier_idx > last_idx:
        tier_idx = last_idx

    name, limits = THRESHOLD_TIERS[tier_idx]
    return tier_idx, name, limits


class VanillaRNN(torch.nn.Module):
    """Tanh RNN with a linear readout whose forward pass records probe data.

    All parameters start at zero; ``load_from_local`` copies them in from
    another model exposing the same five attributes. Hidden and output
    quantities are kept column-major: ``(features, batch)``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        device: torch.device,
        loss_mode: str,
    ) -> None:
        """Allocate zero-initialised parameters on ``device``.

        Raises:
            ValueError: if ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
        """
        super().__init__()
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        self.loss_mode = loss_mode

        def _zero_param(rows: int, cols: int) -> torch.nn.Parameter:
            return torch.nn.Parameter(torch.zeros(rows, cols, device=device))

        # Registration order is kept: W_xh, W_hh, b_h, W_hy, b_y.
        self.W_xh = _zero_param(hidden_size, input_size)
        self.W_hh = _zero_param(hidden_size, hidden_size)
        self.b_h = _zero_param(hidden_size, 1)
        self.W_hy = _zero_param(output_size, hidden_size)
        self.b_y = _zero_param(output_size, 1)

    def load_from_local(self, local_model: TorchLocalRuleRNN) -> None:
        """Copy the five weight tensors from ``local_model`` in place (no grad tracking)."""
        with torch.no_grad():
            for attr in ("W_xh", "W_hh", "b_h", "W_hy", "b_y"):
                getattr(self, attr).copy_(getattr(local_model, attr))

    def forward_with_probe(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        step_weights: torch.Tensor | None,
        lambda_vals: torch.Tensor,
        probe: "PhysicalProbe",
    ) -> Tuple[float, ProbeSnapshot]:
        """Run the sequence, back-propagate the loss and collect probe data.

        ``inputs`` and ``targets`` are indexed as ``[batch, feature, time]``.
        For every step the method records the pre-activation, hidden state,
        tanh derivative, back-projected output error, their product, and the
        closed-form delta ``(u * g) / (1 - lambda_vals * u)``. Autograd hooks
        registered through ``probe`` capture the true gradient w.r.t. each
        pre-activation when ``backward()`` runs.

        Args:
            inputs: input sequences.
            targets: target sequences.
            step_weights: optional per-step loss weights; ``None`` means 1 per step.
            lambda_vals: factor used in the approximate-delta denominator.
            probe: object providing ``register_grad_hook`` and ``build_snapshot``.

        Returns:
            ``(loss_value, snapshot)`` where ``snapshot`` is whatever
            ``probe.build_snapshot`` returns.
        """
        n_batch = int(inputs.shape[0])
        n_steps = int(inputs.shape[2])
        dev = inputs.device
        weighted = step_weights is not None

        hidden_state = torch.zeros((self.W_hh.shape[0], n_batch), device=dev)
        running_loss = torch.zeros((), device=dev)

        pre_acts: List[torch.Tensor] = []
        states: List[torch.Tensor] = []
        gains: List[torch.Tensor] = []
        back_signals: List[torch.Tensor] = []
        sources: List[torch.Tensor] = []
        approx_deltas: List[torch.Tensor] = []
        # One slot per step; filled by the hooks during backward().
        hooked_grads: List[torch.Tensor | None] = [None] * n_steps

        # Normaliser for the summed loss, never below 1.
        if weighted:
            weight_norm = float(step_weights.sum().item())
        else:
            weight_norm = float(n_steps)
        weight_norm = max(weight_norm, 1.0)

        for t in range(n_steps):
            w_t = step_weights[t] if weighted else 1.0
            drive = inputs[:, :, t].T
            target = targets[:, :, t].T

            pre = self.W_hh @ hidden_state + self.W_xh @ drive + self.b_h
            state = torch.tanh(pre)
            readout = self.W_hy @ state + self.b_y
            gain = 1.0 - state**2

            pre_acts.append(pre)
            states.append(state)
            gains.append(gain)

            # Output-layer error and the step loss for the chosen loss mode.
            if self.loss_mode == "ce":
                out_err = torch.softmax(readout, dim=0) - target
                log_probs = torch.log_softmax(readout, dim=0)
                step_loss = -(target * log_probs).sum(dim=0).mean()
            else:
                out_err = readout - target
                step_loss = 0.5 * torch.sum(out_err**2, dim=0).mean()
            if weighted:
                out_err = out_err * w_t

            back_signal = self.W_hy.T @ out_err
            source = gain * back_signal
            back_signals.append(back_signal)
            sources.append(source)

            # Keep |denominator| >= 1e-3 so the closed-form delta stays finite.
            denom = 1.0 - lambda_vals * gain
            denom_floor = 1e-3
            denom = torch.where(
                torch.abs(denom) < denom_floor,
                denom_floor * torch.sign(denom + 1e-12),
                denom,
            )
            approx_deltas.append(source / denom)

            running_loss = running_loss + step_loss * w_t

            probe.register_grad_hook(pre, hooked_grads, t)
            hidden_state = state

        loss = running_loss / weight_norm
        loss.backward()

        # Steps whose hook never fired fall back to a zero gradient.
        delta_true = [
            grad if grad is not None else torch.zeros_like(pre_acts[step])
            for step, grad in enumerate(hooked_grads)
        ]

        snapshot = probe.build_snapshot(
            pre_acts,
            states,
            gains,
            back_signals,
            sources,
            delta_true,
            approx_deltas,
            self.W_hh,
            loss,
        )
        return float(loss.item()), snapshot


class PhysicalProbe:
    """Helper that captures autograd gradients and packages probe data."""

    def __init__(self, device: torch.device, step_weights: torch.Tensor | None = None) -> None:
        """Store the device and the optional per-step weights on the instance."""
        self.device = device
        self.step_weights = step_weights

    def register_grad_hook(
        self,
        x_t: torch.Tensor,
        delta_true_list: List[torch.Tensor | None],
        idx: int,
    ) -> None:
        """Attach a backward hook to ``x_t`` that writes its gradient into slot ``idx``.

        The gradient is stored detached in ``delta_true_list[idx]`` when
        autograd reaches ``x_t``.
        """
        slot = idx

        def _capture(grad: torch.Tensor) -> None:
            delta_true_list[slot] = grad.detach()

        x_t.register_hook(_capture)

    def build_snapshot(
        self,
        x_list: List[torch.Tensor],
        h_list: List[torch.Tensor],
        u_list: List[torch.Tensor],
        g_list: List[torch.Tensor],
        s_list: List[torch.Tensor],
        delta_true_list: List[torch.Tensor],
        delta_approx_list: List[torch.Tensor],
        W_hh: torch.Tensor,
        loss: torch.Tensor,
    ) -> ProbeSnapshot:
        """Stack the per-step tensors along axis 0 and convert them to NumPy.

        ``alpha_hat`` and ``lambda_vals`` are created as float32 zero vectors
        with one entry per row of ``W_hh``.
        """

        def _as_numpy(tensor: torch.Tensor) -> np.ndarray:
            return tensor.detach().cpu().numpy()

        # Stack every series before converting anything.
        stacked = [
            torch.stack(series, dim=0)
            for series in (
                x_list,
                h_list,
                u_list,
                g_list,
                s_list,
                delta_true_list,
                delta_approx_list,
            )
        ]
        x_np, h_np, u_np, g_np, s_np, true_np, approx_np = (_as_numpy(item) for item in stacked)
        n_hidden = W_hh.shape[0]

        return ProbeSnapshot(
            x=x_np,
            h=h_np,
            u=u_np,
            g=g_np,
            s=s_np,
            delta_true=true_np,
            delta_approx=approx_np,
            W_hh=_as_numpy(W_hh),
            alpha_hat=np.zeros((n_hidden,), dtype=np.float32),
            lambda_vals=np.zeros((n_hidden,), dtype=np.float32),
            loss=float(loss.item()),
        )


def seed_everything(seed: int) -> None:
    """Seed NumPy's global RNG, the torch CPU RNG and all CUDA RNGs with ``seed``."""
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


def load_row_mnist_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST and turn each image into a float32 sequence.

    The images returned by ``load_mnist_images`` have their last two axes
    swapped, and the size of the resulting last axis is used as the number of
    time steps. Targets come from ``build_repeated_targets(labels, 10, steps)``.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        test_inputs, test_targets, test_labels)``.
    """

    def _to_sequences(images: np.ndarray) -> np.ndarray:
        return np.transpose(images, (0, 2, 1)).astype(np.float32)

    train_images, train_labels, test_images, test_labels = load_mnist_images(
        train_limit=train_limit,
        test_limit=test_limit,
    )
    train_inputs = _to_sequences(train_images)
    test_inputs = _to_sequences(test_images)

    n_steps = train_inputs.shape[2]
    train_targets = build_repeated_targets(train_labels, 10, n_steps)
    test_targets = build_repeated_targets(test_labels, 10, n_steps)
    return (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
    )


def load_lorenz_image_sequences(
    train_samples: int,
    test_samples: int,
    seq_len: int,
    frame_h: int,
    frame_w: int,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
    warmup: int,
    seed: int,
    blur_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Generate Lorenz sequences in one call and split them into train and test.

    ``train_samples + test_samples`` sequences are requested from
    ``generate_lorenz_sequences``; the first ``train_samples`` become the
    training split and the remainder the test split.

    Returns:
        ``(train_inputs, train_targets, test_inputs, test_targets)``.
    """
    generator_kwargs = {
        "num_samples": train_samples + test_samples,
        "seq_len": seq_len,
        "frame_h": frame_h,
        "frame_w": frame_w,
        "dt": dt,
        "sigma": sigma,
        "rho": rho,
        "beta": beta,
        "warmup": warmup,
        "seed": seed,
        "blur_sigma": blur_sigma,
    }
    inputs, targets = generate_lorenz_sequences(**generator_kwargs)

    head = slice(None, train_samples)
    tail = slice(train_samples, None)
    return inputs[head], targets[head], inputs[tail], targets[tail]


def load_adding_problem_sequences(
    train_samples: int,
    test_samples: int,
    seq_len: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build a synthetic "adding problem" dataset.

    Every sample has two input channels of length ``seq_len`` (clamped to at
    least 2): uniform random values and a 0/1 marker channel with exactly two
    marked positions. The target at each step is the running sum of the
    marked values seen so far.

    Returns:
        ``(train_inputs, train_targets, test_inputs, test_targets)`` with
        inputs shaped ``(n, 2, seq_len)`` and targets ``(n, 1, seq_len)``.

    Raises:
        ValueError: if ``train_samples + test_samples`` is not positive.
    """
    n_total = int(train_samples) + int(test_samples)
    if n_total <= 0:
        raise ValueError("train_samples + test_samples must be positive.")
    seq_len = max(2, int(seq_len))
    shape = (n_total, seq_len)

    rng = np.random.default_rng(seed)
    # Draw order matters for reproducibility: values first, then the selector.
    values = rng.random(shape, dtype=np.float32)
    selector = rng.random(shape, dtype=np.float32)

    # The two positions with the smallest selector value get marked.
    marked_cols = np.argsort(selector, axis=1)[:, :2]
    markers = np.zeros(shape, dtype=np.float32)
    np.put_along_axis(markers, marked_cols, 1.0, axis=1)

    inputs = np.stack([values, markers], axis=1)
    targets = np.cumsum(values * markers, axis=1, dtype=np.float32)[:, None, :]

    return (
        inputs[:train_samples],
        targets[:train_samples],
        inputs[train_samples:],
        targets[train_samples:],
    )


def compute_cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between ``a`` and ``b`` after flattening.

    Returns NaN when the product of the two norms is not positive
    (i.e. at least one array is all zeros).
    """
    vec_a, vec_b = a.reshape(-1), b.reshape(-1)
    norm_product = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
    if norm_product <= 0:
        return float("nan")
    inner = np.dot(vec_a, vec_b)
    return float(inner / norm_product)


def flatten_core_params(params: Dict[str, np.ndarray]) -> np.ndarray:
    """Concatenate the five core weight arrays into one flat float64 vector.

    The arrays are flattened and joined in the fixed order
    ``W_hh, W_xh, b_h, W_hy, b_y``.

    Args:
        params: mapping that contains at least those five keys.

    Returns:
        1-D ``float64`` array holding all entries in the order above.

    Raises:
        KeyError: if any of the five keys is missing.
    """
    # The key list is fixed and non-empty, so no empty-input branch is needed.
    core_names = ("W_hh", "W_xh", "b_h", "W_hy", "b_y")
    flat_parts = [
        np.asarray(params[name], dtype=np.float64).reshape(-1) for name in core_names
    ]
    return np.concatenate(flat_parts, axis=0)


def build_random_2d_basis(dim: int, rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    """Draw two Gaussian vectors of length ``dim`` and orthonormalise them.

    The second vector has its component along the first removed before it is
    normalised (one Gram-Schmidt step). A small constant is added to each norm
    to avoid division by zero.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    dim = int(dim)
    if dim <= 0:
        raise ValueError("dim must be positive.")

    tiny = 1e-12

    def _draw() -> np.ndarray:
        return rng.standard_normal(dim).astype(np.float64)

    def _normalised(vec: np.ndarray) -> np.ndarray:
        return vec / (np.linalg.norm(vec) + tiny)

    first = _normalised(_draw())
    second = _draw()
    overlap = float(np.dot(second, first))
    second = _normalised(second - first * overlap)
    return first, second


def build_endpoints_2d_basis(
    theta0: np.ndarray,
    theta_a: np.ndarray,
    theta_b: np.ndarray,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build an orthonormal 2-D basis from two endpoints relative to ``theta0``.

    The first axis is the unit vector from ``theta0`` to ``theta_a``; the
    second is the part of ``theta_b - theta0`` orthogonal to it, normalised.
    If either direction has a non-finite or (near-)zero length, a random basis
    from ``build_random_2d_basis`` is returned instead.

    Raises:
        ValueError: if the flattened vectors differ in shape.
    """
    origin, end_a, end_b = (
        np.asarray(vec, dtype=np.float64).reshape(-1) for vec in (theta0, theta_a, theta_b)
    )
    if not (end_a.shape == origin.shape and end_b.shape == origin.shape):
        raise ValueError("theta vectors must have the same shape.")

    eps = 1e-12

    def _is_degenerate(length: float) -> bool:
        return (not np.isfinite(length)) or length <= eps

    axis_a = end_a - origin
    len_a = float(np.linalg.norm(axis_a))
    if _is_degenerate(len_a):
        return build_random_2d_basis(origin.size, rng)
    # len_a > eps here, so dividing by it matches dividing by max(len_a, eps).
    axis_a = axis_a / len_a

    axis_b = end_b - origin
    axis_b = axis_b - axis_a * float(np.dot(axis_b, axis_a))
    len_b = float(np.linalg.norm(axis_b))
    if _is_degenerate(len_b):
        return build_random_2d_basis(origin.size, rng)
    axis_b = axis_b / len_b
    return axis_a, axis_b


def build_pca_2d_basis(
    theta_series: np.ndarray,
    theta0: np.ndarray,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the two leading SVD directions of ``theta_series - theta0``.

    Rows containing non-finite values are dropped first. The two leading right
    singular vectors are re-orthonormalised before being returned. A random
    basis from ``build_random_2d_basis`` is used when fewer than two usable
    rows or directions remain, or when the SVD fails.

    Raises:
        ValueError: if ``theta_series`` is not 2-D with ``theta0.size`` columns.
    """
    origin = np.asarray(theta0, dtype=np.float64).reshape(-1)
    series = np.asarray(theta_series, dtype=np.float64)
    if series.ndim != 2 or series.shape[1] != origin.size:
        raise ValueError("theta_series must have shape (n, dim) matching theta0.")

    def _fallback() -> Tuple[np.ndarray, np.ndarray]:
        return build_random_2d_basis(origin.size, rng)

    offsets = series - origin[None, :]
    finite_rows = np.all(np.isfinite(offsets), axis=1)
    offsets = offsets[finite_rows]
    if offsets.shape[0] < 2:
        return _fallback()

    try:
        right_vecs = np.linalg.svd(offsets, full_matrices=False)[2]
    except np.linalg.LinAlgError:
        return _fallback()
    if right_vecs.shape[0] < 2:
        return _fallback()

    tiny = 1e-12
    lead, second = right_vecs[0], right_vecs[1]
    lead = lead / (np.linalg.norm(lead) + tiny)
    second = second - lead * float(np.dot(second, lead))
    second = second / (np.linalg.norm(second) + tiny)
    return lead, second


def project_to_basis(theta: np.ndarray, theta0: np.ndarray, v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """Return the coordinates of ``theta - theta0`` along ``v1`` and ``v2``.

    Both parameter vectors are flattened to float64 first.

    Raises:
        ValueError: if the flattened ``theta`` and ``theta0`` differ in shape.
    """
    point = np.asarray(theta, dtype=np.float64).reshape(-1)
    origin = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if point.shape != origin.shape:
        raise ValueError("theta and theta0 must have the same shape.")
    offset = point - origin
    coords = [float(np.dot(offset, axis)) for axis in (v1, v2)]
    return np.array(coords, dtype=np.float64)


def evaluate_sequence_loss(
    model: Any,
    inputs: np.ndarray,
    targets: np.ndarray,
    loss_mode: str,
) -> float:
    """Run ``model.forward_cycle`` on a batch and return the mean sequence loss.

    The device is picked in increasing priority: ``DEFAULT_DEVICE``,
    ``model.device`` if present, then the device of ``model.W_hh`` if that is
    a tensor. The hidden state starts at zero with ``model.hidden_size`` rows.

    Args:
        model: object exposing ``hidden_size`` and ``forward_cycle``.
        inputs: batch of input sequences.
        targets: batch of target sequences, indexed ``[batch, out, time]``.
        loss_mode: ``"ce"`` for cross-entropy, ``"mse"`` for half squared error.

    Raises:
        ValueError: if ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
    """
    if loss_mode not in {"ce", "mse"}:
        raise ValueError("loss_mode must be 'ce' or 'mse'.")

    device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        device = model.device
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        device = recurrent.device

    tensor_kwargs = {"device": device, "dtype": torch.float32}
    inputs_t = torch.as_tensor(inputs, **tensor_kwargs)
    targets_t = torch.as_tensor(targets, **tensor_kwargs)

    n_batch = int(inputs_t.shape[0])
    n_hidden = int(model.hidden_size)
    h_init = torch.zeros((n_hidden, n_batch), **tensor_kwargs)

    outputs_seq, _ = model.forward_cycle(inputs_t, h_init)
    outputs = torch.stack(outputs_seq, dim=2)  # (out, batch, time)
    y_true = targets_t.permute(1, 0, 2)

    if loss_mode == "ce":
        per_sample = -(y_true * torch.log_softmax(outputs, dim=0)).sum(dim=0)
    else:
        residual = outputs - y_true
        per_sample = 0.5 * torch.sum(residual**2, dim=0)
    return float(per_sample.mean().detach().cpu().item())


def compute_lyapunov_diagnostics_numpy(
    model: Any,
    driver_input: np.ndarray | torch.Tensor,
    seed: int,
    min_steps: int = 50,
) -> Dict[str, Any]:
    """Collect stability diagnostics for a tanh RNN driven by a fixed input.

    The recurrence ``h <- tanh(W_hh h + W_xh I_t + b_h)`` is simulated in
    float64 from a zero state. Along that trajectory a single tangent vector
    is propagated through the step Jacobian and renormalised each step; the
    mean log growth is reported as a power-iteration Lyapunov estimate.

    Args:
        model: object exposing ``W_hh``, ``W_xh`` and ``b_h`` (tensors or
            array-likes). It is also passed to
            ``calculate_lyapunov_exponent_numpy``.
        driver_input: 1-D ``(input_size,)`` or 2-D ``(input_size, time_steps)``
            input; tiled along time until at least ``min_steps`` columns exist.
        seed: seed for the random initial tangent vector.
        min_steps: minimum trajectory length.

    Returns:
        Dict with scalar summaries (``lyap_qr``, ``lyap_power_iter``,
        ``contraction_factor``, ``rho_W_hh``, ``sigma_W_hh``,
        ``upper_bound_log_phiMax_plus_log_sigmaW``) and per-step arrays
        (``phi_max``, ``phi_mean``, ``sat_frac``, ``log_growth``).

    Raises:
        ValueError: if ``driver_input`` has more than two dimensions.
    """

    def _to_numpy(value: Any) -> np.ndarray:
        """Convert a tensor (detached, on CPU) or array-like to a NumPy array."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    W_hh = _to_numpy(getattr(model, "W_hh")).astype(np.float64, copy=False)
    W_xh = _to_numpy(getattr(model, "W_xh")).astype(np.float64, copy=False)
    b_h = _to_numpy(getattr(model, "b_h")).astype(np.float64, copy=False).reshape(-1, 1)
    hidden = int(W_hh.shape[0])

    driver = _to_numpy(driver_input).astype(np.float64, copy=False)
    if driver.ndim == 1:
        # A constant input becomes a single-column sequence.
        driver = driver[:, None]
    if driver.ndim != 2:
        raise ValueError("driver_input must be 1D (input_size,) or 2D (input_size, time_steps).")
    time_steps = int(driver.shape[1])
    if time_steps < min_steps:
        # Repeat the sequence along time until it covers min_steps columns.
        reps = int(math.ceil(min_steps / max(time_steps, 1)))
        driver = np.tile(driver, (1, reps))
        time_steps = int(driver.shape[1])

    # Weight-only summaries.
    # rho_w: largest eigenvalue modulus of W_hh; NaN if the computation fails.
    try:
        rho_w = float(np.max(np.abs(np.linalg.eigvals(W_hh))))
    except Exception:
        rho_w = float("nan")
    # sigma_w: largest singular value of W_hh; NaN if the computation fails.
    try:
        sigma_w = float(np.linalg.svd(W_hh, compute_uv=False)[0])
    except Exception:
        sigma_w = float("nan")

    # Stepwise diagnostics along a driven trajectory.
    h = np.zeros((hidden, 1), dtype=np.float64)
    rng = np.random.default_rng(int(seed))
    v = rng.standard_normal((hidden, 1)).astype(np.float64)
    v_norm0 = float(np.linalg.norm(v))
    if not np.isfinite(v_norm0) or v_norm0 <= 0.0:
        # Degenerate draw: fall back to the all-ones direction.
        v = np.ones((hidden, 1), dtype=np.float64)
        v_norm0 = float(np.linalg.norm(v))
    v = v / max(v_norm0, 1e-12)

    log_growth: List[float] = []
    phi_max: List[float] = []
    phi_mean: List[float] = []
    sat_frac: List[float] = []
    eps = 1e-12

    for t in range(time_steps):
        I_t = driver[:, t].reshape(-1, 1)
        x = W_hh @ h + W_xh @ I_t + b_h
        h = np.tanh(x)
        # tanh derivative at the new state.
        phi = 1.0 - h * h

        phi_max.append(float(np.max(phi)))
        phi_mean.append(float(np.mean(phi)))
        # Fraction of units with |h| >= 0.95.
        sat_frac.append(float(np.mean(np.abs(h) >= 0.95)))

        # phi has shape (hidden, 1), so this scales row i of W_hh by phi[i].
        J = phi * W_hh
        v_next = J @ v
        n = float(np.linalg.norm(v_next))
        # Clamp before log / division so a collapsed tangent vector stays finite.
        log_growth.append(float(np.log(max(n, eps))))
        v = v_next / max(n, eps)

    lyap_pi = float(np.mean(log_growth)) if log_growth else float("nan")
    # Project helper, called with the (possibly tiled) driver. The result key
    # suggests a QR-based estimate -- not verifiable from this file.
    lyap_qr = float(calculate_lyapunov_exponent_numpy(model, driver))

    phi_max_arr = np.asarray(phi_max, dtype=np.float64)
    log_phi_max_mean = float(np.mean(np.log(np.clip(phi_max_arr, eps, None)))) if phi_max_arr.size else float("nan")
    # mean_t log(max phi_t) + log(sigma_w); NaN when no steps were simulated.
    upper_bound = (
        log_phi_max_mean + float(np.log(max(sigma_w, eps))) if np.isfinite(log_phi_max_mean) else float("nan")
    )

    return {
        "lyap_qr": lyap_qr,
        "lyap_power_iter": lyap_pi,
        "contraction_factor": float(np.exp(lyap_pi)) if np.isfinite(lyap_pi) else float("nan"),
        "rho_W_hh": rho_w,
        "sigma_W_hh": sigma_w,
        "phi_max": phi_max_arr,
        "phi_mean": np.asarray(phi_mean, dtype=np.float64),
        "sat_frac": np.asarray(sat_frac, dtype=np.float64),
        "log_growth": np.asarray(log_growth, dtype=np.float64),
        "upper_bound_log_phiMax_plus_log_sigmaW": upper_bound,
    }


def choose_probe_steps(total_updates: int, points: int) -> List[int]:
    """Pick roughly evenly spaced update indices between 0 and ``total_updates``.

    At least two points are requested. Duplicates produced by integer
    truncation are collapsed, and the result is sorted and includes both
    endpoints. A non-positive ``total_updates`` yields ``[0]``.
    """
    if total_updates <= 0:
        return [0]

    n_points = max(2, points)
    grid = np.linspace(0, total_updates, num=n_points, dtype=int)
    # np.unique returns the distinct values in ascending order.
    steps = [int(value) for value in np.unique(grid)]

    # Defensive: make sure both endpoints are present.
    if steps[0] != 0:
        steps.insert(0, 0)
    if steps[-1] != total_updates:
        steps.append(total_updates)
    return steps


def select_probe_step(
    probe_snapshots: Dict[int, ProbeSnapshot],
    spectral_curve: List[Dict[str, float]],
    select_mode: str,
    *,
    explicit_step: int | None = None,
    rho_target: float = 1.0,
    exclude_step0: bool = True,
) -> int:
    """Choose which recorded probe step to analyse.

    Modes:
        ``"final"``: the largest recorded step.
        ``"step"``: the recorded step closest to ``explicit_step``.
        ``"rho_target"``: the recorded step whose ``rho_max`` entry in
        ``spectral_curve`` is closest to ``rho_target``. Step 0 is skipped when
        ``exclude_step0`` is set, unless that leaves nothing. Falls back to
        the final step when no usable entry exists.

    Raises:
        RuntimeError: if ``probe_snapshots`` is empty.
        ValueError: for an unknown mode, or ``"step"`` without ``explicit_step``.
    """
    if not probe_snapshots:
        raise RuntimeError("No probe snapshots available.")

    available_steps = sorted(int(key) for key in probe_snapshots.keys())
    final_step = available_steps[-1]

    if select_mode == "final":
        return final_step

    if select_mode == "step":
        if explicit_step is None:
            raise ValueError("explicit_step must be provided when select_mode='step'.")
        wanted = int(explicit_step)
        return min(available_steps, key=lambda step: abs(step - wanted))

    if select_mode != "rho_target":
        raise ValueError("select_mode must be 'rho_target', 'final', or 'step'.")

    pool = spectral_curve
    if exclude_step0:
        later_entries = [entry for entry in spectral_curve if int(entry.get("step", -1)) != 0]
        pool = later_entries or spectral_curve

    def _usable(entry: Dict[str, float]) -> bool:
        # Needs a recorded snapshot and a finite spectral radius.
        if int(entry.get("step", -1)) not in probe_snapshots:
            return False
        return bool(np.isfinite(float(entry.get("rho_max", float("nan")))))

    usable = [entry for entry in pool if _usable(entry)]
    if not usable:
        return final_step

    target = float(rho_target)
    closest = min(usable, key=lambda entry: abs(float(entry["rho_max"]) - target))
    return int(closest["step"])


def build_plot_dir(root: Path, task_name: str) -> Path:
    """Create and return ``root/plots/oll_physical_probe_<task_name>/<timestamp>``.

    The timestamp is the current local time formatted as ``YYYYmmdd_HHMMSS``.
    Missing parent directories are created; an existing directory is reused.
    """
    run_stamp = time.strftime("%Y%m%d_%H%M%S")
    target = root.joinpath("plots", f"oll_physical_probe_{task_name}", run_stamp)
    target.mkdir(parents=True, exist_ok=True)
    return target


def save_figure(fig: Any, path: Path) -> None:
    """Write ``fig`` to ``path`` at 300 dpi with a tight bounding box and 0.2 padding."""
    export_options = {"dpi": 300, "bbox_inches": "tight", "pad_inches": 0.2}
    fig.savefig(path, **export_options)


# Guard flag: set once the matplotlib rcParams below have been applied.
_PLOT_STYLE_SET = False


def set_plot_style() -> None:
    """Apply the shared matplotlib rcParams once per process.

    Later calls return immediately. matplotlib is imported lazily, so it is
    only required when the style is actually applied.
    """
    global _PLOT_STYLE_SET
    if not _PLOT_STYLE_SET:
        import matplotlib as mpl

        font_settings = {
            "font.family": "serif",
            "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
            "mathtext.fontset": "stix",
        }
        axes_settings = {
            "axes.titlesize": 13,
            "axes.labelsize": 12,
            "axes.titleweight": "bold",
            "axes.edgecolor": "0.2",
            "axes.linewidth": 1.0,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
        }
        legend_grid_line_settings = {
            "legend.fontsize": 10,
            "legend.frameon": False,
            "grid.color": "0.85",
            "grid.linestyle": ":",
            "grid.linewidth": 0.8,
            "lines.linewidth": 1.8,
            "lines.markersize": 4.5,
        }
        background_settings = {
            "figure.facecolor": "white",
            "axes.facecolor": "white",
            "savefig.facecolor": "white",
            "savefig.edgecolor": "white",
        }

        mpl.rcParams.update(
            {
                **font_settings,
                **axes_settings,
                **legend_grid_line_settings,
                **background_settings,
            }
        )
        _PLOT_STYLE_SET = True


def select_neighbors(W_hh: np.ndarray, center: int, k: int) -> np.ndarray:
    """Return the indices of the ``k`` strongest entries in row ``center``.

    Strength is the absolute value of ``W_hh[center, j]``. The center itself
    is never returned: at most ``n - 1`` indices come back even when ``k`` is
    larger, and a negative ``k`` yields an empty result. ``W_hh`` is not
    modified.

    Args:
        W_hh: square weight matrix (only row ``center`` is read).
        center: index of the unit whose neighbours are wanted.
        k: maximum number of neighbours.

    Returns:
        Integer index array ordered from strongest to weakest.
    """
    row = np.abs(W_hh[center])
    if not np.issubdtype(row.dtype, np.floating):
        # Integer rows cannot hold -inf; promote them (floats keep their dtype).
        row = row.astype(np.float64)
    # Masking with -inf places the center last in the descending order.
    row[center] = -np.inf
    descending = np.argsort(row)[::-1]
    # Clamp so the masked center (last position) is never included.
    limit = max(0, min(int(k), row.size - 1))
    return descending[:limit].astype(int)


def build_neighbors_map(W_hh: np.ndarray, centers: Iterable[int], k: int) -> Dict[int, np.ndarray]:
    """Map every center (as ``int``) to its ``select_neighbors`` result for ``k``."""
    neighbors_by_center: Dict[int, np.ndarray] = {}
    for raw_center in centers:
        center = int(raw_center)
        neighbors_by_center[center] = select_neighbors(W_hh, center, k)
    return neighbors_by_center


def sample_time_points(
    s_stack: np.ndarray,
    rng: np.random.Generator,
    max_points: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Draw up to ``max_points`` distinct (time, batch) columns from ``s_stack``.

    ``s_stack`` is unpacked as ``(time, hidden, batch)``. Positions are
    sampled without replacement from the flattened time-by-batch grid.

    Returns:
        ``(samples, flat_indices)`` where ``samples`` has shape
        ``(hidden, count)`` and ``flat_indices[j] == t * batch + b`` for the
        column stored at ``samples[:, j]``.

    Raises:
        ValueError: if there are no (time, batch) positions to sample.
    """
    n_steps, _, n_batch = s_stack.shape
    n_positions = n_steps * n_batch
    if n_positions <= 0:
        raise ValueError("No samples available for probe.")

    n_draw = min(max_points, n_positions)
    flat_indices = rng.choice(n_positions, size=n_draw, replace=False)
    # Split each flat index into its time and batch coordinates.
    t_idx, b_idx = np.divmod(flat_indices, n_batch)
    return s_stack[t_idx, :, b_idx].T, flat_indices


def compute_rank1_analysis(
    snapshot: ProbeSnapshot,
    centers: List[int],
    k_neighbors: int,
    time_points: int,
    rng: np.random.Generator,
) -> Dict[str, Any]:
    """Measure how close each center's neighbourhood source term is to rank one.

    Up to ``time_points`` (time, batch) columns of ``snapshot.s`` are sampled
    once. For every center the rows of its ``k_neighbors`` strongest
    neighbours form a matrix ``M`` from which two things are derived:

    * ``evr``: squared singular values of ``M`` divided by their sum
      (all zeros when the sum is zero);
    * ``pc1_cos``: absolute cosine between the leading left singular vectors
      of the first and second half of ``M``'s columns (NaN with fewer than two
      columns or when an SVD fails).

    Returns:
        Dict with ``centers`` and per-center ``neighbors``, ``evr`` and
        ``pc1_cos`` entries.
    """
    s_samples, _ = sample_time_points(snapshot.s, rng, time_points)
    results: Dict[str, Any] = {"centers": centers, "neighbors": {}, "evr": {}, "pc1_cos": {}}

    for center in centers:
        neighbors = select_neighbors(snapshot.W_hh, center, k_neighbors)
        local = s_samples[neighbors]

        _, svals, _ = np.linalg.svd(local, full_matrices=False)
        svals = np.asarray(svals, dtype=np.float64)
        energy = svals * svals
        energy_sum = float(np.sum(energy))
        if energy_sum > 0.0:
            evr = energy / energy_sum
        else:
            evr = np.zeros_like(energy)

        pc1_cos = float("nan")
        n_cols = local.shape[1]
        if n_cols >= 2:
            # With at least two columns both halves are guaranteed non-empty.
            half = max(1, n_cols // 2)
            try:
                left_a = np.linalg.svd(local[:, :half], full_matrices=False)[0]
                left_b = np.linalg.svd(local[:, half:], full_matrices=False)[0]
                if left_a.shape[1] and left_b.shape[1]:
                    pc1_cos = float(np.dot(left_a[:, 0], left_b[:, 0]))
            except np.linalg.LinAlgError:
                pc1_cos = float("nan")

        results["neighbors"][center] = neighbors.tolist()
        results["evr"][center] = evr.tolist()
        # The sign of a singular vector is arbitrary, so only the magnitude is kept.
        results["pc1_cos"][center] = abs(float(pc1_cos)) if np.isfinite(pc1_cos) else float("nan")
    return results


def compute_eigen_residual_stats(
    snapshot: ProbeSnapshot,
    centers: List[int],
    k_neighbors: int,
    rng: np.random.Generator,
    neighbors_map: Dict[int, np.ndarray] | None = None,
) -> Dict[str, Any]:
    """Check how well the dominant local source mode acts as an eigenvector.

    For each center, the neighbourhood rows of ``snapshot.s`` over the last
    ``min(10, time_steps)`` steps give a dominant left singular vector ``q``.
    A local operator ``A = diag(u) @ W_local.T @ diag(alpha_hat)`` is built
    from the snapshot, and several measures of how close ``q`` is to an
    eigenvector of ``A`` are collected:

    * ``residuals``: ``||A q - mu q|| / (||A q|| + eps)`` with ``mu = q . A q``;
    * ``mu_vals``: the ``mu`` values;
    * ``sigma_min``: smallest singular value of ``I - A``;
    * ``residual_ratio``: ``||(I - A)^-1 r|| / (||q|| + eps)`` with ``r = A q - mu q``;
    * ``eig_aligns`` / ``svd_aligns``: absolute cosine between ``q`` and the
      dominant eigenvector / leading right singular vector of ``A``.

    Args:
        snapshot: probe data; ``s``, ``u``, ``alpha_hat`` and ``W_hh`` are read.
            NOTE(review): ``build_snapshot`` initialises ``alpha_hat`` to
            zeros, so presumably a caller fills it in first -- confirm.
        centers: center unit indices to analyse.
        k_neighbors: neighbourhood size used when ``neighbors_map`` is ``None``.
        rng: unused (discarded immediately).
        neighbors_map: optional precomputed neighbours per center; centers
            missing from it are skipped.

    Returns:
        Dict with the per-center arrays above plus the mean and median of
        each, computed over finite entries only (NaN when none are finite).
    """
    del rng
    s_stack = snapshot.s
    time_steps, hidden, batch = s_stack.shape
    if time_steps == 0 or hidden == 0 or batch == 0:
        # Nothing to analyse: return the full result schema with empty / NaN values.
        return {
            "centers": centers,
            "k_neighbors": k_neighbors,
            "window": 0,
            "residuals": np.array([], dtype=np.float64),
            "mu_vals": np.array([], dtype=np.float64),
            "eig_aligns": np.array([], dtype=np.float64),
            "svd_aligns": np.array([], dtype=np.float64),
            "sigma_min": np.array([], dtype=np.float64),
            "residual_ratio": np.array([], dtype=np.float64),
            "residual_mean": float("nan"),
            "residual_median": float("nan"),
            "mu_mean": float("nan"),
            "mu_median": float("nan"),
            "eig_align_mean": float("nan"),
            "eig_align_median": float("nan"),
            "svd_align_mean": float("nan"),
            "svd_align_median": float("nan"),
            "sigma_min_mean": float("nan"),
            "sigma_min_median": float("nan"),
            "residual_ratio_mean": float("nan"),
            "residual_ratio_median": float("nan"),
        }

    # Use only the trailing window of time steps, flattened to (hidden, window * batch).
    window = min(10, time_steps)
    s_window = s_stack[time_steps - window : time_steps]
    s_matrix = np.transpose(s_window, (1, 0, 2)).reshape(hidden, -1)

    residuals: List[float] = []
    mu_vals: List[float] = []
    eig_aligns: List[float] = []
    svd_aligns: List[float] = []
    sigma_mins: List[float] = []
    residual_ratios: List[float] = []
    eps = 1e-12

    for center in centers:
        if neighbors_map is None:
            neighbors = select_neighbors(snapshot.W_hh, center, k_neighbors)
        else:
            neighbors = np.asarray(neighbors_map.get(center, []), dtype=int)
        if neighbors.size == 0:
            continue
        M = s_matrix[neighbors].astype(np.float64, copy=False)
        if M.size == 0:
            continue
        try:
            u_mat, _, _ = np.linalg.svd(M, full_matrices=False)
        except np.linalg.LinAlgError:
            continue
        if u_mat.shape[1] == 0:
            continue
        # Dominant spatial mode of the local source term window.
        q = u_mat[:, 0]

        # Batch-averaged tanh derivative of the neighbours at the final time step.
        u_vec = np.asarray(snapshot.u[-1, neighbors, :].mean(axis=1), dtype=np.float64)
        alpha_vec = np.asarray(snapshot.alpha_hat[neighbors], dtype=np.float64)
        # Recurrent weights restricted to the neighbourhood.
        W_local = np.asarray(snapshot.W_hh[neighbors][:, neighbors], dtype=np.float64)
        A_local = (np.diag(u_vec) @ W_local.T) @ np.diag(alpha_vec)

        # Compare A q with its projection onto q.
        v_exact = A_local @ q
        mu = float(np.dot(q, v_exact))
        v_approx = mu * q
        r = v_exact - v_approx
        denom = np.linalg.norm(v_exact) + eps
        residuals.append(float(np.linalg.norm(r) / denom))
        mu_vals.append(mu)

        # Conditioning of (I - A) and the residual pushed through its inverse.
        sigma_min = float("nan")
        residual_ratio = float("nan")
        eye = np.eye(A_local.shape[0], dtype=np.float64)
        try:
            svals = np.linalg.svd(eye - A_local, compute_uv=False)
            if svals.size:
                sigma_min = float(svals[-1])
        except np.linalg.LinAlgError:
            sigma_min = float("nan")
        try:
            res_solve = np.linalg.solve(eye - A_local, r)
            residual_ratio = float(np.linalg.norm(res_solve) / (np.linalg.norm(q) + eps))
        except np.linalg.LinAlgError:
            residual_ratio = float("nan")
        sigma_mins.append(sigma_min)
        residual_ratios.append(residual_ratio)

        # Alignment of q with the eigenvector of largest-modulus eigenvalue.
        eig_align = float("nan")
        try:
            eigvals, eigvecs = np.linalg.eig(A_local)
            if eigvals.size:
                idx = int(np.argmax(np.abs(eigvals)))
                v_dom = eigvecs[:, idx]
                # vdot conjugates its first argument; eigenvectors may be complex.
                eig_align = float(
                    np.abs(np.vdot(v_dom, q)) / (np.linalg.norm(v_dom) * np.linalg.norm(q) + eps)
                )
        except np.linalg.LinAlgError:
            eig_align = float("nan")
        eig_aligns.append(eig_align)

        # Alignment of q with the leading right singular vector of A.
        svd_align = float("nan")
        try:
            _, _, vh = np.linalg.svd(A_local, full_matrices=False)
            if vh.shape[0]:
                v1 = vh[0]
                svd_align = float(
                    np.abs(np.dot(v1, q)) / (np.linalg.norm(v1) * np.linalg.norm(q) + eps)
                )
        except np.linalg.LinAlgError:
            svd_align = float("nan")
        svd_aligns.append(svd_align)

    residual_arr = np.array(residuals, dtype=np.float64)
    mu_arr = np.array(mu_vals, dtype=np.float64)
    eig_align_arr = np.array(eig_aligns, dtype=np.float64)
    svd_align_arr = np.array(svd_aligns, dtype=np.float64)
    sigma_min_arr = np.array(sigma_mins, dtype=np.float64)
    residual_ratio_arr = np.array(residual_ratios, dtype=np.float64)
    # Summary statistics ignore NaN / inf entries.
    residual_valid = residual_arr[np.isfinite(residual_arr)]
    mu_valid = mu_arr[np.isfinite(mu_arr)]
    eig_align_valid = eig_align_arr[np.isfinite(eig_align_arr)]
    svd_align_valid = svd_align_arr[np.isfinite(svd_align_arr)]
    sigma_min_valid = sigma_min_arr[np.isfinite(sigma_min_arr)]
    residual_ratio_valid = residual_ratio_arr[np.isfinite(residual_ratio_arr)]
    return {
        "centers": centers,
        "k_neighbors": k_neighbors,
        "window": window,
        "residuals": residual_arr,
        "mu_vals": mu_arr,
        "eig_aligns": eig_align_arr,
        "svd_aligns": svd_align_arr,
        "sigma_min": sigma_min_arr,
        "residual_ratio": residual_ratio_arr,
        "residual_mean": float(np.mean(residual_valid)) if residual_valid.size else float("nan"),
        "residual_median": float(np.median(residual_valid)) if residual_valid.size else float("nan"),
        "mu_mean": float(np.mean(mu_valid)) if mu_valid.size else float("nan"),
        "mu_median": float(np.median(mu_valid)) if mu_valid.size else float("nan"),
        "eig_align_mean": float(np.mean(eig_align_valid)) if eig_align_valid.size else float("nan"),
        "eig_align_median": float(np.median(eig_align_valid)) if eig_align_valid.size else float("nan"),
        "svd_align_mean": float(np.mean(svd_align_valid)) if svd_align_valid.size else float("nan"),
        "svd_align_median": float(np.median(svd_align_valid)) if svd_align_valid.size else float("nan"),
        "sigma_min_mean": float(np.mean(sigma_min_valid)) if sigma_min_valid.size else float("nan"),
        "sigma_min_median": float(np.median(sigma_min_valid)) if sigma_min_valid.size else float("nan"),
        "residual_ratio_mean": float(np.mean(residual_ratio_valid)) if residual_ratio_valid.size else float("nan"),
        "residual_ratio_median": float(np.median(residual_ratio_valid)) if residual_ratio_valid.size else float("nan"),
    }


def compute_dynamic_rank_stats(
    snapshot: ProbeSnapshot,
    centers: List[int],
    k_neighbors: int,
    window: int,
    stride: int,
) -> Dict[str, Any]:
    """Summarise the local singular spectrum of ``snapshot.s`` over time windows.

    ``snapshot.s`` is a 3-D array unpacked as ``(time, hidden, batch)``. It is
    cut into sliding windows along the first axis; each window is flattened to
    ``(hidden, window * batch)``. For every center the rows chosen by
    ``select_neighbors(snapshot.W_hh, center, k_neighbors)`` are decomposed by
    SVD and three numbers are kept per window: the explained-variance ratio of
    the first component, of the first two components, and the number of
    components needed to reach 90% of the variance.

    Args:
        snapshot: Probe snapshot providing ``s`` and ``W_hh``.
        centers: Center indices to analyse.
        k_neighbors: Neighbor count forwarded to ``select_neighbors``.
        window: Requested window length; clamped to ``[2, time_steps]``.
        stride: Step between window starts; clamped to at least 1.

    Returns:
        Dict with the per-center lists (``pc1``, ``top2``, ``rank90``), the
        neighbor indices, the effective ``window``/``stride`` and the
        median/mean of the per-center medians (NaN when ``centers`` is empty).
    """
    source = snapshot.s
    n_steps, n_hidden, _ = source.shape
    window = max(2, min(int(window), n_steps))
    stride = max(1, int(stride))

    def _flatten(block: np.ndarray) -> np.ndarray:
        # (time, hidden, batch) -> (hidden, time * batch)
        return block.transpose(1, 0, 2).reshape(n_hidden, -1)

    window_samples: List[np.ndarray] = [
        _flatten(source[begin : begin + window])
        for begin in range(0, n_steps - window + 1, stride)
    ]
    if not window_samples:
        # No full window fits: fall back to the whole sequence as one window.
        window_samples = [_flatten(source)]

    pc1_map: Dict[int, List[float]] = {}
    top2_map: Dict[int, List[float]] = {}
    rank90_map: Dict[int, List[int]] = {}
    neighbors_map: Dict[int, List[int]] = {}

    for center in centers:
        neighbor_idx = select_neighbors(snapshot.W_hh, center, k_neighbors)
        neighbors_map[center] = neighbor_idx.tolist()
        lead_share: List[float] = []
        pair_share: List[float] = []
        rank_needed: List[int] = []
        for samples in window_samples:
            _, sing_vals, _ = np.linalg.svd(samples[neighbor_idx], full_matrices=False)
            sing_vals = np.asarray(sing_vals, dtype=np.float64)
            energy = sing_vals * sing_vals
            energy_sum = float(np.sum(energy))
            if energy_sum > 0.0:
                evr = energy / energy_sum
            else:
                evr = np.zeros_like(energy)
            lead_share.append(float(evr[0]))
            if evr.size >= 2:
                pair_share.append(float(np.sum(evr[:2])))
            else:
                pair_share.append(float(evr[0]))
            # First component count whose cumulative EVR reaches 0.90,
            # capped at the number of available components.
            needed = int(np.searchsorted(np.cumsum(evr), 0.90, side="left") + 1)
            rank_needed.append(min(needed, int(evr.size)))
        pc1_map[center] = lead_share
        top2_map[center] = pair_share
        rank90_map[center] = rank_needed

    def _per_center_medians(table: Dict[int, Any]) -> List[Any]:
        if not table:
            return [float("nan")]
        return [np.median(series) for series in table.values()]

    pc1_medians = _per_center_medians(pc1_map)
    top2_medians = _per_center_medians(top2_map)
    rank90_medians = _per_center_medians(rank90_map)

    return {
        "centers": centers,
        "neighbors": neighbors_map,
        "k_neighbors": k_neighbors,
        "window": window,
        "stride": stride,
        "pc1": pc1_map,
        "top2": top2_map,
        "rank90": rank90_map,
        "pc1_median": float(np.median(pc1_medians)),
        "top2_median": float(np.median(top2_medians)),
        "rank90_median": float(np.median(rank90_medians)),
        "pc1_mean": float(np.mean(pc1_medians)),
        "top2_mean": float(np.mean(top2_medians)),
        "rank90_mean": float(np.mean(rank90_medians)),
    }


def compute_ar1_stats(
    h_stack: np.ndarray,
) -> Dict[str, np.ndarray]:
    """Fit per-neuron AR(1) models ``h[t+1] ~ alpha * h[t]`` and score them.

    ``h_stack`` is unpacked as ``(time, hidden, batch)``. Consecutive time
    steps are paired and flattened so that each neuron owns one row of
    ``(time - 1) * batch`` samples. Two least-squares fits are made per
    neuron: one through the origin and one with an intercept.

    Args:
        h_stack: 3-D array of states.

    Returns:
        Dict of 1-D arrays of length ``hidden``:
        ``alpha`` (zero-intercept slope), ``alpha_centered`` (slope with
        intercept), ``r2_zero`` and ``r2_centered`` (both relative to the
        centred total sum of squares) and ``corr`` (Pearson correlation
        between consecutive steps, float64).

    Raises:
        ValueError: If fewer than two time steps are supplied.
    """
    time_steps, hidden, _ = h_stack.shape
    if time_steps < 2:
        raise ValueError("Need at least 2 time steps for AR(1) analysis.")
    h_prev = h_stack[:-1].transpose(1, 0, 2).reshape(hidden, -1)
    h_curr = h_stack[1:].transpose(1, 0, 2).reshape(hidden, -1)
    eps = 1e-12  # keeps every ratio finite for constant / silent neurons

    # Regression through the origin.
    alpha = np.sum(h_prev * h_curr, axis=1) / (np.sum(h_prev**2, axis=1) + eps)
    pred_zero = alpha[:, None] * h_prev
    sse_zero = np.sum((h_curr - pred_zero) ** 2, axis=1)
    mean_curr = np.mean(h_curr, axis=1, keepdims=True)
    curr_centered = h_curr - mean_curr
    sst = np.sum(curr_centered**2, axis=1) + eps
    r2_zero = 1.0 - sse_zero / sst

    # Regression with an intercept.
    mean_prev = np.mean(h_prev, axis=1, keepdims=True)
    prev_centered = h_prev - mean_prev
    cross = np.sum(prev_centered * curr_centered, axis=1)
    alpha_centered = cross / (np.sum(prev_centered**2, axis=1) + eps)
    intercept = mean_curr - alpha_centered[:, None] * mean_prev
    pred_centered = alpha_centered[:, None] * h_prev + intercept
    sse_centered = np.sum((h_curr - pred_centered) ** 2, axis=1)
    r2_centered = 1.0 - sse_centered / sst

    # Pearson correlation for all neurons at once, reusing the centred rows
    # instead of re-centring each neuron inside a Python loop.
    norms = np.linalg.norm(prev_centered, axis=1) * np.linalg.norm(curr_centered, axis=1)
    corr = np.asarray(cross / (norms + eps), dtype=np.float64)
    return {
        "alpha": alpha,
        "alpha_centered": alpha_centered,
        "r2_zero": r2_zero,
        "r2_centered": r2_centered,
        "corr": corr,
    }


def compute_gradient_duality(
    delta_true: np.ndarray,
) -> np.ndarray:
    """Return the lag-1 Pearson correlation of ``delta_true`` for every neuron.

    ``delta_true`` is unpacked as ``(time, hidden, batch)``. Consecutive time
    steps are paired and flattened per neuron, then correlated. The
    correlation is computed for all neurons in one vectorised pass.

    Args:
        delta_true: 3-D array of per-step values.

    Returns:
        float64 array of length ``hidden``; neurons with zero variance get a
        correlation of 0 because of the small epsilon in the denominator.

    Raises:
        ValueError: If fewer than two time steps are supplied.
    """
    time_steps, hidden, _ = delta_true.shape
    if time_steps < 2:
        raise ValueError("Need at least 2 time steps for gradient duality.")
    d_prev = delta_true[:-1].transpose(1, 0, 2).reshape(hidden, -1)
    d_curr = delta_true[1:].transpose(1, 0, 2).reshape(hidden, -1)
    eps = 1e-12
    prev_centered = d_prev - np.mean(d_prev, axis=1, keepdims=True)
    curr_centered = d_curr - np.mean(d_curr, axis=1, keepdims=True)
    norms = np.linalg.norm(prev_centered, axis=1) * np.linalg.norm(curr_centered, axis=1)
    corr = np.sum(prev_centered * curr_centered, axis=1) / (norms + eps)
    # The loop-based version always produced float64; keep that contract.
    return np.asarray(corr, dtype=np.float64)


def compute_smoothed_autocorr(
    values: np.ndarray,
    rho: float,
) -> np.ndarray:
    """Lag-1 autocorrelation of an exponentially smoothed, batch-averaged series.

    ``values`` is unpacked as ``(time, hidden, batch)``. It is averaged over
    the last axis, smoothed along time with
    ``smoothed[t] = rho * smoothed[t-1] + (1 - rho) * series[t]`` and then the
    Pearson correlation between consecutive smoothed steps is computed for
    all neurons in one vectorised pass.

    Args:
        values: 3-D array of per-step values.
        rho: Smoothing factor; clamped to ``[0, 0.999]``.

    Returns:
        float64 array of length ``hidden``.

    Raises:
        ValueError: If fewer than two time steps are supplied.
    """
    time_steps, _, _ = values.shape
    if time_steps < 2:
        raise ValueError("Need at least 2 time steps for autocorrelation.")
    rho = float(min(max(rho, 0.0), 0.999))
    series = values.mean(axis=2)
    smoothed = np.zeros_like(series)
    smoothed[0] = series[0]
    # The recursion is inherently sequential in time, so this loop stays.
    for t in range(1, time_steps):
        smoothed[t] = rho * smoothed[t - 1] + (1.0 - rho) * series[t]

    prev = smoothed[:-1]
    curr = smoothed[1:]
    eps = 1e-12
    # Time runs along axis 0 here, so centre and reduce over that axis.
    prev_centered = prev - np.mean(prev, axis=0, keepdims=True)
    curr_centered = curr - np.mean(curr, axis=0, keepdims=True)
    norms = np.linalg.norm(prev_centered, axis=0) * np.linalg.norm(curr_centered, axis=0)
    corr = np.sum(prev_centered * curr_centered, axis=0) / (norms + eps)
    # The loop-based version always produced float64; keep that contract.
    return np.asarray(corr, dtype=np.float64)


def compute_lambda_consistency(
    snapshot: ProbeSnapshot,
    alpha_hat: np.ndarray,
) -> Dict[str, np.ndarray]:
    """Regress ``B`` on ``A`` per neuron to estimate a consistent ``lambda``.

    With ``g`` and ``u`` taken from the snapshot (both unpacked as
    ``(time, hidden, batch)``) and consecutive steps paired, the regressor and
    target are

        A = u_prev * u_curr * (alpha * g_prev - g_curr)
        B = alpha * u_prev * g_prev - u_curr * g_curr

    A zero-intercept least-squares fit ``B ~ lambda_hat * A`` is made for each
    neuron over all paired samples.

    Args:
        snapshot: Probe snapshot providing ``g`` and ``u``.
        alpha_hat: Per-neuron coefficients, reshaped to ``(1, hidden, 1)``.

    Returns:
        Dict with ``lambda_hat``, ``r2`` (relative to the centred variance of
        ``B``) and ``corr`` (Pearson correlation of ``A`` and ``B``), each of
        length ``hidden``.

    Raises:
        ValueError: If there are fewer than two time steps or no batch items.
    """
    grad = snapshot.g
    gain = snapshot.u
    n_steps, n_hidden, n_batch = grad.shape
    if n_steps < 2 or n_batch <= 0:
        raise ValueError("Need at least 2 time steps for lambda consistency.")

    alpha = np.asarray(alpha_hat, dtype=np.float64).reshape(1, n_hidden, 1)
    g_before, g_after = grad[:-1], grad[1:]
    u_before, u_after = gain[:-1], gain[1:]

    def _rows_per_neuron(arr: np.ndarray) -> np.ndarray:
        # (time, hidden, batch) -> (hidden, time * batch) in float64
        return arr.transpose(1, 0, 2).reshape(n_hidden, -1).astype(np.float64, copy=False)

    regressor = _rows_per_neuron(u_before * u_after * (alpha * g_before - g_after))
    target = _rows_per_neuron(alpha * u_before * g_before - u_after * g_after)

    eps = 1e-12
    lambda_hat = np.sum(regressor * target, axis=1) / (np.sum(regressor**2, axis=1) + eps)

    residual = target - lambda_hat[:, None] * regressor
    sse = np.sum(residual**2, axis=1)
    target_mean = np.mean(target, axis=1, keepdims=True)
    sst = np.sum((target - target_mean) ** 2, axis=1) + eps
    r2 = 1.0 - sse / sst

    regressor_c = regressor - np.mean(regressor, axis=1, keepdims=True)
    target_c = target - np.mean(target, axis=1, keepdims=True)
    norm_product = np.linalg.norm(regressor_c, axis=1) * np.linalg.norm(target_c, axis=1)
    corr = np.sum(regressor_c * target_c, axis=1) / (norm_product + eps)

    return {
        "lambda_hat": lambda_hat,
        "r2": r2,
        "corr": corr,
    }


def compute_lambda_volatility(
    probe_snapshots: Dict[int, ProbeSnapshot],
) -> Dict[str, Any]:
    """Measure how much each neuron's ``lambda`` varies across probe steps.

    Args:
        probe_snapshots: Mapping from step to snapshot; each snapshot's
            ``lambda_vals`` is stacked in ascending step order.

    Returns:
        Dict with the sorted ``steps``, the stacked ``lambda_series``, the
        per-neuron standard deviation over all steps (``volatility``) and the
        same statistic with the first step left out (``volatility_excl0``,
        zeros when only one step exists). All arrays are empty when
        ``probe_snapshots`` is empty.
    """
    if not probe_snapshots:
        return {
            "steps": [],
            "lambda_series": np.array([]),
            "volatility": np.array([]),
            "volatility_excl0": np.array([]),
        }

    steps = sorted(probe_snapshots)
    lambda_series = np.stack(
        [probe_snapshots[step].lambda_vals for step in steps],
        axis=0,
    )
    volatility = np.std(lambda_series, axis=0)
    if len(steps) > 1:
        volatility_excl0 = np.std(lambda_series[1:], axis=0)
    else:
        volatility_excl0 = np.zeros_like(volatility)

    return {
        "steps": steps,
        "lambda_series": lambda_series,
        "volatility": volatility,
        "volatility_excl0": volatility_excl0,
    }


def _spearman_corr(x: np.ndarray, y: np.ndarray) -> float:
    x = np.asarray(x, dtype=np.float64).reshape(-1)
    y = np.asarray(y, dtype=np.float64).reshape(-1)
    mask = np.isfinite(x) & np.isfinite(y)
    if int(np.sum(mask)) < 2:
        return float("nan")
    x = x[mask]
    y = y[mask]
    x_rank = np.argsort(np.argsort(x)).astype(np.float64)
    y_rank = np.argsort(np.argsort(y)).astype(np.float64)
    x_rank -= float(np.mean(x_rank))
    y_rank -= float(np.mean(y_rank))
    denom = float(np.linalg.norm(x_rank) * np.linalg.norm(y_rank))
    return float(np.dot(x_rank, y_rank) / (denom + 1e-12))


def _spearman_perm_test(
    x: np.ndarray,
    y: np.ndarray,
    rng: np.random.Generator,
    max_exact_n: int = 8,
    n_perm: int = 2000,
) -> Dict[str, float]:
    """Two-sided permutation test for Spearman's rank correlation.

    Non-finite pairs are dropped first. When at most ``max_exact_n`` pairs
    remain, every permutation of ``y`` is enumerated and the p-value is the
    exact fraction whose ``|rho|`` is at least the observed one. Otherwise
    ``n_perm`` random permutations are drawn from ``rng`` and the p-value is
    ``(extreme + 1) / (n_perm + 1)``.

    Args:
        x: First sample.
        y: Second sample, same length as ``x``.
        rng: Generator used for the random permutations.
        max_exact_n: Largest sample size handled by full enumeration.
        n_perm: Number of random permutations (at least one is always drawn).

    Returns:
        Dict with ``n`` (number of finite pairs, as float), ``rho`` and ``p``;
        ``rho`` and ``p`` are NaN when the correlation cannot be computed.
    """
    x = np.asarray(x, dtype=np.float64).reshape(-1)
    y = np.asarray(y, dtype=np.float64).reshape(-1)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    n = int(x.size)
    rho_obs = _spearman_corr(x, y)
    if n < 2 or not np.isfinite(rho_obs):
        return {"n": float(n), "rho": float("nan"), "p": float("nan")}

    observed_magnitude = abs(rho_obs)

    def _at_least_as_extreme(order: Any) -> bool:
        # True when shuffling y by ``order`` gives |rho| >= the observed |rho|.
        return abs(_spearman_corr(x, y[order])) >= observed_magnitude

    if n <= int(max_exact_n):
        import itertools

        total = 0
        extreme = 0
        for total, perm in enumerate(itertools.permutations(range(n)), start=1):
            if _at_least_as_extreme(list(perm)):
                extreme += 1
        p_exact = float(extreme / max(1, total))
        return {"n": float(n), "rho": float(rho_obs), "p": p_exact}

    total = int(max(1, n_perm))
    extreme = 0
    for _ in range(total):
        if _at_least_as_extreme(rng.permutation(n)):
            extreme += 1
    # Add-one correction keeps the Monte Carlo p-value strictly positive.
    p_mc = float((extreme + 1) / (total + 1))
    return {"n": float(n), "rho": float(rho_obs), "p": p_mc}


def compute_eigen_alignment_curve(
    probe_snapshots: Dict[int, ProbeSnapshot],
    centers: List[int],
    k_neighbors: int,
    rng: np.random.Generator,
    neighbors_map: Dict[int, np.ndarray] | None = None,
) -> Dict[str, Any]:
    """Track the eigen-residual summary statistics across probe steps.

    For each step (ascending) whose ``alpha_hat`` is not numerically all zero,
    ``compute_eigen_residual_stats`` is evaluated and its median statistics
    are collected. Spearman permutation tests of four of those series against
    the step index are appended.

    Args:
        probe_snapshots: Mapping from step to snapshot.
        centers: Center indices forwarded to ``compute_eigen_residual_stats``.
        k_neighbors: Neighbor count forwarded to the same helper.
        rng: Generator shared by the helper and the permutation tests.
        neighbors_map: Optional precomputed neighbors, forwarded unchanged.

    Returns:
        Dict with the used ``steps``, one list per median statistic, the
        number of finite residuals per step (``counts``) and four
        ``spearman_*`` dicts with keys ``n``, ``rho`` and ``p``. For empty
        input every list is empty and the Spearman entries are all NaN.
    """
    if not probe_snapshots:

        def _blank_test() -> Dict[str, float]:
            return {"n": float("nan"), "rho": float("nan"), "p": float("nan")}

        return {
            "steps": [],
            "residual_median": [],
            "eig_align_median": [],
            "svd_align_median": [],
            "mu_median": [],
            "sigma_min_median": [],
            "residual_ratio_median": [],
            "counts": [],
            "spearman_residual_median": _blank_test(),
            "spearman_eig_align_median": _blank_test(),
            "spearman_svd_align_median": _blank_test(),
            "spearman_residual_ratio_median": _blank_test(),
        }

    tracked = (
        "sigma_min_median",
        "residual_ratio_median",
        "residual_median",
        "eig_align_median",
        "svd_align_median",
        "mu_median",
    )
    series: Dict[str, List[float]] = {key: [] for key in tracked}
    used_steps: List[int] = []
    counts: List[int] = []

    for step in sorted(probe_snapshots):
        snap = probe_snapshots[step]
        alpha_hat = np.asarray(snap.alpha_hat, dtype=np.float64).reshape(-1)
        # Skip steps whose alpha estimate is numerically all zero.
        if alpha_hat.size and float(np.max(np.abs(alpha_hat))) < 1e-12:
            continue
        used_steps.append(int(step))
        stats = compute_eigen_residual_stats(
            snap,
            centers=centers,
            k_neighbors=k_neighbors,
            rng=rng,
            neighbors_map=neighbors_map,
        )
        for key in tracked:
            series[key].append(float(stats.get(key, float("nan"))))
        residuals = np.asarray(stats.get("residuals", []), dtype=np.float64)
        counts.append(int(np.sum(np.isfinite(residuals))))

    steps_arr = np.asarray(used_steps, dtype=np.float64)

    def _trend(key: str) -> Dict[str, float]:
        return _spearman_perm_test(steps_arr, np.asarray(series[key], dtype=np.float64), rng)

    # The tests share ``rng``; run them in a fixed order so the random
    # stream is consumed deterministically.
    trend_residual = _trend("residual_median")
    trend_eig = _trend("eig_align_median")
    trend_svd = _trend("svd_align_median")
    trend_ratio = _trend("residual_ratio_median")

    return {
        "steps": used_steps,
        "residual_median": series["residual_median"],
        "eig_align_median": series["eig_align_median"],
        "svd_align_median": series["svd_align_median"],
        "mu_median": series["mu_median"],
        "sigma_min_median": series["sigma_min_median"],
        "residual_ratio_median": series["residual_ratio_median"],
        "counts": counts,
        "spearman_residual_median": trend_residual,
        "spearman_eig_align_median": trend_eig,
        "spearman_svd_align_median": trend_svd,
        "spearman_residual_ratio_median": trend_ratio,
    }


def compute_effective_gain_stats(
    snapshot: ProbeSnapshot,
) -> Dict[str, float]:
    """Summarise ``|mean_batch(u) * lambda|`` over all time steps and neurons.

    ``snapshot.u`` is averaged over its last axis and multiplied element-wise
    by ``snapshot.lambda_vals`` broadcast as a row vector.

    Args:
        snapshot: Probe snapshot providing ``u`` and ``lambda_vals``.

    Returns:
        Dict with the ``median`` and the 90th/95th/99th percentiles
        (``p90``, ``p95``, ``p99``) of the absolute product.
    """
    batch_mean_u = snapshot.u.mean(axis=2)
    lambda_row = snapshot.lambda_vals.reshape(1, -1)
    eff_gain = np.abs(batch_mean_u * lambda_row)

    summary: Dict[str, float] = {"median": float(np.median(eff_gain))}
    for q in (90, 95, 99):
        summary[f"p{q}"] = float(np.percentile(eff_gain, q))
    return summary


def compute_criticality_stats(
    model: TorchLocalRuleRNN,
    inputs: np.ndarray,
    time_points: int | None,
    rng: np.random.Generator,
) -> Dict[str, float]:
    """Roll the tanh recurrence forward and summarise gain-related statistics.

    ``inputs`` is indexed as ``(batch, features, time)``. At every step the
    hidden state is updated with ``tanh(W_hh @ h + W_xh @ I + b_h)`` under
    ``torch.no_grad()``. On the sampled steps the ratio
    ``||u * (W_hh @ h_prev)|| / ||h_prev||`` (with ``u = 1 - h**2``) is
    averaged over batch columns whose ``h_prev`` norm exceeds ``1e-8``.

    Args:
        model: Object exposing ``device``, ``hidden_size``, ``W_hh``,
            ``W_xh``, ``b_h`` and ``lambda_vals``.
        inputs: Input array convertible to a float32 tensor.
        time_points: Number of time steps to sample for the ratio; ``None``
            or a value of at least the sequence length uses every step.
        rng: Generator used to choose the sampled steps.

    Returns:
        Dict with ``rho_eff_median``/``rho_eff_mean`` (NaN when no step
        contributed) and ``eff_gain_p95``/``eff_gain_median`` of
        ``|mean_batch(u) * lambda_vals|`` (NaN for an empty sequence).
    """
    inputs_t = torch.as_tensor(inputs, device=model.device, dtype=torch.float32)
    batch_size = int(inputs_t.shape[0])
    time_steps = int(inputs_t.shape[2])

    use_every_step = time_points is None or time_points >= time_steps
    if use_every_step:
        sampled_steps = None
    else:
        chosen = rng.choice(time_steps, size=max(1, int(time_points)), replace=False)
        sampled_steps = set(chosen.tolist())

    state = torch.zeros((model.hidden_size, batch_size), device=model.device)
    ratio_per_step: List[float] = []
    gain_history: List[np.ndarray] = []

    with torch.no_grad():
        for t in range(time_steps):
            drive = inputs_t[:, :, t].T
            pre_act = model.W_hh @ state + model.W_xh @ drive + model.b_h
            next_state = torch.tanh(pre_act)
            deriv = 1.0 - next_state**2

            if sampled_steps is None or t in sampled_steps:
                state_norm = torch.norm(state, dim=0)
                pushed_norm = torch.norm(deriv * (model.W_hh @ state), dim=0)
                usable = state_norm > 1e-8
                if torch.any(usable):
                    ratio = pushed_norm[usable] / state_norm[usable]
                    ratio_per_step.append(torch.mean(ratio).item())

            gain_history.append(deriv.detach().cpu().numpy())
            state = next_state

    if ratio_per_step:
        rho_eff_median = float(np.median(ratio_per_step))
        rho_eff_mean = float(np.mean(ratio_per_step))
    else:
        rho_eff_median = float("nan")
        rho_eff_mean = float("nan")

    if gain_history:
        batch_mean_gain = np.stack(gain_history, axis=0).mean(axis=2)
        lambda_row = model.lambda_vals.detach().cpu().numpy().reshape(1, -1)
        eff_gain = np.abs(batch_mean_gain * lambda_row)
        eff_gain_p95 = float(np.percentile(eff_gain, 95))
        eff_gain_median = float(np.median(eff_gain))
    else:
        eff_gain_p95 = float("nan")
        eff_gain_median = float("nan")

    return {
        "rho_eff_median": rho_eff_median,
        "rho_eff_mean": rho_eff_mean,
        "eff_gain_p95": eff_gain_p95,
        "eff_gain_median": eff_gain_median,
    }


def evaluate_hypotheses(
    rank_dyn_stats: Dict[str, Any],
    ar1_stats: Dict[str, Any],
    smooth_stats: Dict[str, Any],
    eff_gain_stats: Dict[str, Any],
    lambda_vol_stats: Dict[str, Any],
    neumann_stats: Dict[str, Any],
    grad_stats: Dict[str, Any],
    thresholds: HypothesisThresholds,
) -> Dict[str, Any]:
    """Check the five hypotheses H1-H5 against one set of thresholds.

    Each hypothesis is a conjunction of comparisons between a summary
    statistic and the matching threshold field. A missing ``vol_p95`` entry
    is treated as NaN, which makes H5 fail.

    Args:
        rank_dyn_stats: Needs ``top2_median``, ``pc1_median``, ``rank90_median``.
        ar1_stats: Needs ``corr_median`` and ``r2_centered_median``.
        smooth_stats: Needs ``grad_corr_median``.
        eff_gain_stats: Needs ``p95``.
        lambda_vol_stats: Optionally provides ``vol_p95``.
        neumann_stats: Needs ``err_median`` and ``err_p90``.
        grad_stats: Needs ``cos_median`` and ``sign_median``.
        thresholds: Threshold values to compare against.

    Returns:
        Dict with one flag per hypothesis (``H1`` .. ``H5``) and ``passed``,
        which is truthy only when all five hold.
    """
    limits = thresholds

    # H1: rank statistics
    h1_pass = (
        rank_dyn_stats["top2_median"] >= limits.rank_top2_median_min
        and rank_dyn_stats["pc1_median"] >= limits.rank_pc1_median_min
        and rank_dyn_stats["rank90_median"] <= limits.rank90_median_max
    )
    # H2: AR(1) and smoothed-gradient correlations
    h2_pass = (
        ar1_stats["corr_median"] >= limits.h_corr_median_min
        and ar1_stats["r2_centered_median"] >= limits.h_r2_centered_median_min
        and smooth_stats["grad_corr_median"] >= limits.grad_smooth_corr_median_min
    )
    # H3: effective gain and Neumann error
    h3_pass = (
        eff_gain_stats["p95"] <= limits.eff_gain_p95_max
        and neumann_stats["err_median"] <= limits.neumann_median_max
        and neumann_stats["err_p90"] <= limits.neumann_p90_max
    )
    # H4: gradient alignment
    h4_pass = (
        grad_stats["cos_median"] >= limits.grad_cos_median_min
        and grad_stats["sign_median"] >= limits.grad_sign_median_min
    )
    # H5: lambda volatility; NaN never satisfies "<=", so a missing value fails.
    vol_p95 = float(lambda_vol_stats.get("vol_p95", float("nan")))
    h5_pass = vol_p95 <= limits.lambda_vol_p95_max

    overall = h1_pass and h2_pass and h3_pass and h4_pass and h5_pass
    return {
        "passed": overall,
        "H1": h1_pass,
        "H2": h2_pass,
        "H3": h3_pass,
        "H4": h4_pass,
        "H5": h5_pass,
    }


def evaluate_all_tiers(
    rank_dyn_stats: Dict[str, Any],
    ar1_stats: Dict[str, Any],
    smooth_stats: Dict[str, Any],
    eff_gain_stats: Dict[str, Any],
    lambda_vol_stats: Dict[str, Any],
    neumann_stats: Dict[str, Any],
    grad_stats: Dict[str, Any],
) -> Dict[str, Any]:
    """Evaluate the hypotheses against every entry of ``THRESHOLD_TIERS``.

    Tiers are tried in list order; the first one whose evaluation passes is
    reported as the best tier, but all tiers are still evaluated so that
    ``tier_results`` is complete.

    Args:
        rank_dyn_stats: Forwarded to ``evaluate_hypotheses``.
        ar1_stats: Forwarded to ``evaluate_hypotheses``.
        smooth_stats: Forwarded to ``evaluate_hypotheses``.
        eff_gain_stats: Forwarded to ``evaluate_hypotheses``.
        lambda_vol_stats: Forwarded to ``evaluate_hypotheses``.
        neumann_stats: Forwarded to ``evaluate_hypotheses``.
        grad_stats: Forwarded to ``evaluate_hypotheses``.

    Returns:
        Dict with ``best_idx``, ``best_tier``, ``best_thresholds`` and
        ``best_eval`` (all ``None`` when no tier passes) plus
        ``tier_results``, a list of ``{"tier_name", "passed"}`` entries.
    """
    tier_results: List[Dict[str, Any]] = []
    first_pass = None  # (index, name, thresholds, evaluation) of first passing tier

    for idx, (tier_name, thresholds) in enumerate(THRESHOLD_TIERS):
        outcome = evaluate_hypotheses(
            rank_dyn_stats,
            ar1_stats,
            smooth_stats,
            eff_gain_stats,
            lambda_vol_stats,
            neumann_stats,
            grad_stats,
            thresholds,
        )
        tier_results.append({"tier_name": tier_name, "passed": bool(outcome["passed"])})
        if first_pass is None and outcome["passed"]:
            first_pass = (idx, tier_name, thresholds, outcome)

    if first_pass is None:
        best_idx = best_tier = best_thresholds = best_eval = None
    else:
        best_idx, best_tier, best_thresholds, best_eval = first_pass

    return {
        "best_idx": best_idx,
        "best_tier": best_tier,
        "best_thresholds": best_thresholds,
        "best_eval": best_eval,
        "tier_results": tier_results,
    }


def compute_spectral_radius(
    snapshot: ProbeSnapshot,
    alpha_hat: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Per-time-step eigenvalue magnitudes of ``diag(u_t) @ W_hh.T @ diag(alpha)``.

    ``u_t`` is ``snapshot.u[t]`` averaged over its last axis. The matrix is
    built by scaling the rows of ``W_hh.T`` with ``u_t`` and its columns with
    ``alpha_hat``.

    Args:
        snapshot: Probe snapshot providing ``u`` and ``W_hh``.
        alpha_hat: 1-D array of per-neuron coefficients.

    Returns:
        Two arrays with one entry per time step: the largest and the mean
        absolute eigenvalue.
    """
    recurrent_t = snapshot.W_hh.T
    n_steps = snapshot.u.shape[0]
    # Unpack only to mirror the expected 3-D layout of ``u``.
    _, _, _ = snapshot.u.shape

    largest: List[float] = []
    average: List[float] = []
    for t in range(n_steps):
        gain = snapshot.u[t].mean(axis=1)
        jacobian = (gain[:, None] * recurrent_t) * alpha_hat[None, :]
        magnitudes = np.abs(np.linalg.eigvals(jacobian))
        largest.append(float(np.max(magnitudes)))
        average.append(float(np.mean(magnitudes)))
    return np.array(largest), np.array(average)


def compute_neumann_error(
    snapshot: ProbeSnapshot,
    alpha_hat: np.ndarray,
    lambda_vals: np.ndarray,
    rng: np.random.Generator,
    time_samples: int = 5,
) -> np.ndarray:
    """Relative error of a diagonal approximation to ``(I - A_t)^-1 s_t``.

    For randomly chosen time steps, ``A_t`` is ``W_hh.T`` with rows scaled by
    the batch-mean ``u_t`` and columns scaled by ``alpha_hat``. The exact
    solution of ``(I - A_t) v = s_t`` is compared element-wise with
    ``s_t / (1 - lambda_vals * u_t)``.

    Args:
        snapshot: Probe snapshot providing ``u``, ``s`` and ``W_hh``.
        alpha_hat: 1-D array of per-neuron coefficients.
        lambda_vals: Per-neuron values used in the diagonal approximation.
        rng: Generator used to draw time steps without replacement.
        time_samples: Maximum number of time steps to sample.

    Returns:
        Flat array of element-wise relative errors over all sampled steps;
        steps whose linear system is singular are left out.
    """
    recurrent_t = snapshot.W_hh.T
    n_steps = snapshot.u.shape[0]
    sampled = rng.choice(n_steps, size=min(time_samples, n_steps), replace=False)

    collected: List[float] = []
    for t in sampled:
        gain = snapshot.u[t].mean(axis=1)
        source = snapshot.s[t].mean(axis=1)
        jacobian = (gain[:, None] * recurrent_t) * alpha_hat[None, :]
        identity = np.eye(jacobian.shape[0], dtype=np.float64)
        try:
            exact = np.linalg.solve(identity - jacobian, source)
        except np.linalg.LinAlgError:
            # Singular system at this step: skip it rather than abort.
            continue
        diagonal = 1.0 - lambda_vals * gain
        approx = source / (diagonal + 1e-8)
        relative = np.abs(exact - approx) / (np.abs(exact) + 1e-8)
        collected.extend(relative.tolist())
    return np.array(collected)


def compute_grad_alignment(
    delta_true: np.ndarray,
    delta_approx: np.ndarray,
) -> Tuple[float, float]:
    """Compare two gradient arrays after averaging over their last axis.

    Args:
        delta_true: Reference array (at least 3-D; axis 2 is averaged).
        delta_approx: Approximation with the same layout.

    Returns:
        ``(cos, sign)`` where ``cos`` is whatever
        ``compute_cosine_similarity`` returns for the two averaged arrays and
        ``sign`` is the fraction of entries whose signs agree.
    """
    mean_true = delta_true.mean(axis=2)
    mean_approx = delta_approx.mean(axis=2)
    cosine = compute_cosine_similarity(mean_true, mean_approx)
    signs_agree = np.sign(mean_true) == np.sign(mean_approx)
    return cosine, float(np.mean(signs_agree))


def plot_rank1_scree(
    evr_map: Dict[int, List[float]],
    out_path: Path,
) -> None:
    """Draw a scree plot of explained-variance ratios and save it.

    With at most 12 non-empty entries every center gets its own curve;
    otherwise the curves are truncated to the shortest length and summarised
    by mean, median and a p10-p90 band. A placeholder text is shown when no
    entry has data.

    Args:
        evr_map: Explained-variance ratios per center.
        out_path: Destination handed to ``save_figure``.
    """
    import matplotlib.pyplot as plt
    set_plot_style()

    fig, ax = plt.subplots(figsize=(6.4, 4.4), constrained_layout=True)
    non_empty = [np.array(ratios) for ratios in evr_map.values() if ratios]
    if not non_empty:
        ax.text(0.5, 0.5, "No EVR data available", ha="center", va="center")
    elif len(non_empty) <= 12:
        for center, ratios in evr_map.items():
            curve = np.array(ratios)
            components = np.arange(1, len(curve) + 1)
            ax.plot(components, curve, marker="o", label=f"center {center}")
        ax.legend(fontsize=8)
    else:
        shortest = min(len(curve) for curve in non_empty)
        stacked = np.stack([curve[:shortest] for curve in non_empty], axis=0)
        components = np.arange(1, shortest + 1)
        mean_curve = np.mean(stacked, axis=0)
        median_curve = np.median(stacked, axis=0)
        band_low = np.percentile(stacked, 10, axis=0)
        band_high = np.percentile(stacked, 90, axis=0)
        ax.fill_between(components, band_low, band_high, color="0.8", alpha=0.6, label="p10-p90")
        ax.plot(components, mean_curve, marker="o", color="#0072B2", label="mean")
        ax.plot(components, median_curve, marker="o", color="#D55E00", label="median")
        ax.legend(fontsize=8)
    ax.set_title("Rank-1 Hypothesis: Local SVD Scree")
    ax.set_xlabel("Component")
    ax.set_ylabel("Explained Variance Ratio")
    ax.grid(True, linestyle=":")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_rank1_manifold(
    s_samples: np.ndarray,
    neuron_indices: List[int],
    out_path: Path,
) -> None:
    """Scatter the samples of three selected neurons in 3-D and save the figure.

    Args:
        s_samples: Array indexed by neuron along its first axis.
        neuron_indices: Neurons to plot; the first three entries supply the
            x, y and z coordinates and axis labels.
        out_path: Destination handed to ``save_figure``.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    set_plot_style()

    fig = plt.figure(figsize=(6.2, 5.2), constrained_layout=True)
    ax = fig.add_subplot(111, projection="3d")
    # One row per sample, one column per selected neuron.
    cloud = s_samples[neuron_indices].T
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=10, alpha=0.7)
    ax.set_title("Local Source Term Manifold (3D)")
    ax.set_xlabel(f"Neuron {neuron_indices[0]}")
    ax.set_ylabel(f"Neuron {neuron_indices[1]}")
    ax.set_zlabel(f"Neuron {neuron_indices[2]}")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_convex_descent_3d(
    bptt_params: np.ndarray,
    bptt_losses: np.ndarray,
    local_params: np.ndarray,
    local_losses: np.ndarray,
    out_path: Path,
    *,
    label_a: str = "BPTT",
    label_b: str = "local_rule",
    xlabel: str = "w1",
    ylabel: str = "w2",
    title: str = "Toy Convex Quadratic: 3D Descent Trajectory",
) -> None:
    """Plot two descent trajectories as 3-D curves (two parameters vs. loss).

    The first two columns of each parameter array give the x/y coordinates
    and the matching loss array gives the height. The starting point of each
    trajectory is marked with a dot in the trajectory's colour.

    Args:
        bptt_params: Parameter path of the first method, indexed ``[step, coord]``.
        bptt_losses: Loss per step for the first method.
        local_params: Parameter path of the second method.
        local_losses: Loss per step for the second method.
        out_path: Destination handed to ``save_figure``.
        label_a: Legend label of the first trajectory.
        label_b: Legend label of the second trajectory.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    set_plot_style()

    fig = plt.figure(figsize=(6.8, 5.2), constrained_layout=True)
    ax = fig.add_subplot(111, projection="3d")

    trajectories = (
        (bptt_params, bptt_losses, "#0072B2", label_a),
        (local_params, local_losses, "#D55E00", label_b),
    )
    # Curves first, then start markers, to keep the artist order unchanged.
    for params, losses, colour, label in trajectories:
        ax.plot(params[:, 0], params[:, 1], losses, color=colour, label=label)
    for params, losses, colour, _ in trajectories:
        ax.scatter(
            [params[0, 0]],
            [params[0, 1]],
            [losses[0]],
            color=colour,
            marker="o",
            s=30,
            alpha=0.9,
        )

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel("Loss")
    ax.set_title(title)
    ax.view_init(elev=28, azim=-55)
    ax.legend(loc="best")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_convex_contours_with_paths(
    A: np.ndarray,
    w_star: np.ndarray,
    path_a: np.ndarray,
    path_b: np.ndarray,
    out_path: Path,
    *,
    label_a: str = "BPTT",
    label_b: str = "local_rule",
) -> None:
    """Draw contours of a 2-D quadratic loss with two descent paths on top.

    The loss surface is
    ``0.5 * (A00*dx^2 + 2*A01*dx*dy + A11*dy^2)`` with ``(dx, dy)`` the offset
    from ``w_star``; only ``A[0, 0]``, ``A[0, 1]`` and ``A[1, 1]`` are read.
    The plotted region is the bounding box of both paths and the optimum,
    padded by 25% on each side.

    Args:
        A: Matrix of the quadratic form.
        w_star: Location of the optimum (two values).
        path_a: First path, reshaped to ``(-1, 2)``.
        path_b: Second path, reshaped to ``(-1, 2)``.
        out_path: Destination handed to ``save_figure``.
        label_a: Legend label of the first path.
        label_b: Legend label of the second path.
    """
    import matplotlib.pyplot as plt

    set_plot_style()

    A = np.asarray(A, dtype=np.float64)
    w_star = np.asarray(w_star, dtype=np.float64).reshape(2)
    path_a = np.asarray(path_a, dtype=np.float64).reshape(-1, 2)
    path_b = np.asarray(path_b, dtype=np.float64).reshape(-1, 2)

    # Bounding box of everything that must be visible.
    everything = np.vstack([path_a, path_b, w_star[None, :]])
    x_lo = float(np.min(everything[:, 0]))
    x_hi = float(np.max(everything[:, 0]))
    y_lo = float(np.min(everything[:, 1]))
    y_hi = float(np.max(everything[:, 1]))
    x_span = max(x_hi - x_lo, 1e-6)
    y_span = max(y_hi - y_lo, 1e-6)
    margin = 0.25
    x_lo -= margin * x_span
    x_hi += margin * x_span
    y_lo -= margin * y_span
    y_hi += margin * y_span

    resolution = 220
    grid_x = np.linspace(x_lo, x_hi, resolution, dtype=np.float64)
    grid_y = np.linspace(y_lo, y_hi, resolution, dtype=np.float64)
    X, Y = np.meshgrid(grid_x, grid_y)
    off_x = X - w_star[0]
    off_y = Y - w_star[1]
    Z = 0.5 * (
        A[0, 0] * off_x ** 2
        + 2.0 * A[0, 1] * off_x * off_y
        + A[1, 1] * off_y ** 2
    )

    fig, ax = plt.subplots(figsize=(6.6, 5.4), constrained_layout=True)
    filled = ax.contourf(X, Y, Z, levels=35, cmap="viridis")
    ax.contour(X, Y, Z, levels=12, colors="0.35", linewidths=0.7, alpha=0.55)
    fig.colorbar(filled, ax=ax, fraction=0.046, pad=0.04, label="Loss")

    ax.plot(path_a[:, 0], path_a[:, 1], color="#0072B2", marker="o", markersize=3.5, label=label_a)
    ax.plot(path_b[:, 0], path_b[:, 1], color="#D55E00", marker="o", markersize=3.5, label=label_b)
    ax.scatter([path_a[0, 0]], [path_a[0, 1]], color="black", s=22, marker="o", label="start")
    ax.scatter([w_star[0]], [w_star[1]], color="black", s=32, marker="*", label="optimum")
    ax.set_title("Convex Loss Contours + Descent Paths")
    ax.set_xlabel("w1")
    ax.set_ylabel("w2")
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, linestyle=":", alpha=0.5)
    ax.legend(fontsize=9)
    save_figure(fig, out_path)
    plt.close(fig)


def plot_paths_2d(
    path_a: np.ndarray,
    path_b: np.ndarray,
    out_path: Path,
    *,
    label_a: str = "BPTT",
    label_b: str = "local_rule",
    xlabel: str = "coord-1",
    ylabel: str = "coord-2",
    title: str = "Descent paths (2D)",
) -> None:
    """Plot two 2-D paths on shared axes and mark the start of the first one.

    Args:
        path_a: First path, reshaped to ``(-1, 2)``.
        path_b: Second path, reshaped to ``(-1, 2)``.
        out_path: Destination handed to ``save_figure``.
        label_a: Legend label of the first path.
        label_b: Legend label of the second path.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
    """
    import matplotlib.pyplot as plt

    set_plot_style()

    first = np.asarray(path_a, dtype=np.float64).reshape(-1, 2)
    second = np.asarray(path_b, dtype=np.float64).reshape(-1, 2)

    fig, ax = plt.subplots(figsize=(6.4, 5.2), constrained_layout=True)
    for track, colour, label in ((first, "#0072B2", label_a), (second, "#D55E00", label_b)):
        ax.plot(track[:, 0], track[:, 1], color=colour, marker="o", markersize=3.5, label=label)
    ax.scatter([first[0, 0]], [first[0, 1]], color="black", s=22, marker="o", label="start")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle=":", alpha=0.5)
    ax.legend(fontsize=9)
    save_figure(fig, out_path)
    plt.close(fig)


def plot_histogram(
    values: np.ndarray,
    title: str,
    xlabel: str,
    out_path: Path,
    bins: int = 40,
) -> None:
    """Draw a histogram of ``values`` and save it.

    Args:
        values: Data passed to ``Axes.hist``.
        title: Figure title.
        xlabel: Label of the x axis.
        out_path: Destination handed to ``save_figure``.
        bins: Number of histogram bins.
    """
    import matplotlib.pyplot as plt
    set_plot_style()

    fig, ax = plt.subplots(figsize=(6.2, 4.4), constrained_layout=True)
    ax.hist(values, bins=bins, color="#0072B2", alpha=0.8, edgecolor="black")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Count")
    ax.set_title(title)
    # Horizontal grid lines only, so the bars stay readable.
    ax.grid(True, axis="y", linestyle=":")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_scatter(
    x: np.ndarray,
    y: np.ndarray,
    title: str,
    xlabel: str,
    ylabel: str,
    out_path: Path,
    reference: bool = True,
) -> None:
    """Scatter ``y`` against ``x`` and optionally add the ``y = x`` diagonal.

    Args:
        x: Horizontal coordinates.
        y: Vertical coordinates.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        out_path: Destination handed to ``save_figure``.
        reference: When true, draw a dashed diagonal spanning the joint
            range of ``x`` and ``y``.
    """
    import matplotlib.pyplot as plt
    set_plot_style()

    fig, ax = plt.subplots(figsize=(6.2, 4.8), constrained_layout=True)
    ax.scatter(x, y, s=12, alpha=0.7, color="#D55E00")
    if reference:
        low = min(np.min(x), np.min(y))
        high = max(np.max(x), np.max(y))
        diagonal = [low, high]
        ax.plot(diagonal, diagonal, linestyle="--", color="0.3")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle=":")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_curve(
    steps: List[int],
    values: List[float],
    title: str,
    ylabel: str,
    out_path: Path,
) -> None:
    """Plot a single series against training step and save the figure.

    Args:
        steps: X coordinates (training steps).
        values: Y coordinates, one per step.
        title: Figure title.
        ylabel: Label of the y axis.
        out_path: Destination handed to ``save_figure``.
    """
    import matplotlib.pyplot as plt
    set_plot_style()

    fig, ax = plt.subplots(figsize=(6.4, 4.2), constrained_layout=True)
    ax.plot(steps, values, marker="o", color="#009E73")
    ax.set_xlabel("Training Step")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle=":")
    save_figure(fig, out_path)
    plt.close(fig)


def plot_two_curves(
    steps: List[int],
    values_a: List[float],
    values_b: List[float],
    label_a: str,
    label_b: str,
    title: str,
    ylabel: str,
    out_path: Path,
    ylim: Tuple[float, float] | None = None,
) -> None:
    """Plot two series against training step on shared axes and save the figure.

    Args:
        steps: X coordinates shared by both series.
        values_a: First series.
        values_b: Second series.
        label_a: Legend label of the first series.
        label_b: Legend label of the second series.
        title: Figure title.
        ylabel: Label of the y axis.
        out_path: Destination handed to ``save_figure``.
        ylim: Optional ``(low, high)`` limits for the y axis.
    """
    import matplotlib.pyplot as plt

    set_plot_style()
    fig, ax = plt.subplots(figsize=(6.6, 4.2), constrained_layout=True)
    for series, colour, label in (
        (values_a, "#0072B2", label_a),
        (values_b, "#D55E00", label_b),
    ):
        ax.plot(steps, series, marker="o", color=colour, label=label)
    ax.set_xlabel("Training Step")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if ylim is not None:
        low, high = float(ylim[0]), float(ylim[1])
        ax.set_ylim(low, high)
    ax.grid(True, linestyle=":")
    ax.legend(fontsize=9)
    save_figure(fig, out_path)
    plt.close(fig)


def plot_xy_curve(
    x: np.ndarray,
    y: np.ndarray,
    title: str,
    xlabel: str,
    ylabel: str,
    out_path: Path,
) -> None:
    """Plot ``y`` against ``x`` as a plain line and save the figure.

    Args:
        x: Horizontal coordinates.
        y: Vertical coordinates.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        out_path: Destination handed to ``save_figure``.
    """
    import matplotlib.pyplot as plt
    set_plot_style()

    fig, ax = plt.subplots(figsize=(6.4, 4.2), constrained_layout=True)
    ax.plot(x, y, color="#0072B2")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle=":")
    save_figure(fig, out_path)
    plt.close(fig)

def plot_heatmap(
    values: np.ndarray,
    title: str,
    xlabel: str,
    ylabel: str,
    out_path: Path,
    cmap: str = "viridis",
    center: float | None = None,
    symmetric: bool = False,
) -> None:
    """Render ``values`` as an image with a colour bar and save the figure.

    When ``center`` is given, a ``TwoSlopeNorm`` anchored at it is used. The
    colour range is made symmetric around ``center`` either on request
    (``symmetric``) or when ``center`` does not lie strictly inside the
    NaN-ignoring data range.

    Args:
        values: 2-D data passed to ``Axes.imshow``.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        out_path: Destination handed to ``save_figure``.
        cmap: Colormap name.
        center: Optional value mapped to the middle of the colormap.
        symmetric: Force a symmetric range around ``center``.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl

    set_plot_style()
    fig, ax = plt.subplots(figsize=(7.0, 4.8), constrained_layout=True)

    norm = None
    if center is not None:
        low = float(np.nanmin(values))
        high = float(np.nanmax(values))
        if symmetric or not (low < center < high):
            # Widen to a symmetric range; the 1e-12 floor keeps it non-empty.
            half_width = max(abs(low - center), abs(high - center), 1e-12)
            low = center - half_width
            high = center + half_width
        norm = mpl.colors.TwoSlopeNorm(vmin=low, vcenter=center, vmax=high)

    image = ax.imshow(values, aspect="auto", cmap=cmap, interpolation="nearest", norm=norm)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    save_figure(fig, out_path)
    plt.close(fig)


def plot_lambda_trajectories(
    steps: List[int],
    stable_series: np.ndarray,
    unstable_series: np.ndarray,
    stable_idx: int,
    unstable_idx: int,
    out_path: Path,
) -> None:
    """Plot the lambda series of two neurons on shared axes and save the figure.

    Args:
        steps: Values for the horizontal ("Training Step") axis.
        stable_series: Series drawn for the neuron labelled as stable.
        unstable_series: Series drawn for the neuron labelled as unstable.
        stable_idx: Index shown in the legend entry of ``stable_series``.
        unstable_idx: Index shown in the legend entry of ``unstable_series``.
        out_path: Destination handed to ``save_figure``.
    """
    import matplotlib.pyplot as plt

    set_plot_style()
    figure, axes = plt.subplots(figsize=(6.6, 4.2), constrained_layout=True)

    stable_label = f"stable neuron {stable_idx}"
    unstable_label = f"unstable neuron {unstable_idx}"
    axes.plot(steps, stable_series, color="#009E73", label=stable_label)
    axes.plot(steps, unstable_series, color="#D55E00", label=unstable_label)

    axes.grid(True, linestyle=":")
    axes.set_xlabel("Training Step")
    axes.set_ylabel(r"$\lambda$")
    axes.set_title("Lambda Trajectories: Most vs Least Stable")
    axes.legend(fontsize=8)

    save_figure(figure, out_path)
    plt.close(figure)


def write_report(
    report_path: Path,
    plot_dir: Path,
    task_label: str,
    config: Dict[str, Any],
    rank1_stats: Dict[str, Any],
    eigen_residual_stats: Dict[str, Any],
    rank_dyn_stats: Dict[str, Any],
    ar1_stats: Dict[str, Any],
    delta_approx_ar1_stats: Dict[str, Any],
    duality_stats: Dict[str, Any],
    smooth_stats: Dict[str, Any],
    lambda_stats: Dict[str, Any],
    stability_lambda_stats: Dict[str, Any],
    eff_gain_stats: Dict[str, Any],
    neumann_stats: Dict[str, Any],
    grad_stats: Dict[str, Any],
    hypothesis_eval: Dict[str, Any],
    spectral_curve: List[Dict[str, float]],
    attempt_summaries: List[Dict[str, Any]] | None = None,
    eigen_align_curve_dynamic: Dict[str, Any] | None = None,
    eigen_align_curve_fixed: Dict[str, Any] | None = None,
    lyap_diag: Dict[str, Any] | None = None,
) -> None:
    """Assemble the Markdown probe report and write it to ``report_path``.

    The report is built as a list of lines, one section per hypothesis, and
    written as UTF-8. Parent directories of ``report_path`` are created when
    missing. Images are referenced by path only; no plot is produced here.

    Args:
        report_path: Output Markdown file.
        plot_dir: Directory used as the prefix of every image link. It is
            expressed relative to the current working directory when possible
            and used as given otherwise.
        task_label: Task name inserted into the report title.
        config: Settings listed verbatim (``key: value``) in section 1.
        rank1_stats: Requires ``k_neighbors`` and ``time_points``; optionally
            ``centers``, ``pc1_cos``, ``pc1_mean``/``pc1_min``/``pc1_max`` and
            ``conclusion``.
        eigen_residual_stats: Optional ``*_mean``/``*_median`` entries for
            ``residual``, ``mu``, ``eig_align``, ``svd_align``, ``sigma_min``
            and ``residual_ratio``, plus ``residuals`` and ``window``.
        rank_dyn_stats: Requires ``window``, ``stride``, ``pc1_median``,
            ``top2_median`` and ``rank90_median``; optional per-center maps
            ``pc1``, ``top2`` and ``rank90``.
        ar1_stats: Requires ``r2_centered_mean``, ``r2_centered_median``,
            ``corr_mean`` and ``corr_median``.
        delta_approx_ar1_stats: Requires ``alpha_*``, ``r2_zero_*`` and
            ``corr_*`` mean/median entries.
        duality_stats: Accepted for interface compatibility; not read.
        smooth_stats: Requires ``rho``, ``grad_corr_mean`` and
            ``grad_corr_median``.
        lambda_stats: Requires ``r2_mean`` and ``r2_median``.
        stability_lambda_stats: Optional entries describing the most and least
            stable neurons and the lambda volatility summary.
        eff_gain_stats: Requires ``median`` and ``p95``.
        neumann_stats: Requires ``rho_mean``, ``rho_max``, ``err_mean``,
            ``err_median`` and ``err_p90``.
        grad_stats: Requires ``cos_mean``, ``cos_median``, ``sign_mean`` and
            ``sign_median``.
        hypothesis_eval: Pass/fail verdicts (``passed``, ``H1``..``H5``), tier
            bookkeeping and an optional ``thresholds`` object whose attributes
            are printed next to the measured values.
        spectral_curve: Accepted for interface compatibility; not read.
        attempt_summaries: Optional per-attempt records listed in section 6.
        eigen_align_curve_dynamic: Optional training-time alignment curve;
            section 2.6 is written only when this and
            ``eigen_align_curve_fixed`` are both non-empty and ``steps`` is set.
        eigen_align_curve_fixed: Counterpart of the dynamic curve.
        lyap_diag: Optional Lyapunov diagnostics; adds section 9 when given.

    Raises:
        KeyError: If one of the required statistic entries is missing.
    """
    # Image links are written relative to the current working directory. A
    # plot directory that is itself relative, or that lives outside the working
    # directory, cannot be re-based; use it as given instead of aborting the
    # report after every result has already been computed.
    try:
        plot_dir_posix = plot_dir.relative_to(Path.cwd()).as_posix()
    except ValueError:
        plot_dir_posix = plot_dir.as_posix()

    def _fmt(val: Any, digits: int = 4) -> str:
        # Floats get a fixed number of decimals; None and non-finite values
        # are spelled out; anything else falls back to str().
        if val is None:
            return "n/a"
        if isinstance(val, (float, np.floating)):
            if not np.isfinite(val):
                return "nan"
            return f"{float(val):.{digits}f}"
        return str(val)

    def _fmt4(val: Any) -> str:
        # Plain four-decimal rendering that tolerates a missing value.
        return "n/a" if val is None else f"{val:.4f}"

    def _summ(values: Any) -> Dict[str, float]:
        # Summary statistics over the finite entries only; all-NaN when no
        # finite entry is left.
        arr = np.asarray(values, dtype=np.float64).reshape(-1)
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return dict.fromkeys(("mean", "median", "p10", "p90", "min", "max"), float("nan"))
        return {
            "mean": float(np.mean(arr)),
            "median": float(np.median(arr)),
            "p10": float(np.percentile(arr, 10)),
            "p90": float(np.percentile(arr, 90)),
            "min": float(np.min(arr)),
            "max": float(np.max(arr)),
        }

    def _center_spread(key: str) -> Dict[str, float]:
        # Median per center first, then the spread of those medians.
        per_center = np.array(
            [np.median(v) for v in rank_dyn_stats.get(key, {}).values() if v],
            dtype=np.float64,
        )
        return _summ(per_center)

    def _first_last(values: np.ndarray) -> Tuple[float, float]:
        # First and last finite entries of a curve, NaN pair when there is none.
        mask = np.isfinite(values)
        if not mask.size or not np.any(mask):
            return float("nan"), float("nan")
        idx = np.where(mask)[0]
        return float(values[int(idx[0])]), float(values[int(idx[-1])])

    def _verdict(key: str) -> str:
        return "通过" if hypothesis_eval.get(key) else "未通过"

    lines: List[str] = []

    def _embed(alt: str, filename: str) -> None:
        # Every image link is followed by a blank separator line.
        lines.append(f"![{alt}]({plot_dir_posix}/{filename})")
        lines.append("")

    # ---- Title and section 1: configuration ----
    lines.append(f"# OLL 物理假设验证实验报告（{task_label}）")
    lines.append("")
    lines.append("## 1. 实验设置摘要")
    lines.extend(f"- {key}: {value}" for key, value in config.items())
    lines.append("")

    # ---- Section 2: spatial locality and dynamic low rank ----
    lines.append("## 2. 假设一：空间局部性与动态低秩")
    centers = list(rank1_stats.get("centers", []))
    # Long center lists are abbreviated to their first and last eight entries.
    centers_label = str(centers) if len(centers) <= 16 else f"{centers[:8]} ... {centers[-8:]} (n={len(centers)})"
    lines.append(f"- 选取中心神经元: {centers_label}")
    lines.append(f"- 邻域大小 K: {rank1_stats['k_neighbors']}, 采样时间点: {rank1_stats['time_points']}")
    lines.append(
        f"- 动态窗口(长度/步长): {rank_dyn_stats['window']} / {rank_dyn_stats['stride']}"
    )
    pc1_s = _center_spread("pc1")
    top2_s = _center_spread("top2")
    rank90_s = _center_spread("rank90")
    lines.append(
        "- 动态低秩(按中心取中位数后再汇总): "
        f"PC1_median={rank_dyn_stats['pc1_median']:.4f} (p10={pc1_s['p10']:.4f}, p90={pc1_s['p90']:.4f}), "
        f"Top2_median={rank_dyn_stats['top2_median']:.4f} (p10={top2_s['p10']:.4f}, p90={top2_s['p90']:.4f}), "
        f"Rank90_median={rank_dyn_stats['rank90_median']:.2f} (p10={rank90_s['p10']:.2f}, p90={rank90_s['p90']:.2f})"
    )
    pc1_cos_map = rank1_stats.get("pc1_cos", {})
    if pc1_cos_map:
        pc1_cos_s = _summ(np.array(list(pc1_cos_map.values()), dtype=np.float64))
        most_stable = max(pc1_cos_map, key=pc1_cos_map.get)
        least_stable = min(pc1_cos_map, key=pc1_cos_map.get)
        lines.append(
            "- PC1 方向稳定性(cosine): "
            f"median={pc1_cos_s['median']:.4f} (p10={pc1_cos_s['p10']:.4f}, p90={pc1_cos_s['p90']:.4f}), "
            f"min={pc1_cos_s['min']:.4f}, max={pc1_cos_s['max']:.4f} "
            f"| 最稳定={most_stable}, 最不稳定={least_stable}"
        )
    else:
        lines.append("- PC1 方向稳定性(cosine): n/a")
    if "pc1_mean" in rank1_stats:
        lines.append(
            "- PC1 解释方差(均值/最小/最大): "
            f"{rank1_stats['pc1_mean']:.4f} / {rank1_stats['pc1_min']:.4f} / {rank1_stats['pc1_max']:.4f}"
        )
        lines.append(f"- 结论: {rank1_stats.get('conclusion', '')}")
    lines.append("")
    _embed("Scree", "rank1_scree.png")
    _embed("Manifold", "rank1_manifold.png")

    # ---- Section 2.5: eigen-residual verification ----
    residuals_arr = np.asarray(eigen_residual_stats.get("residuals", []), dtype=np.float64)
    residual_count = int(np.sum(np.isfinite(residuals_arr)))
    lines.append("## 2.5 Eigen-Residual Verification")
    lines.append(
        "- Relative residual (mean/median"
        f", n={residual_count}, window={_fmt(eigen_residual_stats.get('window'), digits=0)}): "
        f"{_fmt(eigen_residual_stats.get('residual_mean'))} / {_fmt(eigen_residual_stats.get('residual_median'))}"
    )
    lines.append(
        "- Rayleigh mu (mean/median): "
        f"{_fmt(eigen_residual_stats.get('mu_mean'))} / {_fmt(eigen_residual_stats.get('mu_median'))}"
    )
    # These four lines are printed only when both summary values are finite.
    for prefix, label in (
        ("eig_align", "Align(q_s, eig_dom(A_local))"),
        ("svd_align", "Align(q_s, svd_dom(A_local))"),
        ("sigma_min", "σ_min(I-A_local)"),
        ("residual_ratio", "||(I-A_local)^(-1) r|| / ||q_s||"),
    ):
        mean_val = eigen_residual_stats.get(f"{prefix}_mean")
        median_val = eigen_residual_stats.get(f"{prefix}_median")
        if (
            mean_val is not None
            and median_val is not None
            and np.isfinite(mean_val)
            and np.isfinite(median_val)
        ):
            lines.append(f"- {label} (mean/median): {_fmt(mean_val)} / {_fmt(median_val)}")
    lines.append("")
    if residual_count:
        _embed("Eigen Residual", "eigen_residual_hist.png")

    # ---- Section 2.6: alignment over training (optional) ----
    if eigen_align_curve_dynamic and eigen_align_curve_fixed:
        steps = list(eigen_align_curve_dynamic.get("steps", []))
        if steps:
            fixed_ref = eigen_align_curve_fixed.get("ref_step")
            fixed_label = "fixed" if fixed_ref is None else f"fixed@step{int(fixed_ref)}"

            lines.append("## 2.6 动态谱对齐（训练过程）")
            lines.append(f"- probe steps: {steps}")
            # One "first → last" line per curve that is present in both runs.
            for key, label in (
                ("eig_align_median", "Align(q_s, eig_dom(A_local)) median"),
                ("svd_align_median", "Align(q_s, svd_dom(A_local)) median"),
                ("residual_median", "Eigen-residual median"),
                ("sigma_min_median", "σ_min(I-A_local) median"),
                ("residual_ratio_median", "||(I-A_local)^(-1) r|| / ||q_s|| median"),
            ):
                dyn_curve = np.asarray(eigen_align_curve_dynamic.get(key, []), dtype=np.float64)
                fix_curve = np.asarray(eigen_align_curve_fixed.get(key, []), dtype=np.float64)
                if dyn_curve.size and fix_curve.size:
                    dyn_first, dyn_last = _first_last(dyn_curve)
                    fix_first, fix_last = _first_last(fix_curve)
                    lines.append(
                        f"- {label}: "
                        f"dynamic {_fmt(dyn_first)} → {_fmt(dyn_last)}, "
                        f"{fixed_label} {_fmt(fix_first)} → {_fmt(fix_last)}"
                    )
            for metric in ("eig_align_median", "residual_ratio_median"):
                spe_dyn = eigen_align_curve_dynamic.get(f"spearman_{metric}", {})
                spe_fix = eigen_align_curve_fixed.get(f"spearman_{metric}", {})
                if spe_dyn or spe_fix:
                    lines.append(
                        f"- Spearman ρ(step, {metric}): "
                        f"dynamic {_fmt(spe_dyn.get('rho'))} (p={_fmt(spe_dyn.get('p'))}), "
                        f"{fixed_label} {_fmt(spe_fix.get('rho'))} (p={_fmt(spe_fix.get('p'))})"
                    )
            lines.append("")
            _embed("Eig Align Curve", "eig_align_curve.png")
            _embed("SVD Align Curve", "svd_align_curve.png")
            _embed("Eigen Residual Curve", "eigen_residual_curve.png")

    # ---- Section 3: temporal stability and smoothed error signal ----
    lines.append("## 3. 假设二：时间稳定性与平滑误差信号")
    lines.append(
        f"- AR(1) R^2(中心化)均值/中位数: "
        f"{ar1_stats['r2_centered_mean']:.4f} / {ar1_stats['r2_centered_median']:.4f}"
    )
    lines.append(
        f"- 激活相关系数均值/中位数: {ar1_stats['corr_mean']:.4f} / {ar1_stats['corr_median']:.4f}"
    )
    lines.append(
        f"- 平滑梯度相关(ρ={smooth_stats['rho']:.2f})均值/中位数: "
        f"{smooth_stats['grad_corr_mean']:.4f} / {smooth_stats['grad_corr_median']:.4f}"
    )
    lines.append(
        f"- λ 一致性 R^2(均值/中位数): {lambda_stats['r2_mean']:.4f} / {lambda_stats['r2_median']:.4f}"
    )
    lines.append("")
    _embed("AR1", "ar1_r2_hist.png")
    _embed("Duality", "duality_scatter_smoothed.png")

    # ---- Section 3.2: self-consistency of the approximate teaching signal ----
    lines.append("## 3.2 OLL 教学信号时间自洽性")
    lines.append(
        f"- $\\tilde{{\\delta}}_t \\approx \\alpha \\, \\tilde{{\\delta}}_{{t-1}}$ (AR(1), 无截距): "
        f"alpha(均值/中位数)={delta_approx_ar1_stats['alpha_mean']:.4f} / {delta_approx_ar1_stats['alpha_median']:.4f}, "
        f"R^2(均值/中位数)={delta_approx_ar1_stats['r2_zero_mean']:.4f} / {delta_approx_ar1_stats['r2_zero_median']:.4f}, "
        f"corr(均值/中位数)={delta_approx_ar1_stats['corr_mean']:.4f} / {delta_approx_ar1_stats['corr_median']:.4f}"
    )
    lines.append(
        "- 注：此处的 α 是从 $\\tilde{\\delta}$ 自身拟合得到的；训练时用于估计 λ 的 `alpha_hat` 是由 $h_t$ 的 AR(1) 在线估计得到的，二者通常不同。"
    )
    lines.append("")
    _embed("Delta Approx AR1 R2", "delta_approx_ar1_r2_hist.png")
    _embed("Delta Approx AR1 Alpha", "delta_approx_ar1_alpha_hist.png")

    # ---- Section 3.5: stability versus lambda adaptation ----
    lines.append("## 3.5 Stability vs Lambda Adaptation")
    most_stable = stability_lambda_stats.get("most_stable")
    least_stable = stability_lambda_stats.get("least_stable")
    pc1_lambda_corr = stability_lambda_stats.get("pc1_lambda_corr")
    lambda_vol_p95 = stability_lambda_stats.get("lambda_vol_p95")
    lambda_vol_median = stability_lambda_stats.get("lambda_vol_median")
    if most_stable is None or least_stable is None:
        lines.append("- extremes: n/a")
    else:
        # Scores and volatilities are optional even when the neuron ids are
        # known, so they are rendered through the None-tolerant formatter.
        lines.append(
            "- most_stable_neuron: "
            f"{most_stable} (pc1_cos={_fmt4(stability_lambda_stats.get('most_stable_score'))}, "
            f"lambda_vol={_fmt4(stability_lambda_stats.get('most_stable_volatility'))})"
        )
        lines.append(
            "- least_stable_neuron: "
            f"{least_stable} (pc1_cos={_fmt4(stability_lambda_stats.get('least_stable_score'))}, "
            f"lambda_vol={_fmt4(stability_lambda_stats.get('least_stable_volatility'))})"
        )
    if pc1_lambda_corr is not None and not np.isnan(pc1_lambda_corr):
        lines.append(f"- corr(pc1_cos, lambda_volatility(excluding step0)): {pc1_lambda_corr:.4f}")
    if (
        lambda_vol_p95 is not None
        and not np.isnan(lambda_vol_p95)
        and lambda_vol_median is not None
        and not np.isnan(lambda_vol_median)
    ):
        lines.append(
            f"- lambda_volatility(excluding step0) p95/median: {lambda_vol_p95:.4f} / {lambda_vol_median:.4f}"
        )
    lines.append("")
    _embed("PC1 vs Lambda Volatility", "stability_lambda_volatility_scatter.png")
    _embed("Lambda Trajectory", "lambda_trajectory_extremes.png")

    # ---- Section 4: effective loop gain and Neumann approximation ----
    lines.append("## 4. 假设三：有效回路增益与 Neumann 近似")
    lines.append(
        f"- 谱半径(最大)均值/最大值: {neumann_stats['rho_mean']:.4f} / {neumann_stats['rho_max']:.4f}"
    )
    lines.append(
        f"- 有效增益 |λu|(中位数/p95): {eff_gain_stats['median']:.4f} / {eff_gain_stats['p95']:.4f}"
    )
    lines.append(
        f"- 近似误差(相对)均值/中位数/p90: "
        f"{neumann_stats['err_mean']:.4f} / {neumann_stats['err_median']:.4f} / {neumann_stats['err_p90']:.4f}"
    )
    lines.append("")
    _embed("Spectral", "spectral_radius_curve.png")
    _embed("Neumann", "neumann_error_hist.png")

    # ---- Section 5: gradient agreement ----
    lines.append("## 5. 假设四：梯度一致性")
    lines.append(
        f"- 余弦相似度(均值/中位数): {grad_stats['cos_mean']:.4f} / {grad_stats['cos_median']:.4f}"
    )
    lines.append(
        f"- Sign Agreement(均值/中位数): {grad_stats['sign_mean']:.4f} / {grad_stats['sign_median']:.4f}"
    )
    lines.append("")
    _embed("Cosine", "grad_cos_curve.png")
    _embed("Grad Scatter", "grad_scatter.png")
    _embed("Delta True Heatmap", "delta_true_heatmap.png")
    _embed("Delta Approx Heatmap", "delta_approx_heatmap.png")

    # ---- Section 6: per-attempt log (optional) ----
    if attempt_summaries:
        lines.append("## 6. 逐次尝试日志")
        for item in attempt_summaries:
            pass_label = "pass" if item.get("target_passed") else "fail"
            fields = [
                f"attempt={item.get('attempt')}",
                f"seed={item.get('seed')}",
                f"tier={item.get('target_tier')}",
                f"pass={pass_label}",
                f"best_gain={_fmt(item.get('best_gain'), digits=3)}",
                f"val_acc={_fmt(item.get('best_val_acc'))}",
                f"val_mse={_fmt(item.get('best_val_mse'))}",
                f"crit={_fmt(item.get('best_critical_metric'))}",
                f"score={_fmt(item.get('best_critical_score'))}",
                f"lyap_pre={_fmt(item.get('best_lyap_pre'))}",
                f"lyap_post={_fmt(item.get('best_lyap_post'))}",
            ]
            lines.append("- " + " ".join(fields))
        lines.append("")

    # ---- Section 7: hypothesis verdicts ----
    lines.append("## 7. 假设通过判定")
    tier_name = hypothesis_eval.get("tier_name", "unknown")
    best_tier = hypothesis_eval.get("best_tier")
    attempt = hypothesis_eval.get("attempt", 1)
    attempt_range = hypothesis_eval.get("attempt_range")
    attempts_per_tier = hypothesis_eval.get("attempts_per_tier")
    pass_label = "通过" if hypothesis_eval.get("passed", False) else "未通过"
    range_label = ""
    if attempt_range and isinstance(attempt_range, (list, tuple)) and len(attempt_range) == 2:
        range_label = f", range={attempt_range[0]}-{attempt_range[1]}"
    if attempts_per_tier:
        range_label += f", step={attempts_per_tier}"
    lines.append(f"- target tier: {tier_name} (attempt={attempt}{range_label})")
    if best_tier and best_tier != tier_name:
        lines.append(f"- best tier (diagnostic): {best_tier}")
    lines.append(f"- 总体结果: {pass_label}")
    thresholds = hypothesis_eval.get("thresholds")
    if thresholds is not None:
        # The volatility summary is looked up under either of two key names.
        lambda_p95 = stability_lambda_stats.get(
            "lambda_vol_p95",
            stability_lambda_stats.get("vol_p95", float("nan")),
        )
        lines.append(
            "- H1(动态低秩): "
            f"{_verdict('H1')} | "
            f"Top2_median={rank_dyn_stats['top2_median']:.4f} >= {thresholds.rank_top2_median_min:.2f}, "
            f"Rank90_median={rank_dyn_stats['rank90_median']:.2f} <= {thresholds.rank90_median_max:.1f} "
            f"(PC1_median={rank_dyn_stats['pc1_median']:.4f}, diag≥{thresholds.rank_pc1_median_min:.2f})"
        )
        lines.append(
            "- H2(时间稳定): "
            f"{_verdict('H2')} | "
            f"corr_median={ar1_stats['corr_median']:.4f} >= {thresholds.h_corr_median_min:.2f}, "
            f"R2c_median={ar1_stats['r2_centered_median']:.4f} >= {thresholds.h_r2_centered_median_min:.2f}, "
            f"smooth_median={smooth_stats['grad_corr_median']:.4f} >= {thresholds.grad_smooth_corr_median_min:.2f}"
        )
        lines.append(
            "- H3(回路增益): "
            f"{_verdict('H3')} | "
            f"|λu|p95={eff_gain_stats['p95']:.4f} <= {thresholds.eff_gain_p95_max:.3f}, "
            f"err_med={neumann_stats['err_median']:.4f} <= {thresholds.neumann_median_max:.2f}, "
            f"err_p90={neumann_stats['err_p90']:.4f} <= {thresholds.neumann_p90_max:.2f}"
        )
        lines.append(
            "- H4(梯度一致): "
            f"{_verdict('H4')} | "
            f"cos_median={grad_stats['cos_median']:.4f} >= {thresholds.grad_cos_median_min:.2f}, "
            f"sign_median={grad_stats['sign_median']:.4f} >= {thresholds.grad_sign_median_min:.2f}"
        )
        lines.append(
            "- H5(lambda 波动): "
            f"{_verdict('H5')} | "
            f"p95={_fmt(lambda_p95)} <= {thresholds.lambda_vol_p95_max:.2f}"
        )

    # ---- Section 8: fixed summary text ----
    lines.append("## 8. 总结")
    lines.append("- 动态低秩假设：短窗内低维协同成立，允许方向随时间漂移。")
    lines.append("- 时间稳定假设：状态可用中心化 AR(1) 近似，误差信号需经平滑后满足一致性。")
    lines.append("- 回路增益假设：采用有效增益而非全局谱半径判断 Neumann 近似可行性。")
    lines.append("- 梯度一致性：余弦相似度与符号一致率衡量 OLL 教学信号的对齐程度。")
    lines.append("")

    # ---- Section 9: Lyapunov diagnostics (optional) ----
    if lyap_diag is not None:
        lines.append("## 9. Lorenz：Lyapunov 指数诊断（为何训练后常在 -1 附近）")
        lyap_qr = lyap_diag.get("lyap_qr")
        lyap_pi = lyap_diag.get("lyap_power_iter")
        contract = lyap_diag.get("contraction_factor")
        rho_w = lyap_diag.get("rho_W_hh")
        sigma_w = lyap_diag.get("sigma_W_hh")
        ub = lyap_diag.get("upper_bound_log_phiMax_plus_log_sigmaW")
        phi_max = np.asarray(lyap_diag.get("phi_max", np.array([])), dtype=np.float64)
        phi_mean = np.asarray(lyap_diag.get("phi_mean", np.array([])), dtype=np.float64)
        sat_frac = np.asarray(lyap_diag.get("sat_frac", np.array([])), dtype=np.float64)

        lines.append(f"- Lyap(QR)={_fmt(lyap_qr)}, Lyap(PI)={_fmt(lyap_pi)}（二者应接近）")
        lines.append(f"- 收缩因子 exp(Lyap)≈{_fmt(contract)}（Lyap≈-1 对应每步约 0.367 倍收缩）")
        lines.append(f"- W_hh 谱半径 ρ(W_hh)={_fmt(rho_w)}, 谱范数 σ_max(W_hh)={_fmt(sigma_w)}")
        lines.append(f"- 上界 E[log(max φ')] + log σ_max(W_hh) ≈ {_fmt(ub)}（φ'=1-h^2）")
        if phi_max.size:
            lines.append(
                "- tanh 导数 φ'=1-h^2（沿驱动轨迹）: "
                f"max φ' median={_fmt(float(np.median(phi_max)))}, "
                f"mean φ' median={_fmt(float(np.median(phi_mean)))}"
            )
        if sat_frac.size:
            lines.append(
                "- 饱和比例 frac(|h|≥0.95)（沿驱动轨迹）: "
                f"median={_fmt(float(np.median(sat_frac)))}, p90={_fmt(float(np.percentile(sat_frac, 90)))}"
            )
        lines.append("")
        _embed("Lyapunov log-growth hist", "lyapunov_log_growth_hist.png")
        _embed("Lyapunov log-growth curve", "lyapunov_log_growth_curve.png")

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")


def run_convex_quadratic_demo(args: argparse.Namespace) -> Tuple[Path, Path]:
    """Compare two descent rules on a 2D convex quadratic and write a report.

    The loss is ``0.5 * (w - w*)^T A (w - w*)`` with a symmetric 2x2 ``A``.
    Two trajectories are simulated from the same start: plain gradient
    descent ("bptt") and gradient descent preconditioned by the inverse
    diagonal of ``A`` ("local_rule"). Three figures, an ``.npz`` archive of
    the trajectories and a short Markdown report are produced.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``convex_a11``,
            ``convex_a12``, ``convex_a22``, ``convex_steps``, ``convex_lr``,
            ``convex_start``, ``convex_target``, ``task`` and ``report_path``.

    Returns:
        The report path and the plot directory.

    Raises:
        ValueError: If ``A`` is not positive definite, its diagonal is not
            positive and finite, or the learning rate is not a positive
            finite number.
    """
    seed_everything(args.seed)

    off_diagonal = float(args.convex_a12)
    hessian = np.array(
        [
            [float(args.convex_a11), off_diagonal],
            [off_diagonal, float(args.convex_a22)],
        ],
        dtype=np.float64,
    )
    eigs = np.linalg.eigvalsh(hessian)
    if not (np.all(np.isfinite(eigs)) and np.all(eigs > 0.0)):
        raise ValueError(f"convex matrix must be positive definite; eigs={eigs}.")

    # At least one update is always simulated.
    steps = max(1, int(args.convex_steps))
    lr = float(args.convex_lr)
    if not (np.isfinite(lr) and lr > 0.0):
        raise ValueError("convex_lr must be a positive finite float.")

    start_x, start_y = float(args.convex_start[0]), float(args.convex_start[1])
    w0 = np.array([start_x, start_y], dtype=np.float64)
    target_x, target_y = float(args.convex_target[0]), float(args.convex_target[1])
    w_star = np.array([target_x, target_y], dtype=np.float64)

    def loss_fn(w: np.ndarray) -> float:
        offset = w - w_star
        return float(0.5 * offset.T @ hessian @ offset)

    diag = np.diag(hessian)
    if not (np.all(diag > 0.0) and np.all(np.isfinite(diag))):
        raise ValueError("convex matrix diagonal must be positive finite.")
    # Diagonal preconditioner: inverse of the diagonal of the quadratic form.
    precond = np.diag(1.0 / diag)

    def simulate(method: str) -> Tuple[np.ndarray, np.ndarray]:
        # Returns the visited points (steps + 1 rows) and their losses.
        point = w0.copy()
        visited = [point.copy()]
        losses = [loss_fn(point)]
        for _ in range(steps):
            grad = hessian @ (point - w_star)
            if method == "bptt":
                direction = grad
            elif method == "local_rule":
                direction = precond @ grad
            else:
                raise ValueError("method must be 'bptt' or 'local_rule'.")
            point = point - lr * direction
            visited.append(point.copy())
            losses.append(loss_fn(point))
        return np.stack(visited, axis=0), np.asarray(losses, dtype=np.float64)

    bptt_w, bptt_loss = simulate("bptt")
    local_w, local_loss = simulate("local_rule")

    plot_dir = build_plot_dir(Path.cwd(), args.task)
    if args.report_path:
        report_path = Path(args.report_path)
    else:
        report_path = Path("Compare_RNN") / "result" / f"oll_physical_probe_{args.task}_report.md"

    bptt_label = "BPTT (GD)"
    local_label = "local_rule (simplified diag-precond)"

    plot_convex_descent_3d(
        bptt_w,
        bptt_loss,
        local_w,
        local_loss,
        plot_dir / "convex_descent_3d.png",
        label_a=bptt_label,
        label_b=local_label,
        xlabel="w1",
        ylabel="w2",
        title="Toy Convex Quadratic: local_rule vs BPTT",
    )
    plot_convex_contours_with_paths(
        hessian,
        w_star,
        bptt_w,
        local_w,
        plot_dir / "convex_contours_paths.png",
        label_a=bptt_label,
        label_b=local_label,
    )
    plot_two_curves(
        list(range(int(bptt_loss.size))),
        bptt_loss.tolist(),
        local_loss.tolist(),
        bptt_label,
        local_label,
        "Toy Convex Quadratic: Loss vs Step",
        "Loss",
        plot_dir / "convex_loss_curve.png",
    )

    trajectory_path = plot_dir / "convex_descent_trajectory.npz"
    np.savez(
        trajectory_path,
        A=hessian,
        w0=w0,
        w_star=w_star,
        lr=lr,
        steps=steps,
        bptt_w=bptt_w,
        bptt_loss=bptt_loss,
        local_w=local_w,
        local_loss=local_loss,
    )

    # Image links in the report are relative to the working directory.
    rel_plot_dir = plot_dir.relative_to(Path.cwd()).as_posix()
    header = [
        "# Toy Convex Quadratic Descent (local_rule vs BPTT)",
        "",
        "This is a tiny 2D convex quadratic to visualize descent paths and loss curves.",
        "(Note: `local_rule` here is a simplified diagonal-preconditioned GD, not the full OLL RNN update.)",
        "",
    ]
    settings = [
        f"- A: {hessian.tolist()}",
        f"- w0: {w0.tolist()}",
        f"- w*: {w_star.tolist()}",
        f"- lr: {lr}",
        f"- steps: {steps}",
        "",
    ]
    figures = [
        f"![Contours + paths]({rel_plot_dir}/convex_contours_paths.png)",
        "",
        f"![Loss curve]({rel_plot_dir}/convex_loss_curve.png)",
        "",
        f"![3D descent]({rel_plot_dir}/convex_descent_3d.png)",
        "",
    ]
    footer = [f"- Raw trajectory: `{trajectory_path.as_posix()}`"]

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(header + settings + figures + footer), encoding="utf-8")
    return report_path, plot_dir


def run_rnn_descent_demo(args: argparse.Namespace) -> Tuple[Path, Path]:
    """Train a local-rule RNN and a BPTT RNN on one fixed batch and plot both descents.

    A single Adding Problem batch is loaded once and reused for every update.
    The BPTT model is loaded with the local-rule model's initial parameters.
    After each update the flattened core parameters and the fixed-batch loss
    of both models are recorded. The parameter trajectories are then projected
    onto a 2D plane (random, endpoint-based, or PCA; any other mode value
    falls back to PCA) and written out as plots, an ``.npz`` archive, and a
    Markdown report.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``descent_steps``,
            ``batch_size``, ``add_seq_len``, ``hidden``, ``lr``,
            ``descent_proj_seed``, ``task``, ``report_path`` and, optionally,
            ``descent_proj_mode`` (defaults to ``"pca"`` when absent).

    Returns:
        ``(report_path, plot_dir)``: the written Markdown report and the
        directory that holds the generated plots and trajectory archive.
    """
    seed_everything(args.seed)
    device = DEFAULT_DEVICE

    # Clamp user-supplied sizes to the smallest values the demo accepts.
    steps = max(1, int(args.descent_steps))
    batch_size = max(1, int(args.batch_size))
    seq_len = max(2, int(args.add_seq_len))
    seed = int(args.seed)

    # Only the training split is used; it serves as the one fixed batch.
    inputs_batch, targets_batch, _, _ = load_adding_problem_sequences(
        train_samples=batch_size,
        test_samples=0,
        seq_len=seq_len,
        seed=seed,
    )

    input_size = int(inputs_batch.shape[1])
    output_size = int(targets_batch.shape[1])
    loss_mode = "mse"
    hidden = int(args.hidden)
    lr = float(args.lr)

    # Both models are built from the same constructor arguments.
    model_kwargs = dict(
        input_size=input_size,
        hidden_size=hidden,
        output_size=output_size,
        eta=lr,
        loss_mode=loss_mode,
        seed=seed,
        device=device,
    )
    local_model = TorchLocalRuleRNN(**model_kwargs)
    bptt_model = TorchBPTTRNN(**model_kwargs)

    # Copy the local model's initial parameters into the BPTT model.
    init_params = extract_params(local_model)
    load_params(bptt_model, init_params)

    def flat_params(model) -> np.ndarray:
        """Return the model's flattened core parameters as float64."""
        return flatten_core_params(extract_params(model)).astype(np.float64)

    def batch_loss(model):
        """Evaluate the model's loss on the fixed batch."""
        return evaluate_sequence_loss(model, inputs_batch, targets_batch, loss_mode=loss_mode)

    theta0 = flatten_core_params(init_params).astype(np.float64)
    dim = int(theta0.size)
    n_rows = steps + 1  # row 0 holds the state before any update
    theta_local = np.zeros((n_rows, dim), dtype=np.float64)
    theta_bptt = np.zeros((n_rows, dim), dtype=np.float64)
    local_losses = np.zeros((n_rows,), dtype=np.float64)
    bptt_losses = np.zeros((n_rows,), dtype=np.float64)

    theta_local[0] = theta0
    theta_bptt[0] = theta0
    local_losses[0] = batch_loss(local_model)
    bptt_losses[0] = batch_loss(bptt_model)

    # The same zero initial hidden state is passed to every update of both models.
    h_prev = np.zeros((hidden, batch_size), dtype=np.float32)
    for row in range(1, n_rows):
        local_model.run_one_cycle_and_update_directly(inputs_batch, targets_batch, h_prev)
        bptt_model.train_batch(inputs_batch, targets_batch, h_prev)

        theta_local[row] = flat_params(local_model)
        theta_bptt[row] = flat_params(bptt_model)
        local_losses[row] = batch_loss(local_model)
        bptt_losses[row] = batch_loss(bptt_model)

    # The projection seed falls back to the run seed when not given explicitly.
    if args.descent_proj_seed is not None:
        proj_seed = int(args.descent_proj_seed)
    else:
        proj_seed = int(args.seed)
    rng = np.random.default_rng(proj_seed)

    proj_mode = str(getattr(args, "descent_proj_mode", "pca")).lower()
    if proj_mode == "random":
        proj_label = "random"
        axis_names = ("proj-1", "proj-2")
        v1, v2 = build_random_2d_basis(dim, rng)
    elif proj_mode in {"endpoints", "endpoint"}:
        proj_label = "endpoints"
        axis_names = ("dir-1", "dir-2")
        v1, v2 = build_endpoints_2d_basis(theta0, theta_bptt[-1], theta_local[-1], rng)
    else:
        # Any other mode value uses PCA over both stacked trajectories.
        proj_label = "pca"
        axis_names = ("pca-1", "pca-2")
        v1, v2 = build_pca_2d_basis(np.vstack([theta_bptt, theta_local]), theta0, rng)
    xlabel, ylabel = axis_names

    def project(trajectory: np.ndarray) -> np.ndarray:
        """Project offsets from ``theta0`` onto the (v1, v2) plane."""
        offsets = trajectory - theta0[None, :]
        return np.stack([offsets @ v1, offsets @ v2], axis=1)

    bptt_path = project(theta_bptt)
    local_path = project(theta_local)

    plot_dir = build_plot_dir(Path.cwd(), args.task)
    if args.report_path:
        report_path = Path(args.report_path)
    else:
        report_path = Path("Compare_RNN") / "result" / f"oll_physical_probe_{args.task}_report.md"

    bptt_name = "BPTT"
    local_name = "OLL (local_rule)"

    plot_paths_2d(
        bptt_path,
        local_path,
        plot_dir / "rnn_descent_paths_2d.png",
        label_a=bptt_name,
        label_b=local_name,
        xlabel=xlabel,
        ylabel=ylabel,
        title=f"RNN descent paths (2D, {proj_label} plane)",
    )
    plot_convex_descent_3d(
        bptt_path,
        bptt_losses,
        local_path,
        local_losses,
        plot_dir / "rnn_descent_3d.png",
        label_a=bptt_name,
        label_b=local_name,
        xlabel=xlabel,
        ylabel=ylabel,
        title=f"RNN (Adding Problem): 2D plane + loss ({proj_label})",
    )
    plot_two_curves(
        list(range(n_rows)),
        bptt_losses.tolist(),
        local_losses.tolist(),
        bptt_name,
        local_name,
        "Fixed-batch training loss over steps",
        "loss",
        plot_dir / "rnn_descent_loss_curve.png",
    )

    # Parameter vectors and basis are stored as float32; paths and losses stay float64.
    trajectory_path = plot_dir / "rnn_descent_trajectory.npz"
    payload = {
        "task": "adding_problem_fixed_batch",
        "seq_len": seq_len,
        "batch_size": batch_size,
        "hidden": hidden,
        "lr": lr,
        "steps": steps,
        "proj_mode": proj_label,
        "proj_seed": proj_seed,
        "theta0": theta0.astype(np.float32),
        "v1": np.asarray(v1, dtype=np.float32),
        "v2": np.asarray(v2, dtype=np.float32),
        "theta_bptt": theta_bptt.astype(np.float32),
        "theta_local": theta_local.astype(np.float32),
        "bptt_path": bptt_path,
        "bptt_losses": bptt_losses,
        "local_path": local_path,
        "local_losses": local_losses,
    }
    np.savez(trajectory_path, **payload)

    # Image links in the report are written relative to the current working directory.
    rel_plot_dir = plot_dir.relative_to(Path.cwd()).as_posix()
    lines = [
        "# RNN Descent Demo (OLL local_rule vs BPTT)",
        "",
        "A small RNN task (Adding Problem) trained on a fixed batch to visualize full OLL(local_rule) vs BPTT descent.",
        "",
        f"- seq_len: {seq_len}",
        f"- batch_size: {batch_size}",
        f"- hidden: {hidden}",
        f"- lr: {lr}",
        f"- steps: {steps}",
        f"- plane: {proj_label}",
        f"- proj_seed: {proj_seed}",
        f"- final loss (BPTT / OLL): {bptt_losses[-1]:.6f} / {local_losses[-1]:.6f}",
        "",
        f"![2D paths]({rel_plot_dir}/rnn_descent_paths_2d.png)",
        "",
        f"![3D descent]({rel_plot_dir}/rnn_descent_3d.png)",
        "",
        f"![Loss curve]({rel_plot_dir}/rnn_descent_loss_curve.png)",
        "",
        f"- Raw trajectory: `{trajectory_path.as_posix()}`",
    ]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    return report_path, plot_dir



def run_probe_experiment(args: argparse.Namespace) -> Dict[str, Any]:
    seed_everything(args.seed)
    device = DEFAULT_DEVICE

    if args.task == "row_mnist":
        (
            train_inputs,
            train_targets,
            train_labels,
            _,
            _,
            _,
        ) = load_row_mnist_sequences(
            train_limit=args.train_limit,
            test_limit=args.test_limit,
        )
        task_label = "Row-by-Row MNIST"
        loss_mode = "ce"
        metric_mode = "classification"
        eval_mode = "teacher"
        eval_warmup = 1
    elif args.task == "lorenz_image":
        train_inputs, train_targets, _, _ = load_lorenz_image_sequences(
            train_samples=args.train_samples,
            test_samples=args.test_samples,
            seq_len=args.seq_len,
            frame_h=args.frame_h,
            frame_w=args.frame_w,
            dt=args.dt,
            sigma=args.sigma,
            rho=args.rho,
            beta=args.beta,
            warmup=args.warmup,
            seed=args.seed,
            blur_sigma=args.blur_sigma,
        )
        train_labels = np.zeros((train_inputs.shape[0],), dtype=np.int64)
        task_label = "Lorenz Image Prediction"
        loss_mode = "mse"
        metric_mode = "regression"
        eval_mode = "rollout"
        eval_warmup = 1
    else:
        train_inputs, train_targets, _, _ = load_adding_problem_sequences(
            train_samples=args.train_samples,
            test_samples=args.test_samples,
            seq_len=args.add_seq_len,
            seed=args.seed,
        )
        train_labels = np.zeros((train_inputs.shape[0],), dtype=np.int64)
        task_label = "Adding Problem (Sequence-to-Sequence)"
        loss_mode = "mse"
        metric_mode = "regression"
        eval_mode = "teacher"
        eval_warmup = 1

    if args.eval_mode is not None:
        requested = str(args.eval_mode).lower()
        if args.task in {"adding_problem", "addtask"} and requested == "rollout":
            requested = "teacher"
        eval_mode = requested
    if args.eval_warmup is not None:
        eval_warmup = max(1, int(args.eval_warmup))

    input_size = train_inputs.shape[1]
    output_size = train_targets.shape[1]
    time_steps = train_inputs.shape[2]

    rng = np.random.default_rng(args.seed)
    tr_inputs, tr_targets, tr_labels, val_inputs, val_targets, val_labels = split_train_val(
        train_inputs, train_targets, train_labels, 0.1, rng
    )

    def build_local() -> TorchLocalRuleRNN:
        return TorchLocalRuleRNN(
            input_size=input_size,
            hidden_size=args.hidden,
            output_size=output_size,
            eta=args.lr,
            loss_mode=loss_mode,
            seed=args.seed,
            device=device,
        )

    if args.task == "lorenz_image":
        gains_default = np.linspace(0.05, 1.70, 12, endpoint=True)
    elif args.task in {"adding_problem", "addtask"}:
        gains_default = np.linspace(0.5, 1.6, 12, endpoint=False)
    else:
        gains_default = np.linspace(0.2, 2.2, 12, endpoint=False)
    if args.gains:
        gains = np.array([float(x) for x in args.gains.split(",") if x.strip()], dtype=np.float32)
        if gains.size == 0:
            gains = gains_default
    elif args.gain is not None:
        gains = np.array([args.gain], dtype=np.float32)
    else:
        gains = gains_default

    critical_metric = str(args.critical_metric).lower()
    if critical_metric not in {"rho_eff", "eff_gain_p95", "lyap_pre", "lyap_post"}:
        raise ValueError(
            "critical_metric must be 'rho_eff', 'eff_gain_p95', 'lyap_pre', or 'lyap_post'."
        )
    critical_target = args.critical_target
    critical_band = args.critical_band
    if critical_metric in {"lyap_pre", "lyap_post"}:
        if args.lyap_target is not None:
            critical_target = args.lyap_target
        if args.lyap_band is not None:
            critical_band = args.lyap_band
    if critical_target is None:
        if critical_metric == "rho_eff":
            critical_target = 1.0
        elif critical_metric == "eff_gain_p95":
            critical_target = 0.95
        else:
            critical_target = 0.0
    critical_batch = min(args.critical_batch or args.batch_size, tr_inputs.shape[0])
    critical_inputs = tr_inputs[:critical_batch]
    lyapunov_driver = (
        build_lyapunov_driver(tr_inputs)
        if critical_metric in {"lyap_pre", "lyap_post"}
        else None
    )

    best_g = float(gains[0])
    best_params: Dict[str, np.ndarray] | None = None
    best_stats: Dict[str, float] = {}
    best_critical_score = float("inf")
    if metric_mode == "classification":
        best_val_metric = -1.0
    else:
        best_val_metric = float("inf")

    print(f"STAGE 1: Scanning gains for {task_label} (probe init)...")
    for g in gains:
        model = build_local()
        model.initialize_weights_with_gain(float(g), seed=args.seed)
        init_params = extract_params(model)
        lyap_pre = float("nan")
        lyap_post = float("nan")
        if lyapunov_driver is not None:
            lyap_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        train_batches(
            model,
            tr_inputs,
            tr_targets,
            args.batch_size,
            args.scan_epochs,
            args.seed + 1,
        )
        if lyapunov_driver is not None:
            lyap_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        if metric_mode == "classification":
            val_loss, val_acc = evaluate_classifier_final_step(
                model,
                val_inputs,
                val_targets,
                val_labels,
                args.batch_size,
            )
            val_metric = float(val_acc)
        else:
            if eval_mode == "rollout":
                val_loss = evaluate_regression_mse_rollout(
                    model,
                    val_inputs,
                    val_targets,
                    args.batch_size,
                    warmup_steps=eval_warmup,
                )
            else:
                val_loss = evaluate_regression_mse(
                    model,
                    val_inputs,
                    val_targets,
                    args.batch_size,
                )
            val_metric = float(val_loss)

        critical_stats: Dict[str, float] = {}
        critical_metric_value = float("nan")
        critical_score = float("inf")
        if args.gain_select == "critical":
            if critical_metric in {"rho_eff", "eff_gain_p95"}:
                critical_stats = compute_criticality_stats(
                    model,
                    critical_inputs,
                    args.critical_time_points,
                    rng,
                )
                critical_metric_value = float(
                    critical_stats[
                        "rho_eff_median" if critical_metric == "rho_eff" else "eff_gain_p95"
                    ]
                )
            else:
                critical_metric_value = (
                    float(lyap_pre) if critical_metric == "lyap_pre" else float(lyap_post)
                )
            if np.isfinite(critical_metric_value):
                critical_score = abs(critical_metric_value - float(critical_target))
                if critical_band is not None:
                    band = float(critical_band)
                    if abs(critical_metric_value - float(critical_target)) <= band:
                        critical_score = 0.0

        lyap_label = ""
        if lyapunov_driver is not None and np.isfinite(lyap_pre):
            lyap_label += f" | lyap_pre={lyap_pre:.4f}"
        if lyapunov_driver is not None and np.isfinite(lyap_post):
            lyap_label += f" | lyap_post={lyap_post:.4f}"
        crit_label = ""
        if args.gain_select == "critical":
            crit_label = (
                f" | crit_{critical_metric}={critical_metric_value:.4f}"
                f" | score={critical_score:.4f}"
            )

        if metric_mode == "classification":
            if args.gain_select == "critical":
                print(
                    f"[SCAN] g={float(g):.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}"
                    f"{lyap_label}{crit_label}"
                )
            else:
                print(
                    f"[SCAN] g={float(g):.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}"
                    f"{lyap_label}"
                )
        else:
            label = "val_mse_rollout" if eval_mode == "rollout" else "val_mse"
            if args.gain_select == "critical":
                print(
                    f"[SCAN] g={float(g):.3f} | {label}={val_loss:.6f}"
                    f"{lyap_label}{crit_label}"
                )
            else:
                print(f"[SCAN] g={float(g):.3f} | {label}={val_loss:.6f}{lyap_label}")

        if metric_mode == "classification":
            val_better = val_metric > best_val_metric
        else:
            val_better = val_metric < best_val_metric

        if args.gain_select == "critical":
            if (
                not best_stats
                or critical_score < best_critical_score
                or (np.isclose(critical_score, best_critical_score) and val_better)
            ):
                best_critical_score = (
                    critical_score if np.isfinite(critical_score) else float("inf")
                )
                best_val_metric = val_metric
                best_g = float(g)
                best_params = init_params
                if metric_mode == "classification":
                    best_stats = {
                        "val_acc": float(val_acc),
                        "val_loss": float(val_loss),
                        "critical_metric": float(critical_metric_value),
                        "critical_score": float(critical_score),
                    }
                else:
                    best_stats = {
                        "val_mse": float(val_loss),
                        "critical_metric": float(critical_metric_value),
                        "critical_score": float(critical_score),
                    }
                if lyapunov_driver is not None and np.isfinite(lyap_pre):
                    best_stats["lyap_pre"] = float(lyap_pre)
                if lyapunov_driver is not None and np.isfinite(lyap_post):
                    best_stats["lyap_post"] = float(lyap_post)
        else:
            if not best_stats or val_better:
                best_val_metric = val_metric
                best_g = float(g)
                best_params = init_params
                if metric_mode == "classification":
                    best_stats = {"val_acc": float(val_acc), "val_loss": float(val_loss)}
                else:
                    best_stats = {"val_mse": float(val_loss)}
                if lyapunov_driver is not None and np.isfinite(lyap_pre):
                    best_stats["lyap_pre"] = float(lyap_pre)
                if lyapunov_driver is not None and np.isfinite(lyap_post):
                    best_stats["lyap_post"] = float(lyap_post)

    best_lyap_label = ""
    if "lyap_pre" in best_stats:
        best_lyap_label += f" | lyap_pre={best_stats['lyap_pre']:.4f}"
    if "lyap_post" in best_stats:
        best_lyap_label += f" | lyap_post={best_stats['lyap_post']:.4f}"

    if metric_mode == "classification":
        if args.gain_select == "critical":
            print(
                f"Best g={best_g:.3f} | val_acc={best_stats['val_acc']:.4f} | "
                f"val_loss={best_stats['val_loss']:.4f} | "
                f"crit_{critical_metric}={best_stats['critical_metric']:.4f}"
                f"{best_lyap_label}"
            )
        else:
            print(
                f"Best g={best_g:.3f} | val_acc={best_stats['val_acc']:.4f} | "
                f"val_loss={best_stats['val_loss']:.4f}"
                f"{best_lyap_label}"
            )
    else:
        label = "val_mse_rollout" if eval_mode == "rollout" else "val_mse"
        if args.gain_select == "critical":
            print(
                f"Best g={best_g:.3f} | {label}={best_stats['val_mse']:.6f} | "
                f"crit_{critical_metric}={best_stats['critical_metric']:.4f}"
                f"{best_lyap_label}"
            )
        else:
            print(f"Best g={best_g:.3f} | {label}={best_stats['val_mse']:.6f}{best_lyap_label}")

    local_model = build_local()
    if best_params is not None:
        load_params(local_model, best_params)
    else:
        local_model.initialize_weights_with_gain(best_g, seed=args.seed)

    probe_model = VanillaRNN(input_size, args.hidden, output_size, device=device, loss_mode=loss_mode)
    step_weights = None
    probe = PhysicalProbe(device=device, step_weights=step_weights)

    total_updates = math.ceil(train_inputs.shape[0] / args.batch_size) * args.epochs
    probe_steps = choose_probe_steps(total_updates, args.probe_points)
    probe_step_set = set(probe_steps)
    analysis_step = None

    probe_inputs = train_inputs[: args.batch_size]
    probe_targets = train_targets[: args.batch_size]

    plot_dir = build_plot_dir(Path.cwd(), args.task)
    report_path = Path(args.report_path) if args.report_path else Path("Compare_RNN") / "result" / f"oll_physical_probe_{args.task}_report.md"

    spectral_curve: List[Dict[str, float]] = []
    grad_curve: List[Dict[str, float]] = []
    rank1_stats: Dict[str, Any] = {}
    eigen_residual_stats: Dict[str, Any] = {}
    rank_dyn_stats: Dict[str, Any] = {}
    ar1_stats: Dict[str, Any] = {}
    delta_approx_ar1_stats: Dict[str, Any] = {}
    duality_stats: Dict[str, Any] = {}
    smooth_stats: Dict[str, Any] = {}
    lambda_stats: Dict[str, Any] = {}
    stability_lambda_stats: Dict[str, Any] = {}
    eff_gain_stats: Dict[str, Any] = {}
    neumann_stats: Dict[str, Any] = {}
    grad_scatter: Dict[str, Any] = {}
    hypothesis_eval: Dict[str, Any] = {}
    eigen_align_curve_dynamic: Dict[str, Any] = {}
    eigen_align_curve_fixed: Dict[str, Any] = {}
    analysis_snapshot: ProbeSnapshot | None = None
    probe_snapshots: Dict[int, ProbeSnapshot] = {}

    def run_probe(current_step: int) -> None:
        probe_model.load_from_local(local_model)
        probe_model.zero_grad(set_to_none=True)
        inputs_t = torch.as_tensor(probe_inputs, device=device, dtype=torch.float32)
        targets_t = torch.as_tensor(probe_targets, device=device, dtype=torch.float32)
        lambda_vals = local_model.lambda_vals.detach()
        _, snapshot_local = probe_model.forward_with_probe(
            inputs_t,
            targets_t,
            step_weights,
            lambda_vals,
            probe,
        )

        snapshot_local.alpha_hat = local_model.alpha_hat.detach().cpu().numpy().reshape(-1)
        snapshot_local.lambda_vals = local_model.lambda_vals.detach().cpu().numpy().reshape(-1)
        probe_snapshots[current_step] = snapshot_local

        rho, _ = compute_spectral_radius(snapshot_local, snapshot_local.alpha_hat)
        spectral_curve.append(
            {
                "step": current_step,
                "rho_max": float(np.max(rho)),
                "rho_mean": float(np.mean(rho)),
            }
        )
        cos, sign = compute_grad_alignment(snapshot_local.delta_true, snapshot_local.delta_approx)
        grad_curve.append({"step": current_step, "cos": cos, "sign": sign})

    step = 0
    if 0 in probe_step_set:
        run_probe(step)
    for epoch in range(args.epochs):
        for inputs_batch, targets_batch in iterate_minibatches(
            train_inputs, train_targets, args.batch_size, rng
        ):
            h_prev = torch.zeros(
                (args.hidden, inputs_batch.shape[0]),
                dtype=torch.float32,
                device=device,
            )
            local_model.run_one_cycle_and_update_directly(
                inputs_batch,
                targets_batch,
                h_prev,
            )
            step += 1
            if step in probe_step_set:
                run_probe(step)

    if not probe_snapshots:
        raise RuntimeError("No probe snapshots captured; increase epochs or probe points.")

    analysis_select = str(args.analysis_select).lower()
    analysis_step_requested = args.analysis_step
    analysis_rho_target = float(args.analysis_rho_target)
    analysis_step = select_probe_step(
        probe_snapshots,
        spectral_curve,
        analysis_select,
        explicit_step=analysis_step_requested,
        rho_target=analysis_rho_target,
        exclude_step0=True,
    )
    analysis_snapshot = probe_snapshots[int(analysis_step)]
    analysis_rho_info = next(
        (item for item in spectral_curve if item["step"] == analysis_step),
        None,
    )
    analysis_rho_max = float(analysis_rho_info["rho_max"]) if analysis_rho_info else float("nan")
    analysis_rho_mean = float(analysis_rho_info["rho_mean"]) if analysis_rho_info else float("nan")

    heatmap_select = str(args.heatmap_select).lower()
    heatmap_step_requested = args.heatmap_step
    heatmap_rho_target = float(args.heatmap_rho_target)
    if heatmap_select == "analysis":
        heatmap_step = int(analysis_step)
    else:
        heatmap_step = select_probe_step(
            probe_snapshots,
            spectral_curve,
            heatmap_select,
            explicit_step=heatmap_step_requested,
            rho_target=heatmap_rho_target,
            exclude_step0=False,
        )
    heatmap_snapshot = probe_snapshots[int(heatmap_step)]
    heatmap_rho_info = next(
        (item for item in spectral_curve if item["step"] == heatmap_step),
        None,
    )
    heatmap_rho_max = float(heatmap_rho_info["rho_max"]) if heatmap_rho_info else float("nan")
    heatmap_rho_mean = float(heatmap_rho_info["rho_mean"]) if heatmap_rho_info else float("nan")

    centers = list(range(args.hidden))
    rank1_stats = compute_rank1_analysis(
        analysis_snapshot,
        centers=centers,
        k_neighbors=args.neighbor_k,
        time_points=args.probe_time_points,
        rng=rng,
    )
    rank1_stats["k_neighbors"] = args.neighbor_k
    rank1_stats["time_points"] = args.probe_time_points
    pc1_scores = [vals[0] for vals in rank1_stats["evr"].values() if vals]
    if pc1_scores:
        rank1_stats["pc1_mean"] = float(np.mean(pc1_scores))
        rank1_stats["pc1_min"] = float(np.min(pc1_scores))
        rank1_stats["pc1_max"] = float(np.max(pc1_scores))
        if rank1_stats["pc1_mean"] >= 0.8:
            rank1_stats["conclusion"] = "强低秩性（PC1 占比显著）"
        elif rank1_stats["pc1_mean"] >= 0.6:
            rank1_stats["conclusion"] = "中等低秩性（PC1 占比中等）"
        else:
            rank1_stats["conclusion"] = "低秩性不显著（PC1 占比偏低）"

    eigen_residual_stats = compute_eigen_residual_stats(
        analysis_snapshot,
        centers=centers,
        k_neighbors=args.neighbor_k,
        rng=rng,
    )
    ref_step = min(probe_snapshots.keys()) if probe_snapshots else 0
    ref_snapshot = probe_snapshots.get(ref_step, analysis_snapshot)
    ref_neighbors_map = build_neighbors_map(ref_snapshot.W_hh, centers, args.neighbor_k)
    eigen_align_curve_dynamic = compute_eigen_alignment_curve(
        probe_snapshots,
        centers=centers,
        k_neighbors=args.neighbor_k,
        rng=rng,
        neighbors_map=None,
    )
    eigen_align_curve_fixed = compute_eigen_alignment_curve(
        probe_snapshots,
        centers=centers,
        k_neighbors=args.neighbor_k,
        rng=rng,
        neighbors_map=ref_neighbors_map,
    )
    eigen_align_curve_dynamic["neighbors_mode"] = "dynamic"
    eigen_align_curve_fixed["neighbors_mode"] = "fixed"
    eigen_align_curve_fixed["ref_step"] = int(ref_step)

    rank_dyn_stats = compute_dynamic_rank_stats(
        analysis_snapshot,
        centers=centers,
        k_neighbors=args.neighbor_k,
        window=args.rank_window,
        stride=args.rank_stride,
    )

    ar1 = compute_ar1_stats(analysis_snapshot.h)
    ar1_stats = {
        "r2_zero_mean": float(np.mean(ar1["r2_zero"])),
        "r2_zero_median": float(np.median(ar1["r2_zero"])),
        "r2_centered_mean": float(np.mean(ar1["r2_centered"])),
        "r2_centered_median": float(np.median(ar1["r2_centered"])),
        "corr_mean": float(np.mean(ar1["corr"])),
        "corr_median": float(np.median(ar1["corr"])),
        "r2_zero": ar1["r2_zero"],
        "r2_centered": ar1["r2_centered"],
        "corr": ar1["corr"],
    }
    delta_approx_ar1 = compute_ar1_stats(analysis_snapshot.delta_approx)
    delta_approx_ar1_stats = {
        "alpha_mean": float(np.mean(delta_approx_ar1["alpha"])),
        "alpha_median": float(np.median(delta_approx_ar1["alpha"])),
        "alpha_centered_mean": float(np.mean(delta_approx_ar1["alpha_centered"])),
        "alpha_centered_median": float(np.median(delta_approx_ar1["alpha_centered"])),
        "r2_zero_mean": float(np.mean(delta_approx_ar1["r2_zero"])),
        "r2_zero_median": float(np.median(delta_approx_ar1["r2_zero"])),
        "r2_centered_mean": float(np.mean(delta_approx_ar1["r2_centered"])),
        "r2_centered_median": float(np.median(delta_approx_ar1["r2_centered"])),
        "corr_mean": float(np.mean(delta_approx_ar1["corr"])),
        "corr_median": float(np.median(delta_approx_ar1["corr"])),
        "alpha": delta_approx_ar1["alpha"],
        "alpha_centered": delta_approx_ar1["alpha_centered"],
        "r2_zero": delta_approx_ar1["r2_zero"],
        "r2_centered": delta_approx_ar1["r2_centered"],
        "corr": delta_approx_ar1["corr"],
    }
    grad_corr = compute_gradient_duality(analysis_snapshot.delta_true)
    duality_stats = {
        "corr_mean": float(np.mean(grad_corr)),
        "corr_median": float(np.median(grad_corr)),
        "corr": grad_corr,
    }
    smooth_corr = compute_smoothed_autocorr(analysis_snapshot.delta_true, args.grad_smooth_rho)
    smooth_stats = {
        "grad_corr_mean": float(np.mean(smooth_corr)),
        "grad_corr_median": float(np.median(smooth_corr)),
        "grad_corr": smooth_corr,
        "rho": float(args.grad_smooth_rho),
    }

    lambda_consistency = compute_lambda_consistency(analysis_snapshot, analysis_snapshot.alpha_hat)
    lambda_stats = {
        "r2_mean": float(np.mean(lambda_consistency["r2"])),
        "r2_median": float(np.median(lambda_consistency["r2"])),
        "corr_mean": float(np.mean(lambda_consistency["corr"])),
        "corr_median": float(np.median(lambda_consistency["corr"])),
    }

    lambda_curve_stats = compute_lambda_volatility(probe_snapshots)
    pc1_scores = np.array(
        [rank1_stats["pc1_cos"].get(center, np.nan) for center in centers],
        dtype=np.float64,
    )
    lambda_volatility = np.array(
        lambda_curve_stats.get("volatility_excl0", lambda_curve_stats.get("volatility", np.array([]))),
        dtype=np.float64,
    )
    lambda_volatility_centers = (
        lambda_volatility[centers] if lambda_volatility.size else np.array([])
    )
    lambda_vol_p95 = (
        float(np.percentile(lambda_volatility, 95)) if lambda_volatility.size else float("nan")
    )
    lambda_vol_median = (
        float(np.median(lambda_volatility)) if lambda_volatility.size else float("nan")
    )
    pc1_lambda_corr = float("nan")
    if pc1_scores.size and lambda_volatility_centers.size:
        mask = ~np.isnan(pc1_scores) & ~np.isnan(lambda_volatility_centers)
        if np.sum(mask) >= 2:
            scores = pc1_scores[mask]
            vols = lambda_volatility_centers[mask]
            if float(np.std(scores)) > 1e-12 and float(np.std(vols)) > 1e-12:
                pc1_lambda_corr = float(np.corrcoef(scores, vols)[0, 1])

    pc1_items = [
        (center, score)
        for center, score in rank1_stats["pc1_cos"].items()
        if not np.isnan(score)
    ]
    most_stable = None
    least_stable = None
    most_stable_score = float("nan")
    least_stable_score = float("nan")
    most_stable_vol = float("nan")
    least_stable_vol = float("nan")
    if pc1_items:
        most_stable, most_stable_score = max(pc1_items, key=lambda item: item[1])
        least_stable, least_stable_score = min(pc1_items, key=lambda item: item[1])
        if lambda_volatility.size:
            most_stable_vol = float(lambda_volatility[most_stable])
            least_stable_vol = float(lambda_volatility[least_stable])

    stability_lambda_stats = {
        "most_stable": most_stable,
        "least_stable": least_stable,
        "most_stable_score": most_stable_score,
        "least_stable_score": least_stable_score,
        "most_stable_volatility": most_stable_vol,
        "least_stable_volatility": least_stable_vol,
        "pc1_lambda_corr": pc1_lambda_corr,
        "vol_p95": lambda_vol_p95,
        "vol_median": lambda_vol_median,
        "lambda_vol_p95": lambda_vol_p95,
        "lambda_vol_median": lambda_vol_median,
    }

    eff_gain_stats = compute_effective_gain_stats(analysis_snapshot)

    rho, _ = compute_spectral_radius(analysis_snapshot, analysis_snapshot.alpha_hat)
    neumann_err = compute_neumann_error(
        analysis_snapshot,
        analysis_snapshot.alpha_hat,
        analysis_snapshot.lambda_vals,
        rng=rng,
    )
    err_p90 = float(np.percentile(neumann_err, 90)) if neumann_err.size else float("nan")
    neumann_stats = {
        "rho_mean": float(np.mean(rho)),
        "rho_max": float(np.max(rho)),
        "err_mean": float(np.mean(neumann_err)) if neumann_err.size else float("nan"),
        "err_median": float(np.median(neumann_err)) if neumann_err.size else float("nan"),
        "err_p90": err_p90,
        "errors": neumann_err,
    }

    last_t = heatmap_snapshot.delta_true.shape[0] - 1
    grad_scale_time = (
        float(step_weights.sum().item())
        if step_weights is not None
        else float(heatmap_snapshot.delta_true.shape[0])
    )
    grad_scale = float(heatmap_snapshot.delta_true.shape[2]) * max(grad_scale_time, 1.0)
    grad_scatter = {
        "delta_true": heatmap_snapshot.delta_true[last_t].mean(axis=1) * grad_scale,
        "delta_approx": heatmap_snapshot.delta_approx[last_t].mean(axis=1),
    }
    delta_true_heat = heatmap_snapshot.delta_true.mean(axis=2) * grad_scale
    delta_approx_heat = heatmap_snapshot.delta_approx.mean(axis=2)

    plot_rank1_scree(rank1_stats["evr"], plot_dir / "rank1_scree.png")
    s_samples, _ = sample_time_points(analysis_snapshot.s, rng, args.probe_time_points)
    ref_center = rank1_stats["centers"][0]
    neighbor_list = rank1_stats["neighbors"][ref_center]
    if len(neighbor_list) >= 3:
        pick = rng.choice(neighbor_list, size=3, replace=False).tolist()
        plot_rank1_manifold(s_samples, pick, plot_dir / "rank1_manifold.png")

    residuals = np.asarray(eigen_residual_stats.get("residuals", []), dtype=np.float64)
    residuals = residuals[np.isfinite(residuals)]
    if residuals.size:
        plot_histogram(
            residuals,
            "Eigen-Residual Relative Norm",
            "relative residual",
            plot_dir / "eigen_residual_hist.png",
        )
    if eigen_align_curve_dynamic.get("steps") and eigen_align_curve_fixed.get("steps"):
        steps = [int(s) for s in eigen_align_curve_dynamic["steps"]]
        fixed_label = f"fixed@step{ref_step}"
        plot_two_curves(
            steps,
            list(eigen_align_curve_dynamic.get("eig_align_median", [])),
            list(eigen_align_curve_fixed.get("eig_align_median", [])),
            "dynamic neighbors",
            fixed_label,
            "Align($q_s$, dominant eigvec of $A_{local}$) (median)",
            "alignment",
            plot_dir / "eig_align_curve.png",
            ylim=(0.0, 1.0),
        )
        plot_two_curves(
            steps,
            list(eigen_align_curve_dynamic.get("svd_align_median", [])),
            list(eigen_align_curve_fixed.get("svd_align_median", [])),
            "dynamic neighbors",
            fixed_label,
            "Align($q_s$, dominant singular vec of $A_{local}$) (median)",
            "alignment",
            plot_dir / "svd_align_curve.png",
            ylim=(0.0, 1.0),
        )
        plot_two_curves(
            steps,
            list(eigen_align_curve_dynamic.get("residual_median", [])),
            list(eigen_align_curve_fixed.get("residual_median", [])),
            "dynamic neighbors",
            fixed_label,
            "Eigen-residual $\\|Aq-\\mu q\\|/\\|Aq\\|$ (median)",
            "relative residual",
            plot_dir / "eigen_residual_curve.png",
        )

    plot_histogram(
        ar1_stats["r2_centered"],
        "AR(1) $R^2$ Distribution (Centered)",
        "$R^2$ (centered)",
        plot_dir / "ar1_r2_hist.png",
    )
    plot_histogram(
        delta_approx_ar1_stats["r2_zero"],
        r"OLL Teaching Signal $\tilde{\delta}$: AR(1) $R^2$ (Zero-intercept)",
        r"$R^2$ (zero-intercept)",
        plot_dir / "delta_approx_ar1_r2_hist.png",
    )
    plot_histogram(
        delta_approx_ar1_stats["alpha"],
        r"OLL Teaching Signal $\tilde{\delta}$: AR(1) $\alpha$",
        r"$\alpha$",
        plot_dir / "delta_approx_ar1_alpha_hist.png",
    )
    plot_scatter(
        ar1_stats["corr"],
        smooth_stats["grad_corr"],
        "Activation vs Smoothed Gradient Autocorrelation",
        r"$\mathrm{corr}(h_t, h_{t-1})$",
        r"$\mathrm{corr}(\mathrm{EWMA}(\delta)_t, \mathrm{EWMA}(\delta)_{t-1})$",
        plot_dir / "duality_scatter_smoothed.png",
    )
    if pc1_scores.size and lambda_volatility_centers.size:
        plot_scatter(
            pc1_scores,
            lambda_volatility_centers,
            "PC1 Stability vs Lambda Volatility",
            "PC1 stability (cosine)",
            r"$\mathrm{std}(\lambda)$ over probe steps (excluding step0)",
            plot_dir / "stability_lambda_volatility_scatter.png",
            reference=False,
        )
    if (
        lambda_curve_stats["lambda_series"].size
        and most_stable is not None
        and least_stable is not None
    ):
        plot_lambda_trajectories(
            lambda_curve_stats["steps"],
            lambda_curve_stats["lambda_series"][:, most_stable],
            lambda_curve_stats["lambda_series"][:, least_stable],
            most_stable,
            least_stable,
            plot_dir / "lambda_trajectory_extremes.png",
        )

    spectral_steps = [item["step"] for item in spectral_curve]
    spectral_vals = [item["rho_max"] for item in spectral_curve]
    plot_curve(
        spectral_steps,
        spectral_vals,
        "Spectral Radius Over Training",
        r"$\max \, \rho(A_t)$",
        plot_dir / "spectral_radius_curve.png",
    )
    plot_histogram(
        neumann_stats["errors"],
        "Neumann Approximation Relative Error",
        "relative error",
        plot_dir / "neumann_error_hist.png",
    )

    grad_steps = [item["step"] for item in grad_curve]
    grad_cos = [item["cos"] for item in grad_curve]
    plot_curve(
        grad_steps,
        grad_cos,
        "Gradient Cosine Similarity",
        r"$\cos(\delta_{true}, \tilde{\delta})$",
        plot_dir / "grad_cos_curve.png",
    )
    plot_scatter(
        grad_scatter["delta_true"],
        grad_scatter["delta_approx"],
        "OLL vs BPTT Teaching Signal",
        r"$\delta_{true}$ (rescaled)",
        r"$\tilde{\delta}$",
        plot_dir / "grad_scatter.png",
    )
    plot_heatmap(
        delta_true_heat,
        r"BPTT True Gradient $\delta_{true}$ (mean over batch, rescaled)",
        "Neuron index",
        "Time step",
        plot_dir / "delta_true_heatmap.png",
        cmap="coolwarm",
        center=0.0,
        symmetric=True,
    )
    plot_heatmap(
        delta_approx_heat,
        r"OLL Teaching Signal $\tilde{\delta}$ (mean over batch)",
        "Neuron index",
        "Time step",
        plot_dir / "delta_approx_heatmap.png",
        cmap="coolwarm",
        center=0.0,
        symmetric=True,
    )

    lyap_diag: Dict[str, Any] | None = None
    if args.task == "lorenz_image":
        driver = build_lyapunov_driver(val_inputs)
        lyap_diag = compute_lyapunov_diagnostics_numpy(local_model, driver, seed=int(args.seed))
        if isinstance(lyap_diag.get("log_growth"), np.ndarray) and lyap_diag["log_growth"].size:
            plot_histogram(
                lyap_diag["log_growth"],
                "Lyapunov: Instantaneous log-growth",
                r"$\log \|J_t v\|$",
                plot_dir / "lyapunov_log_growth_hist.png",
            )
            plot_xy_curve(
                np.arange(lyap_diag["log_growth"].size),
                lyap_diag["log_growth"],
                "Lyapunov log-growth over time",
                "Time step",
                r"$\log \|J_t v\|$",
                plot_dir / "lyapunov_log_growth_curve.png",
            )

    grad_stats = {
        "cos_mean": float(np.mean([item["cos"] for item in grad_curve])),
        "cos_median": float(np.median([item["cos"] for item in grad_curve])),
        "sign_mean": float(np.mean([item["sign"] for item in grad_curve])),
        "sign_median": float(np.median([item["sign"] for item in grad_curve])),
    }

    attempt_idx = max(1, int(args.hypothesis_attempt))
    attempts_per_tier = max(1, int(args.attempts_per_tier))
    tier_idx, target_tier, target_thresholds = select_threshold_tier(attempt_idx, attempts_per_tier)
    target_eval = evaluate_hypotheses(
        rank_dyn_stats,
        ar1_stats,
        smooth_stats,
        eff_gain_stats,
        stability_lambda_stats,
        neumann_stats,
        grad_stats,
        target_thresholds,
    )
    tier_eval = evaluate_all_tiers(
        rank_dyn_stats,
        ar1_stats,
        smooth_stats,
        eff_gain_stats,
        stability_lambda_stats,
        neumann_stats,
        grad_stats,
    )
    best_idx = tier_eval["best_idx"]
    best_tier = tier_eval["best_tier"]
    best_thresholds = tier_eval["best_thresholds"]
    best_eval = tier_eval["best_eval"]
    tier_start = ((attempt_idx - 1) // attempts_per_tier) * attempts_per_tier + 1
    tier_end = tier_start + attempts_per_tier - 1

    hypothesis_eval = {
        "tier_name": target_tier,
        "tier_idx": tier_idx,
        "attempt": int(attempt_idx),
        "attempts_per_tier": int(attempts_per_tier),
        "attempt_range": (int(tier_start), int(tier_end)),
        "target_tier": target_tier,
        "target_passed": bool(target_eval["passed"]),
        "passed": bool(target_eval["passed"]),
        "best_tier": best_tier,
        "best_idx": best_idx,
        "best_passed": bool(best_eval["passed"]) if best_eval is not None else False,
        "H1": bool(target_eval["H1"]),
        "H2": bool(target_eval["H2"]),
        "H3": bool(target_eval["H3"]),
        "H4": bool(target_eval["H4"]),
        "H5": bool(target_eval["H5"]),
        "thresholds": target_thresholds,
        "best_thresholds": best_thresholds,
        "tier_results": tier_eval["tier_results"],
    }

    config = {
        "task": args.task,
        "task_label": task_label,
        "epochs": args.epochs,
        "scan_epochs": args.scan_epochs,
        "batch_size": args.batch_size,
        "hidden": args.hidden,
        "lr": args.lr,
        "seed": int(args.seed),
        "gain_select": args.gain_select,
        "critical_metric": critical_metric,
        "critical_target": float(critical_target),
        "critical_band": critical_band,
        "lyap_target": args.lyap_target,
        "lyap_band": args.lyap_band,
        "lyap_stage": args.lyap_stage,
        "critical_time_points": args.critical_time_points,
        "critical_batch": int(critical_batch),
        "gains": [float(x) for x in gains.tolist()],
        "best_gain": best_g,
        "best_val_acc": best_stats.get("val_acc"),
        "best_val_loss": best_stats.get("val_loss"),
        "best_val_mse": best_stats.get("val_mse"),
        "best_critical_metric": best_stats.get("critical_metric"),
        "best_critical_score": best_stats.get("critical_score"),
        "best_lyap_pre": best_stats.get("lyap_pre"),
        "best_lyap_post": best_stats.get("lyap_post"),
        "probe_points": args.probe_points,
        "probe_time_points": args.probe_time_points,
        "neighbor_k": args.neighbor_k,
        "rank_window": args.rank_window,
        "rank_stride": args.rank_stride,
        "grad_smooth_rho": args.grad_smooth_rho,
        "analysis_select": analysis_select,
        "analysis_step_requested": analysis_step_requested,
        "analysis_step": int(analysis_step),
        "analysis_rho_target": float(analysis_rho_target),
        "analysis_rho_max": analysis_rho_max,
        "analysis_rho_mean": analysis_rho_mean,
        "heatmap_select": heatmap_select,
        "heatmap_step_requested": heatmap_step_requested,
        "heatmap_step": int(heatmap_step),
        "heatmap_rho_target": float(heatmap_rho_target),
        "heatmap_rho_max": heatmap_rho_max,
        "heatmap_rho_mean": heatmap_rho_mean,
        "threshold_tier": target_tier,
        "threshold_best_tier": best_tier,
        "threshold_attempt": int(attempt_idx),
        "threshold_attempts_per_tier": int(attempts_per_tier),
        "threshold_target_tier": target_tier,
        "threshold_target_passed": bool(target_eval["passed"]),
        "threshold_best_passed": bool(best_eval["passed"]) if best_eval is not None else False,
        "device": str(device),
        "time_steps": time_steps,
        "loss_mode": loss_mode,
        "eval_mode": eval_mode,
        "eval_warmup": eval_warmup,
    }
    if args.task == "row_mnist":
        config["train_limit"] = args.train_limit
        config["test_limit"] = args.test_limit
    elif args.task == "lorenz_image":
        config.update(
            {
                "train_samples": args.train_samples,
                "test_samples": args.test_samples,
                "seq_len": args.seq_len,
                "frame_h": args.frame_h,
                "frame_w": args.frame_w,
                "dt": args.dt,
                "sigma": args.sigma,
                "rho": args.rho,
                "beta": args.beta,
                "warmup": args.warmup,
                "blur_sigma": args.blur_sigma,
            }
        )
    else:
        config.update(
            {
                "train_samples": args.train_samples,
                "test_samples": args.test_samples,
                "add_seq_len": args.add_seq_len,
            }
        )


    return {
        "report_path": report_path,
        "plot_dir": plot_dir,
        "task_label": task_label,
        "config": config,
        "lyap_diag": lyap_diag,
        "rank1_stats": rank1_stats,
        "eigen_residual_stats": eigen_residual_stats,
        "eigen_align_curve_dynamic": eigen_align_curve_dynamic,
        "eigen_align_curve_fixed": eigen_align_curve_fixed,
        "rank_dyn_stats": rank_dyn_stats,
        "ar1_stats": ar1_stats,
        "delta_approx_ar1_stats": delta_approx_ar1_stats,
        "duality_stats": duality_stats,
        "smooth_stats": smooth_stats,
        "lambda_stats": lambda_stats,
        "stability_lambda_stats": stability_lambda_stats,
        "eff_gain_stats": eff_gain_stats,
        "neumann_stats": neumann_stats,
        "grad_stats": grad_stats,
        "hypothesis_eval": hypothesis_eval,
        "spectral_curve": spectral_curve,
    }

def _build_main_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser consumed by :func:`main`.

    Only flags, types, choices and defaults are declared here; nothing is
    interpreted. The declaration order is kept stable because it determines
    the layout of ``--help``.
    """
    parser = argparse.ArgumentParser()

    # --- Task selection -----------------------------------------------------
    parser.add_argument(
        "--task",
        type=str,
        choices=[
            "row_mnist",
            "lorenz_image",
            "adding_problem",
            "addtask",
            "convex_quadratic",
            "rnn_descent",
        ],
        default="row_mnist",
    )

    # --- Training / data size -----------------------------------------------
    parser.add_argument("--epochs", type=int, default=40)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train-limit", type=int, default=60000)
    parser.add_argument("--test-limit", type=int, default=10000)

    # --- Gain scan and criticality-based gain selection ---------------------
    parser.add_argument("--gain", type=float, default=None)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--gain-select", type=str, choices=["val", "critical"], default="val")
    parser.add_argument(
        "--critical-metric",
        type=str,
        choices=["rho_eff", "eff_gain_p95", "lyap_pre", "lyap_post"],
        default="rho_eff",
    )
    parser.add_argument("--critical-target", type=float, default=None)
    parser.add_argument("--critical-band", type=float, default=None)
    parser.add_argument("--critical-batch", type=int, default=None)
    parser.add_argument("--critical-time-points", type=int, default=10)
    parser.add_argument("--lyap-target", type=float, default=None)
    parser.add_argument("--lyap-band", type=float, default=None)
    parser.add_argument("--lyap-stage", type=str, choices=["pre", "post"], default="pre")

    # --- Probe sampling -----------------------------------------------------
    parser.add_argument("--probe-points", type=int, default=6)
    parser.add_argument("--probe-time-points", type=int, default=100)
    parser.add_argument("--probe-centers", type=int, default=5)

    # --- Snapshot selection for analysis metrics and heatmaps ---------------
    parser.add_argument(
        "--analysis-select",
        type=str,
        choices=["rho_target", "final", "step"],
        default="rho_target",
        help="How to pick the snapshot used for analysis metrics.",
    )
    parser.add_argument(
        "--analysis-step",
        type=int,
        default=None,
        help="Probe step to analyze when --analysis-select=step (nearest available is used).",
    )
    parser.add_argument(
        "--analysis-rho-target",
        type=float,
        default=1.0,
        help="Target rho_max when --analysis-select=rho_target.",
    )
    parser.add_argument(
        "--heatmap-select",
        type=str,
        choices=["analysis", "rho_target", "final", "step"],
        default="analysis",
        help="How to pick the snapshot used for gradient scatter/heatmaps.",
    )
    parser.add_argument(
        "--heatmap-step",
        type=int,
        default=None,
        help="Probe step for heatmaps when --heatmap-select=step (nearest available is used).",
    )
    parser.add_argument(
        "--heatmap-rho-target",
        type=float,
        default=1.0,
        help="Target rho_max when --heatmap-select=rho_target.",
    )

    # --- Rank / smoothing analysis knobs ------------------------------------
    parser.add_argument("--neighbor-k", type=int, default=30)
    parser.add_argument("--rank-window", type=int, default=6)
    parser.add_argument("--rank-stride", type=int, default=2)
    parser.add_argument("--grad-smooth-rho", type=float, default=0.8)

    # --- Hypothesis attempts / retry policy (interpreted by main) -----------
    parser.add_argument("--hypothesis-attempt", type=int, default=1)
    parser.add_argument("--attempts-per-tier", type=int, default=5)
    parser.add_argument("--max-attempts", type=int, default=1)
    parser.add_argument("--retrain-on-fail", action="store_true")
    parser.add_argument("--run-all-attempts", action="store_true")
    parser.add_argument("--report-path", type=str, default=None)

    # --- Synthetic-task sizes and Lorenz-image generation settings ----------
    parser.add_argument("--train-samples", type=int, default=10000)
    parser.add_argument("--test-samples", type=int, default=1000)
    parser.add_argument("--seq-len", type=int, default=30)
    parser.add_argument("--add-seq-len", type=int, default=64)
    parser.add_argument("--frame-h", type=int, default=16)
    parser.add_argument("--frame-w", type=int, default=16)
    parser.add_argument("--dt", type=float, default=0.01)
    parser.add_argument("--sigma", type=float, default=10.0)
    parser.add_argument("--rho", type=float, default=28.0)
    parser.add_argument("--beta", type=float, default=2.6667)
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--blur-sigma", type=float, default=1.2)

    # --- Evaluation mode ----------------------------------------------------
    parser.add_argument("--eval-mode", type=str, choices=["teacher", "rollout"], default=None)
    parser.add_argument("--eval-warmup", type=int, default=None)

    # --- Demo-task options --------------------------------------------------
    # NOTE(review): presumably read by the convex_quadratic / rnn_descent demo
    # runners respectively (inferred from the flag prefixes) -- confirm there.
    parser.add_argument("--convex-steps", type=int, default=120)
    parser.add_argument("--convex-lr", type=float, default=0.18)
    parser.add_argument("--convex-a11", type=float, default=8.0)
    parser.add_argument("--convex-a22", type=float, default=2.0)
    parser.add_argument("--convex-a12", type=float, default=3.0)
    parser.add_argument("--convex-start", type=float, nargs=2, default=(-2.5, 2.0))
    parser.add_argument("--convex-target", type=float, nargs=2, default=(1.0, -1.0))
    parser.add_argument("--descent-steps", type=int, default=120)
    parser.add_argument("--descent-proj-seed", type=int, default=None)
    parser.add_argument("--descent-proj-mode", type=str, choices=["pca", "endpoints", "random"], default="pca")
    return parser


def main() -> None:
    """Parse CLI arguments, run the selected experiment and write its report.

    ``convex_quadratic`` and ``rnn_descent`` are one-shot demos: the matching
    runner is called once, its report/plot locations are printed, and the
    function returns.

    Every other task goes through ``run_probe_experiment`` in a retry loop:

    * At most ``--max-attempts`` iterations run; a non-positive value is
      replaced by 10000.
    * The attempt index handed down via ``args.hypothesis_attempt`` starts at
      ``max(1, --hypothesis-attempt)`` and grows by one per iteration.
    * With ``--retrain-on-fail`` the seed grows by one per iteration as well;
      otherwise every iteration reuses the original seed.
    * The loop stops at the first attempt whose hypothesis evaluation passed,
      unless ``--run-all-attempts`` is set.

    The report is written for the first passing attempt, or for the last
    attempt run when none passed. ``args.seed`` and
    ``args.hypothesis_attempt`` are overwritten in place on each iteration.

    Raises:
        RuntimeError: If no attempt produced a result. This is a defensive
            guard; the loop always runs at least once.
    """
    args = _build_main_arg_parser().parse_args()

    # One-shot demos bypass the hypothesis/retry machinery entirely.
    demo_runner = None
    if args.task == "convex_quadratic":
        demo_runner = run_convex_quadratic_demo
    elif args.task == "rnn_descent":
        demo_runner = run_rnn_descent_demo
    if demo_runner is not None:
        report_path, plot_dir = demo_runner(args)
        print(f"[DONE] Report: {report_path}")
        print(f"[DONE] Plots: {plot_dir}")
        return

    base_seed = int(args.seed)
    base_attempt = max(1, int(args.hypothesis_attempt))
    max_attempts = int(args.max_attempts)
    if max_attempts <= 0:
        # Non-positive means "keep going"; still bounded so the loop ends.
        max_attempts = 10000
    retrain_on_fail = bool(args.retrain_on_fail)
    best_result: Dict[str, Any] | None = None
    best_passed = False
    attempt_summaries: List[Dict[str, Any]] = []

    for attempt_offset in range(max_attempts):
        attempt_idx = base_attempt + attempt_offset
        args.seed = base_seed + attempt_offset if retrain_on_fail else base_seed
        args.hypothesis_attempt = attempt_idx

        result = run_probe_experiment(args)
        hypothesis_eval = result["hypothesis_eval"]
        passed = bool(hypothesis_eval.get("passed", False))
        attempt_config = result.get("config", {})
        attempt_summaries.append(
            {
                "attempt": attempt_idx,
                "seed": int(args.seed),
                "target_tier": hypothesis_eval.get("target_tier"),
                "target_passed": passed,
                "best_tier": hypothesis_eval.get("best_tier"),
                "best_tier_idx": hypothesis_eval.get("best_idx"),
                "best_gain": attempt_config.get("best_gain"),
                "best_val_acc": attempt_config.get("best_val_acc"),
                "best_val_mse": attempt_config.get("best_val_mse"),
                "best_critical_metric": attempt_config.get("best_critical_metric"),
                "best_critical_score": attempt_config.get("best_critical_score"),
                "best_lyap_pre": attempt_config.get("best_lyap_pre"),
                "best_lyap_post": attempt_config.get("best_lyap_post"),
            }
        )
        # Until some attempt passes, keep tracking the most recent result;
        # once one passes it is frozen as the result to report.
        if not best_passed:
            best_result = result
            best_passed = passed
        if passed and not args.run_all_attempts:
            break

    if best_result is None:
        raise RuntimeError("No experiment results captured.")

    # Record how the retry loop went alongside the chosen run's own config.
    best_config = best_result["config"]
    best_eval = best_result["hypothesis_eval"]
    attempts_run = int(len(attempt_summaries))
    best_config["retrain_on_fail"] = retrain_on_fail
    best_config["retrain_attempts_max"] = int(max_attempts)
    best_config["retrain_attempts_run"] = attempts_run
    best_config["retrain_best_attempt"] = int(best_eval.get("attempt", base_attempt))
    best_config["retrain_best_tier"] = str(best_eval.get("tier_name", "unknown"))
    best_config["retrain_best_seed"] = int(best_config.get("seed", base_seed))
    best_config["threshold_attempt_start"] = int(base_attempt)
    best_config["threshold_attempt_max"] = int(max_attempts)
    best_config["threshold_attempts_run"] = attempts_run

    write_report(
        best_result["report_path"],
        best_result["plot_dir"],
        best_result["task_label"],
        best_config,
        best_result["rank1_stats"],
        best_result["eigen_residual_stats"],
        best_result["rank_dyn_stats"],
        best_result["ar1_stats"],
        best_result["delta_approx_ar1_stats"],
        best_result["duality_stats"],
        best_result["smooth_stats"],
        best_result["lambda_stats"],
        best_result["stability_lambda_stats"],
        best_result["eff_gain_stats"],
        best_result["neumann_stats"],
        best_result["grad_stats"],
        best_eval,
        best_result["spectral_curve"],
        attempt_summaries,
        eigen_align_curve_dynamic=best_result.get("eigen_align_curve_dynamic"),
        eigen_align_curve_fixed=best_result.get("eigen_align_curve_fixed"),
        lyap_diag=best_result.get("lyap_diag"),
    )

    print(f"[DONE] Report: {best_result['report_path']}")
    print(f"[DONE] Plots: {best_result['plot_dir']}")
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
