from __future__ import annotations

import argparse
import json
import math
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import torch


def _timestamp() -> str:
    """Return the current local time as a compact ``YYYYmmdd_HHMMSS`` string."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


def _ensure_dir(path: Path) -> None:
    """Create *path* and any missing parents; do nothing if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def _safe_norm(x: np.ndarray, eps: float = 1e-12) -> float:
    """Return ``np.linalg.norm(x)`` shifted up by *eps* so the result is never exactly zero."""
    magnitude = np.linalg.norm(x)
    return float(magnitude + eps)


def _cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """
    Cosine similarity of *a* and *b*.

    *eps* is added to the product of the two norms, so an all-zero vector
    yields 0.0 instead of a division by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    similarity = np.dot(a, b) / (norm_product + eps)
    return float(similarity)


def _relerr(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """
    Relative error of *b* with respect to the reference *a*.

    Computed as ``||a - b|| / (||a|| + eps)``; *eps* guards against a zero
    reference norm.
    """
    residual = np.linalg.norm(a - b)
    reference = np.linalg.norm(a) + eps
    return float(residual / reference)


def _torch_softmax0(y: torch.Tensor) -> torch.Tensor:
    """Softmax of *y* taken along its first dimension (``dim=0``)."""
    return y.softmax(dim=0)


def _to_np(x: torch.Tensor) -> np.ndarray:
    """Detach *x* from autograd, move it to the CPU and return it as a NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()


def _add_repo_to_path() -> Path:
    """
    Make the directory three levels above this file importable.

    The resolved ancestor (``parents[2]`` of this file's absolute path) is
    inserted at the front of ``sys.path`` unless it is already listed.

    Returns the ancestor directory.
    """
    root = Path(__file__).resolve().parents[2]
    root_entry = str(root)
    if root_entry in sys.path:
        return root
    sys.path.insert(0, root_entry)
    return root


# Runs at import time, before the project-local ``task.*`` import that follows.
COMPARE_DIR = _add_repo_to_path()

from task.common.sequence_core import (  # noqa: E402
    TorchLocalRuleRNN,
    build_repeated_targets,
    build_lyapunov_driver,
    calculate_lyapunov_exponent_numpy,
    evaluate_classifier_final_step,
    load_mnist_images,
    train_batches,
)


def load_row_mnist_sequences(
    train_limit: int,
    test_limit: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load MNIST and arrange each image as a float32 sequence with per-step targets.

    Images come from ``load_mnist_images``; their last two axes are swapped so
    that the sequence (time) axis is axis 2, and targets are produced by
    ``build_repeated_targets(labels, 10, time_steps)``.
    NOTE(review): the images are presumably (N, 28, 28) — confirm against
    ``load_mnist_images``.

    Returns ``(train_inputs, train_targets, train_labels,
    test_inputs, test_targets, test_labels)``.
    """
    images_tr, train_labels, images_te, test_labels = load_mnist_images(
        train_limit=int(train_limit),
        test_limit=int(test_limit),
    )

    swap_last_two = (0, 2, 1)
    train_inputs = np.transpose(images_tr, swap_last_two).astype(np.float32)
    test_inputs = np.transpose(images_te, swap_last_two).astype(np.float32)

    # The number of steps is read from the training split and reused for the test split.
    n_steps = int(train_inputs.shape[2])
    n_classes = 10
    train_targets = build_repeated_targets(train_labels, n_classes, n_steps)
    test_targets = build_repeated_targets(test_labels, n_classes, n_steps)

    return (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
    )


def integrated_autocorr_time(x: np.ndarray, max_lag: int = 200) -> float:
    """
    Estimate the integrated autocorrelation time of a 1-D series.

    The series is flattened and mean-centred. Starting from 0.5, the
    normalised autocorrelation at lags 1, 2, ... is accumulated until the
    first lag whose value is non-positive or non-finite, or until the lag
    limit ``min(max_lag, len(x) - 2)`` is reached.

    Returns NaN when the variance is non-finite or below 1e-12, or when the
    lag limit is smaller than 2.
    """
    undefined = float("nan")

    series = np.asarray(x, dtype=np.float64).reshape(-1)
    centered = series - float(np.mean(series))
    variance = float(np.var(centered))
    if variance < 1e-12 or not np.isfinite(variance):
        return undefined

    lag_limit = int(min(max_lag, centered.size - 2))
    if lag_limit < 2:
        return undefined

    normaliser = variance + 1e-12
    total = 0.5
    lag = 1
    while lag <= lag_limit:
        rho = float(np.mean(centered[:-lag] * centered[lag:]) / normaliser)
        # Truncate the sum at the first non-positive (or invalid) correlation.
        if not (np.isfinite(rho) and rho > 0.0):
            break
        total += rho
        lag += 1
    return float(total)


def explained_variance_ratios(samples: np.ndarray, top_k: int = 8) -> list[float]:
    """
    Explained-variance ratios of the leading components of a sample matrix.

    samples: (N, M) array whose M columns are the samples. Each row is
    centred by subtracting its mean across the columns before the SVD.

    Returns at most ``max(1, top_k)`` ratios (squared singular values divided
    by their total), largest first. Returns an empty list if the SVD fails.
    """
    data = np.asarray(samples, dtype=np.float64)
    centered = data - data.mean(axis=1, keepdims=True)
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
    except np.linalg.LinAlgError:
        return []

    energy = singular_values * singular_values
    total_energy = float(np.sum(energy) + 1e-12)  # offset avoids 0/0 for constant data
    keep = max(1, int(top_k))
    ratios = (energy / total_energy).tolist()
    return [float(ratio) for ratio in ratios[:keep]]


def resolvent_stats(B: np.ndarray) -> dict[str, float]:
    """
    Summarise the resolvent ``R = (I - B)^-1`` of a square matrix *B*.

    Returns a dict with:
      - ``rho``: largest eigenvalue magnitude of *B* (spectral radius);
      - ``evr1``: share of R's squared singular values held by the largest one;
      - ``diag_ratio``: share of R's squared entries that lie on its diagonal.

    All three are NaN when ``I - B`` cannot be solved. ``evr1`` or ``rho``
    alone is NaN when only its own decomposition fails.
    """
    undefined = float("nan")

    size = int(B.shape[0])
    identity = np.eye(size, dtype=np.float64)
    try:
        resolvent = np.linalg.solve(identity - B, identity)
    except np.linalg.LinAlgError:
        return {"rho": undefined, "evr1": undefined, "diag_ratio": undefined}

    # Share of squared singular values carried by the top singular value of R.
    try:
        singular_values = np.linalg.svd(resolvent, compute_uv=False)
    except np.linalg.LinAlgError:
        top_mode_share = undefined
    else:
        sv_energy = singular_values * singular_values
        top_mode_share = float(sv_energy[0] / (np.sum(sv_energy) + 1e-12))

    # Share of R's squared entries that sit on its diagonal.
    diagonal = np.diag(resolvent).astype(np.float64)
    diagonal_energy = np.sum(diagonal * diagonal)
    total_energy = np.sum(resolvent * resolvent) + 1e-12
    diagonal_share = float(diagonal_energy / total_energy)

    # Spectral radius of B itself, not of the resolvent.
    try:
        eigenvalues = np.linalg.eigvals(B)
    except np.linalg.LinAlgError:
        spectral_radius = undefined
    else:
        spectral_radius = float(np.max(np.abs(eigenvalues)))

    return {"rho": spectral_radius, "evr1": top_mode_share, "diag_ratio": diagonal_share}


@dataclass
class ProbeMetrics:
    """
    Summary statistics produced by ``probe_one_cycle`` for one probe batch.

    The per-time-step alignment scores compare the batch-mean exact BPTT delta
    with the batch-mean locally estimated delta; they are reduced over time
    with NaN-ignoring median/percentile functions.
    """

    # Median over time steps of the cosine similarity between exact and estimated delta.
    cos_median: float
    # Median over time steps of ||exact - estimated|| / ||exact||.
    relerr_median: float
    # Median over time steps of the fraction of hidden units whose delta signs agree.
    sign_median: float
    # 10th percentile (over time steps) of the cosine similarity.
    cos_p10: float
    # 90th percentile (over time steps) of the relative error.
    relerr_p90: float
    # Largest explained-variance ratio of the source vectors u_t * g_t.
    src_evr1: float
    # Sum of the five largest explained-variance ratios (NaN if fewer than five exist).
    src_evr5: float
    # Integrated autocorrelation time of the per-step mean squared hidden activity.
    energy_tau_int: float
    # Lag-1 correlation coefficient of that same per-step energy series.
    energy_r1: float
    # Spectral radius of B = diag(mean u) @ W_hh^T at the middle time step.
    rho_B_mid: float
    # Share of the resolvent (I - B)^-1's squared singular values in its top one.
    R_evr1_mid: float
    # Share of the resolvent's squared entries that lie on its diagonal.
    R_diag_ratio_mid: float


@dataclass
class HypothesisThresholds:
    """
    A minimal, *task-facing* hypothesis set (all must pass):

    - Stability: post-training Lyapunov exponent is negative (stable side).
    - Simplicity (local + order-parameter): susceptibility energy is concentrated in
      (i) the diagonal self-response and/or (ii) a single dominant mode.
    - Functional credit assignment: OLL δ aligns with exact BPTT δ.

    The fields below are the cut-offs applied by ``evaluate_hypotheses``; the
    defaults are also read as class attributes for the CLI defaults in ``main``.
    """

    # Stability passes when the post-training Lyapunov exponent is finite and <= this value.
    lyap_post_max: float = 0.0
    # Simplicity passes when R_diag_ratio_mid + R_evr1_mid is finite and >= this value.
    simplicity_min: float = 0.75
    # Alignment requires the median cosine similarity to be >= this value ...
    cos_median_min: float = 0.35
    # ... and the median sign-agreement fraction to be >= this value.
    sign_median_min: float = 0.58


def evaluate_hypotheses(
    *,
    lyap_post: float,
    metrics: ProbeMetrics,
    th: HypothesisThresholds,
) -> dict[str, Any]:
    """
    Check the three stage-1 hypotheses against their thresholds.

    - ``H_stable``: *lyap_post* is finite and ``<= th.lyap_post_max``.
    - ``H_simple``: ``R_diag_ratio_mid + R_evr1_mid`` is finite and
      ``>= th.simplicity_min``.
    - ``H_align``: median cosine and median sign agreement are both finite
      and reach their respective minima.

    Returns a dict with the three flags, their conjunction under ``"passed"``,
    and the simplicity score. A non-finite input fails its hypothesis.
    """
    stable = False
    if np.isfinite(lyap_post):
        stable = bool(lyap_post <= th.lyap_post_max)

    simplicity = float(metrics.R_diag_ratio_mid + metrics.R_evr1_mid)
    simple = False
    if np.isfinite(simplicity):
        simple = bool(simplicity >= th.simplicity_min)

    aligned = False
    if np.isfinite(metrics.cos_median) and np.isfinite(metrics.sign_median):
        cosine_ok = metrics.cos_median >= th.cos_median_min
        aligned = bool(cosine_ok and metrics.sign_median >= th.sign_median_min)

    verdict: dict[str, Any] = {"passed": bool(stable and simple and aligned)}
    verdict["H_stable"] = stable
    verdict["H_simple"] = simple
    verdict["H_align"] = aligned
    verdict["simplicity"] = float(simplicity)
    return verdict


def probe_one_cycle(
    model: TorchLocalRuleRNN,
    inputs: np.ndarray,
    targets: np.ndarray,
    *,
    use_step_weights: bool = False,
) -> ProbeMetrics:
    """
    Compare the model's local delta estimate with exact BPTT on one batch.

    Runs a gradient-free forward pass of the tanh RNN over every time step,
    computing at each step the local estimate
    ``delta_est = u_t * g_t / (1 - lambda * u_t)`` while replaying the
    alpha/lambda running statistics on cloned copies of the model's buffers.
    It then runs the exact backward recursion
    ``delta_t = u_t * (g_t + W_hh^T delta_{t+1})`` with the same ``g_t``,
    ``u_t`` and fixed ``W_hh``, and summarises how well the two agree together
    with a few spectral and temporal statistics.

    Args:
        model: trained RNN; its weights and hyper-parameters are read, and its
            learner-state tensors are cloned rather than updated here.
        inputs: array indexed as ``[batch, feature, time]`` (axis 0 is used as
            the batch size and axis 2 as the number of time steps).
        targets: array indexed the same way as ``inputs``; each time slice is
            subtracted from the softmax output, so presumably one-hot/probability
            targets — TODO confirm against ``build_repeated_targets``.
        use_step_weights: when True and ``model.step_weights`` is not None, the
            output-loss gradient at step ``t`` is scaled by ``step_weights[t]``;
            otherwise every step has weight 1.0.

    Returns:
        A ``ProbeMetrics`` instance with plain-float fields.

    Raises:
        IndexError: if ``inputs`` has zero time steps (no per-step tensors exist
            to size the backward pass).
    """
    device = model.W_hh.device
    inputs_t = torch.as_tensor(inputs, device=device, dtype=torch.float32)
    targets_t = torch.as_tensor(targets, device=device, dtype=torch.float32)
    batch = int(inputs_t.shape[0])
    time_steps = int(inputs_t.shape[2])

    # local copies of the learner state so probing does not modify training buffers
    alpha_num = model.alpha_num.clone()
    alpha_den = model.alpha_den.clone()
    # NOTE: this clone is overwritten inside the loop before it is ever read.
    alpha_hat = model.alpha_hat.clone()
    S_A2 = model.S_A2.clone()
    S_AB = model.S_AB.clone()
    lambda_vals = model.lambda_vals.clone()

    # Hidden state is laid out as (hidden, batch); the probe starts from zeros.
    h_prev = torch.zeros((model.hidden_size, batch), device=device, dtype=torch.float32)
    prev_g: torch.Tensor | None = None
    prev_u: torch.Tensor | None = None

    # Per-time-step records kept for the backward pass and the metrics.
    h_list: list[torch.Tensor] = []
    u_list: list[torch.Tensor] = []
    g_list: list[torch.Tensor] = []
    delta_est_list: list[torch.Tensor] = []

    W_hh = model.W_hh
    W_xh = model.W_xh
    b_h = model.b_h
    W_hy = model.W_hy
    b_y = model.b_y

    step_weights = model.step_weights if (use_step_weights and model.step_weights is not None) else None

    with torch.no_grad():
        for t in range(time_steps):
            w_t = step_weights[t] if step_weights is not None else 1.0
            # Time slice t, transposed so the batch becomes the column axis.
            I_t = inputs_t[:, :, t].T
            y_target_t = targets_t[:, :, t].T

            x_t = W_hh @ h_prev + W_xh @ I_t + b_h
            h_t = torch.tanh(x_t)
            # tanh derivative at the pre-activation.
            u_t = 1.0 - h_t**2
            y_hat_t = W_hy @ h_t + b_y

            # Softmax output minus target, scaled by the step weight, pulled back to the hidden units.
            P_t = _torch_softmax0(y_hat_t)
            dL_dyhat = (P_t - y_target_t) * w_t
            g_t = W_hy.T @ dL_dyhat

            # The estimate uses lambda as it stood BEFORE this step's update below.
            lambda_used = lambda_vals
            denom = 1.0 - lambda_used * u_t
            # Keep |denom| >= denom_floor; the tiny offset makes an exact zero count as positive.
            denom_mask = torch.abs(denom) < model.denom_floor
            denom = torch.where(denom_mask, model.denom_floor * torch.sign(denom + 1e-12), denom)
            delta_est = (u_t * g_t) / denom

            # update alpha_hat
            # Running averages (decay alpha_rho) of batch means of h_t*h_prev and h_prev^2;
            # their ratio is clipped to [alpha_clip_min, alpha_clip_max].
            hthp_mean = torch.mean(h_t * h_prev, dim=1, keepdim=True)
            hphp_mean = torch.mean(h_prev**2, dim=1, keepdim=True)
            alpha_num = model.alpha_rho * alpha_num + (1.0 - model.alpha_rho) * hthp_mean
            alpha_den = model.alpha_rho * alpha_den + (1.0 - model.alpha_rho) * hphp_mean
            raw_alpha = alpha_num / (alpha_den + model.epsilon)
            alpha_hat = torch.clamp(raw_alpha, model.alpha_clip_min, model.alpha_clip_max)

            # update lambda_vals
            # Needs the previous step's g and u, so it is skipped at t == 0.
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (alpha_hat * prev_g - g_t)
                B_s = alpha_hat * prev_u * prev_g - u_t * g_t
                A2_mean = torch.mean(A_s**2, dim=1, keepdim=True)
                AB_mean = torch.mean(A_s * B_s, dim=1, keepdim=True)
                S_A2 = model.lambda_rho * S_A2 + (1.0 - model.lambda_rho) * A2_mean
                S_AB = model.lambda_rho * S_AB + (1.0 - model.lambda_rho) * AB_mean
                # Ratio of the running <A*B> to the running <A^2>, per hidden unit.
                lambda_unproj = S_AB / (S_A2 + model.eps_lambda)
                # Cap |lambda| so that |lambda * u| <= 1 - denom_floor for this batch,
                # and never above model.lambda_cap.
                u_abs_max = torch.max(torch.abs(u_t), dim=1, keepdim=True).values + 1e-12
                safe_cap = (1.0 - model.denom_floor) / u_abs_max
                cap = torch.minimum(safe_cap, torch.full_like(safe_cap, model.lambda_cap))
                lambda_vals = torch.clamp(lambda_unproj, min=-cap, max=cap)

            h_list.append(h_t)
            u_list.append(u_t)
            g_list.append(g_t)
            delta_est_list.append(delta_est)

            prev_g = g_t
            prev_u = u_t
            h_prev = h_t

    # exact BPTT recursion (same g_t, u_t; fixed W_hh)
    # The delta beyond the last step is zero, so the final step reduces to u_T * g_T.
    delta_true_list: list[torch.Tensor] = [torch.zeros_like(delta_est_list[0]) for _ in range(time_steps)]
    delta_next = torch.zeros_like(delta_est_list[0])
    W_T = W_hh.T
    with torch.no_grad():
        for t in reversed(range(time_steps)):
            delta_t = u_list[t] * (g_list[t] + (W_T @ delta_next))
            delta_true_list[t] = delta_t
            delta_next = delta_t

    # metrics across time: use batch-mean vectors per t
    cos_vals: list[float] = []
    rel_vals: list[float] = []
    sign_vals: list[float] = []
    for t in range(time_steps):
        dt = _to_np(delta_true_list[t].mean(dim=1))
        de = _to_np(delta_est_list[t].mean(dim=1))
        cos_vals.append(_cosine(dt, de))
        # Relative error is normalised by the exact (BPTT) delta.
        rel_vals.append(_relerr(dt, de))
        # Fraction of hidden units whose exact and estimated deltas share a sign.
        sign_vals.append(float(np.mean(np.sign(dt) == np.sign(de))))

    cos_arr = np.array(cos_vals, dtype=np.float64)
    rel_arr = np.array(rel_vals, dtype=np.float64)
    sign_arr = np.array(sign_vals, dtype=np.float64)

    # source subspace: samples are (t,b) pairs
    s_stack = torch.stack([u_list[t] * g_list[t] for t in range(time_steps)], dim=0)  # (T,H,B)
    s_np = _to_np(s_stack).transpose(1, 0, 2).reshape(model.hidden_size, -1)  # (H, T*B)
    src_evr = explained_variance_ratios(s_np, top_k=8)
    src_evr1 = float(src_evr[0]) if src_evr else float("nan")
    # The top-5 sum is only reported when at least five ratios were returned.
    src_evr5 = float(sum(src_evr[:5])) if len(src_evr) >= 5 else float("nan")

    # critical slowing down proxy from energy(t)
    # energy(t) is the mean squared hidden activity over hidden units and batch.
    h_stack = torch.stack(h_list, dim=0)
    energy = _to_np(torch.mean(h_stack * h_stack, dim=(1, 2)))
    # Lag-1 correlation is undefined for very short or (near-)constant series.
    if energy.size >= 3 and float(np.std(energy)) > 1e-12:
        energy_r1 = float(np.corrcoef(energy[:-1], energy[1:])[0, 1])
    else:
        energy_r1 = float("nan")
    energy_tau_int = integrated_autocorr_time(energy, max_lag=200)

    # mid-point operator stats using mean-u
    t_mid = time_steps // 2
    u_mid = _to_np(u_list[t_mid].mean(dim=1)).astype(np.float64)  # (H,)
    W_hh_np = _to_np(W_hh).astype(np.float64)
    # Row-scaling W_hh^T by u_mid gives diag(u_mid) @ W_hh^T, the linear map
    # applied to delta_next in the backward recursion above.
    B_mid = (u_mid[:, None] * W_hh_np.T).astype(np.float64)
    R_stats = resolvent_stats(B_mid)

    # Time-wise reductions ignore NaN entries.
    return ProbeMetrics(
        cos_median=float(np.nanmedian(cos_arr)),
        relerr_median=float(np.nanmedian(rel_arr)),
        sign_median=float(np.nanmedian(sign_arr)),
        cos_p10=float(np.nanpercentile(cos_arr, 10)),
        relerr_p90=float(np.nanpercentile(rel_arr, 90)),
        src_evr1=src_evr1,
        src_evr5=src_evr5,
        energy_tau_int=float(energy_tau_int),
        energy_r1=float(energy_r1),
        rho_B_mid=float(R_stats["rho"]),
        R_evr1_mid=float(R_stats["evr1"]),
        R_diag_ratio_mid=float(R_stats["diag_ratio"]),
    )


def main(argv: list[str] | None = None) -> int:
    """
    Command-line entry point for the stage-1 probe.

    For every gain in ``--gains`` a fresh CPU model is built and initialised,
    its Lyapunov exponent is measured, it is trained (an optional scan phase
    followed by the main phase), and it is then measured again, evaluated on
    the test split and probed on the first ``--probe-batch`` test samples.
    One line per gain is printed and all records are written to
    ``summary.json`` inside the output directory.

    Args:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 on completion.

    Raises:
        ValueError: if ``--gains`` contains no values (or a value is not a float).
    """
    parser = argparse.ArgumentParser(description="Stage-1 probe: train OLL local rule then validate hypotheses on task sources.")
    # (flag, type, default) — registered in this order.
    option_table = (
        ("--epochs", int, 2),
        ("--scan-epochs", int, 0),
        ("--batch-size", int, 64),
        ("--hidden", int, 128),
        ("--lr", float, 1e-3),
        ("--seed", int, 42),
        ("--train-limit", int, 4000),
        ("--test-limit", int, 1000),
        ("--probe-batch", int, 64),
        ("--gains", str, "0.7,0.9,1.0,1.1"),
        ("--out-dir", str, ""),
        ("--simplicity-min", float, HypothesisThresholds.simplicity_min),
        ("--cos-median-min", float, HypothesisThresholds.cos_median_min),
        ("--sign-median-min", float, HypothesisThresholds.sign_median_min),
    )
    for flag, caster, fallback in option_table:
        parser.add_argument(flag, type=caster, default=fallback)
    args = parser.parse_args(argv)

    # Comma-separated gains; empty tokens are ignored.
    gains: list[float] = []
    for token in str(args.gains).split(","):
        token = token.strip()
        if token:
            gains.append(float(token))
    if not gains:
        raise ValueError("--gains must contain at least one value.")

    (
        train_inputs,
        train_targets,
        _train_labels,
        test_inputs,
        test_targets,
        test_labels,
    ) = load_row_mnist_sequences(
        train_limit=int(args.train_limit),
        test_limit=int(args.test_limit),
    )
    driver = build_lyapunov_driver(train_inputs)

    if args.out_dir:
        out_dir = Path(args.out_dir)
    else:
        out_dir = Path("plots") / f"oll_stage1_probe_{_timestamp()}"
    _ensure_dir(out_dir)

    thresholds = HypothesisThresholds(
        simplicity_min=float(args.simplicity_min),
        cos_median_min=float(args.cos_median_min),
        sign_median_min=float(args.sign_median_min),
    )

    seed = int(args.seed)
    batch_size = int(args.batch_size)
    n_test = int(test_inputs.shape[0])
    # Training phases as (epoch count, seed offset): optional scan phase, then the main phase.
    phases = ((int(args.scan_epochs), 1), (int(args.epochs), 2))

    results: list[dict[str, Any]] = []
    for g in gains:
        model = TorchLocalRuleRNN(
            input_size=int(train_inputs.shape[1]),
            hidden_size=int(args.hidden),
            output_size=int(train_targets.shape[1]),
            eta=float(args.lr),
            seed=seed,
            device="cpu",
        )
        model.initialize_weights_with_gain(float(g), seed=seed)

        lyap_pre = calculate_lyapunov_exponent_numpy(model, driver)

        for n_epochs, seed_offset in phases:
            if n_epochs > 0:
                train_batches(
                    model,
                    train_inputs,
                    train_targets,
                    batch_size=batch_size,
                    epochs=n_epochs,
                    seed=seed + seed_offset,
                )

        lyap_post = calculate_lyapunov_exponent_numpy(model, driver)

        test_loss, test_acc = evaluate_classifier_final_step(
            model,
            test_inputs,
            test_targets,
            test_labels,
            batch_size=min(batch_size, n_test),
        )

        # Probe on the leading slice of the test split.
        probe_n = min(int(args.probe_batch), n_test)
        metrics = probe_one_cycle(model, test_inputs[:probe_n], test_targets[:probe_n])
        hyp = evaluate_hypotheses(lyap_post=lyap_post, metrics=metrics, th=thresholds)

        record: dict[str, Any] = {
            "g": float(g),
            "lyap_pre": float(lyap_pre),
            "lyap_post": float(lyap_post),
            "test_loss": float(test_loss),
            "test_acc": float(test_acc),
        }
        record.update(dataclass_to_dict(metrics))
        record["hypotheses"] = hyp
        results.append(record)

        print(
            f"[Stage1] g={g:.3f} | acc={test_acc:.3f} | lyap(pre,post)=({lyap_pre:.4f},{lyap_post:.4f}) | "
            f"H={int(hyp['passed'])}({int(hyp['H_stable'])}{int(hyp['H_simple'])}{int(hyp['H_align'])}) | "
            f"cos_med={metrics.cos_median:.3f} | sign_med={metrics.sign_median:.3f} | "
            f"diagR={metrics.R_diag_ratio_mid:.3f} | evr1={metrics.R_evr1_mid:.3f} | simp={hyp['simplicity']:.3f}"
        )

    summary_path = out_dir / "summary.json"
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=2, ensure_ascii=False)
    print(f"[Stage1] Wrote {out_dir / 'summary.json'}")
    return 0


def dataclass_to_dict(obj: Any) -> dict[str, Any]:
    """
    Return a shallow ``{field name: value}`` mapping for a dataclass instance.

    Unlike ``dataclasses.asdict`` the values are not copied or recursed into.
    Only real fields are included: ``ClassVar`` and ``InitVar`` pseudo-fields,
    which also live in ``__dataclass_fields__``, are skipped so that class
    variables do not leak into the result and init-only variables do not
    trigger an ``AttributeError``.

    Args:
        obj: an instance of a ``@dataclass``-decorated class.

    Returns:
        A dict in field-definition order.

    Raises:
        TypeError: if *obj* is not a dataclass instance (including when a
            dataclass *class* is passed instead of an instance).
    """
    # Local import: the module only imports ``dataclass`` at file level.
    from dataclasses import fields, is_dataclass

    # is_dataclass() is also true for dataclass classes, so reject types explicitly.
    if isinstance(obj, type) or not is_dataclass(obj):
        raise TypeError("dataclass_to_dict expects a dataclass instance.")
    return {field.name: getattr(obj, field.name) for field in fields(obj)}


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    sys.exit(main())
