import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

from plotting_utils import apply_plot_style

# Dependencies: the project's own methods/strict_fptt.py, methods/standard_eprop.py
# and task/common/sequence_core.py are expected to live under COMPARE_DIR, the
# grandparent directory of this file. Put that directory at the front of
# sys.path (once) so the project-local imports below resolve when this file is
# run directly as a script.
COMPARE_DIR = Path(__file__).resolve().parents[1]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTRegressor
from methods.standard_eprop import StandardEPropRNN
from task.common.sequence_core import extract_params as extract_params_common, load_params as load_params_common


# =============================
# Config
# =============================
BLOCK_SIZE = 64          # sequence length for the adding task
BATCH_SIZE = 32
EPOCHS = 10              # Stage 2 training epochs per algorithm
LEARNING_RATE = 1e-3
HIDDEN_SIZE = 128

INPUT_SIZE = 2           # (value, marker)
OUTPUT_SIZE = 1          # running sum of marked values

# Stage 1 scan config: recurrent gains g tried for the Local Rule RNN.
# 13 values from 0.5 up to (but excluding) 1.6, i.e. a step of 1.1 / 13 (~0.085).
GAINS_SCAN = np.linspace(0.5, 1.6, 13, endpoint=False)
SCAN_EPOCHS = 10  # epochs per gain in the Stage 1 scan (currently the same as EPOCHS)

# Training schedule and per-algorithm settings.
STEPS_PER_EPOCH = 100    # training batches per epoch
VAL_STEPS = 50           # validation batches per evaluation
FPTT_PARTS = 8
EPROP_SEED = 1234
EPROP_FEEDBACK = "symmetric"
FPTT_LAMBDA = 0.5
FPTT_USE_FED_DYN = False  # no longer used; kept only as a harmless leftover constant


# =============================
# Optimization helpers
# =============================

def normalize_and_clip_gradients(
    grads: List[np.ndarray],
    norm_scale: float,
    max_norm: float,
) -> List[np.ndarray]:
    """
    Scale gradients by ``1 / norm_scale`` and clip their joint L2 norm.

    Args:
        grads: gradient arrays. They are not modified in place.
        norm_scale: divisor applied to every gradient (typically the batch
            size). Values below 1 are treated as 1.
        max_norm: upper bound on the global L2 norm, taken over all ``grads``
            together after scaling.

    Returns:
        New arrays. If the global norm exceeds ``max_norm``, all of them are
        rescaled by one common factor so the global norm becomes ``max_norm``.
        If the norm is not finite (some entry is NaN or inf), the scaled
        gradients are returned unclipped.
    """
    safe_scale = max(1.0, float(norm_scale))
    normalized = [g / safe_scale for g in grads]

    # Accumulate the squared norm in float64. Squaring large float32 entries
    # (|g| > ~1.8e19) overflows to inf in float32. That would route the
    # gradients through the "non-finite" branch below and skip clipping
    # exactly when it is needed most.
    total_norm_sq = sum(
        float(np.sum(np.square(np.asarray(g, dtype=np.float64))))
        for g in normalized
    )
    total_norm = float(np.sqrt(total_norm_sq))

    if not np.isfinite(total_norm) or total_norm <= max_norm:
        return normalized

    clip = max_norm / (total_norm + 1e-8)
    return [g * clip for g in normalized]


# =============================
# Utilities
# =============================

def generate_addition_batch(
    batch_size: int,
    seq_len: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build one batch of the sequence-to-sequence "adding problem".

    Each sample gets ``seq_len`` uniform float32 values in [0, 1) on channel 0.
    Channel 1 is a binary marker track with exactly two 1s at distinct random
    positions. The target at step t is the sum of the values at all marked
    positions <= t. It therefore steps up at each marker and ends at the sum
    of the two marked values.

    Args:
        batch_size: number of independent sequences.
        seq_len: sequence length T. Must be >= 2 so that two distinct markers
            fit; ``rng.choice`` raises ValueError otherwise.
        rng: NumPy generator. It is consumed sample by sample (values first,
            then marker positions), so a given seed yields a reproducible
            batch.

    Returns:
        inputs:  float32 array of shape (batch_size, INPUT_SIZE, T)
        targets: float32 array of shape (batch_size, OUTPUT_SIZE, T)
    """
    inputs = np.zeros((batch_size, INPUT_SIZE, seq_len), dtype=np.float32)
    targets = np.zeros((batch_size, OUTPUT_SIZE, seq_len), dtype=np.float32)

    for sample in range(batch_size):
        # Keep this draw order: it defines how the RNG stream is consumed.
        vals = rng.random(seq_len, dtype=np.float32)
        marked_idx = rng.choice(seq_len, size=2, replace=False)

        flags = np.zeros(seq_len, dtype=np.float32)
        flags[marked_idx] = 1.0

        inputs[sample, 0, :] = vals
        inputs[sample, 1, :] = flags
        # Running sum of the marked values == cumulative sum of the masked track.
        targets[sample, 0, :] = np.cumsum(vals * flags)

    return inputs, targets


def evaluate_addition_model(
    model: Any,
    seq_len: int,
    batch_size: int,
    steps: int,
    rng_seed: int = 12345,
) -> float:
    """
    Average validation loss of ``model`` on freshly generated adding-task batches.

    The per-batch loss is ``0.5 * mean(error ** 2)`` over every output unit,
    sample and time step. This is half the MSE, matching the training loss
    convention used in this file, even though printouts call it "MSE". The
    generator is re-seeded with ``rng_seed`` on every call, so repeated calls
    evaluate on the same validation batches.

    ``model`` must expose ``hidden_size`` and ``forward_cycle(inputs, h0)``,
    returning a per-time-step list of (output, batch) arrays plus the final
    state. Each batch starts from a zero hidden state. Returns 0.0 when
    ``steps`` is 0.
    """
    rng = np.random.default_rng(rng_seed)
    batch_losses: List[float] = []

    for _ in range(steps):
        inputs, targets = generate_addition_batch(batch_size, seq_len, rng)
        h0 = np.zeros((model.hidden_size, batch_size), dtype=np.float32)

        per_step_outputs, _ = model.forward_cycle(inputs, h0)
        predictions = np.stack(per_step_outputs, axis=2)           # (output, batch, T)
        residual = predictions - np.transpose(targets, (1, 0, 2))  # targets -> (output, batch, T)
        batch_losses.append(0.5 * float(np.mean(residual ** 2)))

    return sum(batch_losses) / max(len(batch_losses), 1)


# ===== Lyapunov (Benettin-style QR) =====

def calculate_lyapunov_exponent_numpy(model: Any, driver_input: np.ndarray) -> float:
    """
    Estimate the largest Lyapunov exponent of a tanh RNN driven by a fixed input.

    Uses repeated QR re-orthonormalisation of the tangent basis (Benettin
    style). The hidden state follows ``h <- tanh(W_hh h + W_xh x_t + b_h)``
    from h = 0. At every step the basis is pushed through the Jacobian
    ``diag(1 - h_next**2) @ W_hh`` and re-factorised. The log of each |R_ii|
    (floored at 1e-12) is accumulated. The result is the largest per-direction
    average over the T steps.

    Args:
        model: object exposing ``hidden_size``, ``W_hh``, ``W_xh`` and ``b_h``.
            All are converted to float64.
        driver_input: array of shape (input_size, T).

    Returns:
        The estimate as a float. Returns NaN if a QR factorisation raises
        LinAlgError, and 0.0 when T == 0.
    """
    dim = int(model.hidden_size)
    recurrent = np.asarray(model.W_hh, dtype=np.float64)
    input_w = np.asarray(model.W_xh, dtype=np.float64)
    bias = np.asarray(model.b_h, dtype=np.float64)

    floor = 1e-12
    n_steps = int(driver_input.shape[1])
    state = np.zeros((dim, 1), dtype=np.float64)
    basis = np.eye(dim, dtype=np.float64)
    growth_log = np.zeros(dim, dtype=np.float64)

    for step in range(n_steps):
        drive = driver_input[:, step].reshape(-1, 1).astype(np.float64)
        state = np.tanh(recurrent @ state + input_w @ drive + bias)
        slope = (1.0 - state**2).flatten()
        tangent_map = slope[:, None] * recurrent
        try:
            basis, upper = np.linalg.qr(tangent_map @ basis)
        except np.linalg.LinAlgError:
            return float("nan")
        growth_log += np.log(np.clip(np.abs(np.diag(upper)), floor, None))

    return float(np.max(growth_log / max(n_steps, 1)))


def make_lyapunov_driver_for_addition(seq_len: int, batch_size: int) -> np.ndarray:
    """
    Build the fixed driver sequence used for every Lyapunov estimate.

    Draws one adding-task batch from a generator seeded with 999 (so the
    driver is identical across calls) and averages it over the batch axis.

    Returns:
        float32 array of shape (input_size, seq_len).
    """
    driver_rng = np.random.default_rng(999)
    batch_inputs = generate_addition_batch(batch_size, seq_len, driver_rng)[0]
    return batch_inputs.mean(axis=0).astype(np.float32)


# =============================
# RNN Implementations
# =============================

class NumpyLocalRuleRNN:
    """
    Tanh RNN trained online with a local learning rule, adapted for regression.

    The per-step loss is ``0.5 * mean((y_hat_t - y_t) ** 2)``, and every time
    step of a cycle performs a parameter update. The hidden-layer teaching
    signal is ``delta_t = u_t * g_t / (1 - lambda * u_t)``, where:

    - ``g_t = W_hy^T (y_hat_t - y_t)``;
    - ``u_t = 1 - h_t**2`` is the tanh derivative;
    - ``lambda`` has one value per hidden unit and is re-estimated each step
      from exponential moving averages of statistics built from consecutive
      ``(g, u)`` pairs and an AR(1) coefficient ``alpha_hat`` fitted to
      successive teaching signals.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
        lambda_window: int = 50,
    ) -> None:
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.epsilon = 1e-8  # NOTE(review): not read by any method of this class
        # Global-norm threshold for the per-step gradient clipping in
        # run_one_cycle_and_update_directly (a value <= 0 disables clipping).
        self.max_grad_norm = 5.0

        # Parameters (default init; initialize_weights_with_gain replaces them).
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_y = np.zeros((output_size, 1))

        # alpha: AR(1) coefficient of successive teaching signals (EMA estimate).
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # lambda: EMA with effective window `lambda_window` (rho = (w - 1) / w).
        # |lambda| is capped so that |lambda * u| <= 1 - denom_floor.
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        # Optional diagnostics. None of the methods below update these counters.
        self.diagnostics: Dict[str, Any] = {}
        self.reset_learning_state()
        self.reset_diagnostics()

        # FPTT-related (unused for this experiment, kept for compatibility).
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    def reset_diagnostics(self) -> None:
        """Reset the diagnostics dictionary to empty counters/history."""
        self.diagnostics = {
            "total_neurons_steps": 0,
            "lambda_clip_count": 0,
            "denom_floor_hits": 0,
            "lambda_history": [],
        }

    def reset_learning_state(self) -> None:
        """Zero the per-hidden-unit alpha/lambda estimator state, each of shape (hidden_size, 1)."""
        shape = (self.hidden_size, 1)
        self.alpha_num = np.zeros(shape)
        self.alpha_den = np.zeros(shape)
        self.alpha_hat = np.zeros(shape)
        self.S_A2 = np.zeros(shape)
        self.S_AB = np.zeros(shape)
        self.lambda_vals = np.zeros(shape)

    def initialize_weights_with_gain(self, g: float) -> None:
        """
        Re-draw all parameters as float32 with recurrent gain ``g``.

        ``W_hh ~ N(0, g**2 / hidden_size)``; the input and output weights use
        std 0.1 and the biases are zero. The generator is seeded from
        ``int(g * 1000)``, so the same gain always yields the same weights.
        Also resets the learning-rule state.
        """
        rng = np.random.default_rng(int(g * 1000))
        self.W_xh = rng.standard_normal(self.W_xh.shape, dtype=np.float32) * 0.1
        self.W_hh = rng.standard_normal(self.W_hh.shape, dtype=np.float32) * (
            g / np.sqrt(self.hidden_size)
        )
        self.b_h = np.zeros(self.b_h.shape, dtype=np.float32)
        self.W_hy = rng.standard_normal(self.W_hy.shape, dtype=np.float32) * 0.1
        self.b_y = np.zeros(self.b_y.shape, dtype=np.float32)
        self.reset_learning_state()

    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray,
        targets_cycle: np.ndarray,
        h_prev_cycle: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run one sequence through the network, updating parameters at every step.

        Args:
            inputs_cycle: (batch, input_size, T) input sequence.
            targets_cycle: (batch, output_size, T) regression targets.
            h_prev_cycle: (hidden_size, batch) initial hidden state.

        Returns:
            A pair ``(loss, h_last)``:
            - ``loss``: the per-step loss ``0.5 * mean(error**2)`` summed
              over time steps, not averaged.
            - ``h_last``: the final hidden state.
        """
        h_prev = h_prev_cycle
        total_cycle_loss = 0.0
        batch_size, time_steps = inputs_cycle.shape[0], inputs_cycle.shape[2]
        prev_g, prev_u, prev_delta = None, None, None

        for t in range(time_steps):
            I_t = inputs_cycle[:, :, t].T  # (input, batch)
            y_target_t = targets_cycle[:, :, t].T  # (output, batch)

            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y

            error = y_hat_t - y_target_t
            loss_t = 0.5 * np.mean(error ** 2)
            dL_dyhat = error
            total_cycle_loss += float(loss_t)

            g_t = self.W_hy.T @ dL_dyhat
            u_t = 1.0 - h_t**2

            # Keep |denominator| >= denom_floor while preserving its sign.
            # np.sign(0) == 0, so using it would leave an exactly-zero
            # denominator at zero and turn delta_t into inf/NaN; zero is
            # therefore treated as positive.
            denominator = 1.0 - self.lambda_vals * u_t
            near_zero = np.abs(denominator) < self.denom_floor
            safe_sign = np.where(denominator < 0.0, -1.0, 1.0)
            denominator = np.where(near_zero, self.denom_floor * safe_sign, denominator)
            delta_t = (u_t * g_t) / denominator

            # Estimate alpha_hat from the teaching signal itself:
            # \widetilde{delta}_t \approx alpha * \widetilde{delta}_{t-1}.
            if prev_delta is not None:
                dtp_mean = np.mean(delta_t * prev_delta, axis=1, keepdims=True)
                dpp_mean = np.mean(prev_delta**2, axis=1, keepdims=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                self.alpha_hat = np.clip(
                    self.alpha_num / (self.alpha_den + 1e-8),
                    self.alpha_clip_min,
                    self.alpha_clip_max,
                )

            dW_hh = (delta_t @ h_prev.T) / batch_size
            dW_xh = (delta_t @ I_t.T) / batch_size
            db_h = np.mean(delta_t, axis=1, keepdims=True)
            dW_hy = (dL_dyhat @ h_t.T) / batch_size
            db_y = np.mean(dL_dyhat, axis=1, keepdims=True)

            # Local Rule: clip the joint gradient norm (already batch-averaged,
            # hence norm_scale=1.0).
            if self.max_grad_norm > 0:
                grads = [dW_xh, dW_hh, db_h, dW_hy, db_y]
                dW_xh, dW_hh, db_h, dW_hy, db_y = normalize_and_clip_gradients(
                    grads,
                    norm_scale=1.0,
                    max_norm=self.max_grad_norm,
                )

            self.W_hh -= self.eta * dW_hh
            self.W_xh -= self.eta * dW_xh
            self.b_h -= self.eta * db_h
            self.W_hy -= self.eta * dW_hy
            self.b_y -= self.eta * db_y

            # Update lambda from consecutive (g, u) pairs (least-squares ratio
            # S_AB / S_A2 of EMA statistics), then cap it so that
            # |lambda * u_t| <= 1 - denom_floor.
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t
                A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
                AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)
                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean
                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)
                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                cap = np.minimum((1.0 - self.denom_floor) / u_abs_max, self.lambda_cap)
                self.lambda_vals = np.clip(lambda_unproj, -cap, cap)

            prev_g, prev_u, prev_delta, h_prev = g_t, u_t, delta_t, h_t

        return total_cycle_loss, h_prev

    def forward_cycle(
        self, inputs: np.ndarray, h_prev: np.ndarray
    ) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Run the network without learning.

        Args:
            inputs: (batch, input_size, T) input sequence.
            h_prev: (hidden_size, batch) initial hidden state.

        Returns:
            A pair ``(outputs, h_last)``:
            - ``outputs``: list of T output arrays, each (output_size, batch).
            - ``h_last``: the final hidden state.
        """
        outputs: List[np.ndarray] = []
        for t in range(inputs.shape[2]):
            I_t = inputs[:, :, t].T
            h_t = np.tanh(self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h)
            y_t = self.W_hy @ h_t + self.b_y
            outputs.append(y_t)  # (output, batch)
            h_prev = h_t
        return outputs, h_prev


class BPTTRNN:
    """
    Plain tanh RNN trained with full backpropagation through time (BPTT).

    No gradient clipping is applied. Gradients are summed over every time step
    of the sequence and divided by the batch size only, not by the number of
    time steps. For ``output_size == 1`` they are the gradient of
    ``sum_t 0.5 * mean_batch((y_hat_t - y_t) ** 2)``, i.e. ``T`` times the
    gradient of the time-averaged loss that ``train_batch`` returns.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
    ) -> None:
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.eta = eta

        # Draw order (W_xh, W_hh, W_hy) fixes how the global NumPy RNG is consumed.
        draw = np.random.randn
        self.W_xh = 0.1 * draw(hidden_size, input_size)
        self.W_hh = draw(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = 0.1 * draw(output_size, hidden_size)
        self.b_y = np.zeros((output_size, 1))

    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw only W_hh from N(0, g**2 / hidden_size) with the global NumPy RNG."""
        std = g / np.sqrt(self.hidden_size)
        self.W_hh = std * np.random.randn(*self.W_hh.shape)

    def train_batch(
        self,
        inputs: np.ndarray,
        targets: np.ndarray,
        h_prev: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Do one BPTT update over a whole sequence batch.

        Args:
            inputs: (batch, input_size, T) input sequence.
            targets: (batch, output_size, T) regression targets.
            h_prev: (hidden_size, batch) initial hidden state.

        Returns:
            A pair ``(loss, h_last)``:
            - ``loss``: the per-step loss ``0.5 * mean(error**2)`` averaged
              over the T steps.
            - ``h_last``: the final hidden state.
        """
        batch_size, time_steps = inputs.shape[0], inputs.shape[2]

        # Forward pass. states[k] holds the hidden state *before* step k, so
        # states[0] is h_prev and states[k + 1] is the state produced by step k.
        states: List[np.ndarray] = [h_prev]
        predictions: List[np.ndarray] = []
        loss_sum = 0.0
        for t in range(time_steps):
            h_t = np.tanh(self.W_hh @ states[-1] + self.W_xh @ inputs[:, :, t].T + self.b_h)
            y_hat_t = self.W_hy @ h_t + self.b_y
            states.append(h_t)
            predictions.append(y_hat_t)
            loss_sum += 0.5 * float(np.mean((y_hat_t - targets[:, :, t].T) ** 2))

        # Backward pass (reverse time), accumulating summed gradients.
        grad_W_xh = np.zeros_like(self.W_xh)
        grad_W_hh = np.zeros_like(self.W_hh)
        grad_b_h = np.zeros_like(self.b_h)
        grad_W_hy = np.zeros_like(self.W_hy)
        grad_b_y = np.zeros_like(self.b_y)
        carry = np.zeros((self.hidden_size, batch_size))

        for t in range(time_steps - 1, -1, -1):
            h_t = states[t + 1]
            out_err = predictions[t] - targets[:, :, t].T

            grad_W_hy += out_err @ h_t.T
            grad_b_y += np.sum(out_err, axis=1, keepdims=True)

            pre_act_grad = (1.0 - h_t**2) * (self.W_hy.T @ out_err + carry)

            grad_W_hh += pre_act_grad @ states[t].T
            grad_W_xh += pre_act_grad @ inputs[:, :, t]
            grad_b_h += np.sum(pre_act_grad, axis=1, keepdims=True)

            carry = self.W_hh.T @ pre_act_grad

        # Average over the batch only; no clipping. Updates are in place.
        scale = max(1.0, float(batch_size))
        self.W_xh -= self.eta * (grad_W_xh / scale)
        self.W_hh -= self.eta * (grad_W_hh / scale)
        self.b_h -= self.eta * (grad_b_h / scale)
        self.W_hy -= self.eta * (grad_W_hy / scale)
        self.b_y -= self.eta * (grad_b_y / scale)

        return loss_sum / time_steps, states[-1]

    def forward_cycle(
        self, inputs: np.ndarray, h_prev: np.ndarray
    ) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Run the network without learning.

        Args:
            inputs: (batch, input_size, T) input sequence.
            h_prev: (hidden_size, batch) initial hidden state.

        Returns:
            A pair ``(outputs, h_last)``:
            - ``outputs``: list of T output arrays, each (output_size, batch).
            - ``h_last``: the final hidden state.
        """
        outputs: List[np.ndarray] = []
        h = h_prev
        for x_t in np.moveaxis(inputs, 2, 0):  # x_t: (batch, input) for each step
            h = np.tanh(self.W_hh @ h + self.W_xh @ x_t.T + self.b_h)
            outputs.append(self.W_hy @ h + self.b_y)  # (output, batch)
        return outputs, h


# =============================
# Main Experiment Logic
# =============================

def plot_comparison(results: Dict[str, Dict[str, Any]], gain: float) -> None:
    """
    Show a three-panel figure comparing the trained models.

    Panels:
        a) final validation loss per model;
        b) largest Lyapunov exponent after training;
        c) validation loss per epoch.

    Args:
        results: maps model name -> dict with keys "final_metric" (float),
            "lambda_post" (float) and "history" (sequence of per-epoch
            validation losses). Histories may have different lengths; the
            x-axis of panel (c) is derived from them, not from EPOCHS.
        gain: recurrent gain ``g`` used for initialization (shown in the title).
    """
    names = list(results.keys())
    mses = [results[n]["final_metric"] for n in names]
    lyapunovs = [results[n]["lambda_post"] for n in names]

    fig, axs = plt.subplots(1, 3, figsize=(22, 7), constrained_layout=True)
    fig.suptitle(
        f"RNN Algorithm Comparison on Adding Task (Initialized at g={gain:.2f})",
        fontsize=20,
        weight="bold",
    )
    colors = [cm.viridis(x) for x in np.linspace(0.2, 0.9, len(names))]

    # a) Generalization Performance
    bars = axs[0].bar(names, mses, color=colors, zorder=3)
    axs[0].set_ylabel("Final Validation MSE (lower is better)", fontsize=12)
    axs[0].set_title("a) Generalization Performance", fontsize=14, weight="bold")
    axs[0].grid(True, axis="y", linestyle=":", zorder=0)
    for bar in bars:
        height = bar.get_height()
        axs[0].text(
            bar.get_x() + bar.get_width() / 2.0,
            # A NaN/inf height (diverged model) is not a drawable position.
            height if np.isfinite(height) else 0.0,
            f"{height:.4f}",
            ha="center",
            va="bottom",
        )
    axs[0].tick_params(axis="x", rotation=15)

    # b) Post-Training Dynamics
    bars = axs[1].bar(names, lyapunovs, color=colors, zorder=3)
    axs[1].axhline(
        0, color="crimson", linestyle="--", label="Edge of Chaos (λ=0)", zorder=2
    )
    axs[1].set_ylabel("Max Lyapunov Exponent (λ_max)", fontsize=12)
    axs[1].set_title("b) Post-Training Dynamics", fontsize=14, weight="bold")
    axs[1].grid(True, axis="y", linestyle=":", zorder=0)
    axs[1].legend()
    for bar in bars:
        height = bar.get_height()
        if np.isfinite(height):
            # Put the label just outside the bar end, above or below zero.
            offset = abs(height) * 0.05 + 0.01
            y_pos = height + offset if height >= 0 else height - offset
            va = "bottom" if height >= 0 else "top"
        else:
            y_pos, va = 0.0, "bottom"
        axs[1].text(
            bar.get_x() + bar.get_width() / 2.0,
            y_pos,
            f"{height:.3f}",
            ha="center",
            va=va,
        )
    axs[1].tick_params(axis="x", rotation=15)

    # c) Learning Curves (x-axis taken from each model's own history length)
    longest = 0
    for i, name in enumerate(names):
        history = results[name]["history"]
        longest = max(longest, len(history))
        axs[2].plot(
            np.arange(1, len(history) + 1),
            history,
            label=name,
            color=colors[i],
            lw=2.5,
            marker="o",
            ms=5,
        )
    axs[2].set_xlabel("Epoch", fontsize=12)
    axs[2].set_ylabel("Validation MSE", fontsize=12)
    axs[2].set_title("c) Learning Curves", fontsize=14, weight="bold")
    axs[2].grid(True, linestyle=":")
    axs[2].legend(fontsize=11)
    axs[2].set_xticks(np.arange(1, longest + 1))

    plt.show()


def _clone_params(params: Any) -> Any:
    """
    Return an independent deep copy of a parameter snapshot.

    NOTE(review): extract_params_common / load_params_common are defined
    elsewhere. Training here updates weight arrays in place (``-=``). If
    those helpers hand out references to the live arrays, a shared snapshot
    would be silently rewritten. Copying makes every consumer start from the
    same, untouched initialization regardless of how they are implemented.
    """
    import copy

    return copy.deepcopy(params)


def _scan_local_rule_gains(lyap_driver: np.ndarray) -> List[Dict[str, Any]]:
    """
    Stage 1: train a Local Rule RNN from each gain in GAINS_SCAN.

    Prints one summary line per gain. A record is kept only if the validation
    loss and both Lyapunov exponents (pre and post training) are finite.
    Each record has the keys "g", "mse", "lambda_pre", "lambda_post" and
    "params", where "params" holds the parameters before training.
    """
    scan_results: List[Dict[str, Any]] = []
    for g in GAINS_SCAN:
        model = NumpyLocalRuleRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE)
        model.initialize_weights_with_gain(g)

        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyap_driver)
        initial_params = _clone_params(extract_params_common(model))

        # Fixed seed: every gain sees the same training batches.
        rng = np.random.default_rng(123)
        for _epoch in range(SCAN_EPOCHS):
            for _ in range(STEPS_PER_EPOCH):
                inputs, targets = generate_addition_batch(BATCH_SIZE, BLOCK_SIZE, rng)
                h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)
                model.run_one_cycle_and_update_directly(inputs, targets, h_prev)

        val_mse = evaluate_addition_model(model, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS)
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyap_driver)

        print(
            f"[SCAN] g={g:.3f} | Val MSE: {val_mse:.6f} | "
            f"Lyapunov (pre, post) = ({lambda_pre:.4f}, {lambda_post:.4f}), "
            f"Δ={lambda_post - lambda_pre:.4f}"
        )

        # Drop diverged runs. A NaN loss would otherwise corrupt the sort by
        # "mse" (NaN comparisons are always False) used to pick the optimum.
        if np.isfinite(val_mse) and np.isfinite(lambda_pre) and np.isfinite(lambda_post):
            scan_results.append(
                {
                    "g": g,
                    "mse": val_mse,
                    "lambda_pre": lambda_pre,
                    "lambda_post": lambda_post,
                    "params": initial_params,
                }
            )
    return scan_results


def _select_critical_candidate(scan_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Pick the record closest to the edge of chaos among the best performers.

    Among the three lowest-loss records, returns the one whose pre-training
    Lyapunov exponent is closest to 0.

    Raises:
        RuntimeError: if ``scan_results`` is empty.
    """
    if not scan_results:
        raise RuntimeError(
            "Stage 1 produced no run with finite loss and Lyapunov exponents; "
            "cannot choose an initialization gain."
        )
    top_performers = sorted(scan_results, key=lambda rec: rec["mse"])[:3]
    return min(top_performers, key=lambda rec: abs(rec["lambda_pre"]))


def _train_and_evaluate(
    name: str,
    model: Any,
    initial_params: Any,
    lyap_driver: np.ndarray,
) -> Dict[str, Any]:
    """
    Stage 2: train one algorithm from the shared initialization.

    Validates after every epoch. Returns a dict with the keys
    "final_metric", "lambda_pre", "lambda_post" and "history".

    Raises:
        NotImplementedError: if ``model`` exposes neither
            ``run_one_cycle_and_update_directly`` nor ``train_batch``.
    """
    print(f"\nTraining {name}...")

    # Start every algorithm from the "critical initialization" found in
    # Stage 1. load_params_common is assumed to handle both torch and numpy
    # models. Each model gets its own copy of the snapshot.
    load_params_common(model, _clone_params(initial_params))

    if hasattr(model, "reset_learning_state"):
        model.reset_learning_state()
    if hasattr(model, "reset_state_buffers"):
        model.reset_state_buffers()

    lambda_pre_model = calculate_lyapunov_exponent_numpy(model, lyap_driver)

    history: List[float] = []
    # Fixed seed: every algorithm sees the same training batches.
    rng = np.random.default_rng(456)

    for epoch in range(EPOCHS):
        if hasattr(model, "set_epoch"):
            model.set_epoch(epoch)

        for _ in range(STEPS_PER_EPOCH):
            inputs, targets = generate_addition_batch(BATCH_SIZE, BLOCK_SIZE, rng)
            h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)

            if hasattr(model, "run_one_cycle_and_update_directly"):
                model.run_one_cycle_and_update_directly(inputs, targets, h_prev)
            elif hasattr(model, "train_batch"):
                model.train_batch(inputs, targets, h_prev)
            else:
                raise NotImplementedError(
                    f"Model {name} has no recognized training method."
                )

        val_mse = evaluate_addition_model(model, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS)
        history.append(val_mse)
        print(f"  Epoch {epoch+1}/{EPOCHS} | Val MSE: {val_mse:.6f}")

    final_lambda = calculate_lyapunov_exponent_numpy(model, lyap_driver)

    print(
        f"-> Final Results for {name}: MSE={history[-1]:.6f}, "
        f"Lyapunov pre={lambda_pre_model:.4f}, post={final_lambda:.4f}, "
        f"Δ={final_lambda - lambda_pre_model:.4f}"
    )

    return {
        "final_metric": history[-1],
        "lambda_pre": lambda_pre_model,
        "lambda_post": final_lambda,
        "history": history,
    }


def main() -> None:
    """
    Run the full experiment in three stages.

    1. Scan recurrent gains for the Local Rule RNN and pick a well-performing
       initialization near the edge of chaos.
    2. Train every algorithm (Local Rule, BPTT, E-Prop, FPTT) from that same
       initialization.
    3. Plot the comparison.
    """
    np.random.seed(42)
    apply_plot_style()
    start_time = time.time()

    print("Preparing Lyapunov driver for the adding task...")
    lyap_driver = make_lyapunov_driver_for_addition(BLOCK_SIZE, BATCH_SIZE)

    # ========================================================================
    # STAGE 1: Find the best performing Local Rule model near critical state
    # ========================================================================
    print("\n" + "=" * 60)
    print("STAGE 1: Finding optimal 'g' for Local Rule RNN...")
    print("=" * 60)

    scan_results = _scan_local_rule_gains(lyap_driver)
    best_critical_model = _select_critical_candidate(scan_results)
    g_optimal = best_critical_model["g"]
    initial_params_optimal = best_critical_model["params"]

    print("\n--- Stage 1 Summary ---")
    print(f"Optimal gain 'g' found: {g_optimal:.4f}")
    print(f"  - Achieved MSE: {best_critical_model['mse']:.6f}")
    print(f"  - Lyapunov Exponent PRE (close to 0): {best_critical_model['lambda_pre']:.4f}")
    print(f"  - Lyapunov Exponent POST: {best_critical_model['lambda_post']:.4f}")

    # ========================================================================
    # STAGE 2: Compare algorithms at the optimal point
    # ========================================================================
    print("\n" + "=" * 60)
    print(f"STAGE 2: Comparing algorithms with optimal g = {g_optimal:.4f}")
    print("=" * 60)

    models_to_compare: Dict[str, Any] = {
        "Local Rule": NumpyLocalRuleRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE),
        "BPTT": BPTTRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE),
        "E-Prop": StandardEPropRNN(
            INPUT_SIZE,
            HIDDEN_SIZE,
            OUTPUT_SIZE,
            eta=LEARNING_RATE,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="mse",
            device="cpu",
        ),
        "FPTT": StrictFPTTRegressor(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            output_size=OUTPUT_SIZE,
            eta=LEARNING_RATE,
            parts=FPTT_PARTS,
            clip=5.0,          # max_grad_norm on the PyTorch side
            lmbda=FPTT_LAMBDA,
        ),
    }

    comparison_results: Dict[str, Dict[str, Any]] = {}
    for name, model in models_to_compare.items():
        comparison_results[name] = _train_and_evaluate(
            name, model, initial_params_optimal, lyap_driver
        )

    print(
        f"\nExperiment finished! Total time: {(time.time() - start_time) / 60:.2f} minutes"
    )

    # ========================================================================
    # STAGE 3: Plotting the comparison results
    # ========================================================================
    plot_comparison(comparison_results, g_optimal)


# Run the full experiment (gain scan, algorithm comparison, plot) only when
# this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
