import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from plotting_utils import apply_plot_style

# Put the directory one level above this file's folder on sys.path (once), so
# the absolute imports that follow can be resolved from there.
# NOTE(review): presumably this is where the `methods` and `task` packages
# live - confirm against the repository layout.
COMPARE_DIR = Path(__file__).resolve().parents[1]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

from methods.strict_fptt import StrictFPTTRegressor
from methods.standard_eprop import StandardEPropRNN
from task.common.sequence_core import extract_params as extract_params_common, load_params as load_params_common
# The following utilities are assumed to be in a shared file
# from shared_rnn_utils import build_batch_seed_grid, initialize_rnn_parameters

# =============================
# Config
# =============================
BLOCK_SIZE = 48          # sequence length (number of frames in the video)
BATCH_SIZE = 32          # sequences per generated batch (training and validation)
EPOCHS = 15              # Stage 2 training epochs per algorithm
LEARNING_RATE = 1e-3     # passed as `eta` to every model built in main()
HIDDEN_SIZE = 192        # number of recurrent units in every model

# video frame config: 8x8 moving dot
FRAME_H = 8
FRAME_W = 8
FRAME_SIZE = FRAME_H * FRAME_W

INPUT_SIZE = FRAME_SIZE   # flattened frame
OUTPUT_SIZE = FRAME_SIZE  # predict next frame

# Stage 1 Scan Config
# 13 evenly spaced gains starting at 0.5; endpoint=False excludes 1.6 itself.
GAINS_SCAN = np.linspace(0.5, 1.6, 13, endpoint=False)
SCAN_EPOCHS = 3 # Use fewer epochs for the scan to save time

# training schedule
STEPS_PER_EPOCH = 80     # training batches per epoch
VAL_STEPS = 30           # validation batches per evaluation
FPTT_PARTS = 8           # forwarded as `parts` to the FPTT regressor
EPROP_SEED = 1234        # forwarded as `seed` to the E-Prop model
EPROP_FEEDBACK = "symmetric"  # forwarded as `feedback` to the E-Prop model
FPTT_LAMBDA = 0.5       # retained, but (per the original note) not currently used by strict_fptt
FPTT_USE_FED_DYN = False  # passed as `use_feddyn`; the FPTTRNN wrapper in this file swallows it via **_

CLIP_NORM = 5.0  # gradient-clipping threshold (5.0) used by the Local Rule and FPTT models

# =============================
# Optimization helpers
# =============================

def normalize_and_clip_gradients(
    grads: List[np.ndarray],
    norm_scale: float,
    max_norm: float,
) -> List[np.ndarray]:
    """
    Scale a list of gradients and clip their joint (global) L2 norm.

    Every gradient is first divided by ``max(1.0, norm_scale)`` (so scales
    below 1 never amplify the gradients).  If the global norm of the scaled
    gradients is finite and exceeds ``max_norm``, all of them are shrunk by
    the common factor ``max_norm / (norm + 1e-8)``.  A non-finite global norm
    leaves the scaled gradients untouched.

    Returns a new list; the input arrays are not modified.
    """
    divisor = max(1.0, float(norm_scale))
    scaled = [grad / divisor for grad in grads]

    # Global norm = sqrt of the sum of squared entries across all arrays.
    squared_total = sum(float(np.sum(grad ** 2)) for grad in scaled)
    global_norm = float(np.sqrt(squared_total))

    within_limit = global_norm <= max_norm
    if within_limit or not np.isfinite(global_norm):
        return scaled

    shrink = max_norm / (global_norm + 1e-8)
    return [grad * shrink for grad in scaled]


# =============================
# Utilities: Video sequence task
# =============================

def generate_video_batch(
    batch_size: int,
    seq_len: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build one batch of the synthetic "moving dot" next-frame prediction task.

    Each sequence shows a single bright pixel travelling through a
    FRAME_H x FRAME_W frame with a constant, non-zero velocity drawn from
    {-1, 0, 1}^2 and reflecting off the borders.

    - inputs[b, :, t]  is frame t, flattened row-major, values in {0, 1}
    - targets[b, :, t] is frame t+1; the final step repeats the last frame

    Returns:
      inputs:  (B, FRAME_SIZE, T) float32
      targets: (B, FRAME_SIZE, T) float32
    """
    inputs = np.zeros((batch_size, FRAME_SIZE, seq_len), dtype=np.float32)
    # Row/column view onto the same memory, so the dot can be placed by (y, x).
    grid = inputs.reshape(batch_size, FRAME_H, FRAME_W, seq_len)

    for sample in range(batch_size):
        # Draw order (x, y, then velocity pairs) defines the RNG stream.
        col = int(rng.integers(0, FRAME_W))
        row = int(rng.integers(0, FRAME_H))

        # Redraw the velocity until it is not (0, 0).
        step_x, step_y = 0, 0
        while step_x == 0 and step_y == 0:
            step_x = int(rng.integers(-1, 2))
            step_y = int(rng.integers(-1, 2))

        for t in range(seq_len):
            grid[sample, row, col, t] = 1.0

            # Reflect each velocity component that would leave the frame.
            if not 0 <= col + step_x < FRAME_W:
                step_x = -step_x
            if not 0 <= row + step_y < FRAME_H:
                step_y = -step_y
            col += step_x
            row += step_y

    # Targets are the inputs shifted one step into the future; the last
    # step has no successor and predicts itself.
    targets = np.zeros_like(inputs)
    targets[:, :, :-1] = inputs[:, :, 1:]
    targets[:, :, -1:] = inputs[:, :, -1:]

    return inputs, targets


def evaluate_video_model(
    model: Any,
    seq_len: int,
    batch_size: int,
    steps: int,
    rng_seed: int = 12345,
) -> float:
    """
    Score a model on freshly generated moving-dot batches.

    For each of `steps` batches the model is run from a zero hidden state via
    ``model.forward_cycle`` and the loss ``0.5 * mean(error ** 2)`` is taken
    over every output, sample and time step.  The generator is seeded with
    `rng_seed`, so repeated calls see the same validation data.

    Returns the average of the per-batch losses (0.0 when `steps` is 0).
    """
    rng = np.random.default_rng(rng_seed)
    batch_losses: List[float] = []

    for _ in range(steps):
        inputs, targets = generate_video_batch(batch_size, seq_len, rng)
        initial_state = np.zeros((model.hidden_size, batch_size), dtype=np.float32)

        # forward_cycle yields one (output, batch) array per time step.
        predictions, _ = model.forward_cycle(inputs, initial_state)
        predicted = np.stack(predictions, axis=2)   # (output, batch, T)
        expected = targets.transpose(1, 0, 2)       # (output, batch, T)

        residual = predicted - expected
        batch_losses.append(0.5 * float(np.mean(residual ** 2)))

    return sum(batch_losses) / max(len(batch_losses), 1)


def calculate_lyapunov_exponent_numpy(model: Any, driver_input: np.ndarray) -> float:
    """
    Estimate the largest Lyapunov exponent of a driven tanh RNN.

    The hidden state is iterated as ``h <- tanh(W_hh h + W_xh I_t + b_h)``
    from h = 0 along `driver_input`.  At every step the state Jacobian
    ``diag(1 - h_next**2) @ W_hh`` is applied to an orthonormal basis Q, which
    is re-orthonormalised by QR; the logs of |diag(R)| are accumulated and
    averaged over time.

    Args:
      model: object exposing ``hidden_size``, ``W_hh``, ``W_xh`` and ``b_h``
        (anything ``np.asarray`` accepts).  ``b_h`` may be shaped
        (hidden,) or (hidden, 1); it is always treated as a column vector.
      driver_input: (input_size, T) input sequence.

    Returns:
      The maximum time-averaged log growth rate, ``nan`` if a QR
      decomposition fails, and 0.0 for an empty driver (T == 0).
    """
    n_hidden = int(model.hidden_size)
    W_hh = np.asarray(model.W_hh, dtype=np.float64)
    W_xh = np.asarray(model.W_xh, dtype=np.float64)
    # Force a column vector: a 1-D bias would broadcast the (hidden, 1)
    # pre-activation into a (hidden, hidden) matrix.
    b_h = np.asarray(model.b_h, dtype=np.float64).reshape(-1, 1)
    # Convert the whole driver once instead of once per time step.
    driver = np.asarray(driver_input, dtype=np.float64)

    h = np.zeros((n_hidden, 1), dtype=np.float64)
    Q = np.eye(n_hidden, dtype=np.float64)
    log_r_diag_sum = np.zeros(n_hidden, dtype=np.float64)
    log_floor = 1e-12  # keeps log() finite when a direction collapses to zero
    T = int(driver.shape[1])

    for t in range(T):
        I_t = driver[:, t].reshape(-1, 1)
        h_next = np.tanh(W_hh @ h + W_xh @ I_t + b_h)
        phi_prime = (1.0 - h_next**2).flatten()      # tanh'(x_t)
        jacobian = phi_prime[:, None] * W_hh         # d h_next / d h
        try:
            Q, R = np.linalg.qr(jacobian @ Q)
        except np.linalg.LinAlgError:
            return float("nan")
        r_diag_abs = np.abs(np.diag(R))
        log_r_diag_sum += np.log(np.clip(r_diag_abs, log_floor, None))
        h = h_next

    return float(np.max(log_r_diag_sum / max(T, 1)))


def make_lyapunov_driver_for_video(seq_len: int, batch_size: int) -> np.ndarray:
    """
    Build the input sequence that drives the Lyapunov exponent estimate.

    One video batch is generated from a fixed seed (999) and averaged over
    the batch axis, giving a single deterministic sequence.

    Returns: float32 array of shape (input_size, seq_len)
    """
    frames, _targets = generate_video_batch(
        batch_size, seq_len, np.random.default_rng(999)
    )
    batch_mean = frames.mean(axis=0)
    return batch_mean.astype(np.float32)


# =============================
# RNN Implementations
# =============================

class NumpyLocalRuleRNN:
    """
    Local-rule RNN with online updates and optional FPTT surrogates,
    adapted here for a regression task (MSE loss).
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 eta: float = 0.001, lambda_window: int = 50) -> None:
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.eta = eta
        self.epsilon = 1e-8
        self.max_grad_norm = CLIP_NORM

        # parameters
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_y = np.zeros((output_size, 1))

        # alpha-related (EMA for AR(1) approximation)
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # lambda-related
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        # diagnostics (optional)
        self.diagnostics: Dict[str, Any] = {}
        self.reset_learning_state()
        self.reset_diagnostics()

        # FPTT-related
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: np.ndarray | None = None
        self.fptt_Q_sum: np.ndarray | None = None
        self.fptt_Q_count: np.ndarray | None = None

    # ----- diagnostics -----
    def reset_diagnostics(self) -> None:
        self.diagnostics = {
            "total_neurons_steps": 0,
            "lambda_clip_count": 0,
            "denom_floor_hits": 0,
            "lambda_history": [],
        }

    def reset_learning_state(self) -> None:
        shape = (self.hidden_size, 1)
        self.alpha_num = np.zeros(shape)
        self.alpha_den = np.zeros(shape)
        self.alpha_hat = np.zeros(shape)

        self.S_A2 = np.zeros(shape)
        self.S_AB = np.zeros(shape)
        self.lambda_vals = np.zeros(shape)

    def initialize_weights_with_gain(self, g: float) -> None:
        # Re-initialize all weights for a fresh start, not just W_hh
        rng = np.random.default_rng(int(g * 1000)) # Seed for reproducibility
        self.W_xh = rng.standard_normal(self.W_xh.shape) * 0.1
        self.W_hh = rng.standard_normal(self.W_hh.shape) * (g / np.sqrt(self.hidden_size))
        self.b_h = np.zeros(self.b_h.shape)
        self.W_hy = rng.standard_normal(self.W_hy.shape) * 0.1
        self.b_y = np.zeros(self.b_y.shape)
        self.reset_learning_state()


    # ----- FPTT helpers -----
    def enable_fptt_surrogates(self, time_steps: int, output_size: int,
                               Q_init: np.ndarray | None = None,
                               beta_schedule: str = "linear") -> None:
        if Q_init is None:
            Q_init = np.full((output_size, time_steps), 0.0, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        self.fptt_Q_prev = Q_init.astype(np.float32).copy()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    def reset_fptt_epoch_accumulators(self, time_steps: int | None = None,
                                      output_size: int | None = None) -> None:
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = np.zeros((output_size, time_steps), dtype=np.float64)
        self.fptt_Q_count = np.zeros((time_steps,), dtype=np.int64)

    def finalize_fptt_epoch(self) -> None:
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        counts = np.maximum(1, self.fptt_Q_count.astype(np.float64))
        Q_new = (self.fptt_Q_sum / counts[None, :]).astype(np.float32)
        self.fptt_Q_prev = Q_new

    # ----- single batch (one cycle) forward + update -----
    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray,
        targets_cycle: np.ndarray,
        h_prev_cycle: np.ndarray,
    ) -> Tuple[float, np.ndarray]:
        """
        inputs_cycle:  (B, input_size, T)
        targets_cycle: (B, output_size, T)
        h_prev_cycle: (hidden, B)
        """
        h_prev = h_prev_cycle
        total_cycle_loss = 0.0
        batch_size = inputs_cycle.shape[0]
        time_steps = inputs_cycle.shape[2]

        prev_g = None
        prev_u = None
        prev_delta = None

        for t in range(time_steps):
            I_t = inputs_cycle[:, :, t].T          # (input, batch)
            y_target_t = targets_cycle[:, :, t].T  # (output, batch)

            x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
            h_t = np.tanh(x_t)
            y_hat_t = self.W_hy @ h_t + self.b_y   # (output, batch)

            # ---- loss & dL/dyhat (MSE) ----
            if self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                beta_t = float(t + 1) / float(time_steps)
                Q_t = self.fptt_Q_prev[:, t].reshape(-1, 1)  # (output, 1)
                Q_t_batch = np.repeat(Q_t, repeats=y_hat_t.shape[1], axis=1)
                Y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch

                error = y_hat_t - Y_tilde
                loss_t = 0.5 * np.mean(error ** 2)
                dL_dyhat = error

                if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                    self.fptt_Q_sum[:, t] += np.sum(y_hat_t, axis=1)
                    self.fptt_Q_count[t] += batch_size
            else:
                error = y_hat_t - y_target_t
                loss_t = 0.5 * np.mean(error ** 2)
                dL_dyhat = error

            total_cycle_loss += float(loss_t)

            # local learning rule ingredients
            g_t = self.W_hy.T @ dL_dyhat   # (hidden, batch)
            u_t = 1.0 - h_t**2             # (hidden, batch)

            # ---- use "previous-step" lambda for update ----
            lambda_used = self.lambda_vals.copy()
            self.diagnostics["lambda_history"].append(float(np.mean(lambda_used)))
            self.diagnostics["total_neurons_steps"] += self.hidden_size * batch_size

            denominator = 1.0 - lambda_used * u_t
            denom_mask = np.abs(denominator) < self.denom_floor
            if np.any(denom_mask):
                self.diagnostics["denom_floor_hits"] += int(np.sum(denom_mask))
            denominator = np.where(
                denom_mask,
                self.denom_floor * np.sign(denominator),
                denominator,
            )
            delta_t = (u_t * g_t) / denominator  # (hidden, batch)

            # ---- alpha update (teaching-signal AR(1) in expectation) ----
            if prev_delta is not None:
                dtp_mean = np.mean(delta_t * prev_delta, axis=1, keepdims=True)
                dpp_mean = np.mean(prev_delta**2, axis=1, keepdims=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                raw_alpha = self.alpha_num / (self.alpha_den + self.epsilon)
                self.alpha_hat = np.clip(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

            dW_hh = (delta_t @ h_prev.T) / batch_size
            dW_xh = (delta_t @ I_t.T) / batch_size
            db_h  = np.mean(delta_t, axis=1, keepdims=True)
            dW_hy = (dL_dyhat @ h_t.T) / batch_size
            db_y  = np.mean(dL_dyhat, axis=1, keepdims=True)

            # Local rule 使用 5.0 global norm 裁剪
            if self.max_grad_norm > 0:
                grads = [dW_xh, dW_hh, db_h, dW_hy, db_y]
                clipped_grads = normalize_and_clip_gradients(grads, norm_scale=1.0, max_norm=self.max_grad_norm)
                dW_xh, dW_hh, db_h, dW_hy, db_y = clipped_grads
            
            self.W_xh -= self.eta * dW_xh
            self.W_hh -= self.eta * dW_hh
            self.b_h  -= self.eta * db_h
            self.W_hy -= self.eta * dW_hy
            self.b_y  -= self.eta * db_y

            # ---- update lambda using (prev_u, prev_g) and (u_t, g_t) ----
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t

                A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
                AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)

                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean

                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)

                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                safe_cap = (1.0 - self.denom_floor) / u_abs_max
                cap = np.minimum(safe_cap, self.lambda_cap)
                lambda_clipped = np.clip(lambda_unproj, -cap, cap)

                self.diagnostics["lambda_clip_count"] += int(np.sum(lambda_unproj != lambda_clipped))
                self.lambda_vals = lambda_clipped

            prev_g = g_t
            prev_u = u_t
            prev_delta = delta_t
            h_prev = h_t

        return total_cycle_loss, h_prev

    def forward_cycle(self, inputs, h_prev) -> Tuple[List[np.ndarray], np.ndarray]:
        outputs = []
        for t in range(inputs.shape[2]):
            I_t = inputs[:, :, t].T
            h_t = np.tanh(self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h)
            outputs.append(self.W_hy @ h_t + self.b_y)
            h_prev = h_t
        return outputs, h_prev


class BPTTRNN:
    """
    Vanilla tanh RNN trained with full backpropagation through time.

    Uses the same parameter names and tensor layouts as the other models in
    this file: inputs/targets are (batch, features, T) and hidden states are
    (hidden_size, batch).
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int, eta: float = 0.001) -> None:
        """Draw initial weights from the global NumPy RNG; `eta` is the learning rate."""
        self.input_size, self.hidden_size, self.output_size, self.eta = input_size, hidden_size, output_size, eta
        self.W_xh = np.random.randn(hidden_size, input_size) * 0.1
        self.W_hh = np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))
        self.W_hy = np.random.randn(output_size, hidden_size) * 0.1
        self.b_y = np.zeros((output_size, 1))
        self.max_grad_norm = CLIP_NORM  # no longer used for BPTT (gradients are averaged, not clipped)

    def initialize_weights_with_gain(self, g: float) -> None:
        """Re-draw only W_hh, scaled by ``g / sqrt(hidden_size)``, from the global NumPy RNG."""
        self.W_hh = np.random.randn(*self.W_hh.shape) * (g / np.sqrt(self.hidden_size))

    def train_batch(self, inputs, targets, h_prev) -> Tuple[float, np.ndarray]:
        """
        One BPTT step over a whole sequence batch.

        inputs:  (B, input_size, T)
        targets: (B, output_size, T)
        h_prev:  (hidden, B) initial hidden state

        The gradient is the sum over time of the per-step error gradient,
        divided by the batch size; no clipping is applied.

        Returns:
          (mean per-step loss ``0.5 * mean(error ** 2)``, final hidden state).
          For an empty sequence (T == 0) the loss is 0.0, the weights are
          unchanged and `h_prev` is returned as the final state.
        """
        batch_size, time_steps = inputs.shape[0], inputs.shape[2]
        # Index -1 holds the initial state so h_acts[t - 1] is valid at t == 0.
        h_acts: Dict[int, np.ndarray] = {-1: h_prev}
        y_hats: Dict[int, np.ndarray] = {}
        total_loss = 0.0

        # forward pass
        for t in range(time_steps):
            I_t = inputs[:, :, t].T
            y_target_t = targets[:, :, t].T

            h_t = np.tanh(self.W_hh @ h_acts[t - 1] + self.W_xh @ I_t + self.b_h)
            y_hat_t = self.W_hy @ h_t + self.b_y

            error = y_hat_t - y_target_t
            loss_t = 0.5 * np.mean(error ** 2)
            total_loss += float(loss_t)

            h_acts[t] = h_t
            y_hats[t] = y_hat_t

        # backward pass
        dW_xh, dW_hh, db_h = np.zeros_like(self.W_xh), np.zeros_like(self.W_hh), np.zeros_like(self.b_h)
        dW_hy, db_y = np.zeros_like(self.W_hy), np.zeros_like(self.b_y)
        dh_next = np.zeros((self.hidden_size, batch_size))

        for t in reversed(range(time_steps)):
            # I_batch stays (B, input) so dtanh @ I_batch is (hidden, input).
            I_batch, y_target_t = inputs[:, :, t], targets[:, :, t].T
            h_t, h_prev_t, y_hat_t = h_acts[t], h_acts[t - 1], y_hats[t]

            dL_dyhat = y_hat_t - y_target_t
            dW_hy += dL_dyhat @ h_t.T
            db_y += np.sum(dL_dyhat, axis=1, keepdims=True)

            dh = self.W_hy.T @ dL_dyhat + dh_next
            dtanh = (1.0 - h_t**2) * dh
            dW_hh += dtanh @ h_prev_t.T
            dW_xh += dtanh @ I_batch
            db_h += np.sum(dtanh, axis=1, keepdims=True)
            dh_next = self.W_hh.T @ dtanh

        # Average over the batch only (divide by batch_size); no clipping.
        safe_scale = max(1.0, float(batch_size))
        dW_xh /= safe_scale
        dW_hh /= safe_scale
        db_h  /= safe_scale
        dW_hy /= safe_scale
        db_y  /= safe_scale

        self.W_xh -= self.eta * dW_xh
        self.W_hh -= self.eta * dW_hh
        self.b_h  -= self.eta * db_h
        self.W_hy -= self.eta * dW_hy
        self.b_y  -= self.eta * db_y

        # Guard the average so an empty sequence does not divide by zero.
        mean_loss = total_loss / time_steps if time_steps else 0.0
        return mean_loss, h_acts[time_steps - 1]

    def forward_cycle(self, inputs, h_prev) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Inference-only pass (no weight updates).

        inputs: (B, input_size, T); h_prev: (hidden, B).
        Returns (list of T arrays shaped (output, B), final hidden state).
        """
        outputs = []
        for t in range(inputs.shape[2]):
            I_t = inputs[:, :, t].T
            h_t = np.tanh(self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h)
            outputs.append(self.W_hy @ h_t + self.b_y)
            h_prev = h_t
        return outputs, h_prev


class FPTTRNN(StrictFPTTRegressor):
    """
    Thin adapter around StrictFPTTRegressor giving it the constructor and
    zero-state helper shape used by the experiment loop in this file.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        parts: int = FPTT_PARTS,
        clip: float = CLIP_NORM,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lambda_reg: float = FPTT_LAMBDA,
        **_: Any,
    ) -> None:
        """
        Forward the named settings to StrictFPTTRegressor.

        `lambda_reg` is handed over under the base class's keyword `lmbda`.
        Any additional keyword arguments are accepted and silently dropped.
        """
        regressor_kwargs = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            eta=eta,
            parts=parts,
            clip=clip,
            alpha=alpha,
            beta=beta,
            rho=rho,
            lmbda=lambda_reg,  # renamed on the way through
        )
        super().__init__(**regressor_kwargs)

    def zero_state(self, batch_size: int) -> np.ndarray:
        """Return an all-zero float32 hidden state of shape (hidden_size, batch_size)."""
        state_shape = (self.hidden_size, batch_size)
        return np.zeros(state_shape, dtype=np.float32)

# =============================
# Main Experiment Logic
# =============================

def plot_comparison(results: Dict[str, Dict[str, Any]], gain: float) -> None:
    """
    Draw the three-panel algorithm comparison figure and show it.

    Args:
      results: maps an algorithm name to a dict with at least
        'final_metric' (final validation MSE), 'lambda_post' (Lyapunov
        exponent after training) and 'history' (validation MSE per epoch).
      gain: initial recurrent gain, shown in the figure title.

    Panels: (a) final validation MSE bars, (b) post-training Lyapunov
    exponent bars with a zero reference line, (c) per-epoch learning curves.
    The epoch axis of panel (c) follows the length of each history, so
    histories need not contain exactly EPOCHS entries.
    """
    names = list(results.keys())
    mses = [results[n]['final_metric'] for n in names]
    lyapunovs = [results[n]['lambda_post'] for n in names]

    fig, axs = plt.subplots(1, 3, figsize=(22, 7), constrained_layout=True)
    fig.suptitle(f"RNN Algorithm Comparison on Video Task (Initialized at g={gain:.2f})", fontsize=20, weight="bold")
    colors = [cm.viridis(x) for x in np.linspace(0.2, 0.9, len(names))]

    # --- a) Final Performance ---
    bars = axs[0].bar(names, mses, color=colors, zorder=3)
    axs[0].set_ylabel("Final Validation MSE (lower is better)", fontsize=12)
    axs[0].set_title("a) Generalization Performance", fontsize=14, weight="bold")
    axs[0].grid(True, axis='y', linestyle=":", zorder=0)
    for bar in bars:
        axs[0].text(bar.get_x() + bar.get_width()/2.0, bar.get_height(), f'{bar.get_height():.4f}', ha='center', va='bottom')
    axs[0].tick_params(axis='x', rotation=15)

    # --- b) Post-Training Dynamics ---
    bars = axs[1].bar(names, lyapunovs, color=colors, zorder=3)
    axs[1].axhline(0, color="crimson", linestyle="--", label="Edge of Chaos (λ=0)", zorder=2)
    axs[1].set_ylabel("Max Lyapunov Exponent (λ_max)", fontsize=12)
    axs[1].set_title("b) Post-Training Dynamics", fontsize=14, weight="bold")
    axs[1].grid(True, axis='y', linestyle=":", zorder=0)
    axs[1].legend()
    for bar in bars:
        height = bar.get_height()
        # Put the label just outside the bar, on whichever side it extends.
        va = 'bottom' if height >= 0 else 'top'
        offset = abs(height) * 0.05 + 0.01
        y_pos = height + offset if height >= 0 else height - offset
        axs[1].text(bar.get_x() + bar.get_width() / 2.0, y_pos, f"{height:.3f}", ha="center", va=va)
    axs[1].tick_params(axis='x', rotation=15)

    # --- c) Learning Curves ---
    longest_history = 0
    for i, name in enumerate(names):
        history = results[name]['history']
        longest_history = max(longest_history, len(history))
        # Derive the epoch axis from the data so x and y always match.
        axs[2].plot(np.arange(1, len(history) + 1), history,
                    label=name, color=colors[i], lw=2.5, marker='o', ms=5)
    axs[2].set_xlabel("Epoch", fontsize=12)
    axs[2].set_ylabel("Validation MSE", fontsize=12)
    axs[2].set_title("c) Learning Curves", fontsize=14, weight="bold")
    axs[2].grid(True, linestyle=":")
    axs[2].legend(fontsize=11)
    axs[2].set_xticks(np.arange(1, longest_history + 1))

    plt.show()


def _select_critical_scan_result(scan_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Among the three lowest-MSE scan entries, return the one whose pre-training Lyapunov exponent is closest to zero."""
    if not scan_results:
        raise ValueError(
            "Stage 1 gain scan produced no finite results; cannot select an optimal gain."
        )
    top_performers = sorted(scan_results, key=lambda entry: entry["mse"])[:3]
    return min(top_performers, key=lambda entry: abs(entry["lambda_pre"]))


def main() -> None:
    """
    Run the full three-stage experiment.

    Stage 1 trains a NumpyLocalRuleRNN for SCAN_EPOCHS at every gain in
    GAINS_SCAN and picks, among the three best validation MSEs, the gain
    whose initial Lyapunov exponent is closest to zero.  Stage 2 loads that
    gain's initial parameters into every algorithm, trains each for EPOCHS
    and records validation MSE and Lyapunov exponents.  Stage 3 plots the
    comparison.

    Raises:
      ValueError: if no gain in the scan yields finite metrics.
    """
    apply_plot_style()
    np.random.seed(42)
    start_time = time.time()

    # --- Data and Driver setup ---
    print("Preparing Lyapunov driver for the video task...")
    lyap_driver = make_lyapunov_driver_for_video(BLOCK_SIZE, BATCH_SIZE)

    # ========================================================================
    # STAGE 1: Find the best performing model near the critical state
    # ========================================================================
    print("\n" + "="*60)
    print("STAGE 1: Finding optimal 'g' for Local Rule RNN...")
    print("="*60)

    scan_results = []
    for g in GAINS_SCAN:
        model = NumpyLocalRuleRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE)
        model.initialize_weights_with_gain(g)
        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyap_driver)
        # Snapshot taken before training so Stage 2 can start from it.
        initial_params = extract_params_common(model)

        # Short training for the scan; every gain sees the same data stream.
        rng = np.random.default_rng(123)
        for _epoch in range(SCAN_EPOCHS):
            for _ in range(STEPS_PER_EPOCH):
                inputs, targets = generate_video_batch(BATCH_SIZE, BLOCK_SIZE, rng)
                h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)
                model.run_one_cycle_and_update_directly(inputs, targets, h_prev)

        val_mse = evaluate_video_model(model, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS)
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyap_driver)

        print(
            f"[SCAN] g={g:.3f} | Val MSE: {val_mse:.6f} | "
            f"Lyapunov (pre, post) = ({lambda_pre:.4f}, {lambda_post:.4f}), Δ={lambda_post - lambda_pre:.4f}"
        )
        # Keep only runs whose metrics are all finite: a NaN MSE from a
        # diverged run would make the sort order below meaningless.
        if np.isfinite(val_mse) and np.isfinite(lambda_pre) and np.isfinite(lambda_post):
            scan_results.append({
                "g": g,
                "mse": val_mse,
                "lambda_pre": lambda_pre,
                "lambda_post": lambda_post,
                "params": initial_params,
            })

    best_critical_model = _select_critical_scan_result(scan_results)
    g_optimal, initial_params_optimal = best_critical_model['g'], best_critical_model['params']

    print("\n--- Stage 1 Summary ---")
    print(f"Optimal gain 'g' found: {g_optimal:.4f}")
    print(f"  - Achieved MSE: {best_critical_model['mse']:.6f}")
    print(f"  - Lyapunov Exponent PRE (close to 0): {best_critical_model['lambda_pre']:.4f}")
    print(f"  - Lyapunov Exponent POST: {best_critical_model['lambda_post']:.4f}")

    # ========================================================================
    # STAGE 2: Compare algorithms at the optimal point
    # ========================================================================
    print("\n" + "="*60)
    print(f"STAGE 2: Comparing algorithms with optimal g = {g_optimal:.4f}")
    print("="*60)

    models_to_compare = {
        "Local Rule": NumpyLocalRuleRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE),
        "BPTT": BPTTRNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, eta=LEARNING_RATE),
        "E-Prop": StandardEPropRNN(
            INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE,
            eta=LEARNING_RATE, feedback=EPROP_FEEDBACK, seed=EPROP_SEED, loss_mode="mse",
            device="cpu",
        ),
        "FPTT": FPTTRNN(
            input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, output_size=OUTPUT_SIZE,
            eta=LEARNING_RATE, parts=FPTT_PARTS, lambda_reg=FPTT_LAMBDA, use_feddyn=FPTT_USE_FED_DYN,
            clip=CLIP_NORM # explicit 5.0 clipping
        ),
    }

    comparison_results = {}
    for name, model in models_to_compare.items():
        print(f"\nTraining {name}...")
        # Every algorithm starts from the same Stage 1 initial parameters.
        load_params_common(model, initial_params_optimal)

        # Ensure stateful buffers are reset after loading parameters
        if hasattr(model, "reset_learning_state"): model.reset_learning_state()
        if hasattr(model, "reset_state_buffers"): model.reset_state_buffers()

        lambda_pre_model = calculate_lyapunov_exponent_numpy(model, lyap_driver)
        history = []
        # Same seed for every algorithm, so all see identical training data.
        rng = np.random.default_rng(456)

        for epoch in range(EPOCHS):
            if hasattr(model, "set_epoch"): model.set_epoch(epoch)

            for _ in range(STEPS_PER_EPOCH):
                inputs, targets = generate_video_batch(BATCH_SIZE, BLOCK_SIZE, rng)
                h_prev = np.zeros((model.hidden_size, inputs.shape[0]), dtype=np.float32)

                # Adapt training call based on model's API
                if hasattr(model, 'run_one_cycle_and_update_directly'):
                    model.run_one_cycle_and_update_directly(inputs, targets, h_prev)
                elif hasattr(model, 'train_batch'):
                    model.train_batch(inputs, targets, h_prev)
                else:
                    raise NotImplementedError(f"Model {name} has no recognized training method.")

            val_mse = evaluate_video_model(model, BLOCK_SIZE, BATCH_SIZE, VAL_STEPS)
            history.append(val_mse)
            print(f"  Epoch {epoch+1}/{EPOCHS} | Val MSE: {val_mse:.6f}")

        final_lambda = calculate_lyapunov_exponent_numpy(model, lyap_driver)
        comparison_results[name] = {
            "final_metric": history[-1],
            "lambda_pre": lambda_pre_model,
            "lambda_post": final_lambda,
            "history": history,
        }
        print(
            f"-> Final Results for {name}: MSE={history[-1]:.6f}, "
            f"Lyapunov pre={lambda_pre_model:.4f}, post={final_lambda:.4f}, "
            f"Δ={final_lambda - lambda_pre_model:.4f}"
        )

    print(f"\nExperiment finished! Total time: {(time.time() - start_time) / 60:.2f} minutes")

    # ========================================================================
    # STAGE 3: Plotting the comparison results
    # ========================================================================
    plot_comparison(comparison_results, g_optimal)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
