import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch

from plotting_utils import apply_plot_style

# Put the parent of this file's directory at the front of sys.path (only once),
# so the project-local `methods` package imported below resolves when this file
# is executed directly as a script.
COMPARE_DIR = Path(__file__).resolve().parents[1]
if not any(entry == str(COMPARE_DIR) for entry in sys.path):
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTClassifier
from methods.standard_eprop import StandardEPropRNN


# ==============================
# Global configs
# ==============================

LAMBDA_WINDOW = 50
MNIST_NPZ_FILENAME = "mnist.npz"

# FPTT 配置：这里按官方分类实验对齐
FPTT_PARTS = 10              # 官方 MNIST-10 / CIFAR-10 默认 10 块
FPTT_ORACLE_MOMENTUM = 1.0   # = 直接覆盖行为
FPTT_LAMBDA = 1.0            # 分类任务中 λ=1
FPTT_USE_FED_DYN = False     # 你现在 strict 版本里已经去掉 fed-dyn

EPROP_FEEDBACK = "symmetric"
EPROP_SEED = 1234


# ==============================
# Utilities
# ==============================

def _as_float64_array(value: Any) -> np.ndarray:
    """Convert a NumPy array or torch tensor to a float64 NumPy array.

    Torch tensors are detached and copied to host first, so gradient-tracking
    and GPU tensors are accepted (plain ``np.asarray`` rejects both).
    """
    if torch.is_tensor(value):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


def calculate_lyapunov_exponent_numpy(model: Any, driver_input: np.ndarray) -> float:
    """Estimate the largest Lyapunov exponent of a tanh RNN under a fixed drive.

    Benettin-style QR accumulation. Starting from ``h_0 = 0`` the hidden state
    follows ``h_{t+1} = tanh(W_hh h_t + W_xh x_t + b_h)``. At each step an
    orthonormal frame is pushed through the Jacobian
    ``diag(1 - h_{t+1}^2) @ W_hh`` and re-orthonormalised with QR. The logs of
    ``|diag(R)|`` (floored at 1e-12) are accumulated.

    Args:
        model: object exposing ``hidden_size``, ``W_hh``, ``W_xh`` and ``b_h``.
            These may be NumPy arrays or torch tensors. ``b_h`` may be shaped
            ``(H,)`` or ``(H, 1)``.
        driver_input: ``(input_size, time_steps)`` drive. Column ``t`` is the
            input at step ``t``.

    Returns:
        The largest exponent (mean log growth per step), or ``nan`` if a QR
        factorisation fails. An empty drive yields 0.0.
    """
    n_hidden = int(model.hidden_size)
    W_hh = _as_float64_array(model.W_hh)
    W_xh = _as_float64_array(model.W_xh)
    # Force a column bias. A flat (H,) bias would broadcast the (H, 1) state
    # into an (H, H) matrix and break the Jacobian below.
    b_h = _as_float64_array(model.b_h).reshape(-1, 1)
    drive = _as_float64_array(driver_input)

    h = np.zeros((n_hidden, 1), dtype=np.float64)
    Q = np.eye(n_hidden, dtype=np.float64)
    log_r_diag_sum = np.zeros(n_hidden, dtype=np.float64)
    log_floor = 1e-12

    T = int(drive.shape[1])
    for t in range(T):
        I_t = drive[:, t].reshape(-1, 1)
        x_t = W_hh @ h + W_xh @ I_t + b_h
        h_next = np.tanh(x_t)

        # d h_{t+1} / d h_t = diag(tanh'(x_t)) @ W_hh, with tanh' = 1 - tanh^2.
        phi_prime = (1.0 - h_next**2).flatten()
        jacobian = phi_prime[:, None] * W_hh

        Z = jacobian @ Q
        try:
            Q, R = np.linalg.qr(Z)
        except np.linalg.LinAlgError:
            return float("nan")

        r_diag_abs = np.abs(np.diag(R))
        log_r_diag_sum += np.log(np.clip(r_diag_abs, log_floor, None))
        h = h_next

    lyaps = log_r_diag_sum / max(T, 1)
    return float(np.max(lyaps))


def softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax along axis 0 (the class axis) of ``logits`` shaped (classes, batch).

    Each column's maximum is subtracted before exponentiating, for numerical
    stability. 1e-12 is added to the normaliser.
    """
    column_max = np.max(logits, axis=0, keepdims=True)
    weights = np.exp(logits - column_max)
    normaliser = np.sum(weights, axis=0, keepdims=True) + 1e-12
    return weights / normaliser


def load_mnist_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
    npz_path: str | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST and turn every image into a row-by-row sequence.

    Each image row is one time step. Input arrays are shaped
    ``(N, 28, 28)``, read as ``(batch, input_size, time_steps)``, so
    ``inputs[n, :, t]`` is row ``t`` of image ``n``.

    The raw data comes from the first existing file among ``npz_path`` (if
    given), ``./mnist.npz`` and ``~/.keras/datasets/mnist.npz``. If none
    exists, it falls back to ``tensorflow.keras.datasets.mnist.load_data()``.

    Args:
        train_limit: keep only the first ``train_limit`` training samples.
        test_limit: keep only the first ``test_limit`` test samples.
        npz_path: optional explicit path to an ``mnist.npz`` archive.

    Returns:
        A 7-tuple:

        - ``train_inputs``, ``train_targets``, ``train_labels``
        - ``test_inputs``, ``test_targets``, ``test_labels``
        - ``display_test_images``

        Inputs are standardised with the training-set pixel mean/std. Targets
        are float32 one-hot labels repeated over every time step:
        ``(N, 10, time_steps)``. Labels are int64. ``display_test_images`` are
        the unstandardised test images scaled to [0, 1].

    Raises:
        RuntimeError: if no local archive exists and TensorFlow cannot be imported.
    """
    search_order: List[Path] = [Path(npz_path)] if npz_path is not None else []
    search_order += [
        Path(MNIST_NPZ_FILENAME),
        Path.home() / ".keras" / "datasets" / MNIST_NPZ_FILENAME,
    ]
    archive_path = next((p for p in search_order if p.exists()), None)

    if archive_path is None:
        try:
            from tensorflow.keras.datasets import mnist  # type: ignore
        except Exception as exc:
            raise RuntimeError(
                "MNIST dataset not found locally and TensorFlow is unavailable. "
                "Place 'mnist.npz' alongside this script or install tensorflow."
            ) from exc
        dataset = mnist.load_data()
        print("Loaded MNIST data via tensorflow.keras.datasets.mnist.")
    else:
        with np.load(archive_path) as npz_data:
            dataset = (
                (npz_data["x_train"], npz_data["y_train"]),
                (npz_data["x_test"], npz_data["y_test"]),
            )
        print(f"Loaded MNIST data from '{archive_path}'.")

    (raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = dataset
    train_labels = raw_train_y.astype(np.int64)
    test_labels = raw_test_y.astype(np.int64)

    if train_limit is not None:
        raw_train_x, train_labels = raw_train_x[:train_limit], train_labels[:train_limit]
    if test_limit is not None:
        raw_test_x, test_labels = raw_test_x[:test_limit], test_labels[:test_limit]

    # Unstandardised [0, 1] copy of the test images, returned for display.
    display_test_images = raw_test_x.astype(np.float32) / 255.0

    train_scaled = raw_train_x.astype(np.float32) / 255.0
    test_scaled = raw_test_x.astype(np.float32) / 255.0

    # Standardise both splits with the scalar training-set mean and std.
    pixel_mean = np.mean(train_scaled)
    pixel_std = np.std(train_scaled) + 1e-7
    train_scaled = (train_scaled - pixel_mean) / pixel_std
    test_scaled = (test_scaled - pixel_mean) / pixel_std

    # (N, rows, cols) -> (N, cols, rows). Axis 1 is the per-step input
    # (one image row) and axis 2 is the time step (the row index).
    train_seq = train_scaled.transpose(0, 2, 1)
    test_seq = test_scaled.transpose(0, 2, 1)
    n_steps = train_seq.shape[2]
    num_classes = 10

    def per_step_onehot(label_ids: np.ndarray) -> np.ndarray:
        # Same one-hot label at every time step: (N,) -> (N, classes, n_steps).
        # The FPTT model only uses the last step.
        onehot = np.eye(num_classes, dtype=np.float32)[label_ids]
        return np.repeat(onehot[:, :, None], n_steps, axis=2)

    return (
        train_seq.astype(np.float32),
        per_step_onehot(train_labels).astype(np.float32),
        train_labels,
        test_seq.astype(np.float32),
        per_step_onehot(test_labels).astype(np.float32),
        test_labels,
        display_test_images,
    )


def iterate_minibatches(
    inputs: np.ndarray,
    targets: np.ndarray,
    batch_size: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
    """Yield shuffled ``(inputs, targets)`` minibatches covering each sample once.

    The sample order is shuffled in place with ``rng``. Batches hold
    ``batch_size`` samples, except possibly the last, which holds the
    remainder. Inputs and targets are indexed along axis 0 with the same
    indices.
    """
    n_samples = inputs.shape[0]
    order = np.arange(n_samples)
    rng.shuffle(order)
    batch_index_sets = [
        order[lo : lo + batch_size] for lo in range(0, n_samples, batch_size)
    ]
    for chosen in batch_index_sets:
        yield inputs[chosen], targets[chosen]


def split_train_val(
    inputs: np.ndarray,
    targets: np.ndarray,
    labels: np.ndarray,
    val_fraction: float,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Randomly split samples (axis 0) into training and validation subsets.

    The validation set holds ``int(N * val_fraction)`` samples, but never fewer
    than one. Its indices are the first entries of ``rng.permutation(N)``.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        val_inputs, val_targets, val_labels)``.
    """
    n_total = int(inputs.shape[0])
    n_val = max(1, int(n_total * float(val_fraction)))
    shuffled = rng.permutation(n_total)
    val_idx, train_idx = shuffled[:n_val], shuffled[n_val:]
    arrays = (inputs, targets, labels)
    return tuple(a[train_idx] for a in arrays) + tuple(a[val_idx] for a in arrays)


def extract_params(model: Any) -> Dict[str, np.ndarray]:
    """Snapshot the model's RNN parameters as independent NumPy arrays.

    Looks for ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y``. Attributes the
    model lacks are skipped.

    Conversion per value type:

    - Torch tensors are detached and copied to host, keeping their dtype.
    - NumPy arrays are copied, keeping their dtype, so a snapshot/restore
      round trip does not silently lose float64 precision.
    - Any other value is converted to a float32 array.

    Returns:
        Dict mapping parameter name to a copy that does not alias the model.
    """
    params: Dict[str, np.ndarray] = {}
    for name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        if not hasattr(model, name):
            continue
        value = getattr(model, name)
        if torch.is_tensor(value):
            params[name] = value.detach().cpu().numpy().copy()
        elif isinstance(value, np.ndarray):
            params[name] = value.copy()
        else:
            params[name] = np.asarray(value, dtype=np.float32).copy()
    return params


def _shape_without_singletons(shape: Iterable[int]) -> Tuple[int, ...]:
    """Return ``shape`` with every size-1 dimension removed, e.g. (H, 1) -> (H,)."""
    return tuple(int(dim) for dim in shape if int(dim) != 1)


def load_params(model: Any, params: Dict[str, np.ndarray]) -> None:
    """Copy parameters from ``params`` into ``model``, then reset its training state.

    For each of ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y`` present in both
    ``params`` and ``model``:

    - Torch attributes are overwritten in place (``copy_`` under ``no_grad``),
      keeping their device and dtype.
    - NumPy attributes are replaced by a copy in the attribute's existing
      dtype. This matches the torch branch, so loading never changes a model's
      precision. Non-array attributes get float32.

    When source and destination differ only by size-1 dimensions (e.g. a
    ``(H, 1)`` bias loaded into an ``(H,)`` one), the value is reshaped to the
    destination's shape. Without this, ``copy_`` raises, and a NumPy model
    would silently mis-broadcast. Other shape mismatches are not adjusted.

    Afterwards, calls ``reset_state_buffers()`` if present. It then calls
    ``reset_optimizer()``, or otherwise clears ``model.optimizer.state`` when
    that exists.
    """
    for name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        if name not in params or not hasattr(model, name):
            continue
        value = params[name]
        current = getattr(model, name)
        if torch.is_tensor(current):
            tensor = torch.as_tensor(value, device=current.device, dtype=current.dtype)
            if tensor.shape != current.shape and _shape_without_singletons(
                tensor.shape
            ) == _shape_without_singletons(current.shape):
                tensor = tensor.reshape(current.shape)
            with torch.no_grad():
                current.copy_(tensor)
        else:
            is_array = isinstance(current, np.ndarray)
            target_dtype = current.dtype if is_array else np.float32
            array = np.asarray(value, dtype=target_dtype).copy()
            if (
                is_array
                and array.shape != current.shape
                and _shape_without_singletons(array.shape)
                == _shape_without_singletons(current.shape)
            ):
                array = array.reshape(current.shape)
            setattr(model, name, array)

    if hasattr(model, "reset_state_buffers"):
        model.reset_state_buffers()
    if hasattr(model, "reset_optimizer"):
        model.reset_optimizer()
    elif hasattr(model, "optimizer") and hasattr(model.optimizer, "state"):
        model.optimizer.state.clear()


def _as_numpy_output(value: Any) -> np.ndarray:
    """Return ``value`` as a NumPy array, detaching torch tensors and copying them to host."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _emits_classes_first(
    model: Any, input_batch: np.ndarray, first: np.ndarray, num_classes: int
) -> bool:
    """Return True if per-step outputs are (C, B), False if they are (B, C).

    If the output is square (B == C, C > 1), the shape alone is ambiguous. In
    that case one sample is re-run: a classes-first model then returns (C, 1),
    the other layout (1, C).

    Raises:
        RuntimeError: if ``first`` fits neither layout.
    """
    if first.ndim != 2:
        raise RuntimeError(
            f"Unexpected output shape from model.forward_cycle(): {first.shape}"
        )
    rows_are_classes = first.shape[0] == num_classes
    cols_are_classes = first.shape[1] == num_classes
    if rows_are_classes and cols_are_classes and num_classes > 1:
        h_probe = np.zeros((model.hidden_size, 1), dtype=np.float32)
        probe_seq, _ = model.forward_cycle(input_batch[:1], h_probe)
        return _as_numpy_output(probe_seq[0]).shape[0] == num_classes
    if rows_are_classes:
        return True
    if cols_are_classes:
        return False
    raise RuntimeError(
        f"Unexpected output shape from model.forward_cycle(): {first.shape}"
    )


def evaluate_model(
    model: Any,
    inputs: np.ndarray,
    targets: np.ndarray,
    labels: np.ndarray,
    batch_size: int,
) -> Tuple[float, float]:
    """Evaluate ``model`` on a dataset without updating it.

    Supported output layouts from ``forward_cycle(inputs, h0)``:

    - per-step (classes, batch), as produced by the NumPy RNNs in this file;
    - per-step (batch, classes), presumably the strict FPTT model.

    Torch outputs are detached and copied to host. For square batches
    (batch == classes) the layout is resolved by re-running a single sample.

    Args:
        model: exposes ``hidden_size`` and ``forward_cycle``.
        inputs: (N, input_size, time_steps).
        targets: (N, classes, time_steps) one-hot targets.
        labels: (N,) integer class labels.
        batch_size: evaluation batch size.

    Returns:
        ``(loss, accuracy)``. The loss is the softmax cross-entropy averaged
        over time steps and samples. Accuracy is the argmax of the final-step
        logits compared against ``labels``.
    """
    total_loss = 0.0
    total_samples = 0
    correct = 0
    eps = 1e-12

    num_classes = targets.shape[1]

    for start_idx in range(0, inputs.shape[0], batch_size):
        end_idx = min(start_idx + batch_size, inputs.shape[0])
        input_batch = inputs[start_idx:end_idx]
        target_batch = targets[start_idx:end_idx]
        label_batch = labels[start_idx:end_idx]
        current_batch = input_batch.shape[0]

        h_prev_eval = np.zeros((model.hidden_size, current_batch), dtype=np.float32)
        outputs_seq, _ = model.forward_cycle(input_batch, h_prev_eval)
        steps = [_as_numpy_output(o) for o in outputs_seq]

        if _emits_classes_first(model, input_batch, steps[0], num_classes):
            # (C,B) per step -> (C,B,T) -> (B,C,T)
            outputs_batch = np.transpose(np.stack(steps, axis=2), (1, 0, 2))
        else:
            outputs_batch = np.stack(steps, axis=2)  # (B,C,T)

        n_steps = outputs_batch.shape[2]
        # Softmax over the class axis for every time step at once. This is the
        # same formula as softmax(): max-shifted exponentials, with 1e-12
        # added to the normaliser.
        shifted = outputs_batch - np.max(outputs_batch, axis=1, keepdims=True)
        exp_logits = np.exp(shifted)
        probs = exp_logits / (np.sum(exp_logits, axis=1, keepdims=True) + 1e-12)
        ce = -np.sum(target_batch[:, :, :n_steps] * np.log(probs + eps), axis=1)  # (B,T)
        batch_loss = float(np.mean(ce))

        preds = np.argmax(outputs_batch[:, :, -1], axis=1)
        correct += int(np.sum(preds == label_batch))
        total_loss += batch_loss * current_batch
        total_samples += current_batch

    avg_loss = total_loss / max(total_samples, 1)
    accuracy = correct / max(total_samples, 1)
    return avg_loss, accuracy


# ==============================
# Local Rule RNN (NumPy)
# ==============================

class NumpyLocalRuleRNN:
    """Tanh RNN trained online with a local learning rule (NumPy).

    At every time step the output error is projected onto the hidden layer,
    ``g_t = W_hy^T dL/dy_t``. It is then turned into a teaching signal
    ``delta_t = u_t * g_t / (1 - lambda * u_t)``, where ``u_t = 1 - h_t^2`` and
    ``lambda`` is a per-neuron factor estimated online. Weights are updated
    immediately with clipped SGD.

    Optionally the per-step targets are blended with FPTT-style surrogate
    distributions ``Q_t`` (see ``enable_fptt_surrogates``).

    Column convention: hidden states are (hidden_size, batch) and logits are
    (output_size, batch).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
        lambda_window: int = LAMBDA_WINDOW,
        eta_rec_scale: float = 1.0,
    ) -> None:
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.eta_rec_scale = eta_rec_scale  # scale recurrent updates to optionally preserve structure
        self.epsilon = 1e-8
        self.max_grad_norm = 5.0  # the local rule clips the global gradient norm at 5.0

        # Default init. W_hh is typically redrawn via initialize_weights_with_gain().
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_y = np.zeros((output_size, 1))

        # EMA factor and clip range for the alpha_hat estimate (delta_t ~ alpha * delta_{t-1}).
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # The lambda estimate uses an EMA with factor (w - 1) / w over window w.
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        self.diagnostics: Dict[str, Any] = {}
        self.reset_learning_state()
        self.reset_diagnostics()

        # Surrogates are only used once fptt_Q_prev is set (enable_fptt_surrogates).
        self.use_fptt_surrogates = True
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    def enable_fptt_surrogates(
        self,
        time_steps: int,
        output_size: int,
        Q_init: np.ndarray | None = None,
        beta_schedule: str = "linear",
    ) -> None:
        """Turn on FPTT surrogate targets and reset the per-epoch accumulators.

        Args:
            time_steps: sequence length T.
            output_size: number of classes C.
            Q_init: (C, T) initial surrogate distributions. Defaults to uniform 1/C.
            beta_schedule: stored only. The training loop always uses the
                linear ramp ``beta_t = (t + 1) / T``.
        """
        if Q_init is None:
            Q_init = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        self.fptt_Q_prev = Q_init.astype(np.float32).copy()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    def reset_fptt_epoch_accumulators(
        self,
        time_steps: int | None = None,
        output_size: int | None = None,
    ) -> None:
        """Zero the running sums of predicted distributions used to build the next Q.

        If either size is omitted, both are taken from ``fptt_Q_prev``.

        Raises:
            ValueError: if sizes are omitted and surrogates were never enabled.
        """
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = np.zeros((output_size, time_steps), dtype=np.float64)
        self.fptt_Q_count = np.zeros((time_steps,), dtype=np.int64)

    def finalize_fptt_epoch(self) -> None:
        """Set Q to this epoch's mean predicted distribution per time step.

        Time steps that received no samples keep their previous Q. Without this,
        they would collapse to an all-zero (non-distribution) column.
        """
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        counts = self.fptt_Q_count.astype(np.float64)
        Q_new = (self.fptt_Q_sum / np.maximum(1.0, counts)[None, :]).astype(np.float32)
        if self.fptt_Q_prev is not None and self.fptt_Q_prev.shape == Q_new.shape:
            seen = counts > 0
            Q_new = np.where(seen[None, :], Q_new, self.fptt_Q_prev).astype(np.float32)
        self.fptt_Q_prev = Q_new

    def reset_diagnostics(self) -> None:
        """Clear the counters collected during training."""
        self.diagnostics = {
            # neuron x sample x step entries processed (same unit as denom_floor_hits)
            "total_neurons_steps": 0,
            # per-neuron lambda estimates altered by clipping
            "lambda_clip_count": 0,
            # per-neuron lambda estimates computed (denominator of the clip percentage)
            "lambda_estimate_count": 0,
            # denominator entries pushed out to +/- denom_floor
            "denom_floor_hits": 0,
            # mean lambda used at each step
            "lambda_history": [],
        }

    def reset_learning_state(self) -> None:
        """Zero the running statistics behind the alpha_hat and lambda estimates."""
        shape = (self.hidden_size, 1)
        self.alpha_num = np.zeros(shape)
        self.alpha_den = np.zeros(shape)
        self.alpha_hat = np.zeros(shape)
        self.S_A2 = np.zeros(shape)
        self.S_AB = np.zeros(shape)
        self.lambda_vals = np.zeros(shape)

    def get_lambda_clip_percentage(self) -> float:
        """Percentage of per-neuron lambda estimates that were changed by clipping."""
        total = self.diagnostics.get("lambda_estimate_count", 0)
        if total == 0:
            return 0.0
        return (self.diagnostics["lambda_clip_count"] / total) * 100.0

    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw W_hh i.i.d. from N(0, g^2 / hidden_size). Other weights are untouched."""
        std_dev = g / np.sqrt(self.hidden_size)
        self.W_hh = np.random.randn(*self.W_hh.shape) * std_dev

    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray,
        targets_cycle: np.ndarray,
        h_prev_cycle: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """Run one batch of sequences, applying a local-rule weight update at every step.

        Args:
            inputs_cycle: (batch, input_size, time_steps).
            targets_cycle: (batch, output_size, time_steps) one-hot targets.
            h_prev_cycle: initial hidden state, (hidden_size, batch).

        Returns:
            (sum over time steps of the batch-mean loss, final hidden state).
        """
        h_prev = h_prev_cycle
        total_cycle_loss = 0.0
        batch_size = inputs_cycle.shape[0]
        time_steps = inputs_cycle.shape[2]
        eps = 1e-12

        prev_g = None
        prev_u = None
        prev_delta = None

        for t in range(time_steps):
            I_t = inputs_cycle[:, :, t].T
            y_target_t = targets_cycle[:, :, t].T

            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y

            if self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                # Target blends from the surrogate Q_t (early steps) to the true label (last step).
                beta_t = float(t + 1) / float(time_steps)
                P_t = softmax(y_hat_t)
                Q_t = self.fptt_Q_prev[:, t].reshape(-1, 1)
                Q_t_batch = np.repeat(Q_t, repeats=y_hat_t.shape[1], axis=1)
                Y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch

                CE_true = -np.sum(y_target_t * np.log(P_t + eps), axis=0)
                CE_div  = -np.sum(Q_t_batch * np.log(P_t + eps), axis=0)
                loss_t  = np.mean(beta_t * CE_true + (1.0 - beta_t) * CE_div)

                dL_dyhat = P_t - Y_tilde

                # Accumulate predictions; finalize_fptt_epoch() turns them into the next Q.
                if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                    self.fptt_Q_sum[:, t] += np.sum(P_t, axis=1)
                    self.fptt_Q_count[t]  += batch_size
            else:
                P_t = softmax(y_hat_t)
                dL_dyhat = P_t - y_target_t
                loss_t = -np.mean(np.sum(y_target_t * np.log(P_t + eps), axis=0))

            total_cycle_loss += loss_t

            g_t = self.W_hy.T @ dL_dyhat
            u_t = 1.0 - h_t**2

            lambda_used = self.lambda_vals.copy()
            self.diagnostics["lambda_history"].append(float(np.mean(lambda_used)))
            self.diagnostics["total_neurons_steps"] += self.hidden_size * batch_size

            denominator = 1.0 - lambda_used * u_t
            denom_mask = np.abs(denominator) < self.denom_floor
            if np.any(denom_mask):
                self.diagnostics["denom_floor_hits"] += int(np.sum(denom_mask))
            # Push near-zero denominators out to +/- denom_floor. An exact zero
            # goes to +denom_floor: np.sign(0) == 0 would keep it at zero and
            # divide by zero below.
            floor_signed = self.denom_floor * np.where(denominator < 0, -1.0, 1.0)
            denominator = np.where(denom_mask, floor_signed, denominator)
            delta_t = (u_t * g_t) / denominator

            # Estimate alpha_hat from the teaching signal itself:
            # \widetilde{delta}_t \approx alpha * \widetilde{delta}_{t-1}.
            if prev_delta is not None:
                dtp_mean = np.mean(delta_t * prev_delta, axis=1, keepdims=True)
                dpp_mean = np.mean(prev_delta**2, axis=1, keepdims=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                raw_alpha = self.alpha_num / (self.alpha_den + self.epsilon)
                self.alpha_hat = np.clip(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

            dW_hh = (delta_t @ h_prev.T) / batch_size
            dW_xh = (delta_t @ I_t.T) / batch_size
            db_h  = np.mean(delta_t, axis=1, keepdims=True)
            dW_hy = (dL_dyhat @ h_t.T) / batch_size
            db_y  = np.mean(dL_dyhat, axis=1, keepdims=True)

            # Clip the global norm across all five gradients jointly.
            if self.max_grad_norm > 0:
                grads = [dW_hh, dW_xh, db_h, dW_hy, db_y]
                total_norm = np.sqrt(sum(np.sum(g**2) for g in grads))
                if total_norm > self.max_grad_norm:
                    clip_factor = self.max_grad_norm / (total_norm + 1e-8)
                    dW_hh *= clip_factor
                    dW_xh *= clip_factor
                    db_h  *= clip_factor
                    dW_hy *= clip_factor
                    db_y  *= clip_factor

            # Optionally damp recurrent updates to retain imposed structure
            self.W_hh -= self.eta * self.eta_rec_scale * dW_hh
            self.W_xh -= self.eta * dW_xh
            self.b_h  -= self.eta * db_h
            self.W_hy -= self.eta * dW_hy
            self.b_y  -= self.eta * db_y

            # Online estimate of lambda (used from the next step on).
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t
                A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
                AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)

                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean

                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)

                # Keep |lambda * u| <= 1 - denom_floor so the denominator stays away from 0.
                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                safe_cap = (1.0 - self.denom_floor) / u_abs_max
                cap = np.minimum(safe_cap, self.lambda_cap)
                lambda_clipped = np.clip(lambda_unproj, -cap, cap)

                self.diagnostics["lambda_clip_count"] += int(np.sum(lambda_unproj != lambda_clipped))
                self.diagnostics["lambda_estimate_count"] += int(lambda_unproj.size)
                self.lambda_vals = lambda_clipped

            prev_g = g_t
            prev_u = u_t
            h_prev = h_t
            prev_delta = delta_t

        return total_cycle_loss, h_prev

    def forward_cycle(self, inputs_cycle: np.ndarray, h_prev_cycle: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
        """Run the network forward without learning; returns ([logits (C, B) per step], final h)."""
        h_prev = h_prev_cycle
        outputs: List[np.ndarray] = []
        for t in range(inputs_cycle.shape[2]):
            I_t = inputs_cycle[:, :, t].T
            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y  # (C,B)
            outputs.append(y_hat_t)
            h_prev = h_t
        return outputs, h_prev


# ==============================
# NumPy BPTT RNN baseline
# ==============================

class BPTTRNN:
    """Vanilla tanh RNN trained with full backpropagation through time (NumPy baseline).

    Recurrence: ``h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h)``, ``y_t = W_hy h_t + b_y``.
    A softmax cross-entropy loss is applied at every time step. Hidden states
    are stored column-wise as (hidden_size, batch).
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, eta: float = 0.001) -> None:
        self.input_size, self.hidden_size, self.output_size = input_size, hidden_size, output_size
        self.eta = eta

        draw = np.random.randn
        # Draw order (W_xh, W_hh, W_hy) fixes reproducibility under np.random.seed.
        self.W_xh = 0.1 * draw(hidden_size, input_size)
        self.W_hh = draw(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = 0.1 * draw(output_size, hidden_size)
        self.b_y = np.zeros((output_size, 1))

    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw W_hh i.i.d. from N(0, g^2 / hidden_size). Other weights are untouched."""
        scale = g / np.sqrt(self.hidden_size)
        self.W_hh = scale * np.random.randn(*self.W_hh.shape)

    def train_batch(
        self,
        inputs_batch: np.ndarray,
        targets_batch: np.ndarray,
        h_prev_batch: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """Run one forward/backward pass over a minibatch and take one SGD step.

        Args:
            inputs_batch: (batch, input_size, time_steps).
            targets_batch: (batch, output_size, time_steps) one-hot targets.
            h_prev_batch: initial hidden state, (hidden_size, batch).

        Returns:
            (cross-entropy averaged over time steps and samples, final hidden state).
        """
        n_batch = inputs_batch.shape[0]
        n_steps = inputs_batch.shape[2]
        eps = 1e-12

        # states[k] is the hidden state after k steps; states[0] is the initial state.
        states: List[np.ndarray] = [h_prev_batch]
        step_probs: List[np.ndarray] = []
        loss_sum = 0.0

        for t in range(n_steps):
            x_t = inputs_batch[:, :, t].T  # (input, batch)
            h_t = np.tanh(self.W_hh @ states[t] + self.W_xh @ x_t + self.b_h)
            p_t = softmax(self.W_hy @ h_t + self.b_y)
            target_t = targets_batch[:, :, t].T
            loss_sum += -np.sum(target_t * np.log(p_t + eps)) / n_batch
            states.append(h_t)
            step_probs.append(p_t)

        param_names = ("W_xh", "W_hh", "b_h", "W_hy", "b_y")
        grads: Dict[str, np.ndarray] = {
            name: np.zeros_like(getattr(self, name)) for name in param_names
        }
        carry = np.zeros((self.hidden_size, n_batch))  # dL/dh_t flowing back from t+1

        for t in range(n_steps - 1, -1, -1):
            x_t = inputs_batch[:, :, t].T
            target_t = targets_batch[:, :, t].T
            h_t = states[t + 1]
            h_before = states[t]

            dy = step_probs[t] - target_t  # (C,B)
            grads["W_hy"] += dy @ h_t.T
            grads["b_y"] += np.sum(dy, axis=1, keepdims=True)

            dpre = (1 - h_t**2) * (self.W_hy.T @ dy + carry)
            grads["W_hh"] += dpre @ h_before.T
            grads["W_xh"] += dpre @ x_t.T
            grads["b_h"] += np.sum(dpre, axis=1, keepdims=True)

            carry = self.W_hh.T @ dpre

        # Plain SGD on the batch-averaged gradient; the BPTT baseline does no clipping.
        for name in param_names:
            param = getattr(self, name)
            param -= self.eta * (grads[name] / n_batch)

        return loss_sum / n_steps, states[-1]

    def forward_cycle(self, inputs_cycle: np.ndarray, h_prev_cycle: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
        """Run the network forward without learning; returns ([logits (C, B) per step], final h)."""
        h = h_prev_cycle
        logits: List[np.ndarray] = []
        for x_t in np.moveaxis(inputs_cycle, 2, 0):
            h = np.tanh(self.W_hh @ h + self.W_xh @ x_t.T + self.b_h)
            logits.append(self.W_hy @ h + self.b_y)
        return logits, h


# ==============================
# FPTT RNN wrapper (PyTorch)
# ==============================

class FPTTRNN(StrictFPTTClassifier):
    """Strict FPTT classifier configured for sequential (row-by-row) MNIST.

    A thin adapter over ``StrictFPTTClassifier`` whose keyword names follow the
    ``strict_fptt`` interface. Defaults come from the module-level FPTT
    settings (``FPTT_PARTS``, ``FPTT_ORACLE_MOMENTUM``, ``FPTT_LAMBDA``).
    Gradient clipping defaults to 5.0.

    The wrapper always passes ``use_oracle=True`` and ``label_mode="last"``
    (presumably supervising only the final time step; see
    ``StrictFPTTClassifier``). Extra keyword arguments are accepted and
    silently ignored.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
        parts: int = FPTT_PARTS,
        oracle_momentum: float = FPTT_ORACLE_MOMENTUM,
        clip: float = 5.0,            # gradient clipping at 5.0
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lmbda: float = FPTT_LAMBDA,   # lambda = 1.0 for classification
        warmup_epochs: int = 20,
        oracle_id: str = "mnist",
        **_: Any,
    ) -> None:
        # Settings this wrapper always forces, regardless of caller arguments.
        fixed: Dict[str, Any] = {"label_mode": "last", "use_oracle": True}
        tunable: Dict[str, Any] = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "output_size": output_size,
            "eta": eta,
            "parts": parts,
            "clip": clip,
            "alpha": alpha,
            "beta": beta,
            "rho": rho,
            "lmbda": lmbda,
            "oracle_momentum": oracle_momentum,
            "warmup_epochs": warmup_epochs,
            "oracle_id": oracle_id,
        }
        super().__init__(**tunable, **fixed)


# ==============================
# Main experiment
# ==============================

if __name__ == "__main__":
    np.random.seed(42)
    apply_plot_style()

    TRAIN_LIMIT = 10000
    TEST_LIMIT = 2000
    BATCH_SIZE = 64
    EPOCHS = 25
    ETA = 1e-3
    HIDDEN_SIZE = 128

    (
        train_inputs,
        train_targets,
        train_labels,
        test_inputs,
        test_targets,
        test_labels,
        _,
    ) = load_mnist_sequences(
        train_limit=TRAIN_LIMIT, test_limit=TEST_LIMIT
    )

    input_size = train_inputs.shape[1]
    time_steps = train_inputs.shape[2]
    output_size = train_targets.shape[1]

    # 驱动序列用于 Lyapunov 指数
    # Standard protocol: use validation for model selection/curves, and use test only once for final reporting.
    rng_split = np.random.default_rng(42)
    (
        train_inputs_fit,
        train_targets_fit,
        train_labels_fit,
        val_inputs,
        val_targets,
        val_labels,
    ) = split_train_val(train_inputs, train_targets, train_labels, val_fraction=0.1, rng=rng_split)

    # Driver sequence for Lyapunov exponent (do NOT use the test set).
    lyapunov_driver = np.mean(train_inputs_fit[: min(32, train_inputs_fit.shape[0])], axis=0)

    total_start_time = time.time()

    print("=" * 60)
    print("STAGE 1: Finding optimal gain 'g' for Local Rule RNN...")
    print("=" * 60)

    gains_dense = np.linspace(0.5, 1.6, 13)
    gains_to_test = gains_dense

    local_rule_scan_results = []

    for g in gains_to_test:
        model = NumpyLocalRuleRNN(
            input_size=input_size,
            hidden_size=HIDDEN_SIZE,
            output_size=output_size,
            eta=ETA,
        )
        model.initialize_weights_with_gain(g)
        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

        # 记录初始参数，以便 Stage2 统一初始化
        initial_params = {
            "W_hh": model.W_hh.copy(),
            "W_xh": model.W_xh.copy(),
            "b_h": model.b_h.copy(),
            "W_hy": model.W_hy.copy(),
            "b_y": model.b_y.copy(),
        }

        # 初始化 FPTT surrogate Q_t
        Q0 = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        model.enable_fptt_surrogates(time_steps=time_steps, output_size=output_size, Q_init=Q0)

        rng = np.random.default_rng(42)
        for epoch in range(EPOCHS):
            model.reset_fptt_epoch_accumulators()
            for inputs_batch, targets_batch in iterate_minibatches(
                train_inputs_fit, train_targets_fit, BATCH_SIZE, rng
            ):
                h_prev = np.zeros(
                    (model.hidden_size, inputs_batch.shape[0]), dtype=np.float32
                )
                _, _ = model.run_one_cycle_and_update_directly(
                    inputs_batch, targets_batch, h_prev
                )
            model.finalize_fptt_epoch()

        val_loss, val_acc = evaluate_model(model, val_inputs, val_targets, val_labels, BATCH_SIZE)
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

        print(
            f"[SCAN] g={g:.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f} | "
            f"Lyapunov (pre, post) = ({lambda_pre:.4f}, {lambda_post:.4f}), "
            f"Δ={lambda_post - lambda_pre:.4f}"
        )

        if not (np.isnan(lambda_pre) or np.isnan(lambda_post)):
            local_rule_scan_results.append(
                {
                    "g": g,
                    "val_accuracy": val_acc,
                    "val_loss": val_loss,
                    "lambda_pre": lambda_pre,
                    "lambda_post": lambda_post,
                    "initial_params": initial_params,
                }
            )

    # 选出精度最高的若干，再从中选 Lyapunov 最接近 0 的
    top_performers = sorted(local_rule_scan_results, key=lambda x: x["val_accuracy"], reverse=True)[:3]
    best_critical_model_info = min(
        top_performers, key=lambda x: abs(x["lambda_pre"])
    )
    g_optimal = best_critical_model_info["g"]
    initial_params_optimal = best_critical_model_info["initial_params"]

    print("\n--- Stage 1 Summary ---")
    print(f"Optimal gain 'g' found: {g_optimal:.4f}")
    print(
        f"  - Achieved Val Accuracy: {best_critical_model_info['val_accuracy']:.4f}"
    )
    print(
        f"  - Lyapunov Exponent PRE (close to 0): {best_critical_model_info['lambda_pre']:.4f}"
    )
    print(
        f"  - Lyapunov Exponent POST: {best_critical_model_info['lambda_post']:.4f}"
    )

    print("\n" + "=" * 60)
    print(f"STAGE 2: Comparing algorithms with optimal g = {g_optimal:.4f}")
    print("=" * 60)

    models_to_compare = {
        "Local Rule (FPTT)": NumpyLocalRuleRNN(
            input_size, HIDDEN_SIZE, output_size, eta=ETA
        ),
        "BPTT": BPTTRNN(input_size, HIDDEN_SIZE, output_size, eta=ETA),
        "E-Prop": StandardEPropRNN(
            input_size,
            HIDDEN_SIZE,
            output_size,
            eta=ETA,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="ce",
        ),
        "FPTT": FPTTRNN(
            input_size,
            HIDDEN_SIZE,
            output_size,
            eta=ETA,
            parts=FPTT_PARTS,
            oracle_momentum=FPTT_ORACLE_MOMENTUM,
        ),
    }

    comparison_results = {}

    for name, model in models_to_compare.items():
        print(f"\nTraining {name}...")

        # 统一初始化（用 Stage1 扫描出来的最“临界”的 local rule 初始参数）
        load_params(model, initial_params_optimal)

        lambda_pre_model = calculate_lyapunov_exponent_numpy(
            model, lyapunov_driver
        )
        epoch_acc_history = []
        rng = np.random.default_rng(123)
        uses_fptt_surrogates = name == "Local Rule (FPTT)"
        best_val_acc = -float("inf")
        best_val_loss = float("inf")
        best_epoch = 0
        best_params = None

        if uses_fptt_surrogates:
            Q0 = np.full(
                (output_size, time_steps), 1.0 / output_size, dtype=np.float32
            )
            model.enable_fptt_surrogates(
                time_steps=time_steps, output_size=output_size, Q_init=Q0
            )

        for epoch in range(EPOCHS):
            if hasattr(model, "set_epoch"):
                model.set_epoch(epoch)
            if uses_fptt_surrogates:
                model.reset_fptt_epoch_accumulators()

            for inputs_batch, targets_batch in iterate_minibatches(
                train_inputs_fit, train_targets_fit, BATCH_SIZE, rng
            ):
                h_prev = np.zeros(
                    (model.hidden_size, inputs_batch.shape[0]), dtype=np.float32
                )
                if isinstance(model, NumpyLocalRuleRNN):
                    _, _ = model.run_one_cycle_and_update_directly(
                        inputs_batch, targets_batch, h_prev
                    )
                else:  # BPTT, FPTT, E-Prop
                    _, _ = model.train_batch(inputs_batch, targets_batch, h_prev)

            if uses_fptt_surrogates:
                model.finalize_fptt_epoch()

            val_loss, val_acc = evaluate_model(model, val_inputs, val_targets, val_labels, BATCH_SIZE)
            epoch_acc_history.append(val_acc)
            if val_acc > best_val_acc:
                best_val_acc = float(val_acc)
                best_val_loss = float(val_loss)
                best_epoch = int(epoch + 1)
                best_params = extract_params(model)
            if (epoch + 1) % 5 == 0:
                print(f"  Epoch {epoch+1}/{EPOCHS} | Val Acc: {val_acc:.4f} | Val Loss: {val_loss:.4f}")

        if best_params is not None:
            load_params(model, best_params)

        final_test_loss, final_test_acc = evaluate_model(model, test_inputs, test_targets, test_labels, BATCH_SIZE)
        final_lambda = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

        comparison_results[name] = {
            "test_accuracy": final_test_acc,
            "best_val_accuracy": best_val_acc,
            "best_val_loss": best_val_loss,
            "best_epoch": best_epoch,
            "lyapunov_pre": lambda_pre_model,
            "lyapunov_post": final_lambda,
            "lyapunov_exponent": final_lambda,
            "accuracy_history": epoch_acc_history,
            "model": model,
        }
        print(
            f"-> Final Results for {name}: test_acc={final_test_acc:.4f} | best_val_acc={best_val_acc:.4f} (epoch={best_epoch:02d}) | "
            f"Lyapunov pre={lambda_pre_model:.4f}, post={final_lambda:.4f}, "
            f"Δ={final_lambda - lambda_pre_model:.4f}"
        )

    print(
        f"\nExperiment finished! Total time: {(time.time() - total_start_time) / 60:.2f} minutes"
    )

    # ==============================
    # Plot results
    # ==============================

    names = list(comparison_results.keys())
    accuracies = [comparison_results[n]["test_accuracy"] for n in names]
    lyapunovs = [comparison_results[n]["lyapunov_exponent"] for n in names]

    fig, axs = plt.subplots(1, 3, figsize=(22, 7), constrained_layout=True)
    fig.suptitle(
        f"RNN Learning Algorithm Comparison (Initialized at g={g_optimal:.2f})",
        fontsize=20,
        weight="bold",
    )

    colors = [cm.viridis(0.3), cm.inferno(0.5), cm.plasma(0.6), cm.magma(0.7)]

    # (a) final test accuracy
    ax1 = axs[0]
    bars = ax1.bar(names, accuracies, color=colors, zorder=3)
    ax1.set_ylabel("Final Test Accuracy", fontsize=12)
    ax1.set_title("a) Generalization Performance", fontsize=14, weight="bold")
    ax1.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
    ax1.set_ylim(0, 1.0)
    ax1.spines["top"].set_visible(False)
    ax1.spines["right"].set_visible(False)
    for bar in bars:
        yval = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.0,
            yval + 0.02,
            f"{yval:.3f}",
            ha="center",
            va="bottom",
        )
    ax1.tick_params(axis="x", rotation=15)

    # (b) post-training Lyapunov
    ax2 = axs[1]
    bars = ax2.bar(names, lyapunovs, color=colors, zorder=3)
    ax2.axhline(
        0,
        color="crimson",
        linestyle="--",
        linewidth=1.5,
        label="Edge of Chaos (λ=0)",
        zorder=2,
    )
    ax2.set_ylabel("Max Lyapunov Exponent (λ_max)", fontsize=12)
    ax2.set_title("b) Post-Training Dynamics", fontsize=14, weight="bold")
    ax2.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
    ax2.spines["top"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.legend()
    for bar in bars:
        yval = bar.get_height()
        ax2.text(
            bar.get_x() + bar.get_width() / 2.0,
            yval + np.sign(yval) * 0.03,
            f"{yval:.3f}",
            ha="center",
            va="bottom",
        )
    ax2.tick_params(axis="x", rotation=15)

    # (c) learning curves
    ax3 = axs[2]
    epochs_axis = np.arange(1, EPOCHS + 1)
    for i, name in enumerate(names):
        ax3.plot(
            epochs_axis,
            comparison_results[name]["accuracy_history"],
            label=name,
            color=colors[i],
            linewidth=2.5,
            marker="o",
            markersize=4,
            markevery=5,
        )
    ax3.set_xlabel("Epoch", fontsize=12)
    ax3.set_ylabel("Val Accuracy", fontsize=12)
    ax3.set_title("c) Learning Curves", fontsize=14, weight="bold")
    ax3.grid(True, linestyle=":", linewidth=0.7)
    ax3.legend(fontsize=11)
    ax3.set_ylim(0, 1.0)
    ax3.spines["top"].set_visible(False)
    ax3.spines["right"].set_visible(False)

    plt.show()
