import io
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

from plotting_utils import apply_plot_style

COMPARE_DIR = Path(__file__).resolve().parents[1]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTClassifier
from methods.standard_eprop import StandardEPropRNN
from task.common.sequence_core import extract_params as extract_params_common, load_params as load_params_common


# =============================
# Optimization helpers
# =============================

def normalize_and_clip_gradients(
    grads: List[np.ndarray],
    norm_scale: float,
    max_norm: float,
) -> List[np.ndarray]:
    """
    Divide every gradient by ``norm_scale`` (floored at 1.0; typically the batch
    size) and, if their global L2 norm exceeds ``max_norm``, rescale them all so
    that the global norm equals ``max_norm``.

    The global norm is computed after dividing by the largest absolute entry, so
    large-but-finite gradients (whose plain sum of squares would overflow to inf)
    are still clipped instead of being passed through unchanged. Gradients that
    already contain NaN/inf are returned normalized but unclipped, because no
    finite rescaling can repair them.

    Returns:
        A new list of arrays; the input arrays are not modified.
    """
    safe_scale = max(1.0, float(norm_scale))
    normalized = [g / safe_scale for g in grads]

    peak = max(
        (float(np.max(np.abs(g))) for g in normalized if np.size(g) > 0),
        default=0.0,
    )
    if not np.isfinite(peak) or peak == 0.0:
        # Non-finite entries cannot be fixed by clipping; all-zero needs none.
        return normalized

    # ||g|| = peak * ||g / peak||; every scaled entry lies in [-1, 1], so the
    # squared sum cannot overflow.
    scaled_norm = float(
        np.sqrt(sum(float(np.sum(np.square(g / peak))) for g in normalized))
    )
    total_norm = peak * scaled_norm
    if total_norm <= max_norm:
        return normalized

    if np.isfinite(total_norm):
        clip = max_norm / (total_norm + 1e-8)
        return [g * clip for g in normalized]
    # The true norm exceeds the float range: rescale in two finite steps.
    return [(g / peak) * (max_norm / scaled_norm) for g in normalized]


# =============================
# Config
# =============================
BLOCK_SIZE = 80  # characters per training window (time steps per batch)
BATCH_SIZE = 64
EPOCHS = 8
LEARNING_RATE = 1e-3
HIDDEN_SIZE = 256

# Stage 1 scans a wider range of recurrent gains; stage 2 reuses the best one.
# 13 evenly spaced gains starting at 0.5 with 1.6 excluded (endpoint=False).
GAINS_SCAN = np.linspace(0.5, 1.6, 13, endpoint=False)
STEPS_PER_EPOCH = 200
VAL_STEPS = 50
DATA_PATH = "wikitext-2-raw"
FPTT_PARTS = 8
FPTT_ORACLE_MOMENTUM = 0.35
FPTT_LAMBDA = 0.5
FPTT_USE_FED_DYN = False  # not used here for now, but the constant is kept
EPROP_FEEDBACK = "symmetric"
EPROP_SEED = 1234


# =============================
# Utilities
# =============================

def softmax(logits: np.ndarray) -> np.ndarray:
    """
    Column-wise softmax: normalize over axis 0, so each column of a
    (vocab, batch) or (dim, batch) array becomes a probability vector.

    The column maximum is subtracted first for numerical stability, and a tiny
    constant keeps the denominator strictly positive.
    """
    column_peak = logits.max(axis=0, keepdims=True)
    weights = np.exp(logits - column_peak)
    normalizer = weights.sum(axis=0, keepdims=True) + 1e-12
    return weights / normalizer


def try_load_wikitext2_raw(data_path: str | None) -> Tuple[str, str, str]:
    """
    Returns (train_text, valid_text, test_text) as big strings.
    """

    def read_file(p: str) -> str:
        with io.open(p, "r", encoding="utf-8") as f:
            return f.read()

    def try_local(dirpath: str | None) -> Tuple[str, str, str] | None:
        if dirpath is None:
            return None
        train_p = os.path.join(dirpath, "wiki.train.raw")
        valid_p = os.path.join(dirpath, "wiki.valid.raw")
        test_p = os.path.join(dirpath, "wiki.test.raw")
        if (
            os.path.exists(train_p)
            and os.path.exists(valid_p)
            and os.path.exists(test_p)
        ):
            print(f"Loaded WikiText-2 raw from '{dirpath}/'.")
            return read_file(train_p), read_file(valid_p), read_file(test_p)
        return None

    # 先尝试本地
    for cand in [data_path, "wikitext-2-raw", "wikitext-2", "./", "./data"]:
        got = try_local(cand)
        if got is not None:
            return got

    # 再尝试从 HuggingFace 下载
    try:
        from datasets import load_dataset
    except Exception as exc:
        raise RuntimeError(
            "未找到本地 WikiText-2 raw 文件，且缺少 HuggingFace datasets。\n"
            "请先安装 `pip install datasets`，让代码自动下载并缓存。"
        ) from exc

    print("Downloading WikiText-2 (raw) from HuggingFace ...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_text = "\n".join(ds["train"]["text"])
    valid_text = "\n".join(ds["validation"]["text"])
    test_text = "\n".join(ds["test"]["text"])

    cache_dir = "wikitext-2-raw"
    os.makedirs(cache_dir, exist_ok=True)
    with io.open(os.path.join(cache_dir, "wiki.train.raw"), "w", encoding="utf-8") as f:
        f.write(train_text)
    with io.open(os.path.join(cache_dir, "wiki.valid.raw"), "w", encoding="utf-8") as f:
        f.write(valid_text)
    with io.open(os.path.join(cache_dir, "wiki.test.raw"), "w", encoding="utf-8") as f:
        f.write(test_text)

    print(f"Cached WikiText-2 raw to '{cache_dir}/'.")
    return train_text, valid_text, test_text


def build_char_vocab(text: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """
    Build a character vocabulary from ``text``.

    Distinct characters are sorted by code point; the returned pair is
    (char -> id, id -> char), with ids 0..n-1 in that sorted order.
    """
    alphabet = sorted(set(text))
    char_to_id = dict(zip(alphabet, range(len(alphabet))))
    id_to_char = dict(enumerate(alphabet))
    return char_to_id, id_to_char


def encode(text: str, stoi: Dict[str, int]) -> np.ndarray:
    """
    Map each character of ``text`` to its id as an int64 array.

    Characters absent from ``stoi`` are silently dropped.
    """
    ids = [stoi[char] for char in text if char in stoi]
    return np.array(ids, dtype=np.int64)


def batch_iterator(
    data: np.ndarray,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
    """
    Yield ``steps`` random batches of one-hot (input, next-character) windows.

    Each window starts at a uniformly drawn offset ``s`` with
    ``s + block_size < len(data)`` (every valid start is reachable).

    Yields:
        inputs:  (B, vocab, T) float32 one-hot of ``data[s : s + T]``
        targets: (B, vocab, T) float32 one-hot of ``data[s + 1 : s + T + 1]``

    Raises:
        ValueError: ``data`` holds fewer than ``block_size + 1`` tokens (raised
            on the first iteration, since this is a generator).
    """
    n_tokens = int(data.shape[0])
    window = block_size + 1
    if n_tokens < window:
        raise ValueError(
            f"batch_iterator needs at least block_size + 1 = {window} tokens, "
            f"got {n_tokens}."
        )

    # `integers` excludes `high`; the largest valid start is n_tokens - window.
    high = n_tokens - block_size
    offsets = np.arange(window)
    time_idx = np.arange(block_size)
    row_idx = np.arange(batch_size)[:, None]

    for _ in range(steps):
        starts = rng.integers(0, high, size=(batch_size,))
        seq = data[starts[:, None] + offsets]  # (B, block_size + 1)
        x = np.zeros((batch_size, vocab_size, block_size), dtype=np.float32)
        y = np.zeros_like(x)
        # Broadcast (B, 1), (B, T), (T,) -> sets x[b, seq[b, t], t] = 1.
        x[row_idx, seq[:, :-1], time_idx] = 1.0
        y[row_idx, seq[:, 1:], time_idx] = 1.0
        yield x, y


def evaluate_language_model(
    model: Any,
    data: np.ndarray,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps: int,
) -> float:
    """
    Average next-character cross-entropy (nats per char) on ``data``.

    Batches come from a fixed-seed generator, so repeated calls score the same
    windows. ``model.forward_cycle`` must return a list over time of
    (vocab, batch) logits; each step is soft-maxed over the vocab axis and the
    per-sequence CE is accumulated over every step and every sequence.
    """
    eval_rng = np.random.default_rng(12345)
    log_eps = 1e-12
    loss_sum = 0.0
    token_count = 0

    batches = batch_iterator(data, vocab_size, block_size, batch_size, steps, eval_rng)
    for inputs, targets in batches:
        n_seq = inputs.shape[0]
        init_state = np.zeros((model.hidden_size, n_seq), dtype=np.float32)

        logits_seq, _final_state = model.forward_cycle(inputs, init_state)
        stacked = np.stack(logits_seq, axis=2)  # (vocab, batch, time)

        for step in range(stacked.shape[2]):
            probs = softmax(stacked[:, :, step]).T  # (batch, vocab)
            per_seq = -np.sum(targets[:, :, step] * np.log(probs + log_eps), axis=1)
            loss_sum += float(np.sum(per_seq))
            token_count += n_seq

    return loss_sum / max(token_count, 1)


# ===== Lyapunov (Benettin-style QR) =====

def calculate_lyapunov_exponent_numpy(model: Any, driver_input: np.ndarray) -> float:
    """
    Estimate the largest Lyapunov exponent of ``h <- tanh(W_hh h + W_xh x + b_h)``.

    The state starts at zero and is driven by the columns of ``driver_input``
    (indexed (input, time)). At every step an orthonormal basis is pushed
    through the Jacobian ``diag(1 - h^2) W_hh`` and re-orthonormalized with QR
    (Benettin); the log |R_ii| are averaged over time and the maximum is
    returned. Returns NaN if the QR factorization fails.
    """
    dim = int(model.hidden_size)
    recurrent = np.asarray(model.W_hh, dtype=np.float64)
    input_w = np.asarray(model.W_xh, dtype=np.float64)
    bias = np.asarray(model.b_h, dtype=np.float64)
    n_steps = int(driver_input.shape[1])

    state = np.zeros((dim, 1), dtype=np.float64)
    basis = np.eye(dim, dtype=np.float64)
    log_growth = np.zeros(dim, dtype=np.float64)

    for step in range(n_steps):
        drive = driver_input[:, step].reshape(-1, 1).astype(np.float64)
        state = np.tanh(recurrent @ state + input_w @ drive + bias)
        slope = (1.0 - state**2).flatten()
        tangent = (slope[:, None] * recurrent) @ basis
        try:
            basis, upper = np.linalg.qr(tangent)
        except np.linalg.LinAlgError:
            return float("nan")
        # Floor tiny stretch factors so log() stays finite.
        log_growth += np.log(np.clip(np.abs(np.diag(upper)), 1e-12, None))

    return float(np.max(log_growth / max(n_steps, 1)))


def make_lyapunov_driver_from_data(
    data: np.ndarray, vocab_size: int, block_size: int, batch_size: int
) -> np.ndarray:
    """
    Build a deterministic (vocab, block_size) float32 driving signal.

    One batch is drawn with a fixed seed (999) and its one-hot inputs are
    averaged over the batch axis.
    """
    batches = batch_iterator(
        data, vocab_size, block_size, batch_size, steps=1, rng=np.random.default_rng(999)
    )
    first_inputs, _unused_targets = next(batches)
    averaged = first_inputs.mean(axis=0)
    return averaged.astype(np.float32)


# =============================
# RNN Implementations
# =============================

class NumpyLocalRuleRNN:
    """
    Online "Local Rule" tanh RNN trained with cross-entropy, updating its
    parameters at every time step (aligned with the adding-task version).

    Per step t, with dL/dy_hat = softmax(logits_t) - target_t:
      - g_t = W_hy^T dL/dy_hat                       (hidden, batch)
      - u_t = 1 - h_t^2                              (tanh derivative)
      - delta_t = u_t * g_t / (1 - lambda * u_t)     teaching signal; the
        denominator is kept at least `denom_floor` away from zero
      - gradients use only the immediate presynaptic activity
        (h_{t-1} and I_t for delta_t, h_t for dL/dy_hat), averaged over batch
    alpha_hat = EMA(mean_batch(delta_t * delta_{t-1})) /
                EMA(mean_batch(delta_{t-1}^2)), clipped to [-0.99, 0.99].
    lambda    = S_AB / S_A2, a ratio of EMAs built from (g, u) at t-1 and t,
                projected per neuron to
                |lambda| <= min(0.99, (1 - denom_floor) / max|u_t|).
    The per-step gradients ARE clipped to a global norm of
    `max_grad_norm` (5.0).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
        lambda_window: int = 50,
    ) -> None:
        """
        Args:
            input_size: number of input features (vocab size for one-hot chars).
            hidden_size: number of tanh units.
            output_size: number of output classes.
            eta: learning rate applied at every time step.
            lambda_window: effective EMA window of the lambda estimator
                (lambda_rho = (w - 1) / w); values below 2 are raised to 2.

        W_hh starts as an unscaled standard normal matrix;
        `initialize_weights_with_gain` rescales it.
        """
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.epsilon = 1e-8
        self.max_grad_norm = 5.0

        # Parameters
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_y = np.zeros((output_size, 1))

        # alpha estimator (EMA-based AR(1) fit on the teaching signal)
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # lambda estimator
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        # Diagnostics
        self.diagnostics: Dict[str, Any] = {}
        self.reset_learning_state()
        self.reset_diagnostics()

        # FPTT-related attributes (unused in this experiment; kept only for
        # interface compatibility — nothing in this class reads them).
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    # ----- diagnostics -----
    def reset_diagnostics(self) -> None:
        """Zero the counters updated during training."""
        # NOTE(review): "lambda_history" is never appended to within this class.
        self.diagnostics = {
            "total_neurons_steps": 0,
            "lambda_clip_count": 0,
            "denom_floor_hits": 0,
            "lambda_history": [],
        }

    def reset_learning_state(self) -> None:
        """Zero the per-neuron alpha/lambda EMA statistics and estimates."""
        shape = (self.hidden_size, 1)
        self.alpha_num = np.zeros(shape)
        self.alpha_den = np.zeros(shape)
        self.alpha_hat = np.zeros(shape)

        self.S_A2 = np.zeros(shape)
        self.S_AB = np.zeros(shape)
        self.lambda_vals = np.zeros(shape)

    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw W_hh with i.i.d. N(0, (g / sqrt(hidden_size))^2) entries."""
        std_dev = g / np.sqrt(self.hidden_size)
        self.W_hh = np.random.randn(*self.W_hh.shape) * std_dev

    # ----- forward pass + online updates for one batch (one cycle) -----
    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray,
        targets_cycle: np.ndarray,
        h_prev_cycle: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run one batch of sequences forward, updating the parameters after
        every time step.

        Args:
            inputs_cycle: one-hot inputs indexed as (batch, input, time).
            targets_cycle: one-hot targets indexed as (batch, classes, time).
            h_prev_cycle: initial hidden state, (hidden, batch).

        Returns:
            (sum over time steps of the batch-mean CE loss, final hidden state)
        """
        h_prev = h_prev_cycle
        total_cycle_loss = 0.0
        batch_size = inputs_cycle.shape[0]
        time_steps = inputs_cycle.shape[2]
        eps = 1e-12

        prev_g: np.ndarray | None = None
        prev_u: np.ndarray | None = None
        prev_delta: np.ndarray | None = None

        for t in range(time_steps):
            I_t = inputs_cycle[:, :, t].T       # (input, batch)
            y_target_t = targets_cycle[:, :, t].T  # (classes, batch)

            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            logits_t = self.W_hy @ h_t + self.b_y
            P_t = softmax(logits_t)  # (classes, batch)

            # CE loss
            loss_t = -np.mean(np.sum(y_target_t * np.log(P_t + eps), axis=0))
            total_cycle_loss += float(loss_t)

            dL_dyhat = P_t - y_target_t  # (classes, batch)

            g_t = self.W_hy.T @ dL_dyhat   # (hidden, batch)
            u_t = 1.0 - h_t**2             # (hidden, batch)

            # ---- use the lambda estimated at the previous step ----
            lambda_used = self.lambda_vals.copy()

            # NOTE(review): lambda is clipped to |lambda| <= 0.99 and u_t lies in
            # [0, 1], so |denominator| >= 0.01 > denom_floor and this floor
            # appears unreachable. If it ever fired on an exact 0, np.sign(0) == 0
            # would still produce a zero denominator.
            denominator = 1.0 - lambda_used * u_t
            denom_mask = np.abs(denominator) < self.denom_floor
            if np.any(denom_mask):
                self.diagnostics["denom_floor_hits"] += int(np.sum(denom_mask))
            denominator = np.where(
                denom_mask,
                self.denom_floor * np.sign(denominator),
                denominator,
            )
            delta_t = (u_t * g_t) / denominator

            # ---- alpha update (EMA-based AR(1) fit on the teaching signal) ----
            if prev_delta is not None:
                dtp_mean = np.mean(delta_t * prev_delta, axis=1, keepdims=True)
                dpp_mean = np.mean(prev_delta**2, axis=1, keepdims=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                raw_alpha = self.alpha_num / (self.alpha_den + self.epsilon)
                self.alpha_hat = np.clip(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

            dW_hh = (delta_t @ h_prev.T) / batch_size
            dW_xh = (delta_t @ I_t.T) / batch_size
            db_h = np.mean(delta_t, axis=1, keepdims=True)
            dW_hy = (dL_dyhat @ h_t.T) / batch_size
            db_y = np.mean(dL_dyhat, axis=1, keepdims=True)

            # Local Rule: gradient clipping
            if self.max_grad_norm > 0:
                grads = [dW_xh, dW_hh, db_h, dW_hy, db_y]
                dW_xh, dW_hh, db_h, dW_hy, db_y = normalize_and_clip_gradients(
                    grads,
                    norm_scale=1.0,           # already averaged over the batch
                    max_norm=self.max_grad_norm,
                )

            self.W_hh -= self.eta * dW_hh
            self.W_xh -= self.eta * dW_xh
            self.b_h -= self.eta * db_h
            self.W_hy -= self.eta * dW_hy
            self.b_y -= self.eta * db_y

            # ---- end of step: update lambda_vals from (prev_g, prev_u) and (g_t, u_t) ----
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t

                A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
                AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)

                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean

                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)

                # Per-neuron projection so that |lambda * u| stays below
                # 1 - denom_floor for the current u_t.
                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                safe_cap = (1.0 - self.denom_floor) / u_abs_max
                cap = np.minimum(safe_cap, self.lambda_cap)
                lambda_clipped = np.clip(lambda_unproj, -cap, cap)

                self.diagnostics["lambda_clip_count"] += int(
                    np.sum(lambda_unproj != lambda_clipped)
                )
                self.lambda_vals = lambda_clipped

            prev_g = g_t
            prev_u = u_t
            prev_delta = delta_t
            h_prev = h_t

        return total_cycle_loss, h_prev

    def forward_cycle(
        self, inputs: np.ndarray, h_prev: np.ndarray
    ) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Forward pass without learning.

        Returns the list over time of (vocab, batch) logits and the final
        hidden state.
        """
        outputs: List[np.ndarray] = []
        for t in range(inputs.shape[2]):
            I_t = inputs[:, :, t].T
            h_t = np.tanh(self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h)
            y_t = self.W_hy @ h_t + self.b_y  # (vocab, batch)
            outputs.append(y_t)
            h_prev = h_t
        return outputs, h_prev


class BPTTRNN:
    """
    Plain backpropagation-through-time tanh RNN language model.

    Gradients are summed over the whole sequence and divided by the batch size
    only (not by batch * time). No gradient clipping is applied. The loss
    returned by `train_batch` is the cross-entropy averaged over batch * time.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 0.001,
    ) -> None:
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.eta = eta

        # The draw order (W_xh, W_hh, W_hy) is fixed so a global NumPy seed
        # reproduces the same initialization.
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((output_size, 1))

    def initialize_weights_with_gain(self, g: float) -> None:
        """Redraw W_hh with i.i.d. N(0, (g / sqrt(hidden_size))^2) entries."""
        scale = g / np.sqrt(self.hidden_size)
        self.W_hh = np.random.randn(*self.W_hh.shape) * scale

    def train_batch(
        self,
        inputs: np.ndarray,
        targets: np.ndarray,
        h_prev: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run full BPTT on one batch and apply a single SGD step in place.

        Args:
            inputs: one-hot inputs indexed as (batch, vocab, time).
            targets: one-hot next-character targets, same layout as inputs.
            h_prev: initial hidden state, (hidden, batch).

        Returns:
            (mean CE per token over batch * time, final hidden state)
        """
        n_batch = inputs.shape[0]
        n_steps = inputs.shape[2]
        log_eps = 1e-12

        # states[k + 1] is the hidden state after step k; states[0] is h_prev.
        states = [h_prev]
        step_probs = []
        loss_sum = 0.0

        for step in range(n_steps):
            x_col = inputs[:, :, step].T
            target_col = targets[:, :, step].T
            state = np.tanh(self.W_hh @ states[-1] + self.W_xh @ x_col + self.b_h)
            probs = softmax(self.W_hy @ state + self.b_y)
            states.append(state)
            step_probs.append(probs)
            loss_sum -= np.sum(target_col * np.log(probs + log_eps))

        grad_W_xh = np.zeros_like(self.W_xh)
        grad_W_hh = np.zeros_like(self.W_hh)
        grad_b_h = np.zeros_like(self.b_h)
        grad_W_hy = np.zeros_like(self.W_hy)
        grad_b_y = np.zeros_like(self.b_y)
        carry = np.zeros((self.hidden_size, n_batch))

        for step in range(n_steps - 1, -1, -1):
            err = step_probs[step] - targets[:, :, step].T  # (vocab, batch)
            state = states[step + 1]

            grad_W_hy += err @ state.T
            grad_b_y += np.sum(err, axis=1, keepdims=True)

            d_pre = (1.0 - state**2) * (self.W_hy.T @ err + carry)

            grad_W_hh += d_pre @ states[step].T
            grad_W_xh += d_pre @ inputs[:, :, step]
            grad_b_h += np.sum(d_pre, axis=1, keepdims=True)

            carry = self.W_hh.T @ d_pre

        # Batch-average (sum over time is kept), then update parameters in place.
        scale = max(1.0, float(n_batch))
        updates = (
            (self.W_xh, grad_W_xh),
            (self.W_hh, grad_W_hh),
            (self.b_h, grad_b_h),
            (self.W_hy, grad_W_hy),
            (self.b_y, grad_b_y),
        )
        for param, grad in updates:
            grad /= scale
            param -= self.eta * grad

        return loss_sum / (n_batch * n_steps), states[-1]

    def forward_cycle(
        self, inputs: np.ndarray, h_prev: np.ndarray
    ) -> Tuple[List[np.ndarray], np.ndarray]:
        """Forward pass without learning: per-step (vocab, batch) logits and final state."""
        logits_seq: List[np.ndarray] = []
        state = h_prev
        for step in range(inputs.shape[2]):
            state = np.tanh(self.W_hh @ state + self.W_xh @ inputs[:, :, step].T + self.b_h)
            logits_seq.append(self.W_hy @ state + self.b_y)
        return logits_seq, state


class FPTTRNN(StrictFPTTClassifier):
    """
    Strict FPTT sequence model configured for the WikiText-2 char LM.

    Thin configuration wrapper around StrictFPTTClassifier: every time step is
    supervised (label_mode="all") and the oracle is disabled
    (use_oracle=False). Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        parts: int = FPTT_PARTS,
        oracle_momentum: float = FPTT_ORACLE_MOMENTUM,
        clip: float = 5.0,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lmbda: float = FPTT_LAMBDA,
        warmup_epochs: int = 5,
        oracle_id: str = "wikitext2",
        **_: Any,
    ) -> None:
        fixed_options = {"label_mode": "all", "use_oracle": False}
        tunables = dict(
            eta=eta,
            parts=parts,
            clip=clip,
            alpha=alpha,
            beta=beta,
            rho=rho,
            lmbda=lmbda,
            oracle_momentum=oracle_momentum,
            warmup_epochs=warmup_epochs,
            oracle_id=oracle_id,
        )
        super().__init__(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            **tunables,
            **fixed_options,
        )


# =============================
# Main Experiment Logic
# =============================

def _run_gain_scan(
    train_ids: np.ndarray,
    valid_ids: np.ndarray,
    vocab_size: int,
    lyap_driver: np.ndarray,
) -> List[Dict[str, Any]]:
    """
    Stage 1: train a Local Rule RNN for every gain in GAINS_SCAN.

    Returns one record per gain whose validation BPC and pre/post-training
    Lyapunov exponents are all finite. Each record stores a deep copy of the
    model's initial (pre-training) parameters.
    """
    import copy

    scan_results: List[Dict[str, Any]] = []
    for g in GAINS_SCAN:
        model = NumpyLocalRuleRNN(vocab_size, HIDDEN_SIZE, vocab_size, eta=LEARNING_RATE)
        model.initialize_weights_with_gain(g)

        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyap_driver)
        # Snapshot before training. The deep copy guards against
        # extract_params_common returning live references to the weight arrays,
        # which the in-place updates below would otherwise overwrite.
        initial_params = copy.deepcopy(extract_params_common(model))

        rng = np.random.default_rng(123)  # same batch stream for every gain
        for _epoch in range(EPOCHS):
            for inputs, targets in batch_iterator(
                train_ids,
                vocab_size,
                BLOCK_SIZE,
                BATCH_SIZE,
                STEPS_PER_EPOCH,
                rng,
            ):
                h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)
                model.run_one_cycle_and_update_directly(inputs, targets, h_prev)

        val_loss = evaluate_language_model(
            model, valid_ids, vocab_size, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS
        )
        val_bpc = val_loss / np.log(2.0)
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyap_driver)

        print(
            f"[SCAN] g={g:.3f} | Val BPC: {val_bpc:.4f} | "
            f"Lyapunov (pre, post) = ({lambda_pre:.4f}, {lambda_post:.4f}), "
            f"Δ={lambda_post - lambda_pre:.4f}"
        )

        # Drop diverged runs: a NaN/inf BPC or exponent would corrupt the
        # selection below (NaN keys make sorted() order undefined).
        if np.isfinite(val_bpc) and np.isfinite(lambda_pre) and np.isfinite(lambda_post):
            scan_results.append(
                {
                    "g": g,
                    "bpc": val_bpc,
                    "lambda_pre": lambda_pre,
                    "lambda_post": lambda_post,
                    "params": initial_params,
                }
            )
    return scan_results


def _select_critical_run(scan_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Among the three lowest-BPC runs, pick the one whose initial λ is closest to 0."""
    if not scan_results:
        raise RuntimeError(
            "Stage 1 produced no run with finite BPC and Lyapunov exponents; "
            "cannot choose an optimal gain."
        )
    top_performers = sorted(scan_results, key=lambda r: r["bpc"])[:3]
    return min(top_performers, key=lambda r: abs(r["lambda_pre"]))


def _run_comparison(
    train_ids: np.ndarray,
    valid_ids: np.ndarray,
    vocab_size: int,
    lyap_driver: np.ndarray,
    initial_params: Any,
) -> Dict[str, Dict[str, Any]]:
    """Stage 2: train every algorithm from the same initial parameters."""
    import copy

    models_to_compare: Dict[str, Any] = {
        "Local Rule": NumpyLocalRuleRNN(vocab_size, HIDDEN_SIZE, vocab_size, eta=LEARNING_RATE),
        "BPTT": BPTTRNN(vocab_size, HIDDEN_SIZE, vocab_size, eta=LEARNING_RATE),
        "E-Prop": StandardEPropRNN(
            vocab_size,
            HIDDEN_SIZE,
            vocab_size,
            eta=LEARNING_RATE,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="ce",
            device="cpu",
        ),
        "FPTT": FPTTRNN(
            vocab_size,
            HIDDEN_SIZE,
            vocab_size,
            eta=LEARNING_RATE,
            parts=FPTT_PARTS,
        ),
    }

    comparison_results: Dict[str, Dict[str, Any]] = {}

    for name, model in models_to_compare.items():
        print(f"\nTraining {name}...")

        # Every algorithm starts from Stage 1's "critical" initialization. A
        # fresh deep copy per model prevents one model's in-place updates from
        # leaking into the next if load_params_common does not copy.
        load_params_common(model, copy.deepcopy(initial_params))

        if hasattr(model, "reset_learning_state"):
            model.reset_learning_state()
        if hasattr(model, "reset_state_buffers"):
            model.reset_state_buffers()

        lambda_pre_model = calculate_lyapunov_exponent_numpy(model, lyap_driver)

        history: List[float] = []
        rng = np.random.default_rng(456)

        for epoch in range(EPOCHS):
            if hasattr(model, "set_epoch"):
                model.set_epoch(epoch)

            for inputs, targets in batch_iterator(
                train_ids,
                vocab_size,
                BLOCK_SIZE,
                BATCH_SIZE,
                STEPS_PER_EPOCH,
                rng,
            ):
                h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)

                if isinstance(model, NumpyLocalRuleRNN):
                    model.run_one_cycle_and_update_directly(inputs, targets, h_prev)
                else:
                    # BPTT / E-Prop / FPTT all expose train_batch.
                    model.train_batch(inputs, targets, h_prev)

            val_loss = evaluate_language_model(
                model, valid_ids, vocab_size, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS
            )
            val_bpc = val_loss / np.log(2.0)
            history.append(val_bpc)
            print(f"  Epoch {epoch+1}/{EPOCHS} | Val BPC: {val_bpc:.4f}")

        final_lambda = calculate_lyapunov_exponent_numpy(model, lyap_driver)

        comparison_results[name] = {
            "bpc": history[-1],
            "lambda_pre": lambda_pre_model,
            "lambda_post": final_lambda,
            "lambda": final_lambda,
            "history": history,
        }

        print(
            f"-> Final Results for {name}: BPC={history[-1]:.4f}, "
            f"Lyapunov pre={lambda_pre_model:.4f}, post={final_lambda:.4f}, "
            f"Δ={final_lambda - lambda_pre_model:.4f}"
        )

    return comparison_results


def _plot_comparison(comparison_results: Dict[str, Dict[str, Any]], g_optimal: float) -> None:
    """Stage 3: final BPC bars, post-training λ_max bars, and learning curves."""
    names = list(comparison_results.keys())
    bpcs = [comparison_results[n]["bpc"] for n in names]
    lyapunovs = [comparison_results[n]["lambda"] for n in names]

    fig, axs = plt.subplots(1, 3, figsize=(22, 7), constrained_layout=True)
    fig.suptitle(
        f"RNN Algorithm Comparison on WikiText-2 (Initialized at g={g_optimal:.2f})",
        fontsize=20,
        weight="bold",
    )
    colors = [cm.viridis(x) for x in np.linspace(0.2, 0.9, len(names))]

    # --- a) Final Performance ---
    bars = axs[0].bar(names, bpcs, color=colors, zorder=3)
    axs[0].set_ylabel("Final Validation BPC (lower is better)", fontsize=12)
    axs[0].set_title("a) Generalization Performance", fontsize=14, weight="bold")
    axs[0].grid(True, axis="y", linestyle=":", zorder=0)
    for bar in bars:
        axs[0].text(
            bar.get_x() + bar.get_width() / 2.0,
            bar.get_height(),
            f"{bar.get_height():.3f}",
            ha="center",
            va="bottom",
        )
    axs[0].tick_params(axis="x", rotation=15)

    # --- b) Post-Training Dynamics ---
    bars = axs[1].bar(names, lyapunovs, color=colors, zorder=3)
    axs[1].axhline(
        0, color="crimson", linestyle="--", label="Edge of Chaos (λ=0)", zorder=2
    )
    axs[1].set_ylabel("Max Lyapunov Exponent (λ_max)", fontsize=12)
    axs[1].set_title("b) Post-Training Dynamics", fontsize=14, weight="bold")
    axs[1].grid(True, axis="y", linestyle=":", zorder=0)
    axs[1].legend()
    for bar in bars:
        height = bar.get_height()
        # Anchor the label at the outer end of the bar: above positive bars and
        # below negative ones (va="bottom" alone draws it inside a negative bar).
        axs[1].text(
            bar.get_x() + bar.get_width() / 2.0,
            height + np.sign(height) * 0.01,
            f"{height:.3f}",
            ha="center",
            va="bottom" if height >= 0 else "top",
        )
    axs[1].tick_params(axis="x", rotation=15)

    # --- c) Learning Curves ---
    for i, name in enumerate(names):
        axs[2].plot(
            np.arange(1, EPOCHS + 1),
            comparison_results[name]["history"],
            label=name,
            color=colors[i],
            lw=2.5,
            marker="o",
            ms=5,
        )
    axs[2].set_xlabel("Epoch", fontsize=12)
    axs[2].set_ylabel("Validation BPC", fontsize=12)
    axs[2].set_title("c) Learning Curves", fontsize=14, weight="bold")
    axs[2].grid(True, linestyle=":")
    axs[2].legend(fontsize=11)
    axs[2].set_xticks(np.arange(1, EPOCHS + 1))

    plt.show()


def main() -> None:
    """Run the gain scan (Stage 1), the algorithm comparison (Stage 2) and the plots (Stage 3)."""
    apply_plot_style()
    np.random.seed(42)
    start_time = time.time()

    # --- Data loading and preprocessing ---
    print("Loading WikiText-2 raw and building vocab...")
    train_text, valid_text, _ = try_load_wikitext2_raw(DATA_PATH)
    stoi, _itos = build_char_vocab(train_text)
    vocab_size = len(stoi)
    train_ids = encode(train_text, stoi)
    valid_ids = encode(valid_text, stoi)
    lyap_driver = make_lyapunov_driver_from_data(
        train_ids, vocab_size, BLOCK_SIZE, BATCH_SIZE
    )
    print(f"Vocab size: {vocab_size}, Data loaded.")

    # ========================================================================
    # STAGE 1: Find the best performing Local Rule model near critical state
    # ========================================================================
    print("\n" + "=" * 60)
    print("STAGE 1: Finding optimal 'g' for Local Rule RNN (FPTT OFF)...")
    print("=" * 60)

    scan_results = _run_gain_scan(train_ids, valid_ids, vocab_size, lyap_driver)
    best_critical_model = _select_critical_run(scan_results)
    g_optimal = best_critical_model["g"]
    initial_params_optimal = best_critical_model["params"]

    print("\n--- Stage 1 Summary ---")
    print(f"Optimal gain 'g' found: {g_optimal:.4f}")
    print(f"  - Achieved BPC: {best_critical_model['bpc']:.4f}")
    print(f"  - Lyapunov Exponent PRE (close to 0): {best_critical_model['lambda_pre']:.4f}")
    print(f"  - Lyapunov Exponent POST: {best_critical_model['lambda_post']:.4f}")

    # ========================================================================
    # STAGE 2: Compare algorithms at the optimal point
    # ========================================================================
    print("\n" + "=" * 60)
    print(f"STAGE 2: Comparing algorithms with optimal g = {g_optimal:.4f}")
    print("=" * 60)

    comparison_results = _run_comparison(
        train_ids, valid_ids, vocab_size, lyap_driver, initial_params_optimal
    )

    print(
        f"\nExperiment finished! Total time: {(time.time() - start_time) / 60:.2f} minutes"
    )

    # ========================================================================
    # STAGE 3: Plotting the comparison results
    # ========================================================================
    _plot_comparison(comparison_results, g_optimal)


if __name__ == "__main__":
    main()
