import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch

from plotting_utils import apply_plot_style

COMPARE_DIR = Path(__file__).resolve().parents[1]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTClassifier
from methods.standard_eprop import StandardEPropRNN

# =============================
# Optimization helpers
# =============================

def normalize_and_clip_gradients(
    grads: List[np.ndarray],
    norm_scale: float,
    max_norm: float | None,
) -> List[np.ndarray]:
    """
    1. Scale gradients by 1/norm_scale (usually batch_size).
    2. Clip gradients by global norm.
    """
    safe_scale = max(1.0, float(norm_scale))
    normalized = [g / safe_scale for g in grads]
    
    if max_norm is None or max_norm <= 0:
        return normalized
    
    total_norm_sq = 0.0
    for g in normalized:
        total_norm_sq += float(np.sum(g**2))
    total_norm = float(np.sqrt(total_norm_sq))
    
    if not np.isfinite(total_norm) or total_norm <= max_norm:
        return normalized
    
    clip = max_norm / (total_norm + 1e-8)
    return [g * clip for g in normalized]

# =============================
# Config 
# =============================
# --- Data (Keras IMDB) ---
VOCAB_SIZE = 10000     # num_words passed to the Keras IMDB loader (top words kept)
MAXLEN = 200           # sequence length after left-padding / keeping the last tokens
EMBED_DIM = 64         # width of the frozen random embedding
TRAIN_LIMIT = 5000     # training reviews loaded; increase for better convergence
TEST_LIMIT = 2000      # test reviews loaded

# --- Training ---
BATCH_SIZE = 64
EPOCHS = 12
LEARNING_RATE = 1e-3
HIDDEN_SIZE = 128
# 13 recurrent gains from 0.5 up to (but excluding) 1.6.
GAINS = np.linspace(start=0.5, stop=1.6, num=13, endpoint=False)
CLIP_NORM = 5.0        # global grad-norm clipping threshold (Local Rule & FPTT)
INIT_SEED = 2025

BPTT_USE_FPTT = False  # set to True to train BPTT with FPTT-style mixed targets

# --- E-prop ---
EPROP_FEEDBACK = "symmetric"
EPROP_SEED = 1234

# --- FPTT ---
FPTT_PARTS = 8
FPTT_ORACLE_MOMENTUM = 1.0
# Kept for the parameter name. FPTTRNN forwards it to StrictFPTTClassifier as
# `lmbda`; an earlier note claimed strict_fptt does not use it -- confirm.
FPTT_LAMBDA = 0.5
FPTT_USE_FED_DYN = False   # kept for compatibility; FPTTRNN accepts but ignores use_feddyn
# NOTE(review): larger than EPOCHS (12); if this counts training epochs, FPTT
# never leaves warmup in this script -- confirm.
FPTT_WARMUP_EPOCHS = 20

# =============================
# Utilities
# =============================

def calculate_lyapunov_exponent_numpy(model: "BaseRNN", driver_input: np.ndarray) -> float:
    """
    Estimate the largest Lyapunov exponent of a tanh RNN driven by an input sequence.

    Runs h_t = tanh(W_hh h_{t-1} + W_xh I_t + b_h) from h_0 = 0 and pushes an
    orthonormal frame through the step Jacobians diag(1 - h_t^2) W_hh,
    re-orthonormalizing with QR after every step. Each exponent is the
    time-averaged log |diag(R)|; the largest one is returned.

    Args:
        model: Object exposing ``hidden_size`` and NumPy ``W_hh``, ``W_xh``, ``b_h``.
        driver_input: Array of shape (input_size, T); column t is the input at step t.

    Returns:
        Largest exponent estimate (per step); 0.0 when T == 0; NaN if a QR
        decomposition raises ``LinAlgError``.
    """
    n_hidden = model.hidden_size
    W_hh = model.W_hh.astype(np.float64, copy=False)
    W_xh = model.W_xh.astype(np.float64, copy=False)
    b_h = model.b_h.astype(np.float64, copy=False)
    # Cast the whole driver once instead of once per time step.
    drive = np.asarray(driver_input, dtype=np.float64)

    h = np.zeros((n_hidden, 1), dtype=np.float64)
    Q = np.eye(n_hidden, dtype=np.float64)
    log_r_diag_sum = np.zeros(n_hidden, dtype=np.float64)
    log_floor = 1e-12  # avoids log(0) for directions that collapse completely

    T = drive.shape[1]
    for t in range(T):
        I_t = drive[:, t].reshape(-1, 1)
        h_next = np.tanh(W_hh @ h + W_xh @ I_t + b_h)
        phi_prime = 1.0 - h_next**2  # (n_hidden, 1): broadcasts as a row scaling
        # diag(phi') @ W_hh @ Q == phi' * (W_hh @ Q); row scaling avoids building a
        # dense diagonal matrix and one extra n^3 matrix product per step.
        Z = phi_prime * (W_hh @ Q)
        try:
            Q, R = np.linalg.qr(Z)
        except np.linalg.LinAlgError:
            return float("nan")
        log_r_diag_sum += np.log(np.clip(np.abs(np.diag(R)), log_floor, None))
        h = h_next

    lyapunov_exponents = log_r_diag_sum / max(T, 1)
    return float(np.max(lyapunov_exponents))

def softmax(logits: np.ndarray) -> np.ndarray:
    """
    Softmax over axis 0, so each column is normalized independently.

    The per-column maximum is subtracted before exponentiating for numerical
    stability, and a 1e-12 term in the denominator guards against division by zero.
    """
    col_max = np.max(logits, axis=0, keepdims=True)
    numer = np.exp(logits - col_max)
    denom = np.sum(numer, axis=0, keepdims=True) + 1e-12
    return numer / denom

def pad_truncate(seqs: List[List[int]], maxlen: int, pad_value: int = 0) -> np.ndarray:
    """
    Left-pad / left-truncate integer sequences into a fixed-width matrix.

    Sequences longer than ``maxlen`` keep their LAST ``maxlen`` tokens; shorter
    ones are right-aligned, with ``pad_value`` filling the left side. Empty
    sequences become a row of ``pad_value``.

    Args:
        seqs: Integer sequences (each must support ``len``-free slicing, e.g. lists).
        maxlen: Output width; must be >= 0.
        pad_value: Fill value for padded positions.

    Returns:
        int64 array of shape (len(seqs), maxlen).
    """
    out = np.full((len(seqs), maxlen), pad_value, dtype=np.int64)
    if maxlen == 0:
        # s[-0:] is the whole sequence, not an empty tail, so handle this up front.
        return out
    for i, s in enumerate(seqs):
        tail = np.asarray(s[-maxlen:], dtype=np.int64)
        n = tail.shape[0]
        if n:
            # Skip empty sequences: out[i, -0:] would address the whole row.
            out[i, maxlen - n:] = tail
    return out

def load_imdb_index_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
    vocab_size: int = VOCAB_SIZE,
    maxlen: int = MAXLEN,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the Keras IMDB dataset as fixed-width word-index matrices.

    Args:
        train_limit: Keep only the first ``train_limit`` training reviews (all if None).
        test_limit: Keep only the first ``test_limit`` test reviews (all if None).
        vocab_size: Passed to ``imdb.load_data(num_words=...)``.
        maxlen: Width for ``pad_truncate`` (left-padded with 0, keeps the last tokens).

    Returns:
        (x_train_idx, y_train, x_test_idx, y_test): int64 index matrices of shape
        (N, maxlen) and int64 label vectors.

    Raises:
        RuntimeError: if importing ``tensorflow.keras.datasets`` raises.
    """
    try:
        from tensorflow.keras.datasets import imdb
    except Exception as exc:
        raise RuntimeError("TensorFlow/Keras not installed.") from exc

    train_split, test_split = imdb.load_data(num_words=vocab_size)

    def _head(split, limit):
        # Truncate a (sequences, labels) pair to its first `limit` items.
        xs, ys = split
        if limit is None:
            return xs, ys
        return xs[:limit], ys[:limit]

    x_tr, y_tr = _head(train_split, train_limit)
    x_te, y_te = _head(test_split, test_limit)

    return (
        pad_truncate(x_tr, maxlen=maxlen, pad_value=0),
        np.asarray(y_tr, dtype=np.int64),
        pad_truncate(x_te, maxlen=maxlen, pad_value=0),
        np.asarray(y_te, dtype=np.int64),
    )

def build_frozen_embeddings(vocab_size: int, embed_dim: int, seed: int = 42) -> np.ndarray:
    """
    Create a fixed random embedding table with one column per token id.

    Entries are standard-normal draws from ``np.random.default_rng(seed)``, cast
    to float32 and scaled by 0.1. Column 0 is zeroed, so index 0 (the padding
    value used when loading IMDB) maps to a zero vector.

    Returns:
        float32 array of shape (embed_dim, vocab_size).
    """
    generator = np.random.default_rng(seed)
    table = generator.standard_normal((embed_dim, vocab_size)).astype(np.float32)
    table *= 0.1
    table[:, 0] = 0.0
    return table

def indices_to_emb_sequences(idx: np.ndarray, E: np.ndarray) -> np.ndarray:
    """
    Look up embeddings for a batch of index sequences.

    Args:
        idx: Integer array of shape (N, T).
        E: Embedding table of shape (D, V), one column per token id.

    Returns:
        float32 array of shape (N, D, T) with ``out[n, :, t] == E[:, idx[n, t]]``.
    """
    num_seqs, seq_len = idx.shape
    emb_dim = E.shape[0]
    result = np.zeros((num_seqs, emb_dim, seq_len), dtype=np.float32)
    for step, column_ids in enumerate(idx.T):
        # E[:, ids] is (D, N); transpose to (N, D) to fill one time slice.
        result[:, :, step] = E[:, column_ids].T
    return result

def build_targets_from_labels(labels: np.ndarray, time_steps: int) -> np.ndarray:
    """
    Turn binary class labels into one-hot targets repeated at every time step.

    Args:
        labels: Integer array of shape (N,) indexing one of 2 classes.
        time_steps: Number of time steps T.

    Returns:
        float32 array of shape (N, 2, T); every time slice equals the one-hot label.
    """
    n_samples = labels.shape[0]
    rows = np.arange(n_samples)
    encoded = np.zeros((n_samples, 2), dtype=np.float32)
    encoded[rows, labels] = 1.0
    # Add a trailing time axis and copy the one-hot vector into every step.
    return np.repeat(encoded[..., np.newaxis], time_steps, axis=2)

def iterate_minibatches(
    inputs: np.ndarray,
    targets: np.ndarray,
    batch_size: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
    """
    Yield shuffled (inputs, targets) mini-batches covering the data exactly once.

    The sample order is shuffled in place with ``rng`` when iteration starts.
    Batches are taken along axis 0, and the final batch may be smaller than
    ``batch_size``.
    """
    n_total = inputs.shape[0]
    order = np.arange(n_total)
    rng.shuffle(order)
    for lo in range(0, n_total, batch_size):
        chosen = order[lo : lo + batch_size]
        yield inputs[chosen], targets[chosen]


def split_train_val(
    inputs: np.ndarray,
    targets: np.ndarray,
    labels: np.ndarray,
    val_fraction: float,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Randomly split aligned arrays into training and validation subsets.

    ``max(1, int(N * val_fraction))`` samples, so always at least one, go to
    validation and the remainder to training. The permutation is drawn from ``rng``.

    Returns:
        (train_inputs, train_targets, train_labels,
         val_inputs, val_targets, val_labels)
    """
    n = int(inputs.shape[0])
    n_val = max(1, int(n * float(val_fraction)))
    perm = rng.permutation(n)
    val_sel, train_sel = perm[:n_val], perm[n_val:]
    arrays = (inputs, targets, labels)
    train_part = tuple(arr[train_sel] for arr in arrays)
    val_part = tuple(arr[val_sel] for arr in arrays)
    return train_part + val_part


def extract_params(model: Any) -> Dict[str, np.ndarray]:
    """
    Snapshot a model's RNN weights as independent NumPy arrays.

    Reads ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y`` from ``model`` and
    skips any that are absent. Torch tensors are detached, moved to CPU and
    copied with their dtype preserved; anything else is copied as float32.
    """
    def _to_numpy(value: Any) -> np.ndarray:
        # Return an owned NumPy copy of a tensor or array-like value.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy().copy()
        return np.asarray(value, dtype=np.float32).copy()

    names = ("W_hh", "W_xh", "b_h", "W_hy", "b_y")
    return {name: _to_numpy(getattr(model, name)) for name in names if hasattr(model, name)}


def load_params(model: Any, params: Dict[str, np.ndarray]) -> None:
    """
    Write parameter arrays back into ``model`` and reset its transient state.

    For each of ``W_hh``, ``W_xh``, ``b_h``, ``W_hy``, ``b_y`` present in both
    ``params`` and ``model``: torch tensors are overwritten in place, keeping
    their device and dtype; other attributes are replaced with a float32 copy.
    Afterwards ``model.reset_state_buffers()`` is called if it exists, and the
    optimizer is reset via ``model.reset_optimizer()`` or, if that method is
    absent, by clearing ``model.optimizer.state``.
    """
    param_names = ("W_hh", "W_xh", "b_h", "W_hy", "b_y")
    for name in (n for n in param_names if n in params and hasattr(model, n)):
        src = params[name]
        dst = getattr(model, name)
        if not torch.is_tensor(dst):
            setattr(model, name, np.asarray(src, dtype=np.float32).copy())
            continue
        with torch.no_grad():
            dst.copy_(torch.as_tensor(src, device=dst.device, dtype=dst.dtype))

    if hasattr(model, "reset_state_buffers"):
        model.reset_state_buffers()

    if hasattr(model, "reset_optimizer"):
        model.reset_optimizer()
        return
    optimizer = getattr(model, "optimizer", None)
    if optimizer is not None and hasattr(optimizer, "state"):
        optimizer.state.clear()

def evaluate_classifier_final_step(
    model: "BaseRNN",
    inputs: np.ndarray,
    targets: np.ndarray,
    labels: np.ndarray,
    batch_size: int,
) -> Tuple[float, float]:
    """
    Evaluate a model using only its output at the final time step.

    Each chunk of ``batch_size`` samples is run through ``model.forward_cycle``
    from a zero hidden state. The last step's logits are scored with softmax
    cross-entropy against ``targets[:, :, -1]`` and with argmax accuracy
    against ``labels``.

    Args:
        model: Object with ``hidden_size`` and ``forward_cycle(inputs, h0)``
            returning a per-step sequence of (classes, batch) logits plus a state.
        inputs: Array of shape (N, input_size, T).
        targets: Array of shape (N, classes, T) holding per-step target distributions.
        labels: Integer class labels of shape (N,).
        batch_size: Evaluation chunk size.

    Returns:
        (mean cross-entropy per sample, accuracy); both are 0.0 when N == 0.
    """
    total_loss = 0.0
    total_samples = 0
    correct = 0
    eps = 1e-12

    for start_idx in range(0, inputs.shape[0], batch_size):
        end_idx = min(start_idx + batch_size, inputs.shape[0])
        input_batch = inputs[start_idx:end_idx]
        target_batch = targets[start_idx:end_idx]
        label_batch = labels[start_idx:end_idx]
        current_batch = input_batch.shape[0]

        h_prev_eval = np.zeros((model.hidden_size, current_batch), dtype=np.float32)
        outputs_seq, _ = model.forward_cycle(input_batch, h_prev_eval)
        # Only the last step is scored, so take its (classes, batch) logits directly
        # instead of stacking all T steps into a (classes, batch, T) array first.
        final_logits = np.asarray(outputs_seq[-1]).T  # (batch, classes)

        probs = softmax(final_logits.T).T  # softmax over classes, back to (batch, classes)
        y_final = target_batch[:, :, -1]  # (batch, classes)
        ce = -np.sum(y_final * np.log(probs + eps), axis=1)
        batch_loss = float(np.mean(ce))

        preds = np.argmax(final_logits, axis=1)
        correct += int(np.sum(preds == label_batch))

        total_loss += batch_loss * current_batch
        total_samples += current_batch

    avg_loss = total_loss / max(total_samples, 1)
    accuracy = correct / max(total_samples, 1)
    return avg_loss, accuracy

def orthogonal_matrix(n: int, rng: np.random.Generator) -> np.ndarray:
    """
    Sample an n x n orthogonal matrix and return it as float32.

    A standard-normal matrix drawn from ``rng`` is QR-decomposed. Each column
    of Q is then multiplied by the sign of the matching diagonal entry of R,
    which removes the sign ambiguity of the factorization.
    """
    gaussian = rng.standard_normal((n, n))
    q_factor, r_factor = np.linalg.qr(gaussian)
    column_signs = np.sign(np.diagonal(r_factor))
    return (q_factor * column_signs).astype(np.float32)

def make_initial_params(
    input_size: int,
    hidden_size: int,
    output_size: int,
    g: float,
    rng: np.random.Generator,
) -> Dict[str, np.ndarray]:
    """
    Draw one set of float32 RNN parameters with recurrent gain ``g``.

    ``W_hh`` is a random orthogonal matrix scaled by ``g``; ``W_xh`` and
    ``W_hy`` are standard-normal draws scaled by 0.1; both biases are zero.
    Draws from ``rng`` happen in the order W_hh, W_xh, W_hy, so a fixed seed
    reproduces the same parameters.

    Returns:
        Dict with keys "W_hh" (H, H), "W_xh" (H, I), "b_h" (H, 1),
        "W_hy" (O, H), "b_y" (O, 1), all float32.
    """
    # Scale by a float32 scalar: gains typically come from np.linspace (np.float64),
    # and under NumPy >= 2 (NEP 50) float32_array * np.float64 promotes to float64,
    # which would make W_hh the only float64 parameter.
    W_hh = orthogonal_matrix(hidden_size, rng) * np.float32(g)
    W_xh = rng.standard_normal((hidden_size, input_size)).astype(np.float32) * 0.1
    b_h = np.zeros((hidden_size, 1), dtype=np.float32)
    W_hy = rng.standard_normal((output_size, hidden_size)).astype(np.float32) * 0.1
    b_y = np.zeros((output_size, 1), dtype=np.float32)
    return {"W_hh": W_hh, "W_xh": W_xh, "b_h": b_h, "W_hy": W_hy, "b_y": b_y}

# =============================
# Base RNN
# =============================
class BaseRNN:
    """
    Minimal NumPy Elman RNN with tanh hidden units and a linear readout.

    States are column-major with shape (hidden_size, batch). Parameters:
    ``W_xh`` (H, I), ``W_hh`` (H, H), ``b_h`` (H, 1), ``W_hy`` (O, H), ``b_y`` (O, 1).
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, eta: float = 1e-3) -> None:
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.eta = eta  # learning rate applied by the subclasses' update rules
        # Random placeholders drawn from the global np.random state, in the order
        # W_xh, W_hh, W_hy. Callers can overwrite them with set_params().
        # NOTE: W_hh here is unscaled N(0, 1).
        rand = np.random.randn
        self.W_xh = rand(hidden_size, input_size).astype(np.float32) * 0.1
        self.W_hh = rand(hidden_size, hidden_size).astype(np.float32)
        self.b_h = np.zeros((hidden_size, 1), dtype=np.float32)
        self.W_hy = rand(output_size, hidden_size).astype(np.float32) * 0.1
        self.b_y = np.zeros((output_size, 1), dtype=np.float32)

    def set_params(self, params: Dict[str, np.ndarray]) -> None:
        """Replace all five parameters with copies of ``params`` entries (KeyError if one is missing)."""
        for key in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
            setattr(self, key, params[key].copy())

    def forward_cycle(self, inputs_cycle: np.ndarray, h_prev_cycle: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Run the RNN over every time step without updating any parameters.

        Args:
            inputs_cycle: Inputs of shape (batch, input_size, T).
            h_prev_cycle: Initial hidden state of shape (hidden_size, batch).

        Returns:
            (outputs, h_last): a list of T logit arrays of shape
            (output_size, batch) and the final hidden state.
        """
        state = h_prev_cycle
        logits_per_step: List[np.ndarray] = []
        for step in range(inputs_cycle.shape[2]):
            x_col = inputs_cycle[:, :, step].T
            state = np.tanh(self.W_hh @ state + self.W_xh @ x_col + self.b_h)
            logits_per_step.append(self.W_hy @ state + self.b_y)
        return logits_per_step, state

# =============================
# Local Rule
# =============================
class NumpyLocalRuleRNN(BaseRNN):
    """
    Tanh RNN trained online with a local, per-time-step learning rule.

    At every step the output error is projected onto the hidden units
    (``g_t = W_hy^T dL/dy``), multiplied by the tanh derivative ``u_t`` and
    divided by ``1 - lambda * u_t``. Here ``lambda`` is a per-neuron coefficient
    estimated online from running statistics of past teaching signals.
    Parameters are updated immediately at every step, after global-norm
    clipping to ``CLIP_NORM``.

    Optionally (see ``enable_fptt_surrogates``) each step's target is mixed
    with a per-step class distribution ``Q`` averaged over the previous epoch.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, eta: float = 0.001, lambda_window: int = 50) -> None:
        super().__init__(input_size, hidden_size, output_size, eta)
        self.epsilon = 1e-8          # guards the alpha ratio denominator
        self.alpha_rho = 0.995       # EMA decay for the alpha statistics
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99
        # EMA decay for the lambda statistics, matched to an effective window of
        # `lambda_window` steps: rho = (w - 1) / w.
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99       # hard bound on |lambda|
        self.denom_floor = 1e-3      # minimum |1 - lambda * u| used when computing delta
        self.eps_lambda = 1e-8       # guards the lambda ratio denominator
        self.diagnostics: Dict[str, Any] = {}
        self.reset_learning_state()
        self.reset_diagnostics()
        # The surrogate branch in run_one_cycle_and_update_directly is only taken
        # once fptt_Q_prev has been set (see enable_fptt_surrogates).
        self.use_fptt_surrogates = True
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    def enable_fptt_surrogates(self, time_steps: int, output_size: int, Q_init: np.ndarray | None = None, beta_schedule: str = "linear") -> None:
        """Turn on mixed targets with per-step distribution ``Q_init`` (uniform if None) and reset accumulators."""
        if Q_init is None:
            Q_init = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        self.fptt_Q_prev = Q_init.astype(np.float32).copy()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    def reset_fptt_epoch_accumulators(self, time_steps: int | None = None, output_size: int | None = None) -> None:
        """
        Zero the per-epoch sums of predicted class probabilities.

        Missing sizes are taken from ``fptt_Q_prev``.

        Raises:
            ValueError: if sizes are missing and surrogates were never enabled.
        """
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = np.zeros((output_size, time_steps), dtype=np.float64)
        self.fptt_Q_count = np.zeros((time_steps,), dtype=np.int64)

    def finalize_fptt_epoch(self) -> None:
        """Set ``fptt_Q_prev`` to the per-step mean of the probabilities accumulated this epoch."""
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        counts = np.maximum(1, self.fptt_Q_count.astype(np.float64))
        Q_new = (self.fptt_Q_sum / counts[None, :]).astype(np.float32)
        self.fptt_Q_prev = Q_new

    def reset_diagnostics(self) -> None:
        """Clear the counters and history collected during training."""
        self.diagnostics = {
            # hidden_size * batch per step (matches the element count behind denom_floor_hits)
            "total_neurons_steps": 0,
            "lambda_clip_count": 0,
            # per-neuron lambda entries recomputed (denominator for lambda_clip_count)
            "lambda_update_count": 0,
            "denom_floor_hits": 0,
            # NOTE: grows by one float per time step until reset_diagnostics() is called
            "lambda_history": [],
        }

    def get_lambda_clip_percentage(self) -> float:
        """
        Return the percentage of per-neuron lambda updates whose value was clipped.

        lambda has shape (hidden_size, 1), so the denominator counts one entry per
        neuron per lambda update, not one per batch sample.
        """
        total = self.diagnostics.get("lambda_update_count", 0)
        if total == 0:
            return 0.0
        return (self.diagnostics["lambda_clip_count"] / total) * 100.0

    def reset_learning_state(self) -> None:
        """Zero the running statistics behind alpha and lambda (all shaped (hidden_size, 1))."""
        shape = (self.hidden_size, 1)
        self.alpha_num = np.zeros(shape, dtype=np.float32)
        self.alpha_den = np.zeros(shape, dtype=np.float32)
        self.alpha_hat = np.zeros(shape, dtype=np.float32)
        self.S_A2 = np.zeros(shape, dtype=np.float32)
        self.S_AB = np.zeros(shape, dtype=np.float32)
        self.lambda_vals = np.zeros(shape, dtype=np.float32)

    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray,
        targets_cycle: np.ndarray,
        h_prev_cycle: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run one sequence batch, updating the parameters at every time step.

        Args:
            inputs_cycle: Inputs of shape (batch, input_size, T).
            targets_cycle: Per-step target distributions of shape (batch, classes, T).
            h_prev_cycle: Initial hidden state of shape (hidden_size, batch).

        Returns:
            (mean per-step loss, final hidden state). The loss is 0.0 when T == 0.
        """
        h_prev = h_prev_cycle
        total_cycle_loss = 0.0
        batch_size = inputs_cycle.shape[0]
        time_steps = inputs_cycle.shape[2]
        eps = 1e-12
        prev_g = None
        prev_u = None
        prev_delta = None

        for t in range(time_steps):
            I_t = inputs_cycle[:, :, t].T
            y_target_t = targets_cycle[:, :, t].T

            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y

            if self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                # Mixed target: beta_t * true label + (1 - beta_t) * previous-epoch Q_t.
                beta_t = float(t + 1) / float(time_steps) if self.beta_schedule == "linear" else 1.0
                P_t = softmax(y_hat_t)
                Q_t = self.fptt_Q_prev[:, t].reshape(-1, 1)
                Q_t_batch = np.repeat(Q_t, repeats=y_hat_t.shape[1], axis=1)
                Y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch
                CE_true = -np.sum(y_target_t * np.log(P_t + eps), axis=0)
                CE_div  = -np.sum(Q_t_batch * np.log(P_t + eps), axis=0)
                loss_t = np.mean(beta_t * CE_true + (1.0 - beta_t) * CE_div)
                dL_dyhat = P_t - Y_tilde
                if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                    self.fptt_Q_sum[:, t] += np.sum(P_t, axis=1)
                    self.fptt_Q_count[t] += batch_size
            else:
                P_t = softmax(y_hat_t)
                dL_dyhat = P_t - y_target_t
                loss_t = -np.mean(np.sum(y_target_t * np.log(P_t + eps), axis=0))

            total_cycle_loss += float(loss_t)

            g_t = self.W_hy.T @ dL_dyhat  # error projected onto hidden units
            u_t = 1.0 - h_t**2            # tanh derivative

            lambda_used = self.lambda_vals.copy()
            self.diagnostics["lambda_history"].append(float(np.mean(lambda_used)))
            self.diagnostics["total_neurons_steps"] += self.hidden_size * batch_size

            denominator = 1.0 - lambda_used * u_t
            denom_mask = np.abs(denominator) < self.denom_floor
            if np.any(denom_mask):
                self.diagnostics["denom_floor_hits"] += int(np.sum(denom_mask))
            # copysign keeps the sign but never yields 0 (np.sign(0) == 0 would leave
            # a zero denominator and divide by zero below).
            denominator = np.where(denom_mask, np.copysign(self.denom_floor, denominator), denominator)
            delta_t = (u_t * g_t) / denominator

            # ---- alpha update (teaching-signal AR(1) in expectation) ----
            if prev_delta is not None:
                dtp_mean = np.mean(delta_t * prev_delta, axis=1, keepdims=True)
                dpp_mean = np.mean(prev_delta**2, axis=1, keepdims=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                raw_alpha = self.alpha_num / (self.alpha_den + self.epsilon)
                self.alpha_hat = np.clip(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

            dW_hh = (delta_t @ h_prev.T) / batch_size
            dW_xh = (delta_t @ I_t.T) / batch_size
            db_h  = np.mean(delta_t, axis=1, keepdims=True)
            dW_hy = (dL_dyhat @ h_t.T) / batch_size
            db_y  = np.mean(dL_dyhat, axis=1, keepdims=True)

            # The local rule clips by global norm at CLIP_NORM (5.0); gradients are
            # already batch-averaged, hence norm_scale=1.0.
            grads = [dW_xh, dW_hh, db_h, dW_hy, db_y]
            clipped_grads = normalize_and_clip_gradients(grads, norm_scale=1.0, max_norm=CLIP_NORM)
            dW_xh, dW_hh, db_h, dW_hy, db_y = clipped_grads

            self.W_hh -= self.eta * dW_hh
            self.W_xh -= self.eta * dW_xh
            self.b_h  -= self.eta * db_h
            self.W_hy -= self.eta * dW_hy
            self.b_y  -= self.eta * db_y

            # ---- lambda update: ratio of running statistics, clipped to a safe range ----
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t
                A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
                AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)
                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean
                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)
                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                safe_cap = (1.0 - self.denom_floor) / u_abs_max
                cap = np.minimum(safe_cap, self.lambda_cap)
                lambda_clipped = np.clip(lambda_unproj, -cap, cap)
                self.diagnostics["lambda_clip_count"] += int(np.sum(lambda_unproj != lambda_clipped))
                self.diagnostics["lambda_update_count"] = (
                    self.diagnostics.get("lambda_update_count", 0) + int(lambda_clipped.size)
                )
                self.lambda_vals = lambda_clipped

            prev_g = g_t
            prev_u = u_t
            prev_delta = delta_t
            h_prev = h_t
        return total_cycle_loss / max(time_steps, 1), h_prev

# =============================
# BPTT RNN (Fixed Implementation)
# =============================
class BPTTRNN(BaseRNN):
    """
    Tanh RNN trained with backpropagation through time (no gradient clipping).

    Optionally (see ``enable_fptt_surrogates``) each step's target is mixed
    with a per-step class distribution ``Q`` averaged over the previous epoch.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, eta: float = 0.001) -> None:
        super().__init__(input_size, hidden_size, output_size, eta)
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    def enable_fptt_surrogates(self, time_steps: int, output_size: int, Q_init: np.ndarray | None = None, beta_schedule: str = "linear") -> None:
        """Turn on mixed targets with per-step distribution ``Q_init`` (uniform if None) and reset accumulators."""
        if Q_init is None:
            Q_init = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        self.fptt_Q_prev = Q_init.astype(np.float32).copy()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    def reset_fptt_epoch_accumulators(self, time_steps: int | None = None, output_size: int | None = None) -> None:
        """
        Zero the per-epoch sums of predicted class probabilities.

        Missing sizes are taken from ``fptt_Q_prev``.

        Raises:
            ValueError: if sizes are missing and surrogates were never enabled.
        """
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = np.zeros((output_size, time_steps), dtype=np.float64)
        self.fptt_Q_count = np.zeros((time_steps,), dtype=np.int64)

    def finalize_fptt_epoch(self) -> None:
        """Set ``fptt_Q_prev`` to the per-step mean of the probabilities accumulated this epoch."""
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        counts = np.maximum(1, self.fptt_Q_count.astype(np.float64))
        Q_new = (self.fptt_Q_sum / counts[None, :]).astype(np.float32)
        self.fptt_Q_prev = Q_new

    def train_batch(
        self,
        inputs_batch: np.ndarray,
        targets_batch: np.ndarray,
        h_prev_batch: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        Run one forward/backward pass over the whole sequence and apply one SGD step.

        The gradient does not flow into ``h_prev_batch``. Per-step gradients are
        summed over time and averaged over the batch, without clipping.

        NOTE(review): the returned loss is the per-step MEAN, while the applied
        gradient is that of the per-step SUM (T times larger) -- confirm this
        scaling is intended.

        Args:
            inputs_batch: Inputs of shape (batch, input_size, T).
            targets_batch: Per-step target distributions of shape (batch, classes, T).
            h_prev_batch: Initial hidden state of shape (hidden_size, batch).

        Returns:
            (mean per-step loss, final hidden state). The loss is 0.0 when T == 0.
        """
        batch_size = inputs_batch.shape[0]
        time_steps = inputs_batch.shape[2]
        eps = 1e-12

        # 1. Forward Pass: Cache states
        h_acts = {-1: h_prev_batch}
        y_probs_cache = {}
        target_cache = {}

        total_loss = 0.0

        for t in range(time_steps):
            I_t = inputs_batch[:, :, t].T
            h_prev_t = h_acts[t - 1]

            x_t = self.W_hh @ h_prev_t + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y
            h_acts[t] = h_t

            P_t = softmax(y_hat_t)
            y_target_t = targets_batch[:, :, t].T

            # Determine effective target for loss and gradient
            if self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                beta_t = float(t + 1) / float(time_steps) if self.beta_schedule == "linear" else 1.0
                Q_t = self.fptt_Q_prev[:, t].reshape(-1, 1)
                Q_t_batch = np.repeat(Q_t, repeats=P_t.shape[1], axis=1)
                Y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch

                CE_true = -np.sum(y_target_t * np.log(P_t + eps), axis=0)
                CE_div  = -np.sum(Q_t_batch * np.log(P_t + eps), axis=0)
                loss_t = np.mean(beta_t * CE_true + (1.0 - beta_t) * CE_div)

                if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                    self.fptt_Q_sum[:, t] += np.sum(P_t, axis=1)
                    self.fptt_Q_count[t] += batch_size
            else:
                Y_tilde = y_target_t
                loss_t = -np.mean(np.sum(y_target_t * np.log(P_t + eps), axis=0))

            total_loss += float(loss_t)
            y_probs_cache[t] = P_t
            target_cache[t] = Y_tilde

        # 2. Backward Pass: Accumulate Gradients through Time
        dW_xh = np.zeros_like(self.W_xh)
        dW_hh = np.zeros_like(self.W_hh)
        db_h  = np.zeros_like(self.b_h)
        dW_hy = np.zeros_like(self.W_hy)
        db_y  = np.zeros_like(self.b_y)

        dh_next = np.zeros((self.hidden_size, batch_size), dtype=np.float32)

        for t in reversed(range(time_steps)):
            I_t = inputs_batch[:, :, t].T
            h_t = h_acts[t]
            h_prev_t = h_acts[t - 1]

            # Gradient of softmax cross-entropy w.r.t. the logits at step t
            dy = y_probs_cache[t] - target_cache[t]  # (Classes, Batch)

            # Accumulate Output Layer Gradients
            dW_hy += dy @ h_t.T
            db_y  += np.sum(dy, axis=1, keepdims=True)

            # Backprop into Hidden
            dh = self.W_hy.T @ dy + dh_next
            dtanh = (1.0 - h_t**2) * dh

            # Accumulate Hidden/Input Gradients
            dW_hh += dtanh @ h_prev_t.T
            dW_xh += dtanh @ I_t.T
            db_h  += np.sum(dtanh, axis=1, keepdims=True)

            # Pass to previous step
            dh_next = self.W_hh.T @ dtanh

        # 3. Normalize only (no clipping for BPTT baseline)
        grads = [dW_xh, dW_hh, db_h, dW_hy, db_y]
        grads = normalize_and_clip_gradients(
            grads,
            norm_scale=batch_size,
            max_norm=None,  # average over the batch only; no clipping
        )

        # 4. Update
        self.W_xh -= self.eta * grads[0]
        self.W_hh -= self.eta * grads[1]
        self.b_h  -= self.eta * grads[2]
        self.W_hy -= self.eta * grads[3]
        self.b_y  -= self.eta * grads[4]

        # max(..., 1) avoids ZeroDivisionError for zero-length sequences.
        return total_loss / max(time_steps, 1), h_acts[time_steps - 1]


class FPTTRNN(StrictFPTTClassifier):
    """Strict FPTT classifier with dataset-level oracle sharing for IMDB sentiment.

    A thin adapter over ``StrictFPTTClassifier``. It fixes ``label_mode="last"``
    and ``use_oracle=True`` and takes its defaults from this script's FPTT_* and
    CLIP_NORM constants.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        parts: int = FPTT_PARTS,
        oracle_momentum: float = FPTT_ORACLE_MOMENTUM,
        clip: float = CLIP_NORM,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lambda_reg: float = FPTT_LAMBDA,
        use_feddyn: bool = FPTT_USE_FED_DYN,
        warmup_epochs: int = FPTT_WARMUP_EPOCHS,
        oracle_id: str = "imdb",
        **_: Any,
    ) -> None:
        # `lambda_reg` is forwarded under the base class's name `lmbda`.
        # `use_feddyn` and any extra keyword arguments (collected in `_`) are
        # accepted for call-site compatibility and intentionally dropped.
        forwarded: Dict[str, Any] = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            eta=eta,
            parts=parts,
            clip=clip,
            alpha=alpha,
            beta=beta,
            rho=rho,
            lmbda=lambda_reg,
            oracle_momentum=oracle_momentum,
            warmup_epochs=warmup_epochs,
            oracle_id=oracle_id,
        )
        fixed: Dict[str, Any] = {"label_mode": "last", "use_oracle": True}
        super().__init__(**forwarded, **fixed)

# =============================
# Main
# =============================
def _copy_param_dict(params: Any) -> Any:
    """Return an independent deep copy of a model-parameter container.

    The models in this script update their weights in place (for example
    ``self.W_xh -= self.eta * grads[0]``). A params object shared between a
    live model and a stored snapshot -- or between several models -- would
    therefore be silently overwritten by training. Copying at every hand-off
    keeps stored initial weights and best-epoch checkpoints intact, whether or
    not ``set_params`` / ``load_params`` / ``extract_params`` copy internally.
    """
    import copy

    return copy.deepcopy(params)


def _uniform_output_prior(output_size: int, time_steps: int) -> np.ndarray:
    """Return a float32 ``(output_size, time_steps)`` array filled with ``1 / output_size``.

    Used as the ``Q_init`` argument when enabling FPTT surrogates.
    """
    return np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)


def _uses_fptt_surrogates(model: Any) -> bool:
    """Return whether ``model`` should be trained with FPTT surrogate targets.

    Local-rule models opt in through their ``use_fptt_surrogates`` attribute,
    BPTT models follow the module-level ``BPTT_USE_FPTT`` switch, and every
    other model never uses surrogates.
    """
    if isinstance(model, NumpyLocalRuleRNN):
        return bool(getattr(model, "use_fptt_surrogates", False))
    if isinstance(model, BPTTRNN):
        return bool(BPTT_USE_FPTT)
    return False


def _run_training_epoch(
    model: Any,
    inputs: np.ndarray,
    targets: np.ndarray,
    rng: np.random.Generator,
    use_surrogates: bool,
) -> None:
    """Run one pass of minibatch updates over ``(inputs, targets)``.

    Local-rule models are updated via ``run_one_cycle_and_update_directly``.
    Every other model is updated via ``train_batch``. Each minibatch starts from
    a zero hidden state of shape ``(hidden_size, batch)``. When
    ``use_surrogates`` is true, the model's FPTT epoch accumulators are reset
    before the pass and finalized after it.
    """
    if use_surrogates:
        model.reset_fptt_epoch_accumulators()
    if isinstance(model, NumpyLocalRuleRNN):
        update_step = model.run_one_cycle_and_update_directly
    else:
        update_step = model.train_batch
    for inputs_batch, targets_batch in iterate_minibatches(inputs, targets, BATCH_SIZE, rng):
        h_prev = np.zeros((model.hidden_size, inputs_batch.shape[0]), dtype=np.float32)
        update_step(inputs_batch, targets_batch, h_prev)
    if use_surrogates:
        model.finalize_fptt_epoch()


def _scan_local_rule_gains(
    input_size: int,
    output_size: int,
    time_steps: int,
    train_data: Tuple[Any, ...],
    val_data: Tuple[Any, ...],
    lyapunov_driver: np.ndarray,
) -> List[Dict[str, Any]]:
    """Stage 1: train a fresh Local Rule model at every gain in ``GAINS``.

    Args:
        train_data: ``(inputs, targets)`` used for training.
        val_data: ``(inputs, targets, labels)`` used for evaluation.

    Returns:
        One dict per gain whose pre- and post-training Lyapunov exponents are
        both non-NaN. Each dict holds the gain, validation metrics, both
        exponents and an untouched copy of the initial parameters.
    """
    train_inputs, train_targets = train_data
    val_inputs, val_targets, y_val = val_data
    results: List[Dict[str, Any]] = []

    for gi, g in enumerate(GAINS):
        print(f"[SCAN] Evaluating Local Rule at gain g={g:.3f}")
        seed_init = INIT_SEED * 1000 + gi
        seed_train = seed_init + 500
        init_params = make_initial_params(
            input_size, HIDDEN_SIZE, output_size, g, np.random.default_rng(seed_init)
        )

        local_model = NumpyLocalRuleRNN(input_size, HIDDEN_SIZE, output_size, eta=LEARNING_RATE, lambda_window=50)
        # Give the model its own copy so in-place training cannot mutate the
        # initial params we store below and later reuse for Stage 2.
        local_model.set_params(_copy_param_dict(init_params))
        if _uses_fptt_surrogates(local_model):
            local_model.enable_fptt_surrogates(
                time_steps=time_steps,
                output_size=output_size,
                Q_init=_uniform_output_prior(output_size, time_steps),
                beta_schedule="linear",
            )

        lambda_pre = calculate_lyapunov_exponent_numpy(local_model, lyapunov_driver)

        train_rng = np.random.default_rng(seed_train)
        for _epoch in range(EPOCHS):
            _run_training_epoch(
                local_model, train_inputs, train_targets, train_rng, _uses_fptt_surrogates(local_model)
            )

        val_loss, val_acc = evaluate_classifier_final_step(local_model, val_inputs, val_targets, y_val, BATCH_SIZE)
        lambda_post = calculate_lyapunov_exponent_numpy(local_model, lyapunov_driver)
        lyapunov_ok = not (np.isnan(lambda_pre) or np.isnan(lambda_post))
        delta_lambda = (lambda_post - lambda_pre) if lyapunov_ok else float("nan")

        print(
            f"  -> val_acc: {val_acc:.4f} | val_loss: {val_loss:.4f} | "
            f"Lyapunov (pre={lambda_pre:.4f}, post={lambda_post:.4f}, Δ={delta_lambda:.4f}) | "
            f"clip%={local_model.get_lambda_clip_percentage():.2f}%"
        )

        if lyapunov_ok:
            results.append({
                "g": g,
                "val_accuracy": float(val_acc),
                "val_loss": float(val_loss),
                "lambda_pre": lambda_pre,
                "lambda_post": lambda_post,
                "initial_params": init_params,
            })
    return results


def _select_critical_gain(scan_results: List[Dict[str, Any]], top_k: int = 3) -> Dict[str, Any]:
    """Pick the most "critical" entry among the most accurate scan results.

    The ``top_k`` entries with the highest ``val_accuracy`` are kept. Among
    those, the entry whose ``lambda_pre`` has the smallest absolute value wins.
    Ties are resolved by original order.

    Raises:
        RuntimeError: if ``scan_results`` is empty (every gain produced a NaN
            Lyapunov exponent), instead of the opaque ``ValueError`` raised by
            ``min`` on an empty sequence.
    """
    if not scan_results:
        raise RuntimeError(
            "Gain scan produced no usable results: the Lyapunov exponent was NaN for every gain."
        )
    k = max(1, min(top_k, len(scan_results)))
    top_performers = sorted(scan_results, key=lambda r: r["val_accuracy"], reverse=True)[:k]
    return min(top_performers, key=lambda r: abs(r["lambda_pre"]))


def _train_and_evaluate_model(
    name: str,
    model: Any,
    model_index: int,
    initial_params: Any,
    output_size: int,
    time_steps: int,
    train_data: Tuple[Any, ...],
    val_data: Tuple[Any, ...],
    test_data: Tuple[Any, ...],
    lyapunov_driver: np.ndarray,
) -> Dict[str, Any]:
    """Stage 2: train one model from ``initial_params`` and evaluate it.

    The epoch with the best validation accuracy is checkpointed and restored
    before the single final test evaluation.

    Args:
        train_data: ``(inputs, targets)``.
        val_data / test_data: ``(inputs, targets, labels)``.

    Returns:
        Dict with test accuracy, best validation metrics/epoch, Lyapunov
        exponents before and after training, and the per-epoch validation
        accuracy history.
    """
    train_inputs, train_targets = train_data
    val_inputs, val_targets, y_val = val_data
    test_inputs, test_targets, y_test = test_data

    print(f"Training {name}...")
    # Private copy: every compared model must start from identical g* weights,
    # and in-place updates of one model must not leak into the next.
    load_params(model, _copy_param_dict(initial_params))
    if _uses_fptt_surrogates(model):
        model.enable_fptt_surrogates(
            time_steps=time_steps,
            output_size=output_size,
            Q_init=_uniform_output_prior(output_size, time_steps),
            beta_schedule="linear",
        )

    lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
    acc_history: List[float] = []
    best_val_acc = -float("inf")
    best_val_loss = float("inf")
    best_epoch = 0
    best_params: Dict[str, np.ndarray] | None = None
    eval_log_interval = max(1, EPOCHS // 5)
    rng = np.random.default_rng(INIT_SEED * 10000 + model_index)

    for epoch in range(EPOCHS):
        if hasattr(model, "set_epoch"):
            model.set_epoch(epoch)
        _run_training_epoch(model, train_inputs, train_targets, rng, _uses_fptt_surrogates(model))

        val_loss, val_acc = evaluate_classifier_final_step(model, val_inputs, val_targets, y_val, BATCH_SIZE)
        acc_history.append(val_acc)
        if val_acc > best_val_acc:
            best_val_acc = float(val_acc)
            best_val_loss = float(val_loss)
            best_epoch = int(epoch + 1)
            # Snapshot by value; otherwise later in-place updates would
            # overwrite the "best" checkpoint.
            best_params = _copy_param_dict(extract_params(model))
        if (epoch + 1) % eval_log_interval == 0 or epoch == EPOCHS - 1:
            print(f"  Epoch {epoch + 1}/{EPOCHS} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}")

    if best_params is not None:
        load_params(model, best_params)

    test_loss, test_acc = evaluate_classifier_final_step(model, test_inputs, test_targets, y_test, BATCH_SIZE)
    lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

    print(
        f"-> {name}: test_acc={test_acc:.4f} | best_val_acc={best_val_acc:.4f} (epoch={best_epoch:02d}) | "
        f"Lyapunov pre={lambda_pre:.4f}, post={lambda_post:.4f}, "
        f"Δ={lambda_post - lambda_pre:.4f}"
    )
    return {
        "test_accuracy": test_acc,
        "best_val_accuracy": best_val_acc,
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "lyapunov_pre": lambda_pre,
        "lyapunov_post": lambda_post,
        "accuracy_history": acc_history,
    }


def _plot_comparison(names: List[str], comparison_results: Dict[str, Dict[str, Any]], g_optimal: float) -> None:
    """Show test accuracy, post-training Lyapunov exponents and learning curves side by side."""
    accuracies = [comparison_results[n]["test_accuracy"] for n in names]
    lyapunovs = [comparison_results[n]["lyapunov_post"] for n in names]

    fig, axs = plt.subplots(1, 3, figsize=(22, 7), constrained_layout=True)
    fig.suptitle(f"IMDB RNN Comparison Initialized at g*={g_optimal:.2f}", fontsize=20, weight="bold")

    palette = [cm.viridis(0.25), cm.plasma(0.55), cm.inferno(0.65), cm.cividis(0.55)]
    # Cycle the palette so adding/removing models never raises IndexError.
    colors = [palette[i % len(palette)] for i in range(len(names))]

    ax1 = axs[0]
    bars = ax1.bar(names, accuracies, color=colors, zorder=3)
    ax1.set_ylabel("Final Test Accuracy", fontsize=12)
    ax1.set_title("a) Generalization Performance", fontsize=14, weight="bold")
    ax1.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
    ax1.set_ylim(0, 1.05)
    ax1.spines["top"].set_visible(False)
    ax1.spines["right"].set_visible(False)
    for bar in bars:
        yval = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2.0, yval + 0.02, f"{yval:.3f}", ha="center", va="bottom")

    ax2 = axs[1]
    bars = ax2.bar(names, lyapunovs, color=colors, zorder=3)
    ax2.axhline(0, color="crimson", linestyle="--", linewidth=1.5, label="Edge of Chaos (λ=0)", zorder=2)
    ax2.set_ylabel("Max Lyapunov Exponent (λ_max)", fontsize=12)
    ax2.set_title("b) Post-Training Dynamics", fontsize=14, weight="bold")
    ax2.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
    ax2.spines["top"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.legend()
    for bar in bars:
        yval = bar.get_height()
        if not np.isfinite(yval):
            # A NaN/inf exponent has no meaningful bar end to label.
            continue
        # Anchor labels on the outer end of the bar: above positive bars,
        # below negative ones (va="bottom" would overlap a negative bar).
        ax2.text(
            bar.get_x() + bar.get_width() / 2.0,
            yval + np.sign(yval) * 0.03,
            f"{yval:.3f}",
            ha="center",
            va="bottom" if yval >= 0 else "top",
        )

    ax3 = axs[2]
    epochs_axis = np.arange(1, EPOCHS + 1)
    for i, name in enumerate(names):
        ax3.plot(
            epochs_axis,
            comparison_results[name]["accuracy_history"],
            label=name,
            color=colors[i],
            linewidth=2.5,
            marker="o",
            markersize=4,
            markevery=max(1, EPOCHS // 5),
        )
    ax3.set_xlabel("Epoch", fontsize=12)
    ax3.set_ylabel("Val Accuracy", fontsize=12)
    ax3.set_title("c) Learning Curves", fontsize=14, weight="bold")
    ax3.grid(True, linestyle=":", linewidth=0.7)
    ax3.legend(fontsize=11)
    ax3.set_ylim(0, 1.05)
    ax3.spines["top"].set_visible(False)
    ax3.spines["right"].set_visible(False)

    plt.show()


def main():
    """Run the IMDB sentiment experiment end to end.

    Stage 1 scans the recurrent gains in ``GAINS`` with the Local Rule model
    and picks ``g*`` (most accurate top-3, then smallest |Lyapunov exponent|).
    Stage 2 trains every learning rule from the same ``g*`` initialisation.
    Each model's best epoch is selected on a validation split, and test
    accuracy is reported once at the end. Results are shown in a three-panel
    figure.
    """
    apply_plot_style()
    np.random.seed(42)

    print("Loading IMDB (index sequences)...")
    x_train_idx, y_train, x_test_idx, y_test = load_imdb_index_sequences(
        train_limit=TRAIN_LIMIT, test_limit=TEST_LIMIT, vocab_size=VOCAB_SIZE, maxlen=MAXLEN
    )
    E = build_frozen_embeddings(VOCAB_SIZE, EMBED_DIM)

    print("Building dense embedded sequences...")
    train_inputs = indices_to_emb_sequences(x_train_idx, E)  # (N, D, T)
    test_inputs = indices_to_emb_sequences(x_test_idx, E)

    train_targets = build_targets_from_labels(y_train, time_steps=train_inputs.shape[2])
    test_targets = build_targets_from_labels(y_test, time_steps=test_inputs.shape[2])

    # Standard protocol: use a validation split for hyperparameter selection/curves, and use test only once at the end.
    rng_split = np.random.default_rng(INIT_SEED)
    (
        train_inputs_fit,
        train_targets_fit,
        _y_train_fit,
        val_inputs,
        val_targets,
        y_val,
    ) = split_train_val(train_inputs, train_targets, y_train, val_fraction=0.1, rng=rng_split)

    input_size = train_inputs.shape[1]
    time_steps = train_inputs.shape[2]
    output_size = 2

    # Mean of (up to) the first 32 training sequences drives the Lyapunov
    # estimate; slicing past the end is safe, so no explicit min() is needed.
    lyapunov_driver = np.mean(train_inputs_fit[:32], axis=0)
    total_start_time = time.time()

    train_data = (train_inputs_fit, train_targets_fit)
    val_data = (val_inputs, val_targets, y_val)
    test_data = (test_inputs, test_targets, y_test)

    print("=" * 70)
    print("Stage 1: Scanning gains to find the most critical Local Rule model")
    print("=" * 70)

    local_rule_scan_results = _scan_local_rule_gains(
        input_size, output_size, time_steps, train_data, val_data, lyapunov_driver
    )
    best_critical_model_info = _select_critical_gain(local_rule_scan_results, top_k=3)

    g_optimal = best_critical_model_info["g"]
    initial_params_optimal = best_critical_model_info["initial_params"]

    print("--- Stage 1 Summary ---")
    print(f"Optimal gain g* = {g_optimal:.4f}")
    print(f"  Val Accuracy: {best_critical_model_info['val_accuracy']:.4f}")

    print("=" * 70)
    print(f"Stage 2: Comparing learning rules initialized at g* = {g_optimal:.4f}")
    print("=" * 70)

    models_to_compare: Dict[str, BaseRNN] = {
        "Local Rule": NumpyLocalRuleRNN(input_size, HIDDEN_SIZE, output_size, eta=LEARNING_RATE, lambda_window=50),
        "BPTT": BPTTRNN(input_size, HIDDEN_SIZE, output_size, eta=LEARNING_RATE),
        "E-PROP": StandardEPropRNN(
            input_size,
            HIDDEN_SIZE,
            output_size,
            eta=LEARNING_RATE,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="ce",
        ),
        "FPTT": FPTTRNN(
            input_size,
            HIDDEN_SIZE,
            output_size,
            eta=LEARNING_RATE,
            parts=FPTT_PARTS,
            oracle_momentum=FPTT_ORACLE_MOMENTUM,
            clip=CLIP_NORM,
            alpha=0.1,
            beta=0.5,
            rho=0.0,
            lambda_reg=FPTT_LAMBDA,
            use_feddyn=FPTT_USE_FED_DYN,
            warmup_epochs=FPTT_WARMUP_EPOCHS,
        ),
    }

    comparison_results: Dict[str, Dict[str, Any]] = {}
    for model_index, (name, model) in enumerate(models_to_compare.items()):
        comparison_results[name] = _train_and_evaluate_model(
            name,
            model,
            model_index,
            initial_params_optimal,
            output_size,
            time_steps,
            train_data,
            val_data,
            test_data,
            lyapunov_driver,
        )

    elapsed_minutes = (time.time() - total_start_time) / 60.0
    print(f"Experiment finished! Total time: {elapsed_minutes:.2f} minutes")

    _plot_comparison(list(models_to_compare.keys()), comparison_results, g_optimal)

if __name__ == "__main__":
    main()
