"""核心"""
from __future__ import annotations

import argparse
import csv
import importlib
import inspect
import json
import math
import os
import platform
import random
import re
import subprocess
import sys
import time
import urllib.request
import zipfile
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np
import torch

# File layout:
# 1) General utilities / plotting / time weights: resolve_device / build_time_weights / plot_*.
# 2) Model definitions: TorchLocalRuleRNN / TorchBPTTRNN / StandardEPropRNN / StrictFPTT*.
# 3) Data loading and sequence construction: load_mnist_images / load_permuted_mnist_sequences / load_sequential_cifar10_sequences.
# 4) Training / evaluation / scans and task entry points: train_batches / evaluate_* / scan_gains_* / run_*_task.

COMPARE_DIR = Path(__file__).resolve().parents[2]
if str(COMPARE_DIR) not in sys.path:
    sys.path.insert(0, str(COMPARE_DIR))

# 训练算法实现（严格 FPTT / E-Prop）位于 methods 目录。
from methods.strict_fptt import (
    StrictFPTTClassifier,
    StrictFPTTRegressor,
    build_chunk_schedule,
)
from methods.standard_eprop import StandardEPropRNN

# Shared default hyperparameters, reused by the strict-FPTT / E-Prop / stability-scan code paths.
FPTT_PARTS = 10  # fallback chunk count when a strict-FPTT model exposes no `parts` attribute
FPTT_ORACLE_MOMENTUM = 1.0
FPTT_LAMBDA = 1.0
EPROP_FEEDBACK = "symmetric"
EPROP_SEED = 1234
LYAPUNOV_MIN_STEPS = 50  # shorter Lyapunov driver sequences are tiled up to at least this many steps

# Device used whenever a caller does not name one: CUDA if available, otherwise CPU.

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# resolve_device: turn an optional device spec into a concrete torch.device.
# Note: accepts None, a device string, or an existing torch.device.
def resolve_device(device: torch.device | str | None = None) -> torch.device:
    """Return ``DEFAULT_DEVICE`` for ``None``; otherwise wrap *device* in ``torch.device``."""
    return DEFAULT_DEVICE if device is None else torch.device(device)


# to_tensor: convert array-like data into a tensor on the requested device.
# Note: existing tensors are moved/cast with ``.to``; everything else goes through ``torch.as_tensor``.
def to_tensor(
    array: np.ndarray | torch.Tensor,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return *array* as a tensor with the given ``dtype`` on ``device``."""
    if not torch.is_tensor(array):
        return torch.as_tensor(array, dtype=dtype, device=device)
    return array.to(device=device, dtype=dtype)


# =============================
# General utilities (time weights / batch splitting / statistics)
# =============================

# softmax: numerically stable softmax over axis 0.
# Note: subtracts the per-column maximum before exponentiating, then normalises.
def softmax(logits: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """Return probabilities computed along axis 0 of *logits*.

    A small constant (1e-12) is added to the normaliser, so the result type
    follows the input: tensors in, tensors out; NumPy otherwise.
    """
    if not torch.is_tensor(logits):
        peak = np.max(logits, axis=0, keepdims=True)
        weights = np.exp(logits - peak)
        normaliser = np.sum(weights, axis=0, keepdims=True) + 1e-12
        return weights / normaliser

    peak = logits.max(dim=0, keepdim=True).values
    weights = torch.exp(logits - peak)
    normaliser = weights.sum(dim=0, keepdim=True) + 1e-12
    return weights / normaliser


# iterate_minibatches: yield shuffled (inputs, targets) mini-batches.
# Note: one random ordering is drawn per call; the last batch may be smaller.
def iterate_minibatches(
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]]:
    """Generate mini-batches of *inputs*/*targets* in a random order drawn from *rng*.

    Both arrays are indexed along axis 0 with the same index batch, so the
    pairing between an input and its target is preserved.
    """
    sample_count = inputs.shape[0]
    if torch.is_tensor(inputs):
        # Draw the permutation with NumPy, then move the indices next to the data.
        order = torch.as_tensor(rng.permutation(sample_count), device=inputs.device)
    else:
        # Shuffle an explicit 0..N-1 index array in place.
        order = np.arange(sample_count)
        rng.shuffle(order)

    for offset in range(0, sample_count, batch_size):
        picked = order[offset : offset + batch_size]
        yield inputs[picked], targets[picked]


# split_train_val: random train/validation split by fraction.
# Note: the first ``val_size`` entries of one random permutation form the validation set.
def split_train_val(
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    val_fraction: float,
    rng: np.random.Generator,
) -> Tuple[
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
]:
    """Split the three aligned arrays into train and validation parts.

    The validation set always holds at least one sample.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        val_inputs, val_targets, val_labels)``.
    """
    sample_total = inputs.shape[0]
    holdout = max(1, int(sample_total * val_fraction))
    shuffled = rng.permutation(sample_total)
    val_idx, train_idx = shuffled[:holdout], shuffled[holdout:]

    if torch.is_tensor(inputs):
        # Index tensors must live on the same device as the data they index.
        val_idx = torch.as_tensor(val_idx, device=inputs.device)
        train_idx = torch.as_tensor(train_idx, device=inputs.device)

    aligned = (inputs, targets, labels)
    train_part = tuple(arr[train_idx] for arr in aligned)
    val_part = tuple(arr[val_idx] for arr in aligned)
    return (*train_part, *val_part)


def split_train_val_grouped(
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    groups: np.ndarray | torch.Tensor,
    val_fraction: float,
    rng: np.random.Generator,
) -> Tuple[
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
]:
    """Split train/validation so that no group id appears on both sides.

    Whole groups are drawn at random for validation; at least one group is
    held out and at least one is kept for training. With fewer than two
    distinct groups (or a degenerate mask) this falls back to
    ``split_train_val``.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        val_inputs, val_targets, val_labels)``; samples keep their original
        relative order within each part.

    Raises:
        ValueError: if *inputs* and *groups* differ in length along axis 0.
    """
    if inputs.shape[0] != groups.shape[0]:
        raise ValueError(f"inputs and groups must align: {inputs.shape[0]} vs {groups.shape[0]}")

    if torch.is_tensor(groups):
        group_ids = groups.detach().cpu().numpy()
    else:
        group_ids = np.asarray(groups)

    distinct = np.unique(group_ids)
    if distinct.size < 2:
        return split_train_val(inputs, targets, labels, val_fraction, rng)

    # Hold out at least one group, but never all of them.
    n_val_groups = min(max(1, int(distinct.size * val_fraction)), distinct.size - 1)
    held_out = rng.choice(distinct, size=n_val_groups, replace=False)
    in_val = np.isin(group_ids, held_out)
    if not (np.any(in_val) and not np.all(in_val)):
        return split_train_val(inputs, targets, labels, val_fraction, rng)

    if torch.is_tensor(inputs):
        mask = torch.as_tensor(in_val, device=inputs.device)
        train_idx = torch.where(~mask)[0]
        val_idx = torch.where(mask)[0]
    else:
        train_idx = np.nonzero(~in_val)[0]
        val_idx = np.nonzero(in_val)[0]

    return (
        inputs[train_idx],
        targets[train_idx],
        labels[train_idx],
        inputs[val_idx],
        targets[val_idx],
        labels[val_idx],
    )


# build_repeated_targets: one-hot encode labels and repeat them at every time step.
# Note: the output has shape (num_samples, num_classes, time_steps).
def build_repeated_targets(
    labels: np.ndarray | torch.Tensor,
    num_classes: int,
    time_steps: int,
) -> np.ndarray | torch.Tensor:
    """Return float32 one-hot targets copied along a trailing time axis.

    Tensor labels yield a tensor on the labels' device; anything else is
    handled with NumPy.
    """
    sample_count = labels.shape[0]

    if not torch.is_tensor(labels):
        table = np.zeros((sample_count, num_classes), dtype=np.float32)
        table[np.arange(sample_count), labels] = 1.0
        return np.repeat(np.expand_dims(table, 2), time_steps, axis=2)

    class_ids = labels.to(dtype=torch.long)
    table = torch.zeros(
        (sample_count, num_classes),
        dtype=torch.float32,
        device=class_ids.device,
    )
    table.scatter_(1, class_ids.view(-1, 1), 1.0)
    return table[:, :, None].repeat(1, 1, time_steps)


# normalize_train_test: standardise train/test data with training-set statistics.
# Note: only the training set contributes to mean/std, which avoids test-set leakage.
def normalize_train_test(
    train_data: np.ndarray | torch.Tensor,
    test_data: np.ndarray | torch.Tensor,
) -> Tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]:
    """Return ``(train, test)`` shifted and scaled by the global train mean/std.

    Statistics are scalars computed over every element of *train_data*. Both
    backends use the population standard deviation (divide by N), so tensor
    and NumPy inputs normalise identically and a single-element training
    tensor does not produce NaN. ``1e-7`` is added to the std to avoid
    division by zero.
    """
    if torch.is_tensor(train_data):
        mean = train_data.mean()
        # unbiased=False matches np.std's default (ddof=0); torch's default
        # (N-1) would give a different scale than the NumPy branch.
        std = train_data.std(unbiased=False) + 1e-7
        return (train_data - mean) / std, (test_data - mean) / std
    mean = float(np.mean(train_data))
    std = float(np.std(train_data)) + 1e-7
    return (train_data - mean) / std, (test_data - mean) / std


# build_time_weights: per-time-step loss weights.
# Note: supported modes are none, final/last, and late/linear (case-insensitive).
def build_time_weights(time_steps: int, mode: str | None, min_weight: float = 0.05) -> np.ndarray | None:
    """Return a float32 weight vector of length ``time_steps`` or ``None``.

    Modes:
        ``None`` / ``"none"``: no weighting, returns ``None``.
        ``"final"`` / ``"last"``: ``min_weight`` everywhere, 1.0 at the last step.
        ``"late"`` / ``"linear"``: linear ramp from ``min_weight`` to 1.0.

    ``time_steps`` is clamped to at least 1 and ``min_weight`` to at least 1e-4.

    Raises:
        ValueError: for any other mode string.
    """
    if mode is None:
        return None
    key = mode.lower()
    if key == "none":
        return None

    steps = max(1, int(time_steps))
    floor = float(max(min_weight, 1e-4))

    if key in ("final", "last"):
        weights = np.full(steps, floor, dtype=np.float32)
        weights[-1] = 1.0
    elif key in ("late", "linear"):
        weights = np.linspace(floor, 1.0, steps, dtype=np.float32)
    else:
        raise ValueError(f"Unknown time-weighting mode: {key}")
    return weights


# clip_gradients: rescale gradients so their global L2 norm does not exceed max_norm.
# Note: protects training from exploding gradients.
def clip_gradients(grads: List[np.ndarray | torch.Tensor], max_norm: float) -> List[np.ndarray | torch.Tensor]:
    """Clip *grads* by their joint L2 norm.

    The input list is returned unchanged when ``max_norm <= 0``, when the
    norm is already within bounds, or when the norm is not finite. Otherwise
    a new list of scaled gradients is returned. The backend is chosen from
    the type of the first gradient.
    """
    if max_norm <= 0:
        return grads

    if grads and torch.is_tensor(grads[0]):
        squared = sum(
            (torch.sum(g ** 2) for g in grads),
            torch.zeros((), device=grads[0].device),
        )
        norm = torch.sqrt(squared)
        if torch.isfinite(norm) and norm > max_norm:
            factor = max_norm / (norm + 1e-8)
            return [g * factor for g in grads]
        return grads

    norm = float(np.sqrt(sum((float(np.sum(g**2)) for g in grads), 0.0)))
    if np.isfinite(norm) and norm > max_norm:
        factor = max_norm / (norm + 1e-8)
        return [g * factor for g in grads]
    return grads


# _count_elements: recursively count scalar elements in tensors/arrays/lists/dicts.
# Note: used for complexity reports and memory-size estimates.
def _count_elements(value: Any) -> int:
    """Return the number of scalar elements held by *value*.

    ``None`` counts as 0, tensors and arrays contribute their element count,
    dicts contribute the sum over their values, and lists/tuples the sum over
    their items (so containers of tensors are counted without being converted
    to NumPy). Any other object is measured via ``np.asarray(value).size``.
    """
    if value is None:
        return 0
    if torch.is_tensor(value):
        return int(value.numel())
    if isinstance(value, np.ndarray):
        return int(value.size)
    if isinstance(value, dict):
        return sum(_count_elements(item) for item in value.values())
    if isinstance(value, (list, tuple)):
        return sum(_count_elements(item) for item in value)
    return int(np.asarray(value).size)


# count_model_parameters: total number of parameter elements in a model.
# Note: used for complexity reports and memory-size estimates.
def count_model_parameters(model: Any) -> int:
    """Return the parameter count of *model*.

    ``torch.nn.Module`` instances are measured through ``parameters()``;
    other objects are probed for the attributes ``W_hh``, ``W_xh``, ``b_h``,
    ``W_hy`` and ``b_y``, and whichever exist are summed.
    """
    if isinstance(model, torch.nn.Module):
        return int(sum(p.numel() for p in model.parameters()))
    weight_names = ("W_hh", "W_xh", "b_h", "W_hy", "b_y")
    return sum(
        _count_elements(getattr(model, attr))
        for attr in weight_names
        if hasattr(model, attr)
    )


# count_persistent_state: total size of a model's persistent (non-parameter) state.
# Note: used for complexity reports and memory-size estimates.
def count_persistent_state(model: Any) -> int:
    """Sum the element counts of the known state attributes present on *model*.

    Also includes the ``"sm"`` and ``"lm"`` entries of every state dict held
    in ``model._regularizer._state``, when such a regularizer exists.
    """
    state_names = (
        "alpha_num",
        "alpha_den",
        "alpha_hat",
        "S_A2",
        "S_AB",
        "lambda_vals",
        "fptt_Q_prev",
        "fptt_Q_sum",
        "fptt_Q_count",
        "B_fb",
    )
    total = sum(
        _count_elements(getattr(model, attr))
        for attr in state_names
        if hasattr(model, attr)
    )

    regularizer = getattr(model, "_regularizer", None)
    if regularizer is None or not hasattr(regularizer, "_state"):
        return total
    for entry in regularizer._state.values():
        total += _count_elements(entry.get("sm")) + _count_elements(entry.get("lm"))
    return total


# estimate_model_complexity: summarise a model's parameter and state sizes.
# Note: "total" is simply params + state.
def estimate_model_complexity(model: Any) -> Dict[str, int]:
    """Return ``{"params", "state", "total"}`` element counts for *model*."""
    summary = {
        "params": count_model_parameters(model),
        "state": count_persistent_state(model),
    }
    summary["total"] = summary["params"] + summary["state"]
    return summary


# estimate_update_factor: number of weight updates performed per batch.
# Note: a rough cost estimate that depends on the model type and sequence length.
def estimate_update_factor(model: Any, time_steps: int) -> int:
    """Return how many updates one batch of ``time_steps`` steps triggers.

    * Local-rule and E-Prop models: one update per time step.
    * Strict FPTT models: one update per chunk of the chunk schedule.
    * BPTT models with ``0 < tbptt_steps < time_steps``: one update per
      truncation window.
    * Everything else: a single update.
    """
    horizon = max(1, int(time_steps))

    if isinstance(model, (TorchLocalRuleRNN, StandardEPropRNN)):
        return horizon

    if isinstance(model, (StrictFPTTClassifier, StrictFPTTRegressor)):
        chunk_count = getattr(model, "parts", FPTT_PARTS)
        return len(build_chunk_schedule(horizon, chunk_count))

    if not isinstance(model, TorchBPTTRNN):
        return 1
    window = getattr(model, "tbptt_steps", None)
    if window is None:
        return 1
    window = int(window)
    if window <= 0 or window >= horizon:
        return 1
    return int(math.ceil(horizon / window))


# estimate_training_counts: estimate update and step counts for a training run.
# Note: all three count arguments are clamped to at least 1.
def estimate_training_counts(
    model: Any,
    time_steps: int,
    batches_per_epoch: int,
    epochs: int,
) -> Dict[str, int]:
    """Return per-epoch and total update/step counts for *model*.

    ``updates_per_epoch`` is the per-batch update factor times the batch
    count; ``steps_total`` is batches x epochs x time steps.
    """
    n_batches = max(1, int(batches_per_epoch))
    n_epochs = max(1, int(epochs))
    horizon = max(1, int(time_steps))
    per_batch = estimate_update_factor(model, horizon)
    per_epoch = per_batch * n_batches
    return {
        "batches_per_epoch": n_batches,
        "time_steps": horizon,
        "update_factor": per_batch,
        "updates_per_epoch": per_epoch,
        "updates_total": per_epoch * n_epochs,
        "steps_total": n_batches * n_epochs * horizon,
    }


# _slugify: normalise a string into a safe file-name fragment.
# Note: keeps only ASCII letters/digits, lower-cased and joined by underscores.
def _slugify(value: str) -> str:
    """Return a lower-case ``[a-z0-9_]`` slug of *value*, or ``"task"`` if nothing is left."""
    words = re.findall(r"[a-zA-Z0-9]+", value.strip().lower())
    return "_".join(words) if words else "task"


# resolve_plot_context: work out the plot output directory and file tag.
# Note: the directory is created if it does not exist yet.
def resolve_plot_context(
    args: argparse.Namespace,
    task_label: str,
) -> Tuple[Path, str]:
    """Return ``(plot_dir, plot_tag)`` derived from *args*.

    An explicit ``args.plot_path`` wins: its parent becomes the directory and
    its stem the tag (falling back to the task slug). Otherwise
    ``args.plot_dir`` (default ``plots/<task slug>``) and ``args.plot_tag``
    (default: current timestamp) are used. Missing attributes are treated as
    unset.
    """
    task_slug = _slugify(task_label)
    explicit_path = getattr(args, "plot_path", None)

    if explicit_path:
        target = Path(explicit_path)
        out_dir = target.parent
        tag = target.stem or task_slug
    else:
        requested_dir = getattr(args, "plot_dir", None)
        out_dir = Path("plots") / task_slug if requested_dir is None else Path(requested_dir)
        tag = getattr(args, "plot_tag", None) or time.strftime("%Y%m%d_%H%M%S")

    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir, tag


# build_plot_path: build the output path of one plot file.
# Note: the file name is "<plot_tag>_<slugified suffix>.png" inside plot_dir.
def build_plot_path(plot_dir: Path, plot_tag: str, suffix: str) -> str:
    """Return the PNG path for *suffix* as a string; the directory is not created here."""
    file_name = f"{plot_tag}_{_slugify(suffix)}.png"
    return str(plot_dir / file_name)


# set_global_seed: seed random / numpy / torch in one call.
# Note: makes data splits, batch shuffling and torch random ops reproducible.
def set_global_seed(seed: int, deterministic_cudnn: bool = True) -> None:
    """Seed every random number generator this module relies on.

    Args:
        seed: Value used for all generators (coerced with ``int``).
        deterministic_cudnn: When true and cuDNN is exposed by torch, switch
            cuDNN to deterministic mode and disable its benchmark autotuner.
    """
    seed = int(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not (deterministic_cudnn and hasattr(torch.backends, "cudnn")):
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _find_git_root(start: Path) -> Path | None:
    """Return the nearest of *start* and its ancestors that contains ``.git``, or ``None``."""
    lineage = (start, *start.parents)
    return next((folder for folder in lineage if (folder / ".git").exists()), None)


def collect_environment_info() -> Dict[str, Any]:
    """Collect a snapshot of the runtime environment for experiment logs.

    Covers Python / platform / NumPy / torch versions, CUDA and cuDNN
    settings, and the git commit plus dirty flag of the repository that
    contains this file. Git fields are ``None`` when no repository is found
    or when any git command fails.
    """

    def _git_output(*git_args: str) -> str:
        return subprocess.check_output(["git", *git_args], cwd=str(repo_root), text=True).strip()

    commit = None
    dirty = None
    repo_root = _find_git_root(Path(__file__).resolve())
    if repo_root is not None:
        try:
            commit = _git_output("rev-parse", "HEAD")
            # Any porcelain output means there are uncommitted changes.
            dirty = bool(_git_output("status", "--porcelain"))
        except Exception:
            # Best effort: report neither field if either command fails.
            commit, dirty = None, None

    has_cudnn = hasattr(torch.backends, "cudnn")
    cudnn_backend = torch.backends.cudnn if has_cudnn else None
    try:
        cudnn_version = cudnn_backend.version() if has_cudnn else None
    except Exception:
        cudnn_version = None

    if hasattr(torch, "are_deterministic_algorithms_enabled"):
        deterministic_algos = bool(torch.are_deterministic_algorithms_enabled())
    else:
        deterministic_algos = None

    info: Dict[str, Any] = {}
    info["python_version"] = platform.python_version()
    info["python_executable"] = sys.executable
    info["platform"] = platform.platform()
    info["numpy_version"] = np.__version__
    info["torch_version"] = torch.__version__
    info["cuda_available"] = bool(torch.cuda.is_available())
    info["cuda_version"] = torch.version.cuda
    info["cudnn_version"] = cudnn_version
    # getattr on None falls through to the default, covering the no-cuDNN case.
    info["cudnn_deterministic"] = getattr(cudnn_backend, "deterministic", None)
    info["cudnn_benchmark"] = getattr(cudnn_backend, "benchmark", None)
    info["torch_deterministic_algorithms"] = deterministic_algos
    info["git_commit"] = commit
    info["git_dirty"] = dirty
    return info


# save_results_summary: write experiment results and configuration to disk.
# Note: produces "<plot_tag>_summary.csv" and "<plot_tag>_summary.json" for later comparison.
def save_results_summary(
    task_label: str,
    metric_label: str,
    results: Dict[str, Dict[str, Any]],
    plot_dir: Path,
    plot_tag: str,
    *,
    run_config: Dict[str, Any] | None = None,
    env_info: Dict[str, Any] | None = None,
) -> None:
    """Save one CSV row per method plus a JSON dump of the full results.

    Args:
        task_label: Task name stored in every row and in the JSON payload.
        metric_label: Human-readable name of the reported metric.
        results: Mapping ``method name -> result dict``. Known metric keys are
            copied into fixed CSV columns; ``hparams`` entries become
            ``hp_<key>`` columns.
        plot_dir: Output directory (``Path`` or string). A falsy value
            disables writing entirely.
        plot_tag: Prefix of both output file names.
        run_config: Optional run configuration, written as ``cfg_<key>``
            columns and under ``"run_config"`` in the JSON.
        env_info: Optional environment description; collected via
            ``collect_environment_info()`` when empty or omitted.

    NumPy scalars/arrays, tensors and paths inside *results* or *run_config*
    are converted to plain JSON values; other unknown objects are written as
    their ``str()`` so a finished run never loses its summary to a
    serialisation error.
    """
    if not plot_dir:
        return
    plot_dir = Path(plot_dir)

    run_config = run_config or {}
    env_info = env_info or collect_environment_info()

    def _csv_value(value: Any) -> Any:
        if value is None:
            return ""
        if isinstance(value, (str, int, float, bool)):
            return value
        try:
            return json.dumps(value, ensure_ascii=False, sort_keys=True)
        except Exception:
            return str(value)

    def _json_default(value: Any) -> Any:
        # Called by json.dump only for objects it cannot serialise itself.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        if torch.is_tensor(value):
            return value.detach().cpu().tolist()
        return str(value)

    # Per-method result keys copied verbatim into the CSV, in column order.
    metric_keys = (
        "metric",
        "val_metric",
        "val_loss",
        "best_epoch",
        "lyap_pre",
        "lyap_post",
        "runtime_sec",
        "train_runtime_sec",
        "eval_runtime_sec",
        "runtime_per_update_sec",
        "runtime_per_step_sec",
        "batches_per_epoch",
        "time_steps",
        "update_factor",
        "updates_per_epoch",
        "updates_total",
        "steps_total",
        "complexity_params",
        "complexity_state",
        "complexity_total",
    )

    rows: List[Dict[str, Any]] = []
    for name, data in results.items():
        row: Dict[str, Any] = {
            "task": task_label,
            "metric_label": metric_label,
            "method": name,
        }
        for k, v in run_config.items():
            row[f"cfg_{k}"] = _csv_value(v)
        for k, v in env_info.items():
            row[f"env_{k}"] = _csv_value(v)

        for key in metric_keys:
            row[key] = data.get(key)
        # A method may report no history at all (missing key or explicit None).
        history = data.get("history")
        row["history_len"] = 0 if history is None else len(history)

        hparams = data.get("hparams")
        if isinstance(hparams, dict):
            for k, v in hparams.items():
                row[f"hp_{k}"] = _csv_value(v)

        rows.append(row)

    csv_path = plot_dir / f"{plot_tag}_summary.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        fieldnames: List[str] = []
        if rows:
            # Fixed columns first, then cfg_/env_/hp_ columns in sorted order.
            preferred = ["task", "metric_label", "method", *metric_keys, "history_len"]
            all_keys = set().union(*(row.keys() for row in rows))
            fieldnames = [key for key in preferred if key in all_keys]
            fieldnames.extend(sorted(all_keys.difference(fieldnames)))
        writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
        if rows:
            writer.writeheader()
            writer.writerows(rows)

    json_path = plot_dir / f"{plot_tag}_summary.json"
    payload = {
        "schema_version": 2,
        "task": task_label,
        "metric_label": metric_label,
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "run_config": run_config,
        "environment": env_info,
        "results": results,
    }
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False, default=_json_default)


# ensure_python_package: make sure an importable package is available, installing it if needed.
# Note: tries to import first; on failure runs "pip install" with the current interpreter.
def ensure_python_package(import_name: str, package_name: str | None = None) -> bool:
    """Return ``True`` if *import_name* can be imported, installing it on demand.

    Args:
        import_name: Module name passed to ``importlib.import_module``.
        package_name: Name given to pip when it differs from *import_name*.

    Returns:
        ``True`` when the module is importable (possibly after installation),
        ``False`` when installation or the follow-up import fails. Failures
        are reported on stdout rather than raised.
    """
    try:
        importlib.import_module(import_name)
    except ImportError:
        pass
    else:
        return True

    pkg = package_name or import_name
    print(f"[SETUP] Installing missing package: {pkg}")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
        # The import system caches directory listings; without this a package
        # installed while the interpreter is running may not be found.
        importlib.invalidate_caches()
        importlib.import_module(import_name)
    except Exception as exc:
        # Deliberately broad: this is a best-effort setup helper.
        print(f"[SETUP] Failed to install {pkg}: {exc}")
        return False
    return True


# calculate_lyapunov_exponent_numpy: estimate the largest Lyapunov exponent via repeated QR decomposition.
# Note: despite the "_numpy" suffix, all arithmetic below uses torch ops in float64.
def calculate_lyapunov_exponent_numpy(
    model: Any,
    driver_input: np.ndarray | torch.Tensor,
) -> float:
    """Return the maximum Lyapunov exponent of the model's driven tanh recurrence.

    The hidden state follows ``h_next = tanh(W_hh @ h + W_xh @ I_t + b_h)``,
    starting from zeros, where ``I_t`` is column ``t`` of the driver. At each
    step the Jacobian ``diag(1 - h_next**2) @ W_hh`` is applied to an
    orthonormal basis and re-orthonormalised with QR; the logs of ``|diag(R)|``
    are accumulated and averaged over time.

    Args:
        model: Object exposing ``hidden_size``, ``W_hh``, ``W_xh`` and ``b_h``.
        driver_input: Driver whose axis 1 is iterated as time. A 1-D driver
            is turned into a single column. Drivers shorter than
            ``LYAPUNOV_MIN_STEPS`` are tiled along axis 1.

    Returns:
        The largest time-averaged exponent, or NaN if a QR decomposition
        raises ``RuntimeError``.
    """
    n_hidden = int(model.hidden_size)
    # Device priority: module default < model.device < device of the first
    # tensor-valued recurrent weight attribute.
    device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        device = resolve_device(getattr(model, "device"))
    for attr in ("W_hh", "_W_hh"):
        if hasattr(model, attr):
            value = getattr(model, attr)
            if torch.is_tensor(value):
                device = value.device
                break

    W_hh = to_tensor(getattr(model, "W_hh"), device, dtype=torch.float64)
    W_xh = to_tensor(getattr(model, "W_xh"), device, dtype=torch.float64)
    b_h = to_tensor(getattr(model, "b_h"), device, dtype=torch.float64).reshape(-1, 1)

    h = torch.zeros((n_hidden, 1), dtype=torch.float64, device=device)
    Q = torch.eye(n_hidden, dtype=torch.float64, device=device)
    log_r_diag_sum = torch.zeros(n_hidden, dtype=torch.float64, device=device)
    # Lower clamp applied before the log so a zero diagonal entry does not give -inf.
    log_floor = 1e-12

    driver = to_tensor(driver_input, device, dtype=torch.float64)
    if driver.ndim == 1:
        driver = driver[:, None]
    time_steps = int(driver.shape[1])
    if time_steps < LYAPUNOV_MIN_STEPS:
        # Tile short drivers along the time axis to reach the minimum step count.
        reps = int(math.ceil(LYAPUNOV_MIN_STEPS / max(time_steps, 1)))
        driver = driver.repeat(1, reps)
        time_steps = int(driver.shape[1])
    for t in range(time_steps):
        I_t = driver[:, t].reshape(-1, 1)
        x_t = W_hh @ h + W_xh @ I_t + b_h
        h_next = torch.tanh(x_t)

        # Derivative of tanh at the pre-activation, written in terms of its output.
        phi_prime = (1.0 - h_next**2).flatten()
        jacobian = phi_prime[:, None] * W_hh

        # Propagate the current orthonormal basis and re-orthonormalise it.
        Z = jacobian @ Q
        try:
            Q, R = torch.linalg.qr(Z)
        except RuntimeError:
            return float("nan")

        r_diag_abs = torch.abs(torch.diag(R))
        log_r_diag_sum = log_r_diag_sum + torch.log(torch.clamp(r_diag_abs, min=log_floor))
        h = h_next

    lyaps = log_r_diag_sum / max(time_steps, 1)
    return float(torch.max(lyaps).item())


# build_lyapunov_driver: pick a representative input sample to drive the Lyapunov estimate.
# Note: chooses the sample with the median L2 norm among the first max_samples inputs.
def build_lyapunov_driver(
    inputs: np.ndarray | torch.Tensor, max_samples: int = 32
) -> np.ndarray | torch.Tensor:
    """Return one sample of *inputs* (indexed along axis 0) to use as driver.

    Only the first ``min(max_samples, len(inputs))`` samples are considered.
    Each is flattened, ranked by L2 norm, and the middle-ranked one is
    returned. With a single candidate that sample is returned directly.
    """
    sample_count = min(max_samples, inputs.shape[0])
    candidates = inputs[:sample_count]
    if sample_count <= 1:
        return candidates[0]

    flattened = candidates.reshape(sample_count, -1)
    if torch.is_tensor(inputs):
        ranking = torch.argsort(torch.norm(flattened, dim=1))
    else:
        ranking = np.argsort(np.linalg.norm(flattened, axis=1))
    return candidates[ranking[sample_count // 2]]


# build_lm_lyapunov_driver: build the Lyapunov driver input for a language-model task.
# Note: draws one batch of 8 sequences and keeps the one with the median L2 norm.
def build_lm_lyapunov_driver(
    data: np.ndarray | torch.Tensor,
    vocab_size: int,
    block_size: int,
    seed: int,
) -> np.ndarray | torch.Tensor:
    """Return a single driver sequence sampled from *data*.

    The first batch produced by ``batch_iterator`` (seeded with *seed*) is
    ranked by flattened L2 norm and its middle-ranked sample is returned. If
    the iterator yields nothing, a float32 zero array of shape
    ``(vocab_size, block_size)`` is returned in the same backend as *data*.
    """
    rng = np.random.default_rng(seed)
    batches = batch_iterator(data, vocab_size, block_size, batch_size=8, steps=1, rng=rng)
    first_batch = next(iter(batches), None)

    if first_batch is not None:
        inputs, _ = first_batch
        batch_len = inputs.shape[0]
        flattened = inputs.reshape(batch_len, -1)
        if torch.is_tensor(inputs):
            ranking = torch.argsort(torch.norm(flattened, dim=1))
        else:
            ranking = np.argsort(np.linalg.norm(flattened, axis=1))
        return inputs[ranking[batch_len // 2]]

    # No batch available: fall back to an all-zero driver.
    if torch.is_tensor(data):
        return torch.zeros((vocab_size, block_size), dtype=torch.float32, device=data.device)
    return np.zeros((vocab_size, block_size), dtype=np.float32)


# =============================
# Plotting and result output
# =============================

# _apply_icml_plot_style: set ICML-style fonts, line widths and layout defaults.
# Note: applied globally through matplotlib rcParams; a no-op if matplotlib cannot be imported.
def _apply_icml_plot_style() -> None:
    """Update matplotlib's global rcParams with the shared plot style."""
    try:
        import matplotlib as mpl
    except Exception:
        return

    font_settings = {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
        "font.weight": "normal",
        "axes.labelsize": 11,
        "axes.labelweight": "normal",
        "axes.titlesize": 12,
        "axes.titleweight": "normal",
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "figure.titlesize": 14,
        "figure.titleweight": "normal",
        "legend.fontsize": 9,
    }
    layout_settings = {
        "legend.frameon": False,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.axisbelow": True,
        "grid.alpha": 0.35,
        "grid.linestyle": ":",
        "grid.linewidth": 0.8,
        "lines.linewidth": 2.0,
        "savefig.dpi": 300,
        "figure.dpi": 150,
    }
    mpl.rcParams.update({**font_settings, **layout_settings})


# _lyapunov_plot_enabled: decide whether Lyapunov plots should be produced.
# Note: controlled by the OLL_PHYSICAL_PROBE_NO_LYAPUNOV environment variable.
def _lyapunov_plot_enabled() -> bool:
    """Return ``False`` when the opt-out variable is set to 1/true/yes (case-insensitive)."""
    raw = os.getenv("OLL_PHYSICAL_PROBE_NO_LYAPUNOV", "")
    disabled = str(raw).strip().lower() in {"1", "true", "yes"}
    return not disabled


# _method_palette: produce one colour per method.
# Note: cycles through a fixed eight-colour list, so colours are stable across runs.
def _method_palette(count: int) -> List[str]:
    """Return *count* hex colours, repeating the base palette as needed (empty for count <= 0)."""
    base = [
        "#0072B2",
        "#D55E00",
        "#009E73",
        "#CC79A7",
        "#E69F00",
        "#56B4E9",
        "#F0E442",
        "#999999",
    ]
    if count <= 0:
        return []
    repeats = count // len(base) + 1
    return (base * repeats)[:count]


# _save_figure: single place where figures are written to disk.
# Note: always saves at 300 dpi with a tight bounding box and 0.2 padding.
def _save_figure(fig: Any, path: str | Path) -> None:
    """Write *fig* to *path* using the shared export settings."""
    export_options = {"dpi": 300, "bbox_inches": "tight", "pad_inches": 0.2}
    fig.savefig(path, **export_options)


# plot_comparison_results: draw comparison figures for several methods/models.
# Note: reads metrics from the results dict, configures the axes and saves the images.
def plot_comparison_results(
    task_label: str,
    results: Dict[str, Dict[str, Any]],
    metric_label: str,
    history_label: str | None,
    higher_is_better: bool,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
) -> None:
    """Save up to three comparison figures for the methods in *results*.

    1. A bar chart of each method's ``"metric"`` value (suffix ``metric``).
    2. Grouped bars of ``"lyap_pre"`` / ``"lyap_post"`` (suffix ``lyapunov``),
       unless disabled via ``_lyapunov_plot_enabled()``; missing values are
       plotted as NaN.
    3. Learning curves from each method's ``"history"`` (suffix ``curve``),
       only if at least one history is non-empty.

    Args:
        task_label: Task name used in the figure titles.
        results: Mapping ``method name -> result dict``; ``"metric"`` is required.
        metric_label: Y-axis label of the metric chart (and of the curve plot
            when *history_label* is not given).
        history_label: Optional y-axis label for the learning-curve plot.
        higher_is_better: Only affects the note in the metric chart title.
        plot_dir: Output directory passed to ``build_plot_path``.
        plot_tag: File-name prefix passed to ``build_plot_path``.
        show: If true, call ``plt.show()`` after each figure; otherwise close it.

    Returns early without plotting if matplotlib cannot be imported or
    *results* is empty.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    names = list(results.keys())
    if not names:
        return

    colors = _method_palette(len(names))
    x = np.arange(len(names))

    # --- Figure 1: final metric per method, with the value printed above each bar.
    metric_values = [results[name]["metric"] for name in names]
    fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
    bars = ax.bar(x, metric_values, color=colors, edgecolor="black", linewidth=0.3, zorder=3)
    better_note = "higher is better" if higher_is_better else "lower is better"
    ax.set_title(f"{task_label}: Final Performance ({better_note})", fontsize=12, fontweight="normal")
    ax.set_ylabel(metric_label, fontsize=12, fontweight="normal")
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=15)
    ax.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
    for bar, val in zip(bars, metric_values):
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            bar.get_height(),
            f"{val:.4f}",
            ha="center",
            va="bottom",
            fontsize=9,
        )
    metric_path = build_plot_path(plot_dir, plot_tag, "metric")
    _save_figure(fig, metric_path)
    print(f"[PLOT] Saved metric plot to {metric_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)

    # --- Figure 2: Lyapunov exponent before/after training, side by side per method.
    if _lyapunov_plot_enabled():
        lyap_pre = [results[name].get("lyap_pre", float("nan")) for name in names]
        lyap_post = [results[name].get("lyap_post", float("nan")) for name in names]
        fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
        width = 0.32
        # Same colour per method; the lighter alpha marks the "Pre" bars.
        ax.bar(x - width / 2.0, lyap_pre, width, label="Pre", color=colors, alpha=0.4, zorder=3)
        ax.bar(x + width / 2.0, lyap_post, width, label="Post", color=colors, alpha=0.9, zorder=3)
        # Reference line at an exponent of zero.
        ax.axhline(0.0, color="crimson", linestyle="--", linewidth=1.0, zorder=2)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=15)
        ax.set_title(f"{task_label}: Lyapunov (pre/post)", fontsize=12, fontweight="normal")
        ax.set_ylabel("Lyap max", fontsize=12, fontweight="normal")
        ax.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)
        ax.legend(fontsize=9)
        lyap_path = build_plot_path(plot_dir, plot_tag, "lyapunov")
        _save_figure(fig, lyap_path)
        print(f"[PLOT] Saved lyapunov plot to {lyap_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)
    else:
        print("[PLOT] Lyapunov plot disabled by OLL_PHYSICAL_PROBE_NO_LYAPUNOV.")

    # --- Figure 3: learning curves, one line per method that recorded a history.
    histories = [results[name].get("history", []) for name in names]
    has_history = any(len(h) > 0 for h in histories)
    if has_history:
        markers = ["o", "s", "^", "D", "v", "P", "X", "*", "<", ">"]
        fig, ax = plt.subplots(figsize=(7.4, 4.4), constrained_layout=True)
        for idx, (name, history, color) in enumerate(zip(names, histories, colors)):
            if not history:
                continue
            epochs = np.arange(1, len(history) + 1)
            # Draw roughly ten markers per curve regardless of its length.
            mark_every = max(1, len(history) // 10)
            ax.plot(
                epochs,
                history,
                label=name,
                color=color,
                linewidth=2.0,
                marker=markers[idx % len(markers)],
                markersize=4,
                markevery=mark_every,
            )
        ax.set_xlabel("Epoch", fontsize=12, fontweight="normal")
        ax.set_ylabel(history_label or metric_label, fontsize=12, fontweight="normal")
        ax.set_title(f"{task_label}: Learning Curve", fontsize=12, fontweight="normal")
        ax.grid(True, linestyle=":", linewidth=0.7)
        ax.tick_params(axis="x", labelsize=10)
        ax.tick_params(axis="y", labelsize=10)
        ncol = min(3, max(1, len(names)))
        # Legend is anchored above the axes, centred horizontally.
        ax.legend(fontsize=9, ncol=ncol, loc="upper center", bbox_to_anchor=(0.5, 1.22))
        curve_path = build_plot_path(plot_dir, plot_tag, "curve")
        _save_figure(fig, curve_path)
        print(f"[PLOT] Saved learning-curve plot to {curve_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)


# plot_scan_results: plot gain-scan results and mark the best gain.
# Notes: sorts the scan by gain, draws the metric curve (plus an optional auxiliary
# curve on a twin axis) and, when enabled, a second panel with Lyapunov values.
def plot_scan_results(
    task_label: str,
    gains: Iterable[float],
    metric_values: Iterable[float],
    metric_label: str,
    lyap_pre: Iterable[float] | None,
    lyap_post: Iterable[float] | None,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
    best_g: float | None = None,
    higher_is_better: bool = True,
    aux_values: Iterable[float] | None = None,
    aux_label: str | None = None,
    suffix: str = "scan_lyapunov",
) -> None:
    """Plot a metric-versus-gain scan and save the figure.

    Args:
        task_label: Prefix used in the figure title.
        gains: Scanned gain values (x axis); they are sorted before plotting.
        metric_values: Metric value per gain. If the lengths differ, both series
            are truncated to the shorter one.
        metric_label: Legend and y-axis label of the metric curve.
        lyap_pre: Optional values plotted as "Lyap pre" in the Lyapunov panel.
        lyap_post: Optional values plotted as "Lyap post" in the Lyapunov panel.
        plot_dir: Output directory handed to ``build_plot_path``; ``None``
            disables plotting.
        plot_tag: Tag handed to ``build_plot_path``.
        show: If true the figure is shown, otherwise it is closed after saving.
        best_g: Optional gain marked with a dashed vertical line (finite only).
        higher_is_better: Selects the "higher/lower is better" note in the title.
        aux_values: Optional secondary series drawn on a twin y axis.
        aux_label: Label of the secondary series; it is only drawn when both
            ``aux_values`` and a non-empty ``aux_label`` are given.
        suffix: File-name suffix; "lyapunov" is replaced by "metric" when the
            Lyapunov panel is disabled.

    Optional series shorter than the scan are padded with NaN so that they stay
    aligned with the gains instead of raising an IndexError.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping scan plot: matplotlib unavailable ({exc}).")
        return
    if plot_dir is None:
        return
    _apply_icml_plot_style()

    gains_arr = np.asarray(list(gains), dtype=float)
    metric_arr = np.asarray(list(metric_values), dtype=float)
    if gains_arr.size == 0 or metric_arr.size == 0:
        return
    count = min(gains_arr.size, metric_arr.size)
    gains_arr = gains_arr[:count]
    metric_arr = metric_arr[:count]

    # Sort everything by gain so the curves are drawn left to right.
    order = np.argsort(gains_arr)
    gains_arr = gains_arr[order]
    metric_arr = metric_arr[order]

    def _align_optional(values: Iterable[float] | None) -> np.ndarray | None:
        # Bring an optional series to the scan length and apply the gain ordering.
        if values is None:
            return None
        arr = np.asarray(list(values), dtype=float)[:count]
        missing = count - arr.shape[0]
        if missing > 0:
            # Too few entries: pad with NaN (drawn as gaps) so that indexing
            # with `order` cannot go out of bounds.
            pad = np.full((missing,) + arr.shape[1:], np.nan, dtype=float)
            arr = np.concatenate([arr, pad], axis=0)
        return arr[order]

    aux_arr = _align_optional(aux_values)
    lyap_pre_arr = _align_optional(lyap_pre)
    lyap_post_arr = _align_optional(lyap_post)

    lyap_enabled = _lyapunov_plot_enabled()
    if lyap_enabled:
        fig, axes = plt.subplots(2, 1, figsize=(7.2, 6.4), constrained_layout=True)
        ax_metric = axes[0]
    else:
        fig, ax_metric = plt.subplots(figsize=(7.2, 3.6), constrained_layout=True)
        # The metric panel is the only panel here, so it has to carry the
        # x-axis label that the Lyapunov panel provides otherwise.
        ax_metric.set_xlabel("Gain g")
    ax_metric.plot(gains_arr, metric_arr, color="#0072B2", marker="o", label=metric_label)
    if best_g is not None and np.isfinite(best_g):
        ax_metric.axvline(
            float(best_g),
            color="crimson",
            linestyle="--",
            linewidth=1.0,
            label=f"best g={best_g:.3f}",
        )
    ax_metric.set_ylabel(metric_label)
    note = "higher is better" if higher_is_better else "lower is better"
    ax_metric.set_title(f"{task_label}: Scan ({note})", fontsize=12, fontweight="normal")
    ax_metric.grid(True, axis="y", linestyle=":", linewidth=0.7)
    handles, labels = ax_metric.get_legend_handles_labels()
    if aux_arr is not None and aux_label:
        # The auxiliary curve lives on a twin axis; merge its legend entries
        # into the metric legend so a single legend box is drawn.
        ax_aux = ax_metric.twinx()
        ax_aux.plot(gains_arr, aux_arr, color="#D55E00", marker="s", label=aux_label)
        ax_aux.set_ylabel(aux_label)
        handles2, labels2 = ax_aux.get_legend_handles_labels()
        handles += handles2
        labels += labels2
    if handles:
        ax_metric.legend(handles, labels, fontsize=9, loc="best")

    if lyap_enabled:
        ax_lyap = axes[1]
        has_lyap = False
        # Only draw a Lyapunov curve when it has at least one finite value.
        if lyap_pre_arr is not None and np.any(np.isfinite(lyap_pre_arr)):
            ax_lyap.plot(gains_arr, lyap_pre_arr, color="#56B4E9", marker="o", label="Lyap pre")
            has_lyap = True
        if lyap_post_arr is not None and np.any(np.isfinite(lyap_post_arr)):
            ax_lyap.plot(gains_arr, lyap_post_arr, color="#009E73", marker="s", label="Lyap post")
            has_lyap = True
        ax_lyap.axhline(0.0, color="crimson", linestyle="--", linewidth=1.0)
        if best_g is not None and np.isfinite(best_g):
            ax_lyap.axvline(float(best_g), color="crimson", linestyle="--", linewidth=1.0)
        ax_lyap.set_xlabel("Gain g")
        ax_lyap.set_ylabel("Lyapunov (max)")
        ax_lyap.grid(True, axis="y", linestyle=":", linewidth=0.7)
        if has_lyap:
            ax_lyap.legend(fontsize=9, loc="best")

    scan_suffix = suffix if lyap_enabled else suffix.replace("lyapunov", "metric")
    scan_path = build_plot_path(plot_dir, plot_tag, scan_suffix)
    _save_figure(fig, scan_path)
    if lyap_enabled:
        print(f"[PLOT] Saved scan plot to {scan_path}")
    else:
        print(f"[PLOT] Saved metric-only scan plot to {scan_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# plot_cost_profile: bar charts of model size and runtime for each method.
# Notes: reads per-method entries from the results dict and saves up to four figures
# (complexity, runtime, runtime per update, runtime per step).
def plot_cost_profile(
    task_label: str,
    results: Dict[str, Dict[str, Any]],
    plot_dir: Path,
    plot_tag: str,
    show: bool,
) -> None:
    """Draw the cost figures for which ``results`` carries at least one value.

    Each entry of ``results`` is queried for ``complexity_params``,
    ``complexity_state``, ``runtime_sec``, ``train_runtime_sec``,
    ``eval_runtime_sec``, ``runtime_per_update_sec`` and
    ``runtime_per_step_sec``; missing values are treated as ``None``.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping cost plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    names = list(results.keys())
    if not names:
        return

    def _column(key: str) -> List[Any]:
        # One value (or None) per method, in method order.
        return [results[name].get(key) for name in names]

    def _any_present(values: List[Any]) -> bool:
        return any(v is not None for v in values)

    def _as_float(values: List[Any], missing: float) -> np.ndarray:
        # Convert to floats, substituting `missing` for absent entries.
        return np.array([float(v) if v is not None else missing for v in values])

    params = _column("complexity_params")
    state = _column("complexity_state")
    runtime = _column("runtime_sec")
    train_runtime = _column("train_runtime_sec")
    eval_runtime = _column("eval_runtime_sec")
    runtime_per_update = _column("runtime_per_update_sec")
    runtime_per_step = _column("runtime_per_step_sec")

    has_complexity = _any_present(params) or _any_present(state)
    has_runtime = _any_present(runtime) or _any_present(train_runtime)
    has_runtime_per_update = _any_present(runtime_per_update)
    has_runtime_per_step = _any_present(runtime_per_step)
    if not (has_complexity or has_runtime or has_runtime_per_update or has_runtime_per_step):
        return

    colors = _method_palette(len(names))
    x = np.arange(len(names))
    bar_style = {"edgecolor": "black", "linewidth": 0.3, "zorder": 3}
    text_style = {"ha": "center", "va": "bottom", "fontsize": 9}

    # _scale_time: rescale times by magnitude and return a matching unit.
    # Notes: the unit (s/ms/us) is chosen from the largest finite value.
    def _scale_time(values: np.ndarray) -> Tuple[np.ndarray, str]:
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return values, "s"
        peak = float(np.max(finite))
        for limit, factor, unit in ((1e-3, 1e6, "us"), (1.0, 1e3, "ms")):
            if peak < limit:
                return values * factor, unit
        return values, "s"

    def _decorate(ax: Any, ylabel: str, title: str) -> None:
        # Shared axis cosmetics: method names on x, labels, title, y grid.
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=15)
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontsize=12, fontweight="normal")
        ax.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)

    def _emit(fig: Any, suffix: str, saved_prefix: str) -> None:
        # Save the figure, report the path, then show or close it.
        out_path = build_plot_path(plot_dir, plot_tag, suffix)
        _save_figure(fig, out_path)
        print(f"{saved_prefix} {out_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    def _per_unit_figure(
        raw_values: List[Any],
        ylabel_stem: str,
        title: str,
        suffix: str,
        saved_prefix: str,
    ) -> None:
        # One bar per method with an automatically chosen time unit;
        # methods without a value get a zero-height bar labelled "n/a".
        fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
        raw_arr = _as_float(raw_values, float("nan"))
        scaled_vals, unit = _scale_time(raw_arr)
        scaled_vals = np.nan_to_num(scaled_vals, nan=0.0)
        ax.bar(x, scaled_vals, color=colors, **bar_style)
        for xi, raw, scaled in zip(x, raw_arr, scaled_vals):
            known = np.isfinite(raw)
            label = f"{scaled:.2f}{unit}" if known else "n/a"
            ax.text(xi, scaled if known else 0.0, label, **text_style)
        _decorate(ax, f"{ylabel_stem} ({unit})", title)
        _emit(fig, suffix, saved_prefix)

    if has_complexity:
        fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
        # Element counts are displayed in millions.
        params_m = _as_float(params, 0.0) / 1e6
        state_m = _as_float(state, 0.0) / 1e6
        show_state = np.any(state_m > 0)
        ax.bar(x, params_m, color=colors, label="Params", **bar_style)
        if show_state:
            # Persistent state is stacked on top of the parameter bar.
            ax.bar(x, state_m, bottom=params_m, color="0.7", label="State", **bar_style)
        for xi, total in zip(x, params_m + state_m):
            ax.text(xi, total, f"{total:.2f}", **text_style)
        _decorate(ax, "Elements (M)", f"{task_label}: Model Size (params + persistent state)")
        if show_state:
            ax.legend(fontsize=9)
        _emit(fig, "complexity", "[PLOT] Saved complexity plot to")

    if has_runtime:
        fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
        # Prefer the train/eval split when any method reports a train runtime.
        use_train_runtime = _any_present(train_runtime)
        runtime_arr = _as_float(train_runtime if use_train_runtime else runtime, float("nan"))
        runtime_clean = np.nan_to_num(runtime_arr, nan=0.0)
        if use_train_runtime:
            eval_arr = _as_float(eval_runtime, 0.0)
            has_eval = np.any(eval_arr > 0)
        else:
            eval_arr = np.zeros_like(runtime_clean)
            has_eval = False

        base_kwargs = {"label": "Train"} if use_train_runtime else {}
        ax.bar(x, runtime_clean, color=colors, **bar_style, **base_kwargs)
        if has_eval:
            ax.bar(x, eval_arr, bottom=runtime_clean, color="0.85", label="Eval", **bar_style)
        totals = runtime_clean + (eval_arr if has_eval else 0.0)
        for xi, base, total in zip(x, runtime_arr, totals):
            known = np.isfinite(base)
            label = f"{total:.2f}s" if known else "n/a"
            ax.text(xi, total if known else 0.0, label, **text_style)

        if not use_train_runtime:
            title = f"{task_label}: Runtime"
        elif has_eval:
            title = f"{task_label}: Runtime (train + eval, fixed epochs)"
        else:
            title = f"{task_label}: Runtime (train, fixed epochs)"
        _decorate(ax, "Seconds", title)
        if has_eval:
            ax.legend(fontsize=9)
        _emit(fig, "runtime", "[PLOT] Saved runtime plot to")

    if has_runtime_per_update:
        _per_unit_figure(
            runtime_per_update,
            "Time per update",
            f"{task_label}: Runtime per Update (method-specific)",
            "runtime_per_update",
            "[PLOT] Saved per-update runtime plot to",
        )

    if has_runtime_per_step:
        _per_unit_figure(
            runtime_per_step,
            "Time per step",
            f"{task_label}: Runtime per Step (normalized)",
            "runtime_per_step",
            "[PLOT] Saved per-step runtime plot to",
        )


# _default_time_indices：生成默认时间步索引，便于可视化对齐。
# 说明：为上层训练/评估流程提供辅助支持。
def _default_time_indices(seq_len: int) -> List[int]:
    if seq_len <= 0:
        return [0]
    candidates = [0, seq_len // 2, seq_len - 1]
    unique: List[int] = []
    for idx in candidates:
        if 0 <= idx < seq_len and idx not in unique:
            unique.append(idx)
    return unique


# predict_sequence_outputs_batch: batched teacher-forced sequence prediction.
# Notes: runs model.forward_cycle from a zero hidden state and stacks the per-step outputs.
def predict_sequence_outputs_batch(
    model: Any,
    input_seq: np.ndarray | torch.Tensor,
) -> torch.Tensor:
    """Feed ``input_seq`` through ``model.forward_cycle`` and stack the outputs.

    A 2-D input gets a leading batch axis of size one. The device is taken from
    ``model.W_hh`` when that is a tensor, else from ``model.device``, else from
    ``DEFAULT_DEVICE``. The per-step outputs are stacked along ``dim=2``.
    """
    target_device = resolve_device(model.device) if hasattr(model, "device") else DEFAULT_DEVICE
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        # The weights' actual device wins over the declared one.
        target_device = recurrent.device

    batch = to_tensor(input_seq, device=target_device, dtype=torch.float32)
    if batch.dim() == 2:
        batch = batch.unsqueeze(0)

    n_sequences = int(batch.shape[0])
    initial_state = torch.zeros(
        (model.hidden_size, n_sequences),
        device=target_device,
        dtype=torch.float32,
    )
    step_outputs, _ = model.forward_cycle(batch, initial_state)
    return torch.stack(step_outputs, dim=2)


# predict_sequence_outputs: teacher-forced prediction returned as a NumPy array.
# Notes: wraps predict_sequence_outputs_batch and keeps index 0 of the second axis.
def predict_sequence_outputs(model: Any, input_seq: np.ndarray | torch.Tensor) -> np.ndarray:
    """Return ``predict_sequence_outputs_batch(...)[:, 0, :]`` as a CPU NumPy array."""
    selected = predict_sequence_outputs_batch(model, input_seq)[:, 0, :]
    return selected.detach().cpu().numpy()


# predict_sequence_rollout_batch: batched rollout (closed-loop) sequence prediction.
# Notes: uses the true inputs for the warm-up steps, then feeds each output back as the next input.
def predict_sequence_rollout_batch(
    model: Any,
    input_seq: np.ndarray | torch.Tensor,
    warmup_steps: int = 1,
) -> torch.Tensor:
    """Roll the model forward, feeding its own outputs back after the warm-up.

    The recurrence is evaluated directly from the model's ``W_hh``, ``W_xh``,
    ``b_h``, ``W_hy`` and ``b_y`` with a tanh hidden nonlinearity, under
    ``torch.no_grad()``. ``warmup_steps`` is clamped to ``[1, time_steps]``.
    If the model's ``output_size`` differs from the input size, feedback is not
    possible and the call falls back to ``predict_sequence_outputs_batch``.
    Per-step outputs are stacked along ``dim=2``.
    """
    device = resolve_device(model.device) if hasattr(model, "device") else DEFAULT_DEVICE
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        # The weights' actual device wins over the declared one.
        device = recurrent.device

    inputs_t = to_tensor(input_seq, device=device, dtype=torch.float32)
    if inputs_t.dim() == 2:
        inputs_t = inputs_t.unsqueeze(0)
    batch_size = int(inputs_t.shape[0])
    time_steps = int(inputs_t.shape[2])
    input_size = int(inputs_t.shape[1])
    if int(getattr(model, "output_size", input_size)) != input_size:
        return predict_sequence_outputs_batch(model, inputs_t)

    def _param(name: str, as_column: bool = False) -> torch.Tensor:
        # Fetch a model parameter as float32 on `device`; biases become columns.
        value = to_tensor(getattr(model, name), device=device, dtype=torch.float32)
        return value.reshape(-1, 1) if as_column else value

    W_hh = _param("W_hh")
    W_xh = _param("W_xh")
    b_h = _param("b_h", as_column=True)
    W_hy = _param("W_hy")
    b_y = _param("b_y", as_column=True)

    hidden_size = int(getattr(model, "hidden_size", W_hh.shape[0]))
    state = torch.zeros((hidden_size, batch_size), device=device, dtype=torch.float32)

    n_warmup = max(1, min(int(warmup_steps), time_steps))
    rollout: List[torch.Tensor] = []
    feedback = None

    with torch.no_grad():
        for t in range(time_steps):
            # Ground-truth input during warm-up, previous prediction afterwards.
            drive = inputs_t[:, :, t].T if t < n_warmup else feedback
            state = torch.tanh(W_hh @ state + W_xh @ drive + b_h)
            feedback = W_hy @ state + b_y
            rollout.append(feedback)

    return torch.stack(rollout, dim=2)


# predict_sequence_rollout: rollout prediction returned as a NumPy array.
# Notes: wraps predict_sequence_rollout_batch and keeps index 0 of the second axis.
def predict_sequence_rollout(
    model: Any,
    input_seq: np.ndarray | torch.Tensor,
    warmup_steps: int = 1,
) -> np.ndarray:
    """Return ``predict_sequence_rollout_batch(...)[:, 0, :]`` as a CPU NumPy array."""
    selected = predict_sequence_rollout_batch(model, input_seq, warmup_steps=warmup_steps)[:, 0, :]
    return selected.detach().cpu().numpy()


# plot_image_sequence_predictions: visualise predicted image frames next to the targets.
# Notes: renders one horizontal strip of frames (one column per selected time step)
# for the inputs (if usable), the targets and every prediction, saving each strip.
def plot_image_sequence_predictions(
    task_label: str,
    predictions: Dict[str, np.ndarray],
    targets: np.ndarray,
    frame_h: int,
    frame_w: int,
    frame_channels: int | None,
    time_indices: List[int] | None,
    frame_index: int | None,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
    inputs: np.ndarray | None = None,
    mode_label: str | None = None,
) -> None:
    """Save frame strips for inputs, targets and each method's prediction.

    ``targets`` and every prediction are indexed as ``[feature, time]``. The
    feature axis must hold one frame of ``frame_h * frame_w * channels`` values
    or a whole number of such frames, in which case ``frame_index`` selects one.
    Predictions with too few feature rows are reported and left out. Time
    indices outside the sequence are dropped; if none remain the defaults from
    ``_default_time_indices`` are used. Values are clipped to [0, 1] before
    display.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping prediction plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    if not predictions:
        return
    if frame_h <= 0 or frame_w <= 0 or frame_h * frame_w <= 0:
        print("[PLOT] Skipping prediction plot: invalid frame size.")
        return

    frame_channels = max(1, int(frame_channels or 1))
    frame_depth = frame_h * frame_w * frame_channels

    # Locate the rows of the requested frame on the feature axis.
    n_rows = targets.shape[0]
    if n_rows == frame_depth:
        start = 0
    elif n_rows % frame_depth == 0:
        chosen = int(frame_index or 0)
        if not 0 <= chosen < n_rows // frame_depth:
            print("[PLOT] Skipping prediction plot: frame index out of range.")
            return
        start = chosen * frame_depth
    else:
        print("[PLOT] Skipping prediction plot: frame size mismatch.")
        return
    frame_slice = slice(start, start + frame_depth)

    usable: Dict[str, np.ndarray] = {}
    for name, pred in predictions.items():
        if pred.shape[0] >= frame_slice.stop:
            usable[name] = pred
        else:
            print(f"[PLOT] Skipping prediction plot: {name} frame size mismatch.")
    if not usable:
        print("[PLOT] Skipping prediction plot: no predictions with matching frame size.")
        return

    seq_len = targets.shape[1]
    requested = time_indices or _default_time_indices(seq_len)
    steps = [t for t in requested if 0 <= t < seq_len] or _default_time_indices(seq_len)

    # _render_strip: render selected time steps of a frame sequence side by side.
    # Notes: one subplot per time step, without ticks or spines.
    def _render_strip(frames: np.ndarray, title: str, suffix: str) -> None:
        n_cols = len(steps)
        fig, axes = plt.subplots(
            1,
            n_cols,
            figsize=(2.8 * n_cols, 2.6),
            gridspec_kw={"wspace": 0.02},
        )
        fig.subplots_adjust(left=0.02, right=0.99, top=0.86, bottom=0.12)
        if n_cols == 1:
            # plt.subplots returns a bare Axes for a single column.
            axes = [axes]
        fig.suptitle(title, fontsize=13, fontweight="normal", y=0.98)

        for ax, t in zip(axes, steps):
            column = frames[:, t]
            if frame_channels == 1:
                ax.imshow(column.reshape(frame_h, frame_w), cmap="viridis", vmin=0.0, vmax=1.0)
            else:
                image = np.clip(column.reshape(frame_h, frame_w, frame_channels), 0.0, 1.0)
                if frame_channels == 3:
                    ax.imshow(image)
                else:
                    # Any other channel count is shown as the channel mean.
                    ax.imshow(np.mean(image, axis=2), cmap="viridis", vmin=0.0, vmax=1.0)
            ax.set_xlabel(f"t={t}", fontsize=10, fontweight="normal", labelpad=6)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.tick_params(left=False, bottom=False)
            for spine in ax.spines.values():
                spine.set_visible(False)

        plot_path = build_plot_path(plot_dir, plot_tag, suffix)
        _save_figure(fig, plot_path)
        print(f"[PLOT] Saved prediction frames to {plot_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    # The input strip is drawn only when `inputs` lines up with the targets.
    inputs_usable = (
        inputs is not None
        and inputs.ndim == 2
        and inputs.shape[0] >= frame_slice.stop
        and inputs.shape[1] == targets.shape[1]
    )
    if inputs_usable:
        _render_strip(np.clip(inputs[frame_slice], 0.0, 1.0), f"{task_label}: Input (t)", "pred_input")

    _render_strip(np.clip(targets[frame_slice], 0.0, 1.0), f"{task_label}: Target", "pred_target")

    mode_note = f" ({mode_label})" if mode_label else ""
    for name, pred in usable.items():
        _render_strip(
            np.clip(pred[frame_slice], 0.0, 1.0),
            f"{task_label}: {name}{mode_note}",
            f"pred_{_slugify(name)}",
        )


# plot_timeseries_predictions: plot predicted time series against the target.
# Notes: one subplot per selected output dimension; the target is drawn in black
# and each method's prediction is overlaid in its palette colour.
def plot_timeseries_predictions(
    task_label: str,
    predictions: Dict[str, np.ndarray],
    targets: np.ndarray,
    dims: List[int] | None,
    plot_path: str | None,
    show: bool,
) -> None:
    """Overlay every prediction on the target, one row per output dimension.

    ``targets`` must be a non-empty 2-D array indexed as ``[dim, time]``.
    ``dims`` defaults to the first (up to) three dimensions; entries outside
    the valid range are dropped and ``[0]`` is used if none remain. A
    prediction with too few rows for a dimension is skipped for that row. The
    figure is saved only when ``plot_path`` is truthy.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping prediction plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    if not predictions:
        return
    if targets.ndim != 2 or targets.shape[0] == 0:
        print("[PLOT] Skipping prediction plot: invalid target shape.")
        return

    output_size, seq_len = targets.shape
    candidates = list(range(min(output_size, 3))) if dims is None else dims
    rows = [int(d) for d in candidates if 0 <= int(d) < output_size] or [0]

    time_axis = np.arange(seq_len)
    n_rows = len(rows)
    fig, axes = plt.subplots(
        n_rows,
        1,
        figsize=(8.5, 2.4 * n_rows),
        sharex=True,
        gridspec_kw={"hspace": 0.28},
    )
    fig.subplots_adjust(top=0.88, bottom=0.1)
    if n_rows == 1:
        # plt.subplots returns a bare Axes for a single row.
        axes = [axes]

    fig.suptitle(f"{task_label}: Prediction vs Target", fontsize=14, fontweight="normal")

    palette = _method_palette(len(predictions))
    styled_predictions = list(zip(list(predictions.items()), palette))
    last_row = n_rows - 1

    for row, dim in enumerate(rows):
        ax = axes[row]
        ax.plot(time_axis, targets[dim], color="black", linewidth=2.2, label="Target")
        for (name, pred), color in styled_predictions:
            if pred.shape[0] > dim:
                ax.plot(time_axis, pred[dim], linewidth=1.7, alpha=0.85, label=name, color=color)
        ax.set_ylabel(f"Dim {dim}", fontsize=12, fontweight="normal")
        ax.grid(True, linestyle=":", linewidth=0.7)
        for axis_name in ("x", "y"):
            ax.tick_params(axis=axis_name, labelsize=10)
        ax.margins(x=0)
        if row == last_row:
            ax.set_xlabel("Time", fontsize=12, fontweight="normal")
        if row == 0:
            # A single legend on the top row covers all rows.
            ax.legend(
                fontsize=9,
                ncol=min(3, max(1, len(predictions) + 1)),
                loc="upper left",
                frameon=True,
                framealpha=0.9,
            )

    if plot_path:
        _save_figure(fig, plot_path)
        print(f"[PLOT] Saved prediction plot to {plot_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# plot_trajectory_predictions: plot predicted 3D trajectories against the target.
# Notes: draws the target and every usable prediction in one 3D axes and gives
# all three axes the same span, centred on the data.
def plot_trajectory_predictions(
    task_label: str,
    predictions: Dict[str, np.ndarray],
    targets: np.ndarray,
    dims: List[int] | None,
    plot_path: str | None,
    show: bool,
    skip_steps: int = 0,
) -> None:
    """Draw target and predicted trajectories along three chosen dimensions.

    ``targets`` must be 2-D with at least three rows, indexed as ``[dim, time]``.
    ``dims`` defaults to ``[0, 1, 2]``; only its first three entries are used
    and they must be valid row indices. The first ``skip_steps`` time steps are
    left out. Predictions that are not 2-D, have too few rows or are not longer
    than ``skip_steps`` are skipped. The figure is saved only when
    ``plot_path`` is truthy.
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    except Exception as exc:
        print(f"[PLOT] Skipping prediction plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    if not predictions:
        return
    if targets.ndim != 2 or targets.shape[0] < 3:
        print("[PLOT] Skipping prediction plot: invalid target shape.")
        return

    output_size = targets.shape[0]
    chosen = [0, 1, 2] if dims is None else [int(d) for d in dims]
    xyz = chosen[:3]
    if len(chosen) < 3 or not all(0 <= d < output_size for d in xyz):
        print("[PLOT] Skipping prediction plot: invalid trajectory dims.")
        return

    skip_steps = max(0, int(skip_steps))
    if targets.shape[1] <= skip_steps:
        print("[PLOT] Skipping prediction plot: skip_steps exceeds trajectory length.")
        return
    tail = slice(skip_steps, None)

    target_curves = [targets[d, tail] for d in xyz]
    fig = plt.figure(figsize=(7.6, 6.2), constrained_layout=True)
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(*target_curves, color="black", linewidth=2.2, label="Target")

    # Running per-axis bounds over the target and all plotted predictions.
    lows = [float(np.min(curve)) for curve in target_curves]
    highs = [float(np.max(curve)) for curve in target_curves]

    colors = _method_palette(len(predictions))
    for idx, (name, pred) in enumerate(predictions.items()):
        if pred.shape[0] <= max(xyz) or pred.ndim != 2 or pred.shape[1] <= skip_steps:
            continue
        curves = [pred[d, tail] for d in xyz]
        ax.plot(*curves, linewidth=1.5, alpha=0.9, label=name, color=colors[idx % len(colors)])
        for axis_no, curve in enumerate(curves):
            lows[axis_no] = min(lows[axis_no], float(np.min(curve)))
            highs[axis_no] = max(highs[axis_no], float(np.max(curve)))

    # Setting the camera is best-effort; a failure here must not abort the plot.
    try:
        ax.view_init(elev=25, azim=-60)
    except Exception:
        pass

    spans = np.array([hi - lo for lo, hi in zip(lows, highs)], dtype=np.float64)
    max_range = float(np.max(spans))
    if max_range <= 0.0 or not np.isfinite(max_range):
        max_range = 1.0
    # Same half-width on every axis (largest span plus a 5% margin).
    half = 0.5 * max_range + 0.05 * max_range
    for set_limits, lo, hi in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), lows, highs):
        mid = 0.5 * (hi + lo)
        set_limits(mid - half, mid + half)

    for set_label, text in ((ax.set_xlabel, "x"), (ax.set_ylabel, "y"), (ax.set_zlabel, "z")):
        set_label(text, fontsize=12, fontweight="normal")
    ax.set_title(f"{task_label}: 3D Trajectory", fontsize=12, fontweight="normal")
    ax.legend(fontsize=9, loc="upper left", bbox_to_anchor=(1.02, 1.0))

    if plot_path:
        _save_figure(fig, plot_path)
        print(f"[PLOT] Saved prediction plot to {plot_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# =============================
# Local-rule RNN (online updates + optional FPTT surrogates)
# =============================

# RNN trained with a local (online) learning rule, optionally with FPTT soft labels.
class TorchLocalRuleRNN:
    """Tanh RNN whose weights are updated at every time step without autograd.

    All parameters are plain tensors (not ``torch.nn.Parameter``); gradients are
    computed by hand inside ``run_one_cycle_and_update_directly`` and applied
    in place with step size ``eta``.
    """

    # __init__: store hyper-parameters, draw initial weights, reset running statistics.
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        lambda_window: int = 50,
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        seed: int | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Build the network.

        Args:
            input_size: Number of input features per time step.
            hidden_size: Number of hidden units.
            output_size: Number of output units.
            eta: Step size used for every in-place weight update.
            lambda_window: Averaging window for the lambda statistics; values
                below 2 are raised to 2. The decay is ``(window - 1) / window``.
            loss_mode: ``"ce"`` (softmax cross-entropy) or ``"mse"``.
            max_grad_norm: Threshold handed to ``clip_gradients`` at every step.
            seed: Seed for the NumPy generator that draws the initial weights.
            device: Target device; resolved through ``resolve_device``.

        Raises:
            ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
        """
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        rng = np.random.default_rng(seed)
        self.device = resolve_device(device)
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.eta = float(eta)
        self.epsilon = 1e-8
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        # Optional per-time-step loss weights; None means every step has weight 1.
        self.step_weights: torch.Tensor | None = None

        input_scale = min(0.1, 1.0 / math.sqrt(max(1, input_size)))
        self.W_xh = to_tensor(
            rng.standard_normal((hidden_size, input_size)).astype(np.float32) * input_scale,
            self.device,
        )
        # W_hh is drawn as unscaled standard normal here; initialize_weights_with_gain
        # redraws it with std g / sqrt(hidden_size).
        self.W_hh = to_tensor(
            rng.standard_normal((hidden_size, hidden_size)).astype(np.float32),
            self.device,
        )
        self.b_h = torch.zeros((hidden_size, 1), dtype=torch.float32, device=self.device)
        self.W_hy = to_tensor(
            rng.standard_normal((output_size, hidden_size)).astype(np.float32) * 0.1,
            self.device,
        )
        self.b_y = torch.zeros((output_size, 1), dtype=torch.float32, device=self.device)

        # Exponential-average decay and clipping range for the alpha_hat estimate.
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # Decay, cap and numerical floors for the per-unit lambda estimate.
        self.lambda_window = max(2, int(lambda_window))
        self.lambda_rho = (self.lambda_window - 1) / self.lambda_window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        # FPTT soft-label state; inactive until enable_fptt_surrogates is called.
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: torch.Tensor | None = None
        self.fptt_Q_sum: torch.Tensor | None = None
        self.fptt_Q_count: torch.Tensor | None = None

        self.reset_learning_state()

    # reset_learning_state: zero all running statistics used by the learning rule.
    def reset_learning_state(self) -> None:
        """Reset the alpha/lambda running averages to zero tensors of shape (hidden_size, 1)."""
        shape = (self.hidden_size, 1)
        self.alpha_num = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.alpha_den = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.alpha_hat = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.S_A2 = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.S_AB = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.lambda_vals = torch.zeros(shape, dtype=torch.float32, device=self.device)

    # initialize_weights_with_gain: redraw the recurrent matrix with a given gain.
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        """Replace ``W_hh`` with Gaussian entries of std ``g / sqrt(hidden_size)``.

        Args:
            g: Gain of the recurrent weights.
            seed: Seed for the NumPy generator used for the draw.
        """
        rng = np.random.default_rng(seed)
        std_dev = g / math.sqrt(self.hidden_size)
        self.W_hh = to_tensor(
            rng.standard_normal(self.W_hh.shape).astype(np.float32) * std_dev,
            self.device,
        )

    # enable_fptt_surrogates: turn on FPTT soft labels and initialise Q.
    def enable_fptt_surrogates(
        self,
        time_steps: int,
        output_size: int,
        Q_init: np.ndarray | torch.Tensor | None = None,
        beta_schedule: str = "linear",
    ) -> None:
        """Enable FPTT soft labels.

        Args:
            time_steps: Sequence length; sizes the accumulators.
            output_size: Number of classes; sizes the accumulators.
            Q_init: Initial soft labels, indexed as ``[class, time]``. When
                omitted, a uniform distribution over classes is used.
            beta_schedule: ``"linear"`` ramps the true-label weight as
                ``(t + 1) / time_steps``; any other value uses weight 1.
        """
        # FPTT soft labels: fall back to a uniform distribution when no Q_init is given.
        if Q_init is None:
            Q_init = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        self.fptt_Q_prev = to_tensor(Q_init, self.device, dtype=torch.float32).clone()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    # reset_fptt_epoch_accumulators: clear the per-epoch FPTT sums and counts.
    def reset_fptt_epoch_accumulators(
        self,
        time_steps: int | None = None,
        output_size: int | None = None,
    ) -> None:
        """Zero the per-epoch probability sums and sample counts.

        When either size is omitted, both are taken from the shape of the
        current soft labels.

        Raises:
            ValueError: If a size is missing and no soft labels exist yet.
        """
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = torch.zeros(
            (output_size, time_steps),
            dtype=torch.float64,
            device=self.device,
        )
        self.fptt_Q_count = torch.zeros((time_steps,), dtype=torch.int64, device=self.device)

    # finalize_fptt_epoch: turn the accumulated probabilities into next epoch's Q.
    def finalize_fptt_epoch(self) -> None:
        """Replace the soft labels with the mean predicted probabilities of the epoch.

        Time steps that received no samples keep their previous soft label
        instead of being overwritten with zeros (a zero column is not a valid
        distribution and would bias the next epoch's surrogate targets).
        Does nothing when the accumulators have not been created.
        """
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        # Clamp only guards the division; unvisited steps are restored below.
        counts = torch.clamp(self.fptt_Q_count.to(dtype=torch.float64), min=1.0)
        Q_new = (self.fptt_Q_sum / counts[None, :]).to(dtype=torch.float32)
        if self.fptt_Q_prev is not None and self.fptt_Q_prev.shape == Q_new.shape:
            visited = (self.fptt_Q_count > 0)[None, :]
            Q_new = torch.where(visited, Q_new, self.fptt_Q_prev.to(dtype=torch.float32))
        self.fptt_Q_prev = Q_new

    # run_one_cycle_and_update_directly: run one sequence and update weights at every step.
    def run_one_cycle_and_update_directly(
        self,
        inputs_cycle: np.ndarray | torch.Tensor,
        targets_cycle: np.ndarray | torch.Tensor,
        h_prev_cycle: np.ndarray | torch.Tensor,
    ) -> Tuple[float, torch.Tensor]:
        """Process one batch of sequences, applying an in-place update after each step.

        Args:
            inputs_cycle: Inputs indexed as ``[batch, feature, time]``.
            targets_cycle: Targets indexed as ``[batch, output, time]``.
            h_prev_cycle: Initial hidden state indexed as ``[hidden, batch]``.

        Returns:
            ``(loss, h_last)``: the summed (weighted) step losses divided by the
            total step weight (floored at 1), and the hidden state after the
            final step.
        """
        inputs = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        targets = to_tensor(targets_cycle, self.device, dtype=torch.float32)
        h_prev = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
        total_cycle_loss = torch.zeros((), device=self.device)
        batch_size = int(inputs.shape[0])
        time_steps = int(inputs.shape[2])
        eps = 1e-12
        step_weights = self.step_weights
        if step_weights is not None and not torch.is_tensor(step_weights):
            step_weights = to_tensor(step_weights, self.device, dtype=torch.float32)
        weight_norm = float(step_weights.sum().item()) if step_weights is not None else float(time_steps)
        weight_norm = max(weight_norm, 1.0)

        # Quantities of the previous time step, needed by the alpha/lambda estimators.
        prev_g: torch.Tensor | None = None
        prev_u: torch.Tensor | None = None
        prev_delta: torch.Tensor | None = None

        with torch.no_grad():
            for t in range(time_steps):
                step_weight = step_weights[t] if step_weights is not None else 1.0
                I_t = inputs[:, :, t].T
                y_target_t = targets[:, :, t].T

                x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
                h_t = torch.tanh(x_t)
                y_hat_t = self.W_hy @ h_t + self.b_y

                if self.loss_mode == "ce" and self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                    # FPTT: blend the true label with the historical soft label Q_t,
                    # and accumulate the predicted probabilities for the next Q.
                    beta_t = float(t + 1) / float(time_steps) if self.beta_schedule == "linear" else 1.0
                    P_t = softmax(y_hat_t)
                    Q_t = self.fptt_Q_prev[:, t].reshape(-1, 1)
                    Q_t_batch = Q_t.repeat(1, y_hat_t.shape[1])
                    Y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch

                    CE_true = -torch.sum(y_target_t * torch.log(P_t + eps), dim=0)
                    CE_div = -torch.sum(Q_t_batch * torch.log(P_t + eps), dim=0)
                    loss_t = torch.mean(beta_t * CE_true + (1.0 - beta_t) * CE_div)
                    dL_dyhat = P_t - Y_tilde

                    if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                        self.fptt_Q_sum[:, t] += torch.sum(P_t, dim=1).to(self.fptt_Q_sum.dtype)
                        self.fptt_Q_count[t] += batch_size
                elif self.loss_mode == "ce":
                    P_t = softmax(y_hat_t)
                    loss_t = -torch.mean(torch.sum(y_target_t * torch.log(P_t + eps), dim=0))
                    dL_dyhat = P_t - y_target_t
                else:
                    error = y_hat_t - y_target_t
                    loss_t = 0.5 * torch.mean(torch.sum(error**2, dim=0))
                    dL_dyhat = error

                if step_weights is not None:
                    loss_t = loss_t * step_weight
                    dL_dyhat = dL_dyhat * step_weight

                total_cycle_loss = total_cycle_loss + loss_t

                # Output error projected back onto the hidden units, and the tanh derivative.
                g_t = self.W_hy.T @ dL_dyhat
                u_t = 1.0 - h_t**2

                # Use the lambda estimated up to the previous step; it is refreshed
                # only after this step's weight update.
                lambda_used = self.lambda_vals.clone()
                denominator = 1.0 - lambda_used * u_t
                # Keep |denominator| >= denom_floor so the division stays bounded.
                denom_mask = torch.abs(denominator) < self.denom_floor
                denominator = torch.where(
                    denom_mask,
                    self.denom_floor * torch.sign(denominator + 1e-12),
                    denominator,
                )
                delta_t = (u_t * g_t) / denominator

                # Estimate alpha_hat from the teaching signal itself:
                # \widetilde{\delta}_t \approx \alpha \widetilde{\delta}_{t-1}.
                if prev_delta is not None:
                    dtp_mean = torch.mean(delta_t * prev_delta, dim=1, keepdim=True)
                    dpp_mean = torch.mean(prev_delta**2, dim=1, keepdim=True)
                    self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                    self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                    raw_alpha = self.alpha_num / (self.alpha_den + self.epsilon)
                    self.alpha_hat = torch.clamp(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

                # Batch-averaged gradients for all five parameter tensors.
                dW_hh = (delta_t @ h_prev.T) / batch_size
                dW_xh = (delta_t @ I_t.T) / batch_size
                db_h = torch.mean(delta_t, dim=1, keepdim=True)
                dW_hy = (dL_dyhat @ h_t.T) / batch_size
                db_y = torch.mean(dL_dyhat, dim=1, keepdim=True)

                dW_hh, dW_xh, db_h, dW_hy, db_y = clip_gradients(
                    [dW_hh, dW_xh, db_h, dW_hy, db_y],
                    self.max_grad_norm,
                )

                self.W_hh.add_(-self.eta * dW_hh)
                self.W_xh.add_(-self.eta * dW_xh)
                self.b_h.add_(-self.eta * db_h)
                self.W_hy.add_(-self.eta * dW_hy)
                self.b_y.add_(-self.eta * db_y)

                # Refresh the per-unit lambda as the ratio of two running averages,
                # then clamp it so that 1 - lambda * u stays away from zero.
                if prev_g is not None and prev_u is not None:
                    A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t)
                    B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t
                    A2_mean = torch.mean(A_s**2, dim=1, keepdim=True)
                    AB_mean = torch.mean(A_s * B_s, dim=1, keepdim=True)
                    self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                    self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean

                    lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)
                    u_abs_max = torch.max(torch.abs(u_t), dim=1, keepdim=True).values + 1e-12
                    safe_cap = (1.0 - self.denom_floor) / u_abs_max
                    cap = torch.minimum(safe_cap, torch.full_like(safe_cap, self.lambda_cap))
                    self.lambda_vals = torch.clamp(lambda_unproj, min=-cap, max=cap)

                prev_g = g_t
                prev_u = u_t
                h_prev = h_t
                prev_delta = delta_t

        return float((total_cycle_loss / weight_norm).item()), h_prev

    # forward_cycle: inference-only pass over a full sequence.
    def forward_cycle(
        self, inputs_cycle: np.ndarray | torch.Tensor, h_prev_cycle: np.ndarray | torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Run the network over a sequence without updating any weights.

        Args:
            inputs_cycle: Inputs indexed as ``[batch, feature, time]``.
            h_prev_cycle: Initial hidden state indexed as ``[hidden, batch]``.

        Returns:
            ``(outputs, h_last)``: one output tensor per time step, and the
            hidden state after the final step.
        """
        inputs = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        h_prev = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
        outputs: List[torch.Tensor] = []
        with torch.no_grad():
            for t in range(inputs.shape[2]):
                I_t = inputs[:, :, t].T
                x_t = self.W_hh @ h_prev + self.W_xh @ I_t + self.b_h
                h_prev = torch.tanh(x_t)
                outputs.append(self.W_hy @ h_prev + self.b_y)
        return outputs, h_prev


# Alias: NumpyLocalRuleRNN is the same class object as TorchLocalRuleRNN.
# NOTE(review): presumably kept so older call sites using the NumPy-era name keep working — confirm.
NumpyLocalRuleRNN = TorchLocalRuleRNN

# =============================
# BPTT RNN (baseline, autograd)
# =============================

class TorchBPTTRNN(torch.nn.Module):
    """Single-layer tanh RNN trained through autograd (full or truncated BPTT) with plain SGD."""

    # __init__: record hyper-parameters, register the five parameters, build the optimizer.
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eta: float = 1e-3,
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        tbptt_steps: int | None = None,
        time_normalization: bool = True,
        seed: int | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Build the network.

        Args:
            input_size: Number of input features per time step.
            hidden_size: Number of hidden units.
            output_size: Number of output units.
            eta: SGD learning rate.
            loss_mode: ``"ce"`` (softmax cross-entropy) or ``"mse"``.
            max_grad_norm: Gradient-norm clipping threshold; values <= 0 disable clipping.
            tbptt_steps: Truncation window; ``None`` or a value outside
                ``(0, sequence_length)`` means full BPTT.
            time_normalization: Whether the loss is divided by the total step weight.
            seed: Seed for the NumPy generator that draws the initial weights.
            device: Target device; resolved through ``resolve_device``.

        Raises:
            ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
        """
        super().__init__()
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        rng = np.random.default_rng(seed)

        self.device = resolve_device(device)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.eta = float(eta)
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        # Optional per-time-step loss weights; None means every step has weight 1.
        self.step_weights: torch.Tensor | None = None
        self.tbptt_steps = int(tbptt_steps) if tbptt_steps is not None else None
        self.time_normalization = bool(time_normalization)

        # Draw order (input, recurrent, readout) determines the values for a given seed.
        input_scale = min(0.1, 1.0 / math.sqrt(max(1, input_size)))
        w_in = rng.standard_normal((hidden_size, input_size)).astype(np.float32) * input_scale
        w_rec = rng.standard_normal((hidden_size, hidden_size)).astype(np.float32)
        w_out = rng.standard_normal((output_size, hidden_size)).astype(np.float32) * 0.1

        self.W_xh = torch.nn.Parameter(to_tensor(w_in, self.device))
        self.W_hh = torch.nn.Parameter(to_tensor(w_rec, self.device))
        self.b_h = torch.nn.Parameter(
            torch.zeros((hidden_size, 1), dtype=torch.float32, device=self.device)
        )
        self.W_hy = torch.nn.Parameter(to_tensor(w_out, self.device))
        self.b_y = torch.nn.Parameter(
            torch.zeros((output_size, 1), dtype=torch.float32, device=self.device)
        )
        self.reset_optimizer()

    # reset_optimizer: replace the optimizer with a fresh SGD instance.
    def reset_optimizer(self) -> None:
        """Create a new SGD optimizer over all parameters with learning rate ``eta``."""
        self.optimizer = torch.optim.SGD(self.parameters(), lr=self.eta)

    # initialize_weights_with_gain: redraw the recurrent matrix with a given gain.
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        """Overwrite ``W_hh`` in place with Gaussian entries of std ``g / sqrt(hidden_size)``."""
        generator = np.random.default_rng(seed)
        scale = g / math.sqrt(self.hidden_size)
        fresh = to_tensor(
            generator.standard_normal(self.W_hh.shape).astype(np.float32) * scale,
            self.device,
        )
        with torch.no_grad():
            self.W_hh.copy_(fresh)

    # _unroll_span: forward the network over steps [first, last) and sum the step losses.
    def _unroll_span(
        self,
        x_seq: torch.Tensor,
        y_seq: torch.Tensor,
        hidden: torch.Tensor,
        weights: torch.Tensor | None,
        first: int,
        last: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        span_loss = torch.zeros((), device=self.device)
        for t in range(first, last):
            hidden = torch.tanh(self.W_hh @ hidden + self.W_xh @ x_seq[:, :, t].T + self.b_h)
            prediction = self.W_hy @ hidden + self.b_y
            wanted = y_seq[:, :, t].T

            if self.loss_mode == "ce":
                step_loss = -(wanted * torch.log_softmax(prediction, dim=0)).sum(dim=0).mean()
            else:
                residual = prediction - wanted
                step_loss = 0.5 * torch.sum(residual**2, dim=0).mean()

            if weights is not None:
                step_loss = step_loss * weights[t]
            span_loss = span_loss + step_loss
        return span_loss, hidden

    # _backward_and_step: backpropagate, clip the gradient norm, apply one SGD step.
    def _backward_and_step(self, objective: torch.Tensor, clip_scale: float) -> None:
        objective.backward()
        if self.max_grad_norm > 0:
            limit = self.max_grad_norm
            if not self.time_normalization:
                # An un-normalised loss grows with the span weight, so the threshold is widened.
                limit = self.max_grad_norm * math.sqrt(clip_scale)
            torch.nn.utils.clip_grad_norm_(self.parameters(), limit)
        self.optimizer.step()

    # train_batch: one training pass over a batch of sequences (full or truncated BPTT).
    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor,
    ) -> Tuple[float, torch.Tensor]:
        """Train on one batch and return the loss and the final hidden state.

        Args:
            inputs_batch: Inputs indexed as ``[batch, feature, time]``.
            targets_batch: Targets indexed as ``[batch, output, time]``.
            h_prev_batch: Initial hidden state indexed as ``[hidden, batch]``.

        Returns:
            ``(loss, h_last)``: the summed step losses (divided by the total
            step weight when ``time_normalization`` is on) and the detached
            hidden state after the last step.
        """
        x_seq = to_tensor(inputs_batch, self.device, dtype=torch.float32)
        y_seq = to_tensor(targets_batch, self.device, dtype=torch.float32)
        hidden = to_tensor(h_prev_batch, self.device, dtype=torch.float32)
        n_steps = int(x_seq.shape[2])

        weights = self.step_weights
        if weights is not None and not torch.is_tensor(weights):
            weights = to_tensor(weights, self.device, dtype=torch.float32)
        if weights is None:
            total_weight = float(n_steps)
        else:
            total_weight = float(weights.sum().item())
        total_weight = max(total_weight, 1.0)

        window = int(self.tbptt_steps) if self.tbptt_steps is not None else 0

        if not (0 < window < n_steps):
            # Full BPTT: a single backward pass through the whole sequence.
            self.optimizer.zero_grad(set_to_none=True)
            summed, hidden = self._unroll_span(x_seq, y_seq, hidden, weights, 0, n_steps)
            divisor = total_weight if self.time_normalization else 1.0
            objective = summed / max(divisor, 1.0)
            self._backward_and_step(objective, total_weight)
            return float(objective.item()), hidden.detach()

        # Truncated BPTT: one optimizer step per window, gradient cut between windows.
        running_loss = 0.0
        for lo in range(0, n_steps, window):
            self.optimizer.zero_grad(set_to_none=True)
            hi = min(n_steps, lo + window)
            summed, hidden = self._unroll_span(x_seq, y_seq, hidden, weights, lo, hi)

            if weights is None:
                span_weight = float(hi - lo)
            else:
                span_weight = float(weights[lo:hi].sum().item())
            span_weight = max(span_weight, 1.0)

            divisor = span_weight if self.time_normalization else 1.0
            objective = summed / max(divisor, 1.0)
            self._backward_and_step(objective, span_weight)

            running_loss += float(summed.detach().item())
            hidden = hidden.detach()

        report_divisor = total_weight if self.time_normalization else 1.0
        return running_loss / max(report_divisor, 1.0), hidden.detach()

    # forward_cycle: inference-only pass over a full sequence.
    def forward_cycle(
        self, inputs_cycle: np.ndarray | torch.Tensor, h_prev_cycle: np.ndarray | torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Run the network over a sequence under ``torch.no_grad()``.

        Returns:
            ``(outputs, h_last)``: one output tensor per time step, and the
            hidden state after the final step.
        """
        x_seq = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        hidden = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
        collected: List[torch.Tensor] = []
        with torch.no_grad():
            for t in range(x_seq.shape[2]):
                hidden = torch.tanh(self.W_hh @ hidden + self.W_xh @ x_seq[:, :, t].T + self.b_h)
                collected.append(self.W_hy @ hidden + self.b_y)
        return collected, hidden


# Alias: BPTTRNN is the same class object as TorchBPTTRNN.
BPTTRNN = TorchBPTTRNN


# =============================
# Evaluation helpers
# =============================

# evaluate_classifier_final_step: loss and accuracy of the prediction at the last time step.
def evaluate_classifier_final_step(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    batch_size: int,
) -> Tuple[float, float]:
    """Evaluate a sequence classifier using only its final-step output.

    Args:
        model: Object exposing ``hidden_size`` and ``forward_cycle(inputs, h0)``.
        inputs: Inputs, batched along axis 0.
        targets: Target distributions indexed as ``[sample, class, time]``.
        labels: Integer class labels, one per sample.
        batch_size: Number of samples per evaluation batch.

    Returns:
        ``(mean cross-entropy, accuracy)`` over all samples.
    """
    # Device preference: DEFAULT_DEVICE < model.device < device of a tensor model.W_hh.
    eval_device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        eval_device = resolve_device(model.device)
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        eval_device = recurrent.device

    x_all = to_tensor(inputs, device=eval_device, dtype=torch.float32)
    y_all = to_tensor(targets, device=eval_device, dtype=torch.float32)
    label_all = to_tensor(labels, device=eval_device, dtype=torch.long)

    tiny = 1e-12
    n_samples = x_all.shape[0]
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for lo in range(0, n_samples, batch_size):
        hi = min(lo + batch_size, n_samples)
        x_chunk, y_chunk, label_chunk = x_all[lo:hi], y_all[lo:hi], label_all[lo:hi]
        n_chunk = int(x_chunk.shape[0])

        # Every evaluation batch starts from a zero hidden state.
        h0 = torch.zeros((model.hidden_size, n_chunk), dtype=torch.float32, device=eval_device)
        step_outputs, _ = model.forward_cycle(x_chunk, h0)
        # Stack the per-step outputs on a new last axis, then put the batch axis first.
        batch_first = torch.stack(step_outputs, dim=2).permute(1, 0, 2)

        last_logits = batch_first[:, :, -1]
        last_probs = torch.softmax(last_logits, dim=1)
        last_target = y_chunk[:, :, -1]
        per_sample_ce = -torch.sum(last_target * torch.log(last_probs + tiny), dim=1)
        loss_sum += float(per_sample_ce.mean().item()) * n_chunk

        predicted = torch.argmax(last_logits, dim=1)
        n_correct += int((predicted == label_chunk).sum().item())
        n_seen += n_chunk

    denom = max(n_seen, 1)
    return loss_sum / denom, n_correct / denom


# evaluate_regression_mse: regression MSE with teacher forcing (ground-truth inputs at every step).
def evaluate_regression_mse(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    step_weights: np.ndarray | torch.Tensor | None = None,
) -> float:
    """Compute the (optionally time-weighted) mean squared error of ``model.forward_cycle``.

    Args:
        model: Object exposing ``hidden_size`` and ``forward_cycle(inputs, h0)``.
        inputs: Inputs, batched along axis 0.
        targets: Targets indexed as ``[sample, output, time]``.
        batch_size: Number of samples per evaluation batch.
        step_weights: Optional per-time-step weights; their sum (floored at 1)
            replaces the sequence length in the normalisation.

    Returns:
        The sample-weighted average of the per-batch MSE values.

    Raises:
        ValueError: If ``step_weights`` does not have one entry per time step.
    """
    # Device preference: DEFAULT_DEVICE < model.device < device of a tensor model.W_hh.
    eval_device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        eval_device = resolve_device(model.device)
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        eval_device = recurrent.device

    if step_weights is None:
        time_weights = None
        weight_total = 1.0
    else:
        time_weights = to_tensor(step_weights, device=eval_device, dtype=torch.float32)
        weight_total = max(float(time_weights.sum().item()), 1.0)

    x_all = to_tensor(inputs, device=eval_device, dtype=torch.float32)
    y_all = to_tensor(targets, device=eval_device, dtype=torch.float32)
    n_samples = x_all.shape[0]

    weighted_sum = 0.0
    n_seen = 0
    for lo in range(0, n_samples, batch_size):
        hi = min(lo + batch_size, n_samples)
        x_chunk = x_all[lo:hi]
        y_chunk = y_all[lo:hi]
        n_chunk = int(x_chunk.shape[0])

        h0 = torch.zeros((model.hidden_size, n_chunk), dtype=torch.float32, device=eval_device)
        step_outputs, _ = model.forward_cycle(x_chunk, h0)
        # Outputs are stacked on a new last axis; targets are permuted to the same layout.
        sq_err = (torch.stack(step_outputs, dim=2) - y_chunk.permute(1, 0, 2)) ** 2

        if time_weights is not None:
            if time_weights.shape[0] != sq_err.shape[2]:
                raise ValueError("step_weights length must match the sequence length.")
            chunk_mse = float(
                torch.sum(sq_err * time_weights.view(1, 1, -1))
                / (sq_err.shape[0] * sq_err.shape[1] * weight_total)
            )
        else:
            chunk_mse = float(sq_err.mean().item())

        weighted_sum += chunk_mse * n_chunk
        n_seen += n_chunk

    return weighted_sum / max(n_seen, 1)


# evaluate_regression_mse_rollout: regression MSE of rollout predictions.
def evaluate_regression_mse_rollout(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    warmup_steps: int = 1,
    step_weights: np.ndarray | torch.Tensor | None = None,
) -> float:
    """Compute the (optionally time-weighted) MSE of ``predict_sequence_rollout_batch``.

    Args:
        model: Model handed to ``predict_sequence_rollout_batch``.
        inputs: Inputs indexed as ``[sample, feature, time]``.
        targets: Targets indexed as ``[sample, output, time]``.
        batch_size: Number of samples per evaluation batch.
        warmup_steps: Warm-up length passed to the rollout helper, clamped to
            ``[1, sequence_length]``.
        step_weights: Optional per-time-step weights; their sum (floored at 1)
            replaces the sequence length in the normalisation.

    Returns:
        The sample-weighted average of the per-batch MSE values.

    Raises:
        ValueError: If ``step_weights`` does not have one entry per time step.
    """
    # Device preference: DEFAULT_DEVICE < model.device < device of a tensor model.W_hh.
    eval_device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        eval_device = resolve_device(model.device)
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        eval_device = recurrent.device

    if step_weights is None:
        time_weights = None
        weight_total = 1.0
    else:
        time_weights = to_tensor(step_weights, device=eval_device, dtype=torch.float32)
        weight_total = max(float(time_weights.sum().item()), 1.0)

    x_all = to_tensor(inputs, device=eval_device, dtype=torch.float32)
    y_all = to_tensor(targets, device=eval_device, dtype=torch.float32)
    n_samples = x_all.shape[0]
    seq_len = int(x_all.shape[2])
    warmup_steps = max(1, min(int(warmup_steps), seq_len))

    weighted_sum = 0.0
    n_seen = 0
    for lo in range(0, n_samples, batch_size):
        hi = min(lo + batch_size, n_samples)
        x_chunk = x_all[lo:hi]
        y_chunk = y_all[lo:hi]
        n_chunk = int(x_chunk.shape[0])

        rollout = predict_sequence_rollout_batch(
            model,
            x_chunk,
            warmup_steps=warmup_steps,
        )
        # Targets are permuted so their first two axes are swapped before subtraction.
        sq_err = (rollout - y_chunk.permute(1, 0, 2)) ** 2

        if time_weights is not None:
            if time_weights.shape[0] != sq_err.shape[2]:
                raise ValueError("step_weights length must match the sequence length.")
            chunk_mse = float(
                torch.sum(sq_err * time_weights.view(1, 1, -1))
                / (sq_err.shape[0] * sq_err.shape[1] * weight_total)
            )
        else:
            chunk_mse = float(sq_err.mean().item())

        weighted_sum += chunk_mse * n_chunk
        n_seen += n_chunk

    return weighted_sum / max(n_seen, 1)


# batch_iterator: yield random fixed-length language-model windows as one-hot batches.
def batch_iterator(
    data: np.ndarray | torch.Tensor,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]]:
    """Sample ``steps`` batches of next-token prediction windows from a token stream.

    Each window covers ``block_size + 1`` consecutive tokens; the first
    ``block_size`` tokens are the inputs and the last ``block_size`` tokens
    (shifted by one) are the targets.

    Args:
        data: 1-D sequence of integer token ids (NumPy array or torch tensor).
        vocab_size: Number of classes for the one-hot encoding.
        block_size: Number of time steps per window.
        batch_size: Number of windows per batch.
        steps: Number of batches to yield.
        rng: NumPy generator used to draw the window start positions.

    Yields:
        ``(x, y)`` float32 one-hot arrays/tensors of shape
        ``(batch_size, vocab_size, block_size)``, of the same kind (NumPy or
        torch) as ``data``.

    Raises:
        ValueError: If ``data`` holds fewer than ``block_size + 1`` tokens
            (raised when the generator is first advanced).
    """
    N = data.shape[0]
    min_len = block_size + 1
    if N < min_len:
        raise ValueError(
            f"LM data too short for block_size={block_size}: got {N} tokens, need >= {min_len}. "
            "Check the dataset or reduce --ptb-block-size."
        )
    # Valid starts are 0 .. N - block_size - 1 inclusive; rng.integers excludes the
    # upper bound, so it must be N - block_size for the last window to be reachable.
    num_starts = N - block_size

    # Loop-invariant helpers are built once rather than on every step.
    is_torch = torch.is_tensor(data)
    if is_torch:
        device = data.device
        offsets_t = torch.arange(block_size + 1, device=device)
    else:
        offsets_np = np.arange(block_size + 1)
        eye = np.eye(vocab_size, dtype=np.float32)

    for _ in range(steps):
        starts = rng.integers(0, num_starts, size=(batch_size,), dtype=np.int64)
        if is_torch:
            starts_t = torch.as_tensor(starts, device=device)
            seq = data[starts_t[:, None] + offsets_t[None, :]]
            x = torch.nn.functional.one_hot(seq[:, :-1], num_classes=vocab_size).to(dtype=torch.float32)
            y = torch.nn.functional.one_hot(seq[:, 1:], num_classes=vocab_size).to(dtype=torch.float32)
            yield x.permute(0, 2, 1), y.permute(0, 2, 1)
        else:
            seq = data[starts[:, None] + offsets_np[None, :]]
            x = eye[seq[:, :-1]]
            y = eye[seq[:, 1:]]
            yield np.transpose(x, (0, 2, 1)), np.transpose(y, (0, 2, 1))


# evaluate_language_model: average per-token negative log-likelihood of a language model.
def evaluate_language_model(
    model: Any,
    data: np.ndarray | torch.Tensor,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps: int,
) -> float:
    """Estimate the mean per-token cross-entropy on randomly sampled windows.

    Windows are drawn by ``batch_iterator`` with a generator seeded with the
    fixed value 12345, so repeated calls sample the same windows.

    Args:
        model: Object exposing ``hidden_size`` and ``forward_cycle(inputs, h0)``.
        data: Token stream handed to ``batch_iterator``.
        vocab_size: Vocabulary size for the one-hot encoding.
        block_size: Number of time steps per window.
        batch_size: Number of windows per batch.
        steps: Number of batches to evaluate.

    Returns:
        Total negative log-likelihood divided by the number of scored tokens.
    """
    # Device preference: DEFAULT_DEVICE < model.device < device of a tensor model.W_hh.
    eval_device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        eval_device = resolve_device(model.device)
    recurrent = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent):
        eval_device = recurrent.device

    sampler = np.random.default_rng(12345)
    nll_sum = 0.0
    n_tokens = 0

    for x_np, y_np in batch_iterator(data, vocab_size, block_size, batch_size, steps, sampler):
        x_batch = to_tensor(x_np, device=eval_device, dtype=torch.float32)
        y_batch = to_tensor(y_np, device=eval_device, dtype=torch.float32)
        n_rows = int(x_batch.shape[0])

        # Every evaluation batch starts from a zero hidden state.
        h0 = torch.zeros((model.hidden_size, n_rows), dtype=torch.float32, device=eval_device)
        step_outputs, _ = model.forward_cycle(x_batch, h0)
        # Stack per-step outputs on a new last axis, then put the batch axis first.
        logits = torch.stack(step_outputs, dim=2).permute(1, 0, 2)

        token_nll = -(y_batch * torch.log_softmax(logits, dim=1)).sum(dim=1)
        mean_nll = float(token_nll.mean().item())

        scored = n_rows * logits.shape[2]
        nll_sum += mean_nll * n_rows * logits.shape[2]
        n_tokens += scored

    return nll_sum / max(n_tokens, 1)

# =============================
# Data loading and sequence construction
# =============================

# File name of the MNIST archive. load_mnist_images looks for it in the working
# directory and under ~/.keras/datasets (after any explicitly supplied path).
MNIST_NPZ = "mnist.npz"
# Base URLs under which the PTB text files are hosted.
# NOTE(review): presumably tried in order as download mirrors — confirm at the call site.
PTB_URL_BASES = (
    "https://raw.githubusercontent.com/wojzaremba/lstm/master/data",
    "https://raw.githubusercontent.com/tomsercu/lstm/master/data",
)
# File names of the PTB train / validation / test splits.
PTB_FILES = ("ptb.train.txt", "ptb.valid.txt", "ptb.test.txt")


# load_mnist_images: load MNIST images and labels.
# Note: tries local npz files, then torchvision, then tensorflow; scales and normalizes the images.
def load_mnist_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
    npz_path: str | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST and return ``(train_images, train_labels, test_images, test_labels)``.

    Sources are tried in this order: ``npz_path`` (when given), ``mnist.npz`` in
    the working directory, ``~/.keras/datasets/mnist.npz``, torchvision, and
    finally tensorflow.

    Args:
        train_limit: Keep only the first ``train_limit`` training samples.
        test_limit: Keep only the first ``test_limit`` test samples.
        npz_path: Optional explicit archive location, checked first.

    Returns:
        Images as float32 divided by 255 and then passed through
        ``normalize_train_test``; labels as int64.

    Raises:
        RuntimeError: No local archive exists, torchvision failed, and
            tensorflow cannot be imported.
    """
    search_order: List[Path] = []
    if npz_path is not None:
        search_order.append(Path(npz_path))
    search_order += [Path(MNIST_NPZ), Path.home() / ".keras" / "datasets" / MNIST_NPZ]

    dataset = None
    # Only the first existing candidate is opened.
    path = next((entry for entry in search_order if entry.exists()), None)
    if path is not None:
        with np.load(path) as archive:
            dataset = (
                (archive["x_train"], archive["y_train"]),
                (archive["x_test"], archive["y_test"]),
            )
        print(f"Loaded MNIST data from '{path}'.")

    if dataset is None:
        try:
            from torchvision.datasets import MNIST  # type: ignore
            splits = [MNIST(root="./data", train=is_train, download=True) for is_train in (True, False)]
            dataset = tuple((np.array(split.data), np.array(split.targets)) for split in splits)
            print("Loaded MNIST data via torchvision.datasets.MNIST.")
        except Exception:
            # Any torchvision failure (missing package, download error, ...)
            # falls through to the tensorflow loader.
            try:
                from tensorflow.keras.datasets import mnist  # type: ignore
            except Exception as exc:
                raise RuntimeError(
                    "MNIST dataset not found locally and TensorFlow is unavailable. "
                    "Place 'mnist.npz' next to this script or install tensorflow."
                ) from exc
            dataset = mnist.load_data()
            print("Loaded MNIST data via tensorflow.keras.datasets.mnist.")

    (train_images, train_labels), (test_images, test_labels) = dataset

    if train_limit is not None:
        train_images, train_labels = train_images[:train_limit], train_labels[:train_limit]
    if test_limit is not None:
        test_images, test_labels = test_images[:test_limit], test_labels[:test_limit]

    # Scale raw pixel values by 1/255 before the shared train/test normalization.
    train_images, test_images = normalize_train_test(
        train_images.astype(np.float32) / 255.0,
        test_images.astype(np.float32) / 255.0,
    )

    return train_images, train_labels.astype(np.int64), test_images, test_labels.astype(np.int64)


# load_permuted_mnist_sequences: build permuted-MNIST sequence data.
# Note: every image is flattened, its pixels are reordered by one fixed permutation, and it is fed one pixel per step.
def load_permuted_mnist_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
    permute_seed: int = 1234,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return permuted-MNIST inputs, per-step targets and labels for train and test.

    Args:
        train_limit: Forwarded to ``load_mnist_images``.
        test_limit: Forwarded to ``load_mnist_images``.
        permute_seed: Seed of the generator that draws the pixel permutation.

    Returns:
        ``(train_inputs, train_targets, train_labels, test_inputs,
        test_targets, test_labels)``. Inputs are float32 with shape
        ``(N, 1, P)`` where ``P`` is the flattened pixel count; targets come
        from ``build_repeated_targets`` with 10 classes and ``P`` time steps.
    """
    images_tr, labels_tr, images_te, labels_te = load_mnist_images(
        train_limit=train_limit, test_limit=test_limit
    )

    # One permutation of the 28*28 pixel positions, shared by both splits.
    pixel_order = np.random.default_rng(permute_seed).permutation(28 * 28)

    def as_pixel_sequence(images: np.ndarray) -> np.ndarray:
        flattened = images.reshape(images.shape[0], -1)
        shuffled = flattened[:, pixel_order]
        # Insert a singleton feature axis: (N, P) -> (N, 1, P).
        return shuffled[:, None, :].astype(np.float32)

    train_inputs = as_pixel_sequence(images_tr)
    test_inputs = as_pixel_sequence(images_te)

    time_steps = train_inputs.shape[2]
    train_targets = build_repeated_targets(labels_tr, 10, time_steps)
    test_targets = build_repeated_targets(labels_te, 10, time_steps)

    return train_inputs, train_targets, labels_tr, test_inputs, test_targets, labels_te


# load_row_cifar10_sequences: unroll CIFAR-10 row by row into sequences (time_steps=H, input_size=W*C).
def load_row_cifar10_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return CIFAR-10 as row sequences: one image row per time step.

    Args:
        train_limit: Forwarded to ``load_cifar10_images``.
        test_limit: Forwarded to ``load_cifar10_images``.

    Returns:
        ``(train_inputs, train_targets, train_labels, test_inputs,
        test_targets, test_labels)``. Inputs are float32 with shape
        ``(N, W*C, H)``; targets come from ``build_repeated_targets`` with
        10 classes and ``H`` time steps.

    Raises:
        RuntimeError: The training images are not a 4-D array.
    """
    train_images, train_labels, test_images, test_labels = load_cifar10_images(
        train_limit=train_limit,
        test_limit=test_limit,
    )

    if train_images.ndim != 4:
        raise RuntimeError(f"Unexpected CIFAR-10 image shape: {train_images.shape}")

    # A trailing axis of size 3 is treated as (N, H, W, C); anything else as (N, C, H, W).
    channels_last = train_images.shape[-1] == 3
    if channels_last:
        _, height, width, channels = train_images.shape
    else:
        _, channels, height, width = train_images.shape

    def to_rows(images: np.ndarray) -> np.ndarray:
        if not channels_last:
            # NOTE(review): this moves H forward but leaves C ahead of W, so each
            # row is laid out channel-major, unlike the channel-last branch where
            # channels are interleaved per pixel. Kept as in the original.
            images = np.transpose(images, (0, 2, 1, 3))
        # The geometry of the training split is applied to both splits.
        return images.reshape(images.shape[0], height, width * channels)

    def to_sequence(rows: np.ndarray) -> np.ndarray:
        # (N, H, W*C) -> (N, W*C, H): features first, time last.
        return np.transpose(rows, (0, 2, 1)).astype(np.float32)

    train_inputs = to_sequence(to_rows(train_images))
    test_inputs = to_sequence(to_rows(test_images))

    time_steps = int(train_inputs.shape[2])
    train_targets = build_repeated_targets(train_labels, 10, time_steps)
    test_targets = build_repeated_targets(test_labels, 10, time_steps)
    return train_inputs, train_targets, train_labels, test_inputs, test_targets, test_labels


# load_cifar10_images: load CIFAR-10 images and labels.
# Note: tries torchvision first, then tensorflow; scales and normalizes the images.
def load_cifar10_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load CIFAR-10 and return ``(x_train, y_train, x_test, y_test)``.

    Args:
        train_limit: Keep only the first ``train_limit`` training samples.
        test_limit: Keep only the first ``test_limit`` test samples.

    Returns:
        Images as float32 divided by 255 and then passed through
        ``normalize_train_test``; labels as int64.

    Raises:
        RuntimeError: torchvision failed and tensorflow cannot be imported.
    """
    try:
        from torchvision.datasets import CIFAR10  # type: ignore
        train_set = CIFAR10(root="./data", train=True, download=True)
        test_set = CIFAR10(root="./data", train=False, download=True)
        x_train, y_train = train_set.data, np.array(train_set.targets)
        x_test, y_test = test_set.data, np.array(test_set.targets)
    except Exception:
        # Any torchvision failure falls through to the tensorflow loader.
        try:
            from tensorflow.keras.datasets import cifar10  # type: ignore
        except Exception as exc:
            raise RuntimeError(
                "CIFAR-10 dataset not available. Install torchvision or tensorflow."
            ) from exc
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        # Drop singleton axes from the label arrays.
        y_train, y_test = y_train.squeeze(), y_test.squeeze()

    if train_limit is not None:
        x_train, y_train = x_train[:train_limit], y_train[:train_limit]
    if test_limit is not None:
        x_test, y_test = x_test[:test_limit], y_test[:test_limit]

    # Scale raw pixel values by 1/255 before the shared train/test normalization.
    x_train, x_test = normalize_train_test(
        x_train.astype(np.float32) / 255.0,
        x_test.astype(np.float32) / 255.0,
    )

    return x_train, y_train.astype(np.int64), x_test, y_test.astype(np.int64)


# load_sequential_cifar10_sequences: unroll CIFAR-10 into a time series.
# Note: each image becomes a sequence whose feature vectors have 3 values.
def load_sequential_cifar10_sequences(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return CIFAR-10 flattened into sequences of 3-value feature vectors.

    Args:
        train_limit: Forwarded to ``load_cifar10_images``.
        test_limit: Forwarded to ``load_cifar10_images``.

    Returns:
        ``(train_inputs, train_targets, train_labels, test_inputs,
        test_targets, test_labels)``. Inputs are float32 with shape
        ``(N, 3, T)``; targets come from ``build_repeated_targets`` with
        10 classes and ``T`` time steps.
    """
    images_tr, labels_tr, images_te, labels_te = load_cifar10_images(
        train_limit=train_limit, test_limit=test_limit
    )

    def as_sequence(images: np.ndarray) -> np.ndarray:
        # Group the flattened values in triples, then put features before time.
        # NOTE(review): a triple is one RGB pixel only if the images are
        # channel-last -- confirm against load_cifar10_images.
        triples = images.reshape(images.shape[0], -1, 3)
        return np.transpose(triples, (0, 2, 1)).astype(np.float32)

    train_inputs = as_sequence(images_tr)
    test_inputs = as_sequence(images_te)

    time_steps = train_inputs.shape[2]
    train_targets = build_repeated_targets(labels_tr, 10, time_steps)
    test_targets = build_repeated_targets(labels_te, 10, time_steps)

    return train_inputs, train_targets, labels_tr, test_inputs, test_targets, labels_te


def prepare_uci_har_npz(
    npz_path: str,
    root: str | None = None,
    compress: bool = True,
) -> str:
    root_dir = root or os.path.join("data", "uci_har")
    os.makedirs(root_dir, exist_ok=True)
    out_dir = os.path.dirname(npz_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00240/UCI%20HAR%20Dataset.zip"
    zip_path = os.path.join(root_dir, "UCI_HAR_Dataset.zip")
    dataset_dir = os.path.join(root_dir, "UCI HAR Dataset")

    def safe_extract(zf: zipfile.ZipFile, dest: str) -> None:
        dest_path = Path(dest).resolve()
        for member in zf.infolist():
            member_path = (dest_path / member.filename).resolve()
            try:
                member_path.relative_to(dest_path)
            except ValueError as exc:
                raise RuntimeError(f"Unsafe path in zip: {member.filename}") from exc
        zf.extractall(dest_path)

    def download_with_retry(url: str, out_path: str, retries: int = 3, timeout: int = 60) -> None:
        last_exc: Exception | None = None
        for attempt in range(retries):
            try:
                req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
                with urllib.request.urlopen(req, timeout=timeout) as response, open(out_path, "wb") as f:
                    f.write(response.read())
                return
            except Exception as exc:
                last_exc = exc
                if attempt < retries - 1:
                    wait = 2**attempt
                    print(f"[HAR] Download failed ({exc}). Retrying in {wait}s...")
                    time.sleep(wait)
        if last_exc is not None:
            raise last_exc

    if not os.path.isdir(dataset_dir):
        if not os.path.exists(zip_path) or os.path.getsize(zip_path) <= 0:
            print(f"[HAR] Downloading {url}")
            download_with_retry(url, zip_path)
        try:
            with zipfile.ZipFile(zip_path) as zf:
                safe_extract(zf, root_dir)
        except zipfile.BadZipFile as exc:
            raise RuntimeError(
                f"Downloaded UCI HAR zip is corrupted: {zip_path}. Delete it and retry."
            ) from exc

    signal_names = [
        "body_acc_x",
        "body_acc_y",
        "body_acc_z",
        "body_gyro_x",
        "body_gyro_y",
        "body_gyro_z",
        "total_acc_x",
        "total_acc_y",
        "total_acc_z",
    ]

    def load_signals(split: str) -> np.ndarray:
        base = os.path.join(dataset_dir, split, "Inertial Signals")
        signals = []
        for name in signal_names:
            path = os.path.join(base, f"{name}_{split}.txt")
            if not os.path.exists(path):
                raise RuntimeError(f"UCI HAR missing file: {path}")
            signals.append(np.loadtxt(path, dtype=np.float32))
        return np.stack(signals, axis=1).astype(np.float32)

    def load_labels(split: str) -> np.ndarray:
        path = os.path.join(dataset_dir, split, f"y_{split}.txt")
        if not os.path.exists(path):
            raise RuntimeError(f"UCI HAR missing labels: {path}")
        labels = np.loadtxt(path, dtype=np.int64).reshape(-1)
        return (labels - 1).astype(np.int64)

    def load_subjects(split: str) -> np.ndarray:
        path = os.path.join(dataset_dir, split, f"subject_{split}.txt")
        if not os.path.exists(path):
            raise RuntimeError(f"UCI HAR missing subject ids: {path}")
        return np.loadtxt(path, dtype=np.int64).reshape(-1).astype(np.int64)

    x_train = load_signals("train")
    y_train = load_labels("train")
    subject_train = load_subjects("train")
    x_test = load_signals("test")
    y_test = load_labels("test")
    subject_test = load_subjects("test")

    if compress:
        np.savez_compressed(
            npz_path,
            x_train=x_train,
            y_train=y_train,
            subject_train=subject_train,
            x_test=x_test,
            y_test=y_test,
            subject_test=subject_test,
        )
    else:
        np.savez(
            npz_path,
            x_train=x_train,
            y_train=y_train,
            subject_train=subject_train,
            x_test=x_test,
            y_test=y_test,
            subject_test=subject_test,
        )
    print(f"[HAR] Saved prepared dataset to {npz_path}")
    return npz_path


def load_uci_har_subject_ids(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    default_root = root or os.path.join("data", "uci_har")
    candidate = npz_path or os.path.join(default_root, "uci_har.npz")
    if not os.path.exists(candidate):
        if auto_download:
            prepare_uci_har_npz(candidate, root=root)
        else:
            raise RuntimeError("UCI HAR npz not found. Provide a file via --har-npz, or enable auto-download.")

    data = np.load(candidate)
    subject_train = data.get("subject_train")
    subject_test = data.get("subject_test")
    if subject_train is None or subject_test is None:
        dataset_dir = os.path.join(default_root, "UCI HAR Dataset")
        train_path = os.path.join(dataset_dir, "train", "subject_train.txt")
        test_path = os.path.join(dataset_dir, "test", "subject_test.txt")
        if not (os.path.exists(train_path) and os.path.exists(test_path)):
            raise RuntimeError(
                "UCI HAR subject ids are missing. Delete the existing npz and rerun to regenerate, "
                "or keep the extracted dataset under data/uci_har/."
            )
        subject_train = np.loadtxt(train_path, dtype=np.int64).reshape(-1)
        subject_test = np.loadtxt(test_path, dtype=np.int64).reshape(-1)

    if train_limit is not None:
        subject_train = subject_train[:train_limit]
    if test_limit is not None:
        subject_test = subject_test[:test_limit]

    return subject_train.astype(np.int64), subject_test.astype(np.int64)


def load_uci_har_sequences(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    default_root = root or os.path.join("data", "uci_har")
    candidate = npz_path or os.path.join(default_root, "uci_har.npz")
    if not os.path.exists(candidate):
        if auto_download:
            prepare_uci_har_npz(candidate, root=root)
        else:
            raise RuntimeError(
                "UCI HAR npz not found. Provide a preprocessed file via --har-npz, or enable auto-download."
            )

    data = np.load(candidate)
    x_train = data.get("x_train")
    y_train = data.get("y_train")
    x_test = data.get("x_test")
    y_test = data.get("y_test")
    if x_train is None or y_train is None or x_test is None or y_test is None:
        raise RuntimeError("UCI HAR npz missing required arrays: x_train/y_train/x_test/y_test.")

    if train_limit is not None:
        x_train = x_train[:train_limit]
        y_train = y_train[:train_limit]
    if test_limit is not None:
        x_test = x_test[:test_limit]
        y_test = y_test[:test_limit]

    x_train, x_test = normalize_train_test(x_train.astype(np.float32), x_test.astype(np.float32))
    time_steps = int(x_train.shape[2])
    num_classes = int(max(np.max(y_train), np.max(y_test))) + 1
    train_targets = build_repeated_targets(y_train.astype(np.int64), num_classes, time_steps)
    test_targets = build_repeated_targets(y_test.astype(np.int64), num_classes, time_steps)
    return x_train, train_targets, y_train.astype(np.int64), x_test, test_targets, y_test.astype(np.int64)


# try_load_ptb_text：尝试从本地或下载读取 PTB 文本。
# 说明：优先读本地，失败再下载并缓存。
def try_load_ptb_text(data_path: str | None) -> Tuple[str, str, str]:
    # read_file：读取文本文件内容并返回字符串。
    # 说明：按文本方式读取并返回字符串。
    def read_file(p: str) -> str:
        with open(p, "r", encoding="utf-8") as f:
            return f.read()

    # is_nonempty_file：判断文件是否存在且非空。
    # 说明：检查文件大小是否大于 0。
    def is_nonempty_file(p: str) -> bool:
        try:
            return os.path.getsize(p) > 0
        except OSError:
            return False

    # try_local：尝试在本地路径读取文件。
    # 说明：检查路径存在后读取文件。
    def try_local(dirpath: str | None) -> Tuple[str, str, str] | None:
        if dirpath is None:
            return None
        train_p = os.path.join(dirpath, "ptb.train.txt")
        valid_p = os.path.join(dirpath, "ptb.valid.txt")
        test_p = os.path.join(dirpath, "ptb.test.txt")
        if os.path.exists(train_p) and os.path.exists(valid_p) and os.path.exists(test_p):
            if not (is_nonempty_file(train_p) and is_nonempty_file(valid_p) and is_nonempty_file(test_p)):
                print(f"[PTB] Found empty PTB file(s) under '{dirpath}/'. Ignoring.")
                return None
            print(f"Loaded PTB from '{dirpath}/'.")
            return read_file(train_p), read_file(valid_p), read_file(test_p)
        return None

    for cand in [data_path, "data/penn", "data/ptb", "ptb", "penn", "./data", "./"]:
        got = try_local(cand)
        if got is not None:
            return got

    # download_with_retry：带重试机制下载文件并返回本地路径。
    # 说明：多次重试下载并回退到本地缓存。
    def download_with_retry(url: str, out_path: str, retries: int = 3, timeout: int = 30) -> None:
        last_exc: Exception | None = None
        for attempt in range(retries):
            try:
                req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
                with urllib.request.urlopen(req, timeout=timeout) as response, open(out_path, "wb") as f:
                    f.write(response.read())
                return
            except Exception as exc:
                last_exc = exc
                if attempt < retries - 1:
                    wait = 2**attempt
                    print(f"[PTB] Download failed ({exc}). Retrying in {wait}s...")
                    time.sleep(wait)
        if last_exc is not None:
            raise last_exc

    # download_ptb_files：下载 Penn Treebank 数据集文件。
    # 说明：下载 PTB 的 train/valid/test 文本。
    def download_ptb_files(target_dir: str) -> Tuple[str, str, str]:
        os.makedirs(target_dir, exist_ok=True)
        for fname in PTB_FILES:
            out_path = os.path.join(target_dir, fname)
            if is_nonempty_file(out_path):
                continue
            last_exc: Exception | None = None
            for base in PTB_URL_BASES:
                url = f"{base}/{fname}"
                print(f"[PTB] Downloading {url}")
                try:
                    download_with_retry(url, out_path)
                    last_exc = None
                    break
                except Exception as exc:
                    last_exc = exc
                    continue
            if last_exc is not None:
                raise last_exc
        return (
            read_file(os.path.join(target_dir, "ptb.train.txt")),
            read_file(os.path.join(target_dir, "ptb.valid.txt")),
            read_file(os.path.join(target_dir, "ptb.test.txt")),
        )

    if ensure_python_package("datasets"):
        try:
            from datasets import load_dataset  # type: ignore
            ds = load_dataset("ptb_text_only")

            field = None
            for candidate in ("sentence", "text"):
                if candidate in ds["train"].column_names:
                    field = candidate
                    break
            if field is None:
                raise RuntimeError(f"Unexpected PTB dataset schema: {ds['train'].column_names}")

            train_text = "\n".join(ds["train"][field])
            valid_text = "\n".join(ds["validation"][field])
            test_text = "\n".join(ds["test"][field])
            print("Loaded PTB from HuggingFace datasets (ptb_text_only).")
            return train_text, valid_text, test_text
        except Exception as exc:
            print(f"[PTB] HuggingFace datasets failed ({exc}). Falling back to direct download.")

    target_dir = data_path or os.path.join("data", "ptb")
    try:
        train_text, valid_text, test_text = download_ptb_files(target_dir)
        print(f"Loaded PTB via direct download into '{target_dir}'.")
        return train_text, valid_text, test_text
    except Exception as exc:
        raise RuntimeError(
            "PTB files not found locally and auto-download failed.\n"
            "Provide ptb.train.txt/ptb.valid.txt/ptb.test.txt under a folder (use --ptb-path),\n"
            "or install datasets: pip install datasets"
        ) from exc


# build_char_vocab: build the character vocabulary and the char<->id mappings.
# Note: ids follow the sorted order of the distinct characters in the text.
def build_char_vocab(text: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Return ``(stoi, itos)`` for the distinct characters of ``text``.

    Characters are sorted and numbered from 0. ``stoi`` maps a character to
    its id and ``itos`` is the inverse mapping. Both are empty for an empty
    ``text``.
    """
    alphabet = sorted(set(text))
    stoi = dict(zip(alphabet, range(len(alphabet))))
    itos = dict(enumerate(alphabet))
    return stoi, itos


# encode_text: map text to a sequence of character ids.
# Note: characters that are not in the vocabulary are skipped.
def encode_text(text: str, stoi: Dict[str, int]) -> np.ndarray:
    """Return the ids of the characters of ``text`` as a 1-D int64 array.

    Characters that have no entry in ``stoi`` are dropped without error, so
    the result can be shorter than ``text``.
    """
    lookup = stoi.get
    known_ids = (idx for idx in map(lookup, text) if idx is not None)
    return np.fromiter(known_ids, dtype=np.int64)


# to_sequence_features: arrange sequence data as float32 (N, features, T) model input.
# Note: 4-D/5-D inputs have all axes after the second flattened into one feature axis.
def to_sequence_features(data: np.ndarray) -> np.ndarray:
    """Return ``data`` as a float32 array with the layout ``(N, features, T)``.

    * 3-D input: returned as is when ``shape[1] < shape[2]``; otherwise
      (including equal sizes) the last two axes are swapped.
    * 4-D / 5-D input ``(N, T, ...)``: every axis after the second is flattened
      into one feature axis, then features and time are swapped. The flattening
      keeps the memory order, so channel-first and channel-last frames are
      both accepted and are not reordered.

    Raises:
        ValueError: ``data`` is not 3-, 4- or 5-dimensional.
    """
    if data.ndim == 3:
        # Heuristic: the shorter of the two trailing axes is taken to be the
        # feature axis and must come first.
        if data.shape[1] < data.shape[2]:
            return data.astype(np.float32)
        return np.transpose(data, (0, 2, 1)).astype(np.float32)
    if data.ndim in (4, 5):
        n, t = data.shape[:2]
        # The former per-layout branches for 5-D input all performed this same
        # reshape, so one code path covers every case.
        features = math.prod(data.shape[2:])
        return data.reshape(n, t, features).transpose(0, 2, 1).astype(np.float32)
    raise ValueError(f"Unsupported DVS data shape: {data.shape}")


# prepare_dvs_cifar10_npz: preprocess DVS-CIFAR10 and save it as an npz file.
# Note: needs the third-party tonic and aedat packages; events are binned into frames.
def prepare_dvs_cifar10_npz(
    npz_path: str,
    root: str | None = None,
    time_bins: int = 10,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
    train_limit: int | None = None,
    test_limit: int | None = None,
    compress: bool = True,
) -> str:
    """Convert DVS-CIFAR10 recordings to frame arrays and save them to ``npz_path``.

    The dataset class is looked up in ``tonic.datasets`` under several candidate
    names, and its constructor signature is inspected so that differing
    root/split/download parameter names can be handled. Each sample is turned
    into ``time_bins`` frames by ``tonic.transforms.ToFrame``. When the
    constructor exposes no split parameter, the whole dataset is loaded and
    split 90/10 here (per label when the dataset has a usable ``targets``).

    Args:
        npz_path: Destination passed to ``np.savez`` / ``np.savez_compressed``.
        root: Dataset directory; defaults to ``data/dvs_cifar10``.
        time_bins: Frames per sample; values below 1 are raised to 1.
        spatial_downsample: Side length of the square blocks that are summed;
            values below 1 are raised to 1, and 1 disables downsampling.
        use_polarity: When false, 4-D frames are summed over the axis that is
            judged to be the channel axis.
        train_limit: Upper bound on the number of training samples.
        test_limit: Upper bound on the number of test samples.
        compress: Use ``np.savez_compressed`` instead of ``np.savez``.

    Returns:
        ``npz_path``. The archive stores ``x_train``/``x_test`` (uint16),
        ``y_train``/``y_test`` (int64) and several ``meta_*`` entries that
        record the preprocessing settings.

    Raises:
        RuntimeError: A required package is unavailable, the dataset class is
            not found, the data looks incomplete after a repair attempt, a
            split has no samples, or the frame size is not divisible by
            ``spatial_downsample``.
    """
    time_bins = max(1, int(time_bins))
    spatial_downsample = max(1, int(spatial_downsample))
    # NOTE(review): ensure_python_package is defined elsewhere; presumably it
    # reports whether the package can be imported -- confirm.
    if not ensure_python_package("tonic"):
        raise RuntimeError(
            "Auto-download for DVS-CIFAR10 requires the tonic package.\n"
            "Install it with: pip install tonic"
        )
    if not ensure_python_package("aedat"):
        raise RuntimeError(
            "DVS-CIFAR10 .aedat4 parsing requires the aedat package.\n"
            "Install it with: pip install aedat"
        )

    import tonic.datasets as tonic_datasets  # type: ignore
    from tonic.download_utils import extract_archive  # type: ignore
    from tonic.transforms import ToFrame  # type: ignore

    root_dir = root or os.path.join("data", "dvs_cifar10")
    os.makedirs(root_dir, exist_ok=True)
    out_dir = os.path.dirname(npz_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The class name is probed under several spellings; the first callable wins.
    dataset_cls = None
    for name in ("DVSCIFAR10", "CIFAR10DVS", "DVSCifar10", "cifar10dvs"):
        candidate = getattr(tonic_datasets, name, None)
        if candidate is None or not callable(candidate):
            continue
        dataset_cls = candidate
        break
    if dataset_cls is None:
        raise RuntimeError(
            "Could not find DVS-CIFAR10 dataset in tonic.datasets. "
            "Try upgrading tonic: pip install -U tonic"
        )
    # Rewrite a figshare.com download URL to the ndownloader host. This mutates
    # the class attribute, so it affects every later use of the class.
    dataset_url = getattr(dataset_cls, "url", None)
    if isinstance(dataset_url, str) and dataset_url.startswith("https://figshare.com/ndownloader/files/"):
        file_id = dataset_url.rstrip("/").split("/")[-1]
        alt_url = f"https://ndownloader.figshare.com/files/{file_id}"
        if alt_url != dataset_url:
            print(f"[DVS] Figshare URL blocked by WAF; using direct downloader: {alt_url}")
            setattr(dataset_cls, "url", alt_url)

    # Inspect the constructor to find which keyword (if any) carries the root
    # directory, the split selector and the download switch.
    dataset_sig = inspect.signature(dataset_cls)
    params = dataset_sig.parameters
    root_param = None
    for key in ("root", "data_path", "path", "root_dir", "data_dir", "data_root", "save_to", "base_folder"):
        if key in params:
            root_param = key
            break
    split_param = None
    for key in ("train", "train_split", "split", "mode", "set"):
        if key in params:
            split_param = key
            break
    download_param = None
    for key in ("download", "download_and_prepare"):
        if key in params:
            download_param = key
            break
    positional_params = [
        p
        for p in params.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    # Positional parameters without a default; if any exist and no root keyword
    # was recognised, the root directory is passed positionally.
    required_params = [p for p in positional_params if p.default is inspect._empty]

    # build_dataset: instantiate the dataset class with the detected arguments.
    # Note: `train` selects the split; None means no split argument is passed.
    def build_dataset(train: bool | None) -> Any:
        kwargs: Dict[str, Any] = {}
        args: List[Any] = []
        if root_param is not None:
            kwargs[root_param] = root_dir
        elif required_params:
            args.append(root_dir)
        if train is not None and split_param is not None:
            # A parameter literally named "train" gets the boolean; any other
            # split parameter gets the string "train" or "test".
            if split_param == "train":
                kwargs[split_param] = train
            else:
                kwargs[split_param] = "train" if train else "test"
        if download_param is not None:
            kwargs[download_param] = True
        return dataset_cls(*args, **kwargs)

    # split_full_dataset: split the sample indices into train and test parts.
    # Note: stratified per label when `targets` is usable, otherwise one random split.
    def split_full_dataset(
        dataset: Any,
        train_fraction: float = 0.9,
        seed: int = 0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        total = len(dataset)
        if total <= 0:
            raise RuntimeError("DVS-CIFAR10 dataset is empty.")
        targets = getattr(dataset, "targets", None)
        rng = np.random.default_rng(seed)
        if targets is None or len(targets) != total:
            # No usable labels: shuffle all indices and cut once.
            indices = np.arange(total)
            rng.shuffle(indices)
            split_idx = int(total * train_fraction)
            split_idx = min(max(1, split_idx), total - 1)
            return indices[:split_idx], indices[split_idx:]

        targets = np.asarray(targets)
        train_indices: List[int] = []
        test_indices: List[int] = []
        for label in np.unique(targets):
            label_indices = np.where(targets == label)[0]
            rng.shuffle(label_indices)
            split_idx = int(len(label_indices) * train_fraction)
            # Clamp so that a label with at least two samples lands in both parts.
            split_idx = min(max(1, split_idx), len(label_indices) - 1)
            train_indices.extend(label_indices[:split_idx])
            test_indices.extend(label_indices[split_idx:])
        train_array = np.array(train_indices, dtype=np.int64)
        test_array = np.array(test_indices, dtype=np.int64)
        # Shuffle again so the samples are no longer grouped by label.
        rng.shuffle(train_array)
        rng.shuffle(test_array)
        return train_array, test_array

    # materialize: load the selected samples into in-memory numpy arrays.
    # Note: returns (frames as uint16, labels as int64) for at most `limit` samples.
    def materialize(
        dataset: Any,
        split: str,
        limit: int | None,
        indices: np.ndarray | None = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        if indices is None:
            indices = np.arange(len(dataset))
        indices = np.asarray(indices, dtype=np.int64)
        total = len(indices)
        count = total if limit is None else min(total, limit)
        if count <= 0:
            raise RuntimeError("Requested zero samples for DVS-CIFAR10 preparation.")

        # maybe_merge_polarity: when use_polarity is false, sum 4-D frames over
        # the axis judged to be the channel axis; otherwise return them unchanged.
        def maybe_merge_polarity(frames: Any) -> Any:
            if use_polarity or not hasattr(frames, "ndim") or frames.ndim != 4:
                return frames
            # Handle both (T, C, H, W) and (T, H, W, C).
            if frames.shape[1] <= 4 and frames.shape[-1] > 4:
                return frames.sum(axis=1)
            if frames.shape[-1] <= 4 and frames.shape[1] > 4:
                return frames.sum(axis=-1)
            if frames.shape[1] in (1, 2, 3, 4):
                return frames.sum(axis=1)
            return frames.sum(axis=-1)

        # The first sample is loaded up front; its processed shape sizes the output buffer.
        first_frames, first_label = dataset[int(indices[0])]
        first_frames = maybe_merge_polarity(first_frames)

        # downsample_frames: sum non-overlapping f x f spatial blocks and clip
        # the result to the uint16 range; a factor of 1 returns the array as is.
        def downsample_frames(frames: Any) -> np.ndarray:
            arr = np.asarray(frames)
            if spatial_downsample <= 1:
                return arr
            f = spatial_downsample

            if arr.ndim == 3:
                t, h, w = arr.shape
                if h % f != 0 or w % f != 0:
                    raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                reshaped = arr.reshape(t, h // f, f, w // f, f)
                summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            if arr.ndim == 4:
                # Handle both (T, C, H, W) and (T, H, W, C).
                if arr.shape[1] <= 4 and arr.shape[2] > 4 and arr.shape[3] > 4:
                    t, c, h, w = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, c, h // f, f, w // f, f)
                    summed = reshaped.sum(axis=(3, 5), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)
                if arr.shape[-1] <= 4 and arr.shape[1] > 4 and arr.shape[2] > 4:
                    t, h, w, c = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, h // f, f, w // f, f, c)
                    summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            raise RuntimeError(f"Unsupported frame shape for spatial downsample: {arr.shape}")

        first_frames = downsample_frames(first_frames)
        frames_shape = first_frames.shape
        x_data = np.zeros((count,) + frames_shape, dtype=np.uint16)
        y_data = np.zeros((count,), dtype=np.int64)
        x_data[0] = np.asarray(first_frames).astype(np.uint16)
        y_data[0] = int(first_label)
        for idx in range(1, count):
            frames, label = dataset[int(indices[idx])]
            frames = downsample_frames(maybe_merge_polarity(frames))
            x_data[idx] = np.asarray(frames).astype(np.uint16)
            y_data[idx] = int(label)
            if idx % 1000 == 0:
                print(f"[DVS] Processed {idx}/{count} {split} samples")
        return x_data, y_data

    print("[DVS] Preparing DVS-CIFAR10 npz. This may take a while.")
    if split_param is None:
        # The constructor has no split selector: load everything, split here.
        dataset = build_dataset(None)
        expected_total = 10000  # used only in the error message below
        expected_classes = 10
        total = len(dataset)
        targets = getattr(dataset, "targets", None)
        unique_labels = set(int(x) for x in targets) if targets is not None else set()
        # Fewer than 9000 samples or fewer than 10 labels is treated as an
        # incomplete download/extraction; try one repair pass.
        if total < 9000 or len(unique_labels) < expected_classes:
            print(
                f"[DVS] Detected incomplete CIFAR10-DVS dataset (samples={total}, classes={len(unique_labels)}). "
                "Attempting repair via re-extract..."
            )
            try:
                if hasattr(dataset, "download"):
                    dataset.download()
                data_filename = getattr(dataset_cls, "data_filename", None)
                base_dir = getattr(dataset, "location_on_system", None)
                if data_filename and base_dir:
                    for fname in data_filename:
                        archive_path = os.path.join(base_dir, fname)
                        if os.path.exists(archive_path):
                            extract_archive(archive_path)
            except Exception as exc:
                # Best effort: a failed repair is reported, then the dataset is re-checked.
                print(f"[DVS] Repair attempt failed: {exc}")
            dataset = build_dataset(None)
            total = len(dataset)
            targets = getattr(dataset, "targets", None)
            unique_labels = set(int(x) for x in targets) if targets is not None else set()
            if total < 9000 or len(unique_labels) < expected_classes:
                raise RuntimeError(
                    "CIFAR10-DVS download seems incomplete. "
                    f"Got {total} samples across {len(unique_labels)} classes; expected ~{expected_total}. "
                    f"Delete '{root_dir}' and retry."
                )
        transform = ToFrame(sensor_size=dataset.sensor_size, n_time_bins=time_bins)
        dataset.transform = transform
        train_idx, test_idx = split_full_dataset(dataset)
        x_train, y_train = materialize(dataset, "train", train_limit, train_idx)
        x_test, y_test = materialize(dataset, "test", test_limit, test_idx)
    else:
        # The constructor selects the split itself: build one dataset per split.
        train_dataset = build_dataset(True)
        train_transform = ToFrame(sensor_size=train_dataset.sensor_size, n_time_bins=time_bins)
        train_dataset.transform = train_transform
        x_train, y_train = materialize(train_dataset, "train", train_limit)

        test_dataset = build_dataset(False)
        test_transform = ToFrame(sensor_size=test_dataset.sensor_size, n_time_bins=time_bins)
        test_dataset.transform = test_transform
        x_test, y_test = materialize(test_dataset, "test", test_limit)

    # Preprocessing settings saved next to the data; a limit of None is stored as -1.
    meta: Dict[str, Any] = {
        "meta_version": np.asarray(1, dtype=np.int64),
        "meta_dataset": np.asarray(getattr(dataset_cls, "__name__", "unknown")),
        "meta_time_bins": np.asarray(int(time_bins), dtype=np.int64),
        "meta_spatial_downsample": np.asarray(int(spatial_downsample), dtype=np.int64),
        "meta_use_polarity": np.asarray(1 if use_polarity else 0, dtype=np.int8),
        "meta_train_limit": np.asarray(-1 if train_limit is None else int(train_limit), dtype=np.int64),
        "meta_test_limit": np.asarray(-1 if test_limit is None else int(test_limit), dtype=np.int64),
    }
    if compress:
        np.savez_compressed(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    else:
        np.savez(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    print(f"[DVS] Saved prepared dataset to {npz_path}")
    return npz_path


# load_dvs_cifar10_sequences: load the DVS-CIFAR10 sequence dataset from an npz cache.
# Covers cache creation/validation, truncation to the requested limits, and conversion
# to normalized sequence features plus per-step targets.
def load_dvs_cifar10_sequences(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
    time_bins: int = 10,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load DVS-CIFAR10 from a prepared npz, building or rebuilding it when allowed.

    Args:
        npz_path: Path of the prepared npz; defaults to ``<root>/dvs_cifar10.npz`` with
            ``root`` falling back to ``data/dvs_cifar10``.
        train_limit: Keep at most this many training samples (``None`` keeps all).
        test_limit: Keep at most this many test samples (``None`` keeps all).
        auto_download: When true, a missing or incompatible npz is (re)built through
            ``prepare_dvs_cifar10_npz``. When false, a missing file raises and an
            incompatible one is used as-is after a warning is printed.
        root: Dataset root forwarded to ``prepare_dvs_cifar10_npz``.
        time_bins: Requested number of time bins (values below 1 count as 1 when the
            cache is validated).
        spatial_downsample: Requested spatial downsampling factor (clamped to >= 1).
        use_polarity: Whether the stored frames are expected to be 5-D (with a polarity
            axis) rather than 4-D.

    Returns:
        ``(x_train, train_targets, y_train, x_test, test_targets, y_test)``. Labels are
        int64; targets come from ``build_repeated_targets`` with 10 classes and
        ``x_train.shape[2]`` steps.

    Raises:
        RuntimeError: If the npz is missing while ``auto_download`` is disabled, if it
            lacks the required arrays, or if the sequence conversion runs out of memory.
    """
    default_root = root or os.path.join("data", "dvs_cifar10")
    candidate = npz_path or os.path.join(default_root, "dvs_cifar10.npz")
    if not os.path.exists(candidate):
        if not auto_download:
            raise RuntimeError(
                "DVS-CIFAR10 npz not found. Provide a preprocessed file via --dvs-npz.\n"
                "Expected keys: x_train/y_train/x_test/y_test (preferred) or train_inputs/train_labels."
            )
        prepare_dvs_cifar10_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            spatial_downsample=spatial_downsample,
            use_polarity=use_polarity,
            train_limit=train_limit,
            test_limit=test_limit,
        )
    expected_time_bins = int(max(1, int(time_bins)))
    spatial_downsample = max(1, int(spatial_downsample))

    def _load_npz_arrays(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
        # np.load keeps an .npz archive open until it is closed; read everything inside the
        # context manager so no file handle lingers while the cache may be rewritten below.
        with np.load(path) as data:
            x_tr = data.get("x_train", data.get("train_inputs"))
            y_tr = data.get("y_train", data.get("train_labels"))
            x_te = data.get("x_test", data.get("test_inputs"))
            y_te = data.get("y_test", data.get("test_labels"))
            raw_meta = {
                key: data.get(f"meta_{key}")
                for key in (
                    "time_bins",
                    "use_polarity",
                    "spatial_downsample",
                    "train_limit",
                    "test_limit",
                    "dataset",
                )
            }
        if x_tr is None or y_tr is None or x_te is None or y_te is None:
            raise RuntimeError("DVS-CIFAR10 npz missing required arrays.")

        def _scalar(key: str) -> Any:
            # Metadata entries are stored as 0-d arrays; absent keys stay None.
            value = raw_meta[key]
            return None if value is None else np.asarray(value).item()

        def _limit(key: str) -> int | None:
            # A negative stored limit encodes "no limit".
            value = _scalar(key)
            if value is None:
                return None
            value = int(value)
            return None if value < 0 else value

        bins_meta = _scalar("time_bins")
        polarity_meta = _scalar("use_polarity")
        downsample_meta = _scalar("spatial_downsample")
        dataset_meta = _scalar("dataset")
        file_meta: Dict[str, Any] = {
            "time_bins": None if bins_meta is None else int(bins_meta),
            "use_polarity": None if polarity_meta is None else bool(int(polarity_meta)),
            "spatial_downsample": None if downsample_meta is None else int(downsample_meta),
            "train_limit": _limit("train_limit"),
            "test_limit": _limit("test_limit"),
            "dataset": None if dataset_meta is None else str(dataset_meta),
        }
        return x_tr, y_tr, x_te, y_te, file_meta

    x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)

    # Collect every way in which the cached file disagrees with the current request.
    mismatch: List[str] = []
    file_time_bins = meta.get("time_bins")
    if file_time_bins is not None and int(file_time_bins) != expected_time_bins:
        mismatch.append(f"time_bins(file={file_time_bins}, req={expected_time_bins})")
    file_use_polarity = meta.get("use_polarity")
    if file_use_polarity is not None and bool(file_use_polarity) != bool(use_polarity):
        mismatch.append(f"use_polarity(file={file_use_polarity}, req={bool(use_polarity)})")
    file_spatial_downsample = meta.get("spatial_downsample")
    if file_spatial_downsample is not None and int(file_spatial_downsample) != spatial_downsample:
        mismatch.append(f"spatial_downsample(file={file_spatial_downsample}, req={spatial_downsample})")
    if file_spatial_downsample is None and spatial_downsample != 1:
        mismatch.append("spatial_downsample(file=unknown, req!=1)")

    # Files without metadata are still checked through the array shape itself.
    inferred_time_bins = int(x_train.shape[1]) if x_train.ndim in (4, 5) else None
    if inferred_time_bins is not None and inferred_time_bins != expected_time_bins:
        mismatch.append(f"time_bins(file={inferred_time_bins}, req={expected_time_bins})")
    if use_polarity and x_train.ndim != 5:
        mismatch.append(f"polarity_shape(file_ndim={x_train.ndim}, req=5)")
    if not use_polarity and x_train.ndim == 5:
        mismatch.append("polarity_shape(file_ndim=5, req=4)")

    # A cache built with a limit cannot serve an unlimited or larger request.
    file_train_limit = meta.get("train_limit")
    file_test_limit = meta.get("test_limit")
    if train_limit is None and file_train_limit is not None:
        mismatch.append(f"train_limit(file={file_train_limit}, req=None)")
    if test_limit is None and file_test_limit is not None:
        mismatch.append(f"test_limit(file={file_test_limit}, req=None)")
    if train_limit is not None and file_train_limit is not None and int(file_train_limit) < int(train_limit):
        mismatch.append(f"train_limit(file={file_train_limit}, req={int(train_limit)})")
    if test_limit is not None and file_test_limit is not None and int(file_test_limit) < int(test_limit):
        mismatch.append(f"test_limit(file={file_test_limit}, req={int(test_limit)})")

    # Unlimited requests on suspiciously small files are treated as truncated caches.
    if train_limit is None and int(x_train.shape[0]) < 1000:
        mismatch.append(f"train_samples(file={int(x_train.shape[0])}, req>=1000)")
    if test_limit is None and int(x_test.shape[0]) < 100:
        mismatch.append(f"test_samples(file={int(x_test.shape[0])}, req>=100)")

    if mismatch:
        if auto_download:
            print(f"[DVS] Cached DVS-CIFAR10 npz incompatible ({', '.join(mismatch)}). Rebuilding: {candidate}")
            prepare_dvs_cifar10_npz(
                candidate,
                root=root,
                time_bins=time_bins,
                spatial_downsample=spatial_downsample,
                use_polarity=use_polarity,
                train_limit=train_limit,
                test_limit=test_limit,
            )
            x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)
        else:
            # Rebuilding is not allowed: keep going with the file, but say so explicitly.
            print(
                f"[DVS] Warning: cached DVS-CIFAR10 npz may be incompatible ({', '.join(mismatch)}); "
                f"auto_download is disabled, using it as-is: {candidate}"
            )

    if train_limit is not None:
        x_train = x_train[:train_limit]
        y_train = y_train[:train_limit]
    if test_limit is not None:
        x_test = x_test[:test_limit]
        y_test = y_test[:test_limit]

    try:
        x_train = to_sequence_features(x_train)
        x_test = to_sequence_features(x_test)
        x_train, x_test = normalize_train_test(x_train, x_test)
    except MemoryError as exc:
        # Estimate the float32 footprint of x_train to give an actionable error message.
        n_train = int(getattr(x_train, "shape", (0,))[0])
        ndim = getattr(x_train, "ndim", 0)
        time_steps = expected_time_bins
        if ndim == 3:
            # x_train was already converted before the failure: read both sizes off it.
            input_size = int(x_train.shape[1])
            time_steps = int(x_train.shape[2])
        else:
            if ndim == 4:
                h, w, channels = int(x_train.shape[2]), int(x_train.shape[3]), 1
            elif ndim == 5:
                # (N, T, H, W, C) or (N, T, C, H, W): a small axis is taken as channels.
                a, b, c = int(x_train.shape[2]), int(x_train.shape[3]), int(x_train.shape[4])
                if a <= 4 and b > 4 and c > 4 and not (c <= 4 and a > 4 and b > 4):
                    channels, h, w = a, b, c
                else:
                    h, w, channels = a, b, c
            else:
                h, w, channels = 128, 128, (2 if use_polarity else 1)
            input_size = int(h) * int(w) * int(max(1, channels))
        approx_gb = (n_train * input_size * time_steps * 4) / (1024**3)
        raise RuntimeError(
            "Out of RAM while materializing DVS-CIFAR10 sequences as float32. "
            f"Requested ~{approx_gb:.1f} GiB for x_train alone. "
            "Fix: reduce memory by using --dvs-spatial-downsample 4 (32x32), "
            "or --dvs-no-polarity, or smaller --dvs-time-bins, or set --train-limit/--test-limit, "
            "or run on a machine with more RAM."
        ) from exc

    time_steps = x_train.shape[2]
    train_targets = build_repeated_targets(y_train.astype(np.int64), 10, time_steps)
    test_targets = build_repeated_targets(y_test.astype(np.int64), 10, time_steps)
    return x_train, train_targets, y_train.astype(np.int64), x_test, test_targets, y_test.astype(np.int64)


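# NOTE: generate_lorenz_sequences (the Lorenz dynamical-system sequence generator) is defined
# further below, after the SHD loaders; the SHD preparation helper comes first.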
# prepare_shd_npz: preprocess SHD and save it as an npz file.
# Event streams are binned with ToFrame into a fixed number of time bins and then flattened
# into (N, input_size, time_steps) sequence inputs.
def prepare_shd_npz(
    npz_path: str,
    root: str | None = None,
    time_bins: int = 100,
    train_limit: int | None = None,
    test_limit: int | None = None,
    compress: bool = True,
) -> str:
    """Build the SHD train/test arrays through tonic and write them to ``npz_path``.

    Args:
        npz_path: Destination file; its parent directory is created when needed.
        root: Directory handed to the tonic dataset (defaults to ``data/shd``).
        time_bins: Number of ToFrame time bins (values below 1 become 1).
        train_limit: Maximum number of training samples to store (``None`` stores all).
        test_limit: Maximum number of test samples to store (``None`` stores all).
        compress: Use ``np.savez_compressed`` instead of ``np.savez``.

    Returns:
        ``npz_path``, after the file has been written with the keys
        ``x_train``/``y_train``/``x_test``/``y_test``.

    Raises:
        RuntimeError: If tonic is unavailable, has no ``SHD`` dataset, or a split ends up
            with zero samples.
        ValueError: If a frame stack has fewer than two dimensions.
    """
    time_bins = max(1, int(time_bins))
    if not ensure_python_package("tonic"):
        raise RuntimeError(
            "Auto-download for SHD requires the tonic package.\n"
            "Install it with: pip install tonic"
        )

    import inspect

    import tonic.datasets as tonic_datasets  # type: ignore
    from tonic.transforms import ToFrame  # type: ignore

    root_dir = root or os.path.join("data", "shd")
    os.makedirs(root_dir, exist_ok=True)
    out_dir = os.path.dirname(npz_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dataset_cls = getattr(tonic_datasets, "SHD", None)
    if dataset_cls is None:
        raise RuntimeError(
            "Could not find SHD dataset in tonic.datasets. Try upgrading tonic: pip install -U tonic"
        )

    # The constructor signature is inspected so that several naming schemes for the
    # root/split arguments are accepted.
    params = inspect.signature(dataset_cls).parameters
    root_param = next((key for key in ("root", "save_to", "data_path", "path") if key in params), None)
    split_param = next((key for key in ("train", "train_split", "split", "mode", "set") if key in params), None)
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    has_required_positional = any(
        p.kind in positional_kinds and p.default is inspect.Parameter.empty for p in params.values()
    )

    def build_dataset(train: bool) -> Any:
        # Instantiate the dataset for one split, passing the root by keyword when possible.
        args: List[Any] = []
        kwargs: Dict[str, Any] = {}
        if root_param is not None:
            kwargs[root_param] = root_dir
        elif has_required_positional:
            args.append(root_dir)
        if split_param == "train":
            # A parameter literally named "train" receives the boolean flag.
            kwargs[split_param] = train
        elif split_param is not None:
            kwargs[split_param] = "train" if train else "test"
        return dataset_cls(*args, **kwargs)

    def frames_to_sequence(frames: Any) -> np.ndarray:
        # Flatten (T, ...) frames to (T, features), then flip to (features, T).
        frames_np = np.asarray(frames)
        if frames_np.ndim < 2:
            raise ValueError(f"Unexpected SHD frame shape: {frames_np.shape}")
        t_steps = int(frames_np.shape[0])
        feat = int(np.prod(frames_np.shape[1:]))
        return frames_np.reshape(t_steps, feat).T

    def materialize(dataset: Any, split: str, limit: int | None) -> Tuple[np.ndarray, np.ndarray]:
        # Read up to `limit` samples into dense uint16 inputs and int64 labels.
        total = len(dataset)
        count = total if limit is None else min(total, limit)
        if count <= 0:
            raise RuntimeError("Requested zero samples for SHD preparation.")

        x_data = None
        y_data = np.zeros((count,), dtype=np.int64)
        for idx in range(count):
            frames, label = dataset[idx]
            sequence = frames_to_sequence(frames)
            if x_data is None:
                # The first sample fixes the per-sample shape of the whole buffer.
                x_data = np.zeros((count,) + sequence.shape, dtype=np.uint16)
            x_data[idx] = sequence.astype(np.uint16)
            y_data[idx] = int(label)
            if idx > 0 and idx % 1000 == 0:
                print(f"[SHD] Processed {idx}/{count} {split} samples")
        return x_data, y_data

    def load_split(train: bool, split: str, limit: int | None) -> Tuple[np.ndarray, np.ndarray]:
        # Build one split, attach the ToFrame transform and materialize it.
        dataset = build_dataset(train)
        dataset.transform = ToFrame(sensor_size=dataset.sensor_size, n_time_bins=time_bins)
        return materialize(dataset, split, limit)

    print("[SHD] Preparing SHD npz. This may take a while.")
    x_train, y_train = load_split(True, "train", train_limit)
    x_test, y_test = load_split(False, "test", test_limit)

    save = np.savez_compressed if compress else np.savez
    save(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
    print(f"[SHD] Saved prepared dataset to {npz_path}")
    return npz_path


# load_shd_sequences: load the SHD sequence dataset from an npz cache.
def load_shd_sequences(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
    time_bins: int = 100,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load SHD from a prepared npz, creating it first when it is missing and allowed.

    Args:
        npz_path: Path of the prepared npz; defaults to ``<root>/shd.npz`` with ``root``
            falling back to ``data/shd``.
        train_limit: Keep at most this many training samples (``None`` keeps all).
        test_limit: Keep at most this many test samples (``None`` keeps all).
        auto_download: When true, a missing npz is built through ``prepare_shd_npz``.
        root: Dataset root forwarded to ``prepare_shd_npz``.
        time_bins: Only used when the npz has to be built; an existing file is not
            checked against it.

    Returns:
        ``(x_train, train_targets, y_train, x_test, test_targets, y_test)``. Inputs are
        float32 arrays passed through ``normalize_train_test``, labels are int64, and the
        number of classes is inferred as the largest label plus one.

    Raises:
        RuntimeError: If the npz is missing while ``auto_download`` is disabled, lacks the
            required arrays, or contains no labels.
    """
    default_root = root or os.path.join("data", "shd")
    candidate = npz_path or os.path.join(default_root, "shd.npz")
    if not os.path.exists(candidate):
        if not auto_download:
            raise RuntimeError(
                "SHD npz not found. Provide a preprocessed file via --shd-npz.\n"
                "Expected keys: x_train/y_train/x_test/y_test."
            )
        prepare_shd_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            train_limit=train_limit,
            test_limit=test_limit,
        )

    # np.load keeps an .npz archive open until it is closed; the arrays are fully read
    # inside the context manager, so the handle is released immediately afterwards.
    with np.load(candidate) as data:
        x_train = data.get("x_train", data.get("train_inputs"))
        y_train = data.get("y_train", data.get("train_labels"))
        x_test = data.get("x_test", data.get("test_inputs"))
        y_test = data.get("y_test", data.get("test_labels"))
    if x_train is None or y_train is None or x_test is None or y_test is None:
        raise RuntimeError("SHD npz missing required arrays.")

    if train_limit is not None:
        x_train = x_train[:train_limit]
        y_train = y_train[:train_limit]
    if test_limit is not None:
        x_test = x_test[:test_limit]
        y_test = y_test[:test_limit]

    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)
    x_train, x_test = normalize_train_test(x_train, x_test)

    y_train_i = y_train.astype(np.int64)
    y_test_i = y_test.astype(np.int64)
    # Infer the class count from the labels actually present in either split.
    max_train = int(np.max(y_train_i)) if y_train_i.size else -1
    max_test = int(np.max(y_test_i)) if y_test_i.size else -1
    num_classes = max(max_train, max_test) + 1
    if num_classes <= 0:
        raise RuntimeError("SHD labels are empty; cannot infer num_classes.")

    time_steps = int(x_train.shape[2])
    train_targets = build_repeated_targets(y_train_i, num_classes, time_steps)
    test_targets = build_repeated_targets(y_test_i, num_classes, time_steps)
    return x_train, train_targets, y_train_i, x_test, test_targets, y_test_i


def generate_lorenz_sequences(
    num_samples: int,
    seq_len: int,
    frame_h: int,
    frame_w: int,
    dt: float,
    sigma: float,
    rho: float,
    beta: float,
    warmup: int,
    seed: int,
    blur_sigma: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Render Lorenz trajectories as flattened frame sequences for next-frame prediction.

    Each sample starts from a random state near (0, 1, 1), is advanced ``warmup`` times
    with an explicit Euler step of size ``dt``, and is then rendered for ``seq_len + 1``
    steps. A frame is a Gaussian blob of width ``blur_sigma`` whose column follows x
    (mapped from [-30, 30]) and whose row follows z (mapped from [0, 50]), both clipped
    to the frame.

    Args:
        num_samples: Number of independent trajectories.
        seq_len: Number of time steps per returned sequence.
        frame_h: Frame height in pixels.
        frame_w: Frame width in pixels.
        dt: Euler integration step.
        sigma: Lorenz sigma parameter.
        rho: Lorenz rho parameter.
        beta: Lorenz beta parameter.
        warmup: Number of integration steps discarded before rendering.
        seed: Seed of the NumPy random generator used for the initial states.
        blur_sigma: Standard deviation of the rendered Gaussian blob, in pixels.

    Returns:
        ``(inputs, targets)``, both float32 of shape
        ``(num_samples, frame_h * frame_w, seq_len)``; ``targets[..., t]`` is the frame
        that follows ``inputs[..., t]``.
    """
    rng = np.random.default_rng(seed)
    grid_x, grid_y = np.meshgrid(np.arange(frame_w), np.arange(frame_h))

    x_min, x_max = -30.0, 30.0
    z_min, z_max = 0.0, 50.0
    n_pixels = frame_h * frame_w
    start_offset = np.array([0.0, 1.0, 1.0], dtype=np.float32)

    def euler_step(state: np.ndarray) -> np.ndarray:
        # One explicit Euler update of the Lorenz equations.
        x, y, z = state
        dx = sigma * (y - x)
        dy = x * (rho - z) - y
        dz = x * y - beta * z
        return state + dt * np.array([dx, dy, dz], dtype=np.float32)

    inputs = np.zeros((num_samples, n_pixels, seq_len), dtype=np.float32)
    targets = np.zeros((num_samples, n_pixels, seq_len), dtype=np.float32)

    for sample in range(num_samples):
        state = rng.normal(scale=0.5, size=(3,)).astype(np.float32) + start_offset
        for _ in range(warmup):
            state = euler_step(state)

        # One extra frame is rendered so every input frame has a successor.
        rendered = np.empty((seq_len + 1, n_pixels), dtype=np.float32)
        for step in range(seq_len + 1):
            # The frame shows the state as it was before this step's update.
            x, _, z = state
            state = euler_step(state)

            x_norm = (x - x_min) / (x_max - x_min)
            z_norm = (z - z_min) / (z_max - z_min)
            cx = np.clip(x_norm, 0.0, 1.0) * (frame_w - 1)
            cy = np.clip(z_norm, 0.0, 1.0) * (frame_h - 1)

            dist2 = (grid_x - cx) ** 2 + (grid_y - cy) ** 2
            frame = np.exp(-dist2 / (2.0 * blur_sigma**2)).astype(np.float32)
            rendered[step] = frame.reshape(-1)

        inputs[sample] = rendered[:-1].T
        targets[sample] = rendered[1:].T

    return inputs, targets

# =============================
# Training and evaluation helpers
# =============================

# extract_params: snapshot the model's weight/bias attributes as numpy arrays for saving.
# Converts model parameters (tensors or array-likes) into independent numpy copies.
def extract_params(model: Any) -> Dict[str, np.ndarray]:
    """Return copies of ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y`` as numpy arrays.

    Tensor attributes are detached, moved to the CPU and copied with their own dtype;
    any other attribute is converted to a float32 array copy.
    """

    def snapshot(value: Any) -> np.ndarray:
        if torch.is_tensor(value):
            return value.detach().cpu().numpy().copy()
        return np.asarray(value, dtype=np.float32).copy()

    return {name: snapshot(getattr(model, name)) for name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y")}


# load_params: write numpy parameter arrays back into the model.
# Tensor attributes are overwritten in place; other attributes are replaced by a copy.
def load_params(model: Any, params: Dict[str, np.ndarray]) -> None:
    """Restore ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and ``b_y`` from ``params``.

    Afterwards the model's ``reset_state_buffers`` hook is called when present, and the
    optimizer is reset through ``reset_optimizer`` or, failing that, by clearing
    ``model.optimizer.state`` when both exist.
    """
    for name in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        new_value = params[name]
        target = getattr(model, name)
        if not torch.is_tensor(target):
            setattr(model, name, new_value.copy())
            continue
        # Keep the existing tensor object (device/dtype) and overwrite it without autograd.
        replacement = to_tensor(new_value, device=target.device, dtype=target.dtype)
        with torch.no_grad():
            target.copy_(replacement)

    if hasattr(model, "reset_state_buffers"):
        model.reset_state_buffers()

    if hasattr(model, "reset_optimizer"):
        model.reset_optimizer()
        return
    if hasattr(model, "optimizer") and hasattr(model.optimizer, "state"):
        model.optimizer.state.clear()

def extract_model_hparams(model: Any) -> Dict[str, Any]:
    """Collect the hyperparameters that ``model`` exposes as attributes.

    Only attributes that exist on the model are reported. Values that are int, float,
    bool, str or ``None`` are stored unchanged; other values are converted to ``float``
    when possible and to ``str`` otherwise, except for the strict-FPTT group, which is
    always stringified. ``decay_lambda`` is forced to ``float``, ``feedback_type`` to
    ``str``, and the regularizer fields are stored as ``reg_<name>`` floats.
    """
    hparams: Dict[str, Any] = {}

    def is_plain(value: Any) -> bool:
        return value is None or isinstance(value, (int, float, bool, str))

    def numeric_or_text(value: Any) -> Any:
        # Prefer a float for non-plain values, fall back to their string form.
        if is_plain(value):
            return value
        try:
            return float(value)
        except Exception:
            return str(value)

    def text_only(value: Any) -> Any:
        return value if is_plain(value) else str(value)

    def collect(names: Tuple[str, ...], convert: Any) -> None:
        for attr in names:
            if hasattr(model, attr):
                hparams[attr] = convert(getattr(model, attr))

    collect(("eta", "loss_mode", "max_grad_norm", "time_normalization", "tbptt_steps"), numeric_or_text)

    # E-Prop / RFLO
    collect(("decay_lambda",), float)
    collect(("feedback_type",), str)

    # Strict FPTT
    collect(
        ("parts", "grad_clip", "oracle_momentum", "warmup_epochs", "oracle_id", "label_mode", "use_oracle"),
        text_only,
    )
    if hasattr(model, "_regularizer"):
        reg = getattr(model, "_regularizer")
        for attr in ("alpha", "beta", "rho", "lmbda"):
            if hasattr(reg, attr):
                hparams[f"reg_{attr}"] = float(getattr(reg, attr))

    # Local-rule specific (stability-related)
    collect(("lambda_window", "lambda_cap", "denom_floor", "alpha_rho", "lambda_rho"), numeric_or_text)

    return hparams


# train_batches: train the model for a number of epochs over shuffled mini-batches.
# Every batch starts from a zero hidden state; optional FPTT surrogate targets are
# accumulated within an epoch and finalized at its end.
def train_batches(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    epochs: int,
    seed: int,
    use_surrogates: bool = False,
    step_weights: np.ndarray | torch.Tensor | None = None,
    epoch_offset: int = 0,
    progress: bool = False,
    progress_prefix: str = "[TRAIN]",
) -> None:
    """Run ``epochs`` training epochs on ``model`` in place.

    Args:
        model: Object exposing ``hidden_size`` and either
            ``run_one_cycle_and_update_directly`` (``TorchLocalRuleRNN``) or
            ``train_batch``; optional ``device``, ``W_hh``, ``step_weights`` and
            ``set_epoch`` attributes are honoured when present.
        inputs: Input sequences; axis 2 is read as the number of time steps.
        targets: Target sequences; axis 1 is read as the output size.
        batch_size: Mini-batch size handed to ``iterate_minibatches``.
        epochs: Number of epochs to run.
        seed: Seed of the NumPy generator passed to ``iterate_minibatches``.
        use_surrogates: Enable the FPTT surrogate targets (``TorchLocalRuleRNN`` only).
        step_weights: Optional per-step weights; their length must equal the number of
            time steps. Stored on ``model.step_weights`` when the model has that attribute.
        epoch_offset: Added to the epoch index passed to ``set_epoch`` and shown in logs.
        progress: Print per-epoch and periodic per-batch progress lines.
        progress_prefix: Prefix of every progress line.

    Raises:
        ValueError: If ``step_weights`` does not match the sequence length.
    """
    rng = np.random.default_rng(seed)

    # Device priority: the W_hh tensor's device, then model.device, then the default.
    device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        device = resolve_device(model.device)
    recurrent_weight = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent_weight):
        device = recurrent_weight.device

    inputs_t = to_tensor(inputs, device=device, dtype=torch.float32)
    targets_t = to_tensor(targets, device=device, dtype=torch.float32)
    time_steps = inputs_t.shape[2]
    output_size = targets_t.shape[1]
    batches_per_epoch = math.ceil(int(inputs_t.shape[0]) / batch_size)

    if step_weights is not None and step_weights.shape[0] != time_steps:
        raise ValueError("step_weights length must match the sequence length.")
    if hasattr(model, "step_weights"):
        if step_weights is None:
            model.step_weights = None
        else:
            model.step_weights = to_tensor(step_weights, device=device, dtype=torch.float32)

    is_local_rule = isinstance(model, TorchLocalRuleRNN)
    with_surrogates = use_surrogates and is_local_rule
    epoch_total = epoch_offset + epochs
    # Roughly five per-batch progress lines per epoch.
    log_every = max(1, batches_per_epoch // 5)

    # FPTT soft labels are accumulated within each epoch and updated when the epoch ends.
    for epoch in range(epochs):
        if hasattr(model, "set_epoch"):
            model.set_epoch(epoch + epoch_offset)
        epoch_label = epoch + epoch_offset + 1
        if progress:
            print(
                f"{progress_prefix} epoch={epoch_label:02d}/{epoch_total:02d} | "
                f"batches={batches_per_epoch} | time_steps={int(time_steps)}",
                flush=True,
            )

        if with_surrogates:
            if model.fptt_Q_prev is None:
                # First use: start from a uniform distribution over the outputs.
                uniform_targets = torch.full(
                    (output_size, time_steps),
                    1.0 / output_size,
                    dtype=torch.float32,
                    device=device,
                )
                model.enable_fptt_surrogates(time_steps, output_size, uniform_targets)
            model.reset_fptt_epoch_accumulators(time_steps, output_size)

        minibatches = iterate_minibatches(inputs_t, targets_t, batch_size, rng)
        for batch_idx, (inputs_batch, targets_batch) in enumerate(minibatches, start=1):
            h_prev = torch.zeros(
                (model.hidden_size, inputs_batch.shape[0]),
                dtype=torch.float32,
                device=device,
            )
            if is_local_rule:
                model.run_one_cycle_and_update_directly(inputs_batch, targets_batch, h_prev)
            else:
                model.train_batch(inputs_batch, targets_batch, h_prev)

            at_log_point = batch_idx % log_every == 0 or batch_idx == batches_per_epoch
            if progress and at_log_point:
                print(
                    f"{progress_prefix} epoch={epoch_label:02d}/{epoch_total:02d} | "
                    f"batch={batch_idx:04d}/{batches_per_epoch:04d}",
                    flush=True,
                )

        if with_surrogates:
            model.finalize_fptt_epoch()

# train_epoch_lm: train a language model for one epoch.
# Batches come from batch_iterator; every batch starts from a zero hidden state.
def train_epoch_lm(
    model: Any,
    data: np.ndarray | torch.Tensor,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps: int,
    seed: int,
    epoch: int,
    use_surrogates: bool = False,
) -> None:
    """Run one language-model training epoch on ``model`` in place.

    Args:
        model: Object exposing ``hidden_size`` and either
            ``run_one_cycle_and_update_directly`` (``TorchLocalRuleRNN``) or
            ``train_batch``; optional ``device``, ``W_hh`` and ``set_epoch`` are honoured.
        data: Corpus handed to ``batch_iterator``.
        vocab_size: Vocabulary size (also the surrogate-target output size).
        block_size: Sequence length per sample (also the surrogate-target step count).
        batch_size: Batch size handed to ``batch_iterator``.
        steps: Step count handed to ``batch_iterator``.
        seed: Seed of the NumPy generator passed to ``batch_iterator``.
        epoch: Epoch index forwarded to ``model.set_epoch`` when available.
        use_surrogates: Enable the FPTT surrogate targets (``TorchLocalRuleRNN`` only).
    """
    if hasattr(model, "set_epoch"):
        model.set_epoch(epoch)
    rng = np.random.default_rng(seed)

    # Device priority: the W_hh tensor's device, then model.device, then the default.
    device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        device = resolve_device(model.device)
    recurrent_weight = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent_weight):
        device = recurrent_weight.device

    is_local_rule = isinstance(model, TorchLocalRuleRNN)
    with_surrogates = use_surrogates and is_local_rule

    # The language model reuses the per-step FPTT soft-label surrogate targets.
    if with_surrogates:
        if model.fptt_Q_prev is None:
            # First use: start from a uniform distribution over the vocabulary.
            uniform_targets = torch.full(
                (vocab_size, block_size),
                1.0 / vocab_size,
                dtype=torch.float32,
                device=device,
            )
            model.enable_fptt_surrogates(block_size, vocab_size, uniform_targets)
        model.reset_fptt_epoch_accumulators(block_size, vocab_size)

    for batch_inputs, batch_targets in batch_iterator(data, vocab_size, block_size, batch_size, steps, rng):
        inputs_t = to_tensor(batch_inputs, device=device, dtype=torch.float32)
        targets_t = to_tensor(batch_targets, device=device, dtype=torch.float32)
        h_prev = torch.zeros((model.hidden_size, inputs_t.shape[0]), dtype=torch.float32, device=device)
        if is_local_rule:
            model.run_one_cycle_and_update_directly(inputs_t, targets_t, h_prev)
        else:
            model.train_batch(inputs_t, targets_t, h_prev)

    if with_surrogates:
        model.finalize_fptt_epoch()


# scan_gains_classification: sweep initialization gains and pick the best one for classification.
# Each candidate is trained briefly from a fresh model, scored on validation accuracy and,
# when a driver is given, probed with a Lyapunov exponent before and after training.
def scan_gains_classification(
    gains: np.ndarray,
    build_model: callable,
    train_inputs: np.ndarray,
    train_targets: np.ndarray,
    train_labels: np.ndarray,
    val_inputs: np.ndarray,
    val_targets: np.ndarray,
    val_labels: np.ndarray,
    batch_size: int,
    scan_epochs: int,
    seed: int,
    lyapunov_driver: np.ndarray | None = None,
    use_surrogates: bool = False,
    step_weights: np.ndarray | None = None,
    plot_dir: Path | None = None,
    plot_tag: str | None = None,
    task_label: str | None = None,
    plot_show: bool = False,
) -> Tuple[float, Dict[str, np.ndarray], Dict[str, float]]:
    """Train one fresh model per gain and return the gain with the best validation accuracy.

    Args:
        gains: Candidate gains; the first one is the fallback winner.
        build_model: Zero-argument factory returning a model with
            ``initialize_weights_with_gain``.
        train_inputs: Training inputs (axis 0 samples, axis 2 time steps).
        train_targets: Training targets passed to ``train_batches``.
        train_labels: Unused by this function.
        val_inputs: Validation inputs for ``evaluate_classifier_final_step``.
        val_targets: Validation targets for ``evaluate_classifier_final_step``.
        val_labels: Validation labels for ``evaluate_classifier_final_step``.
        batch_size: Batch size for training and evaluation.
        scan_epochs: Number of training epochs per gain.
        seed: Seed for weight initialization; ``seed + 1`` seeds the batch order.
        lyapunov_driver: When given, the Lyapunov exponent is computed before and after
            training through ``calculate_lyapunov_exponent_numpy``.
        use_surrogates: Forwarded to ``train_batches``.
        step_weights: Forwarded to ``train_batches``.
        plot_dir: Output directory for the scan plot; plotting needs this and ``plot_tag``.
        plot_tag: Tag forwarded to ``plot_scan_results``.
        task_label: Plot title (defaults to ``"Scan"``).
        plot_show: Forwarded to ``plot_scan_results``.

    Returns:
        ``(best_g, best_params, best_stats)`` where ``best_params`` are the parameters of
        the winning model at initialization (before training) and ``best_stats`` holds
        ``val_loss``/``val_acc`` plus ``lambda_pre``/``lambda_post`` when a driver is given.
    """
    track_lyapunov = lyapunov_driver is not None
    winner_gain = float(gains[0])
    winner_acc = -1.0
    winner_params: Dict[str, np.ndarray] = {}
    winner_stats: Dict[str, float] = {}
    # One record per gain: (gain, val_acc, val_loss, lambda_pre, lambda_post).
    records: List[Tuple[float, float, float, float, float]] = []

    num_gains = int(np.asarray(gains).size)
    batches_per_epoch = math.ceil(int(train_inputs.shape[0]) / batch_size)
    time_steps = int(train_inputs.shape[2])
    print(
        f"[SCAN] start | gains={num_gains} | scan_epochs={int(scan_epochs)} | "
        f"train={int(train_inputs.shape[0])} | batches/epoch={batches_per_epoch} | time_steps={time_steps}",
        flush=True,
    )

    for gain_idx, g in enumerate(gains, start=1):
        prefix = f"[SCAN {gain_idx}/{num_gains} g={float(g):.3f}]"
        print(f"{prefix} init", flush=True)
        model = build_model()
        model.initialize_weights_with_gain(float(g), seed=seed)
        # Snapshot taken before training: the scan returns initial weights, not trained ones.
        init_params = extract_params(model)

        lambda_pre = float("nan")
        if track_lyapunov:
            print(f"{prefix} lyap=pre", flush=True)
            lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

        train_start = time.perf_counter()
        train_batches(
            model,
            train_inputs,
            train_targets,
            batch_size,
            scan_epochs,
            seed + 1,
            use_surrogates=use_surrogates,
            step_weights=step_weights,
            progress=True,
            progress_prefix=prefix,
        )
        train_sec = time.perf_counter() - train_start
        val_loss, val_acc = evaluate_classifier_final_step(model, val_inputs, val_targets, val_labels, batch_size)

        lambda_post = float("nan")
        if track_lyapunov:
            print(f"{prefix} lyap=post", flush=True)
            lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)

        records.append((float(g), float(val_acc), float(val_loss), float(lambda_pre), float(lambda_post)))

        if val_acc > winner_acc:
            winner_acc = val_acc
            winner_gain = float(g)
            winner_params = init_params
            winner_stats = {"val_loss": val_loss, "val_acc": val_acc}
            if track_lyapunov:
                winner_stats["lambda_pre"] = lambda_pre
                winner_stats["lambda_post"] = lambda_post

        if track_lyapunov:
            delta = lambda_post - lambda_pre
            print(
                f"[SCAN] g={g:.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f} | "
                f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f}) | train_sec={train_sec:.1f}",
                flush=True,
            )
        else:
            print(
                f"[SCAN] g={g:.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f} | train_sec={train_sec:.1f}",
                flush=True,
            )

    if plot_dir is not None and plot_tag is not None and records:
        plot_scan_results(
            task_label or "Scan",
            [rec[0] for rec in records],
            [rec[1] for rec in records],
            "Val Accuracy",
            [rec[3] for rec in records],
            [rec[4] for rec in records],
            plot_dir,
            plot_tag,
            plot_show,
            best_g=winner_gain,
            higher_is_better=True,
            aux_values=[rec[2] for rec in records],
            aux_label="Val Loss",
        )

    return winner_gain, winner_params, winner_stats


# scan_gains_regression: scan candidate gains and pick the best initialization for regression.
# Notes: for every gain a fresh model is built and initialized with that gain, the
# Lyapunov exponent is optionally measured before and after training, the model is
# trained for `scan_epochs` epochs and then scored on the validation split. The gain
# with the lowest validation MSE wins.
#
# Args:
#   gains: candidate gain values; gains[0] is the default/fallback choice.
#   build_model: zero-argument factory returning a fresh model exposing
#       `initialize_weights_with_gain(gain, seed=...)`.
#   train_inputs / train_targets: training split; axis 0 is used as the sample count.
#   val_inputs / val_targets: validation split; axis 2 of `val_inputs` is used as the
#       number of time steps.
#   batch_size: batch size forwarded to training and evaluation.
#   scan_epochs: number of training epochs per candidate.
#   seed: seed for weight initialization; training uses `seed + 1`.
#   lyapunov_driver: when given, Lyapunov exponents are computed pre/post training.
#   use_surrogates / step_weights: forwarded unchanged to `train_batches`
#       (`step_weights` is also forwarded to the evaluation helper).
#   eval_mode: None or "teacher" scores with `evaluate_regression_mse`;
#       "rollout" (aliases "autoregressive", "ar") scores with
#       `evaluate_regression_mse_rollout`. Any other value falls back to the
#       non-rollout evaluation.
#   eval_warmup: rollout warm-up length, clamped to [1, time_steps].
#   plot_dir / plot_tag / task_label / plot_show: plotting options; the scan plot is
#       produced only when both `plot_dir` and `plot_tag` are given.
#
# Returns:
#   (best_g, best_params, best_stats). `best_params` are the parameters extracted
#   right after initialization (before training) for the selected gain.
#   `best_stats` holds "val_loss" plus "lambda_pre"/"lambda_post" when a Lyapunov
#   driver was supplied. If no candidate ever compares as better (e.g. every
#   validation loss is NaN), the first gain's params/stats are returned so the
#   result stays consistent with the `best_g = gains[0]` default.
def scan_gains_regression(
    gains: np.ndarray,
    build_model: callable,
    train_inputs: np.ndarray,
    train_targets: np.ndarray,
    val_inputs: np.ndarray,
    val_targets: np.ndarray,
    batch_size: int,
    scan_epochs: int,
    seed: int,
    lyapunov_driver: np.ndarray | None = None,
    use_surrogates: bool = False,
    step_weights: np.ndarray | None = None,
    eval_mode: str | None = None,
    eval_warmup: int = 1,
    plot_dir: Path | None = None,
    plot_tag: str | None = None,
    task_label: str | None = None,
    plot_show: bool = False,
) -> Tuple[float, Dict[str, np.ndarray], Dict[str, float]]:
    best_g = float(gains[0])
    best_metric = float("inf")
    best_params: Dict[str, np.ndarray] = {}
    best_stats: Dict[str, float] = {}
    # Per-candidate traces, only consumed by the optional scan plot.
    scan_g: List[float] = []
    scan_metric: List[float] = []
    scan_lyap_pre: List[float] = []
    scan_lyap_post: List[float] = []

    # Normalize the evaluation mode; autoregressive aliases map onto "rollout".
    eval_mode = str(eval_mode or "teacher").lower()
    if eval_mode in {"autoregressive", "ar"}:
        eval_mode = "rollout"
    use_rollout = eval_mode == "rollout"
    time_steps = int(val_inputs.shape[2])
    # Keep the warm-up inside the available sequence length.
    eval_warmup = max(1, min(int(eval_warmup), time_steps))
    val_label = "val_mse_rollout" if use_rollout else "val_mse"
    metric_label = "Val MSE (rollout)" if use_rollout else "Val MSE"

    num_gains = int(np.asarray(gains).size)
    batches_per_epoch = math.ceil(int(train_inputs.shape[0]) / batch_size)
    print(
        f"[SCAN] start | gains={num_gains} | scan_epochs={int(scan_epochs)} | "
        f"train={int(train_inputs.shape[0])} | batches/epoch={batches_per_epoch} | time_steps={time_steps} | mode={eval_mode}",
        flush=True,
    )

    for gain_idx, g in enumerate(gains, start=1):
        prefix = f"[SCAN {gain_idx}/{num_gains} g={float(g):.3f}]"
        print(f"{prefix} init", flush=True)
        model = build_model()
        model.initialize_weights_with_gain(float(g), seed=seed)
        # Snapshot the untrained parameters: the scan selects an initialization,
        # not a trained model.
        init_params = extract_params(model)
        lambda_pre = float("nan")
        if lyapunov_driver is not None:
            print(f"{prefix} lyap=pre", flush=True)
            lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        train_start = time.perf_counter()
        train_batches(
            model,
            train_inputs,
            train_targets,
            batch_size,
            scan_epochs,
            seed + 1,
            use_surrogates=use_surrogates,
            step_weights=step_weights,
            progress=True,
            progress_prefix=prefix,
        )
        train_sec = time.perf_counter() - train_start
        if use_rollout:
            val_loss = evaluate_regression_mse_rollout(
                model,
                val_inputs,
                val_targets,
                batch_size,
                warmup_steps=eval_warmup,
                step_weights=step_weights,
            )
        else:
            val_loss = evaluate_regression_mse(
                model,
                val_inputs,
                val_targets,
                batch_size,
                step_weights=step_weights,
            )
        lambda_post = float("nan")
        if lyapunov_driver is not None:
            print(f"{prefix} lyap=post", flush=True)
            lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        scan_g.append(float(g))
        scan_metric.append(float(val_loss))
        scan_lyap_pre.append(float(lambda_pre))
        scan_lyap_post.append(float(lambda_post))

        # `val_loss < best_metric` is False for NaN (and for an inf loss against the
        # initial inf), so a scan where nothing "improves" would otherwise return
        # best_g=gains[0] together with empty params/stats. Record the first
        # candidate as a fallback WITHOUT touching best_metric, so that any later
        # finite loss still replaces it.
        improved = bool(val_loss < best_metric)
        if improved or not best_stats:
            if improved:
                best_metric = val_loss
            best_g = float(g)
            best_params = init_params
            best_stats = {"val_loss": val_loss}
            if lyapunov_driver is not None:
                best_stats.update(
                    {
                        "lambda_pre": lambda_pre,
                        "lambda_post": lambda_post,
                    }
                )
        if lyapunov_driver is None:
            print(f"[SCAN] g={g:.3f} | {val_label}={val_loss:.6f} | train_sec={train_sec:.1f}", flush=True)
        else:
            delta = lambda_post - lambda_pre
            print(
                f"[SCAN] g={g:.3f} | {val_label}={val_loss:.6f} | "
                f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f}) | train_sec={train_sec:.1f}",
                flush=True,
            )

    # Plot only when an output location is configured and something was scanned.
    if plot_dir is not None and plot_tag is not None and scan_g:
        plot_scan_results(
            task_label or "Scan",
            scan_g,
            scan_metric,
            metric_label,
            scan_lyap_pre,
            scan_lyap_post,
            plot_dir,
            plot_tag,
            plot_show,
            best_g=best_g,
            higher_is_better=False,
        )

    return best_g, best_params, best_stats


# scan_gains_lm: scan candidate gains and pick the best initialization for language modelling.
# Notes: for every gain a fresh model is built and initialized with that gain, the
# Lyapunov exponent is optionally measured before and after training, the model is
# trained for `scan_epochs` epochs via `train_epoch_lm` and then scored with
# `evaluate_language_model`. The gain with the lowest validation loss wins.
#
# Args:
#   gains: candidate gain values; gains[0] is the default/fallback choice.
#   build_model: zero-argument factory returning a fresh model exposing
#       `initialize_weights_with_gain(gain, seed=...)`.
#   train_data / val_data: training and validation data forwarded to the
#       training/evaluation helpers.
#   vocab_size, block_size, batch_size, steps_per_epoch, val_steps: forwarded
#       unchanged to `train_epoch_lm` / `evaluate_language_model`.
#   scan_epochs: number of training epochs per candidate.
#   seed: seed for weight initialization; epoch `e` trains with `seed + e + 1`.
#   lyapunov_driver: when given, Lyapunov exponents are computed pre/post training.
#   use_surrogates: forwarded unchanged to `train_epoch_lm`.
#   plot_dir / plot_tag / task_label / plot_show: plotting options; the scan plot is
#       produced only when both `plot_dir` and `plot_tag` are given.
#
# Returns:
#   (best_g, best_params, best_stats). `best_params` are the parameters extracted
#   right after initialization (before training) for the selected gain.
#   `best_stats` holds "val_loss" plus "lambda_pre"/"lambda_post" when a Lyapunov
#   driver was supplied. If no candidate ever compares as better (e.g. every
#   validation loss is NaN), the first gain's params/stats are returned so the
#   result stays consistent with the `best_g = gains[0]` default.
def scan_gains_lm(
    gains: np.ndarray,
    build_model: callable,
    train_data: np.ndarray,
    val_data: np.ndarray,
    vocab_size: int,
    block_size: int,
    batch_size: int,
    steps_per_epoch: int,
    val_steps: int,
    scan_epochs: int,
    seed: int,
    lyapunov_driver: np.ndarray | None = None,
    use_surrogates: bool = False,
    plot_dir: Path | None = None,
    plot_tag: str | None = None,
    task_label: str | None = None,
    plot_show: bool = False,
) -> Tuple[float, Dict[str, np.ndarray], Dict[str, float]]:
    best_g = float(gains[0])
    best_metric = float("inf")
    best_params: Dict[str, np.ndarray] = {}
    best_stats: Dict[str, float] = {}
    # Per-candidate traces, only consumed by the optional scan plot.
    scan_g: List[float] = []
    scan_metric: List[float] = []
    scan_lyap_pre: List[float] = []
    scan_lyap_post: List[float] = []

    for g in gains:
        model = build_model()
        model.initialize_weights_with_gain(float(g), seed=seed)
        # Snapshot the untrained parameters: the scan selects an initialization,
        # not a trained model.
        init_params = extract_params(model)
        lambda_pre = float("nan")
        if lyapunov_driver is not None:
            lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        for epoch in range(scan_epochs):
            train_epoch_lm(
                model,
                train_data,
                vocab_size,
                block_size,
                batch_size,
                steps_per_epoch,
                seed + epoch + 1,
                epoch,
                use_surrogates=use_surrogates,
            )
        val_loss = evaluate_language_model(model, val_data, vocab_size, block_size, batch_size, val_steps)
        lambda_post = float("nan")
        if lyapunov_driver is not None:
            lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        scan_g.append(float(g))
        scan_metric.append(float(val_loss))
        scan_lyap_pre.append(float(lambda_pre))
        scan_lyap_post.append(float(lambda_post))

        # `val_loss < best_metric` is False for NaN (and for an inf loss against the
        # initial inf), so a scan where nothing "improves" would otherwise return
        # best_g=gains[0] together with empty params/stats. Record the first
        # candidate as a fallback WITHOUT touching best_metric, so that any later
        # finite loss still replaces it.
        improved = bool(val_loss < best_metric)
        if improved or not best_stats:
            if improved:
                best_metric = val_loss
            best_g = float(g)
            best_params = init_params
            best_stats = {"val_loss": val_loss}
            if lyapunov_driver is not None:
                best_stats.update(
                    {
                        "lambda_pre": lambda_pre,
                        "lambda_post": lambda_post,
                    }
                )
        if lyapunov_driver is None:
            print(f"[SCAN] g={g:.3f} | val_loss={val_loss:.4f}")
        else:
            delta = lambda_post - lambda_pre
            print(
                f"[SCAN] g={g:.3f} | val_loss={val_loss:.4f} | "
                f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
            )

    # Plot only when an output location is configured and something was scanned.
    if plot_dir is not None and plot_tag is not None and scan_g:
        plot_scan_results(
            task_label or "Scan",
            scan_g,
            scan_metric,
            "Val Loss",
            scan_lyap_pre,
            scan_lyap_post,
            plot_dir,
            plot_tag,
            plot_show,
            best_g=best_g,
            higher_is_better=False,
        )

    return best_g, best_params, best_stats


# =============================
# 任务入口（组装模型 + 扫描 + 训练对比组）
# =============================

# run_classification_task: classification entry point - prepares data, scans gains,
# then trains and evaluates every comparison model.
# Notes:
#   task_data: dict providing "train_inputs"/"train_targets"/"train_labels",
#       "test_inputs"/"test_targets"/"test_labels", "input_size", "output_size" and
#       optionally "task_name", "time_weighting", "train_groups", "prediction_plot".
#   args: parsed CLI namespace; reads seed, batch_size, epochs, scan_epochs, hidden,
#       lr, task and the optional plot / step_labels / time_weighting /
#       tbptt_short / tbptt_long / no_eprop attributes.
#   gains: candidate gains forwarded to scan_gains_classification.
#   Returns None; results are printed and, when plotting is enabled, written through
#   save_results_summary / plot_comparison_results / plot_cost_profile.
def run_classification_task(task_data: Dict[str, Any], args: argparse.Namespace, gains: np.ndarray) -> None:
    set_global_seed(int(getattr(args, "seed", 0)))
    # Step 1: prepare data, weights and plot paths.
    train_inputs_np = task_data["train_inputs"]
    train_targets_np = task_data["train_targets"]
    train_labels_np = task_data["train_labels"]
    test_inputs_np = task_data["test_inputs"]
    test_targets_np = task_data["test_targets"]
    test_labels_np = task_data["test_labels"]
    task_label = task_data.get("task_name", getattr(args, "task", "Sequence Task"))
    plot_enabled = getattr(args, "plot", True)
    plot_dir = None
    plot_tag = None
    if plot_enabled:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)
    # step_labels=fptt switches on use_surrogates (soft labels / surrogate weights,
    # per the original author's note; the mechanics live in train_batches).
    step_label_mode = str(getattr(args, "step_labels", "final")).lower()
    use_surrogates = step_label_mode == "fptt"
    # A "time_weighting" entry in task_data takes precedence over the CLI attribute here.
    time_weighting = task_data.get("time_weighting", getattr(args, "time_weighting", None))
    # Axis 2 of the inputs is used as the number of time steps throughout this function.
    step_weights = build_time_weights(train_inputs_np.shape[2], time_weighting)
    device = DEFAULT_DEVICE
    train_inputs = to_tensor(train_inputs_np, device=device, dtype=torch.float32)
    train_targets = to_tensor(train_targets_np, device=device, dtype=torch.float32)
    train_labels = to_tensor(train_labels_np, device=device, dtype=torch.long)
    train_groups_np = task_data.get("train_groups")
    train_groups = (
        to_tensor(train_groups_np, device=device, dtype=torch.long) if train_groups_np is not None else None
    )
    test_inputs = to_tensor(test_inputs_np, device=device, dtype=torch.float32)
    test_targets = to_tensor(test_targets_np, device=device, dtype=torch.float32)
    test_labels = to_tensor(test_labels_np, device=device, dtype=torch.long)
    step_weights_t = to_tensor(step_weights, device=device, dtype=torch.float32) if step_weights is not None else None
    time_steps = int(train_inputs_np.shape[2])
    # Truncation windows for the two TBPTT baselines; the long window defaults to
    # min(10, time_steps) but never less than 2.
    tbptt_short = max(1, int(getattr(args, "tbptt_short", 1)))
    tbptt_long_default = max(2, min(10, time_steps))
    tbptt_long_value = getattr(args, "tbptt_long", None)
    tbptt_long = tbptt_long_default if tbptt_long_value is None else int(tbptt_long_value)

    # Hold out a validation split (fraction argument 0.1); the grouped variant is
    # used when group ids are supplied.
    rng = np.random.default_rng(args.seed)
    if train_groups is None:
        tr_inputs, tr_targets, tr_labels, val_inputs, val_targets, val_labels = split_train_val(
            train_inputs, train_targets, train_labels, 0.1, rng
        )
    else:
        tr_inputs, tr_targets, tr_labels, val_inputs, val_targets, val_labels = split_train_val_grouped(
            train_inputs, train_targets, train_labels, train_groups, 0.1, rng
        )
    train_inputs_fit = tr_inputs
    train_targets_fit = tr_targets
    train_labels_fit = tr_labels
    batches_per_epoch = math.ceil(int(train_inputs_fit.shape[0]) / args.batch_size)
    print(
        f"[DATA] train={int(train_inputs_fit.shape[0])} | val={int(val_inputs.shape[0])} | "
        f"test={int(test_inputs.shape[0])} | input_size={int(task_data['input_size'])} | "
        f"time_steps={time_steps} | device={device}",
        flush=True,
    )

    # Run metadata handed to save_results_summary at the end of the function.
    # NOTE(review): "train_size" is taken from the full pre-split `train_inputs`,
    # whereas the [DATA] line above reports the post-split training size - confirm
    # which one is intended.
    run_config: Dict[str, Any] = {
        "task_id": getattr(args, "task", None),
        "task_type": "classification",
        "task_name": task_label,
        "seed": int(getattr(args, "seed", 0)),
        "epochs": int(getattr(args, "epochs", 0)),
        "scan_epochs": int(getattr(args, "scan_epochs", 0)),
        "batch_size": int(getattr(args, "batch_size", 0)),
        "hidden": int(getattr(args, "hidden", 0)),
        "lr": float(getattr(args, "lr", 0.0)),
        "time_weighting": time_weighting,
        "step_labels": step_label_mode,
        "use_surrogates": bool(use_surrogates),
        "tbptt_short": int(tbptt_short),
        "tbptt_long": int(tbptt_long),
        "gains": [float(x) for x in np.asarray(gains).ravel().tolist()],
        "input_size": int(task_data["input_size"]),
        "output_size": int(task_data["output_size"]),
        "time_steps": int(time_steps),
        "train_size": int(train_inputs.shape[0]),
        "val_size": int(val_inputs.shape[0]),
        "test_size": int(test_inputs.shape[0]),
        "device": str(device),
        "args": dict(vars(args)),
        "argv": list(sys.argv),
        "eprop_feedback": EPROP_FEEDBACK,
        "eprop_seed": int(EPROP_SEED),
        "fptt_parts": int(FPTT_PARTS),
        "fptt_lambda": float(FPTT_LAMBDA),
        "fptt_oracle_momentum": float(FPTT_ORACLE_MOMENTUM),
    }

    # build_local: construct a fresh TorchLocalRuleRNN with this task's sizes,
    # learning rate (eta=args.lr), cross-entropy loss mode and seed.
    # Used both as the scan factory and for the final Local Rule model.
    def build_local() -> TorchLocalRuleRNN:
        return TorchLocalRuleRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            loss_mode="ce",
            seed=args.seed,
            device=device,
        )

    # Step 2: gain scan (Local Rule) to select the best initialization.
    lyapunov_driver = build_lyapunov_driver(val_inputs)
    best_g, init_params, stats = scan_gains_classification(
        gains,
        build_local,
        tr_inputs,
        tr_targets,
        tr_labels,
        val_inputs,
        val_targets,
        val_labels,
        args.batch_size,
        args.scan_epochs,
        args.seed,
        lyapunov_driver=lyapunov_driver,
        use_surrogates=use_surrogates,
        step_weights=step_weights_t,
        plot_dir=plot_dir,
        plot_tag=plot_tag,
        task_label=task_label,
        plot_show=False,
    )
    # `stats` is indexed directly, so the scan must return "val_acc" and "val_loss".
    summary = f"Best g={best_g:.3f} | val_acc={stats['val_acc']:.4f} | val_loss={stats['val_loss']:.4f}"
    if "lambda_pre" in stats and "lambda_post" in stats:
        delta = stats["lambda_post"] - stats["lambda_pre"]
        summary += f" | lyap(pre:{stats['lambda_pre']:.4f}, post:{stats['lambda_post']:.4f}, d:{delta:.4f})"
    print(summary)
    run_config["best_g"] = float(best_g)
    run_config["scan_best"] = {k: float(v) for k, v in stats.items()}

    local_name = "Local Rule (FPTT)" if use_surrogates else "Local Rule"
    # Comparison groups: Local Rule / BPTT / E-Prop / FPTT.
    # Step 3: build the comparison models (Local/BPTT/E-Prop/FPTT/TBPTT).
    # Dict insertion order is also the training order in step 4.
    skip_eprop = bool(getattr(args, "no_eprop", False))
    models: Dict[str, Any] = {local_name: build_local()}
    models["BPTT"] = BPTTRNN(
        task_data["input_size"],
        args.hidden,
        task_data["output_size"],
        eta=args.lr,
        loss_mode="ce",
        time_normalization=False,
        seed=args.seed,
        device=device,
    )
    if not skip_eprop:
        # E-Prop is seeded with the module-level EPROP_SEED rather than args.seed.
        models["E-Prop"] = StandardEPropRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="ce",
            device=device,
        )
    models["FPTT"] = StrictFPTTClassifier(
        task_data["input_size"],
        args.hidden,
        task_data["output_size"],
        eta=args.lr,
        parts=FPTT_PARTS,
        clip=5.0,
        lmbda=FPTT_LAMBDA,
        oracle_momentum=FPTT_ORACLE_MOMENTUM,
        label_mode="last",
        oracle_id=args.task,
        use_oracle=True,
        device=device,
    )
    # TBPTT baselines are added only when the truncation window is shorter than the
    # sequence; the long variant is skipped when it equals the short one.
    if 0 < tbptt_short < time_steps:
        models[f"TBPTT-{tbptt_short}"] = BPTTRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_short,
            time_normalization=False,
            seed=args.seed,
            device=device,
        )
    if tbptt_long != tbptt_short and 0 < tbptt_long < time_steps:
        models[f"TBPTT-{tbptt_long}"] = BPTTRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_long,
            time_normalization=False,
            seed=args.seed,
            device=device,
        )

    print(f"[RUN] models={', '.join(models.keys())}", flush=True)

    # Optional prediction-plot configuration. The bundle is validated and assembled
    # here; any problem only prints a "[PLOT] Skipping ..." message.
    # NOTE(review): within this function the bundle and the values derived from it
    # below are not read again after they are assigned - confirm whether the
    # prediction plot is meant to be produced for classification tasks.
    prediction_cfg = task_data.get("prediction_plot") or {}
    prediction_bundle: Dict[str, Any] | None = None
    if plot_enabled and prediction_cfg:
        if test_inputs_np.shape[0] == 0:
            print("[PLOT] Skipping prediction plot: empty test set.")
        else:
            plot_type = str(prediction_cfg.get("type", "")).lower()
            # No explicit type but frame dimensions present -> treat as an image plot.
            if not plot_type and "frame_h" in prediction_cfg and "frame_w" in prediction_cfg:
                plot_type = "image"
            if plot_type not in {"image", "trajectory3d", "timeseries"}:
                print("[PLOT] Skipping prediction plot: unknown type.")
            else:
                # Clamp the requested sample index into the test set.
                sample_index = int(prediction_cfg.get("sample_index", 0))
                sample_index = max(0, min(sample_index, test_inputs_np.shape[0] - 1))
                prediction_bundle = {
                    "plot_type": plot_type,
                    "sample_input": test_inputs_np[sample_index],
                    "sample_target": test_targets_np[sample_index],
                    "predictions": {},
                    "model_names": prediction_cfg.get("models"),
                    "local_only": bool(prediction_cfg.get("local_only", False)),
                }
                pred_mode = str(prediction_cfg.get("mode", "teacher")).lower()
                pred_warmup = max(1, int(prediction_cfg.get("warmup_steps", 1)))
                # Rollout feeds outputs back as inputs, so the sizes must match;
                # otherwise fall back to teacher forcing.
                if pred_mode in {"rollout", "autoregressive", "ar"}:
                    if task_data["input_size"] != task_data["output_size"]:
                        print(
                            "[PLOT] Autoregressive prediction requires input_size == output_size; "
                            "using teacher forcing."
                        )
                        pred_mode = "teacher"
                if pred_mode in {"rollout", "autoregressive", "ar"}:
                    mode_label = "Rollout"
                elif pred_mode in {"teacher", "teacher_forcing"}:
                    mode_label = "Teacher"
                else:
                    mode_label = pred_mode.replace("_", " ").title()
                prediction_bundle.update(
                    {
                        "prediction_mode": pred_mode,
                        "prediction_warmup": pred_warmup,
                        "show_input": bool(prediction_cfg.get("show_input", False)),
                        "mode_label": mode_label,
                    }
                )
                if plot_type == "image":
                    frame_h = int(prediction_cfg.get("frame_h", 0))
                    frame_w = int(prediction_cfg.get("frame_w", 0))
                    frame_channels = max(1, int(prediction_cfg.get("frame_channels", 1)))
                    frame_depth = frame_h * frame_w * frame_channels
                    # Axis 1 of the targets must hold a whole number of frames.
                    if (
                        frame_h > 0
                        and frame_w > 0
                        and frame_depth > 0
                        and test_targets_np.shape[1] % frame_depth == 0
                    ):
                        time_indices = prediction_cfg.get("time_indices")
                        if time_indices is not None:
                            time_indices = [int(t) for t in time_indices]
                        frame_index = prediction_cfg.get("frame_index")
                        if frame_index is not None:
                            frame_index = int(frame_index)
                        prediction_bundle.update(
                            {
                                "frame_h": frame_h,
                                "frame_w": frame_w,
                                "frame_channels": frame_channels,
                                "time_indices": time_indices,
                                "frame_index": frame_index,
                            }
                        )
                    else:
                        prediction_bundle = None
                        print("[PLOT] Skipping prediction plot: frame size mismatch.")
                elif plot_type in {"trajectory3d", "timeseries"}:
                    dims = prediction_cfg.get("dims")
                    if dims is not None:
                        dims = [int(d) for d in dims]
                    prediction_bundle["dims"] = dims
    # Reset to defaults, then re-read the final settings from the bundle (if any).
    pred_mode = None
    pred_warmup = 1
    mode_label = None
    show_input = False
    if prediction_bundle is not None:
        pred_mode = str(prediction_bundle.get("prediction_mode", "teacher")).lower()
        pred_warmup = int(prediction_bundle.get("prediction_warmup", 1))
        mode_label = prediction_bundle.get("mode_label")
        show_input = bool(prediction_bundle.get("show_input", False))
    # Step 4: train/evaluate every model and collect the results.
    results: Dict[str, Dict[str, Any]] = {}
    # Log roughly five times per run (and always on the last epoch).
    log_every = max(1, args.epochs // 5)

    for name, model in models.items():
        print(
            f"[{name}] start | epochs={int(args.epochs)} | batches/epoch={int(batches_per_epoch)} | time_steps={time_steps}",
            flush=True,
        )
        # Every model starts from the initialization selected by the gain scan.
        # NOTE(review): load_params is defined elsewhere; how it maps Local Rule
        # parameters onto the other model classes is not visible here.
        load_params(model, init_params)
        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        history: List[float] = []
        # Surrogate targets are applied to the Local Rule model only.
        use_surrogates_local = use_surrogates and name == local_name
        complexity = estimate_model_complexity(model)
        update_stats = estimate_training_counts(model, time_steps, batches_per_epoch, args.epochs)
        train_runtime_sec = 0.0
        eval_runtime_sec = 0.0
        best_epoch = 0
        best_val_metric = -float("inf")
        best_val_loss = float("inf")
        best_params: Dict[str, np.ndarray] | None = None
        if plot_enabled:
            # Plotting path: train one epoch at a time, validate after each epoch,
            # record the history and keep the parameters with the best val accuracy.
            for epoch in range(args.epochs):
                train_start = time.perf_counter()
                train_batches(
                    model,
                    train_inputs_fit,
                    train_targets_fit,
                    args.batch_size,
                    1,
                    args.seed + 10 + epoch,
                    use_surrogates=use_surrogates_local,
                    step_weights=step_weights_t,
                    epoch_offset=epoch,
                )
                train_runtime_sec += time.perf_counter() - train_start
                eval_start = time.perf_counter()
                val_loss, val_acc = evaluate_classifier_final_step(
                    model, val_inputs, val_targets, val_labels, args.batch_size
                )
                eval_runtime_sec += time.perf_counter() - eval_start
                history.append(float(val_acc))
                if val_acc > best_val_metric:
                    best_val_metric = float(val_acc)
                    best_val_loss = float(val_loss)
                    best_epoch = epoch + 1
                    best_params = extract_params(model)
                if (epoch + 1) % log_every == 0 or (epoch + 1) == args.epochs:
                    print(
                        f"[{name}] epoch={epoch+1:02d} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}"
                    )
        else:
            # Non-plotting path: train all epochs in one call with no per-epoch
            # validation, so no best checkpoint is tracked and the final weights
            # are the ones evaluated.
            train_start = time.perf_counter()
            train_batches(
                model,
                train_inputs_fit,
                train_targets_fit,
                args.batch_size,
                args.epochs,
                args.seed + 10,
                use_surrogates=use_surrogates_local,
                step_weights=step_weights_t,
                epoch_offset=0,
            )
            train_runtime_sec = time.perf_counter() - train_start
        # Restore the best-validation checkpoint (if one was recorded) before testing.
        if best_params is not None:
            load_params(model, best_params)
        eval_start = time.perf_counter()
        final_test_loss, final_test_acc = evaluate_classifier_final_step(
            model, test_inputs, test_targets, test_labels, args.batch_size
        )
        eval_runtime_sec += time.perf_counter() - eval_start
        # best_epoch == 0 means no epoch was ever selected (e.g. the non-plotting
        # path); report the last epoch and mark the validation numbers as unknown.
        if best_epoch == 0:
            best_epoch = int(args.epochs)
            best_val_metric = float("nan")
            best_val_loss = float("nan")
        runtime_sec = train_runtime_sec + eval_runtime_sec
        updates_total = update_stats["updates_total"]
        steps_total = update_stats["steps_total"]
        # Guard against division by zero when no updates/steps are counted.
        runtime_per_update_sec = (
            train_runtime_sec / updates_total if updates_total > 0 else float("nan")
        )
        runtime_per_step_sec = (
            train_runtime_sec / steps_total if steps_total > 0 else float("nan")
        )
        # Measured on the parameters that were just tested (restored checkpoint if any).
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        delta = lambda_post - lambda_pre
        print(
            f"[{name}] test_acc={final_test_acc:.4f} | test_loss={final_test_loss:.4f} | "
            f"best_val_acc={best_val_metric:.4f} (epoch={best_epoch:02d}) | "
            f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
        )
        results[name] = {
            "metric": float(final_test_acc),
            "val_metric": float(best_val_metric),
            "val_loss": float(best_val_loss),
            "best_epoch": int(best_epoch),
            "lyap_pre": float(lambda_pre),
            "lyap_post": float(lambda_post),
            "history": history,
            "complexity_params": float(complexity["params"]),
            "complexity_state": float(complexity["state"]),
            "complexity_total": float(complexity["total"]),
            "runtime_sec": float(runtime_sec),
            "train_runtime_sec": float(train_runtime_sec),
            "eval_runtime_sec": float(eval_runtime_sec),
            "runtime_per_update_sec": float(runtime_per_update_sec),
            "runtime_per_step_sec": float(runtime_per_step_sec),
            "batches_per_epoch": int(update_stats["batches_per_epoch"]),
            "time_steps": int(update_stats["time_steps"]),
            "update_factor": int(update_stats["update_factor"]),
            "updates_per_epoch": int(update_stats["updates_per_epoch"]),
            "updates_total": int(update_stats["updates_total"]),
            "steps_total": int(update_stats["steps_total"]),
            "hparams": extract_model_hparams(model),
        }

    # Summary file and comparison/cost plots are produced only when plotting is enabled.
    if plot_enabled and plot_dir is not None and plot_tag is not None:
        save_results_summary(
            task_label,
            "Test Accuracy",
            results,
            plot_dir,
            plot_tag,
            run_config=run_config,
        )
        plot_comparison_results(
            task_label,
            results,
            metric_label="Test Accuracy",
            history_label="Val Accuracy",
            higher_is_better=True,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        plot_cost_profile(
            task_label,
            results,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )


# _resolve_regression_eval_mode: choose the regression evaluation mode ("teacher" or "rollout").
# Note: precedence is args.eval_mode > evaluation.mode > prediction_plot.mode > args.pred_mode > "teacher".
def _resolve_regression_eval_mode(
    args: argparse.Namespace,
    eval_cfg: Dict[str, Any],
    prediction_cfg: Dict[str, Any],
) -> str:
    candidates = (
        getattr(args, "eval_mode", None),
        eval_cfg.get("mode"),
        prediction_cfg.get("mode"),
        getattr(args, "pred_mode", None),
    )
    mode = str(next((c for c in candidates if c is not None), "teacher")).lower()
    if mode in {"autoregressive", "ar"}:
        mode = "rollout"
    elif mode == "teacher_forcing":
        # The prediction-plot config accepts this spelling and is also one of
        # the fallback sources above, so treat it as an alias instead of
        # reporting it as an unknown mode.
        mode = "teacher"
    if mode not in {"teacher", "rollout"}:
        print(f"[EVAL] Unknown eval mode '{mode}', using teacher forcing.")
        mode = "teacher"
    return mode


# _resolve_regression_eval_warmup: choose the rollout warm-up length, clamped to [1, time_steps].
# Note: precedence is args.eval_warmup > evaluation.warmup_steps > prediction_plot.warmup_steps > args.pred_warmup > 1.
def _resolve_regression_eval_warmup(
    args: argparse.Namespace,
    eval_cfg: Dict[str, Any],
    prediction_cfg: Dict[str, Any],
    time_steps: int,
) -> int:
    candidates = (
        getattr(args, "eval_warmup", None),
        eval_cfg.get("warmup_steps"),
        prediction_cfg.get("warmup_steps"),
        getattr(args, "pred_warmup", None),
    )
    warmup = next((c for c in candidates if c is not None), 1)
    return max(1, min(int(warmup), time_steps))


# _build_regression_prediction_bundle: collect everything needed to plot one test-sample prediction.
# Note: returns None (after printing the reason) when the plot has to be skipped.
def _build_regression_prediction_bundle(
    prediction_cfg: Dict[str, Any],
    task_data: Dict[str, Any],
    test_inputs_np: np.ndarray,
    test_targets_np: np.ndarray,
) -> Dict[str, Any] | None:
    if test_inputs_np.shape[0] == 0:
        print("[PLOT] Skipping prediction plot: empty test set.")
        return None

    plot_type = str(prediction_cfg.get("type", "")).lower()
    if not plot_type and "frame_h" in prediction_cfg and "frame_w" in prediction_cfg:
        # No explicit type, but frame geometry is configured: treat as image.
        plot_type = "image"
    if plot_type not in {"image", "trajectory3d", "timeseries"}:
        print("[PLOT] Skipping prediction plot: unknown type.")
        return None

    # Clamp the requested sample index into the valid test-set range.
    sample_index = int(prediction_cfg.get("sample_index", 0))
    sample_index = max(0, min(sample_index, test_inputs_np.shape[0] - 1))

    # An explicit None is treated like a missing key, matching how the
    # evaluation settings treat the same config entries.
    raw_mode = prediction_cfg.get("mode")
    pred_mode = "teacher" if raw_mode is None else str(raw_mode).lower()
    raw_warmup = prediction_cfg.get("warmup_steps")
    pred_warmup = 1 if raw_warmup is None else max(1, int(raw_warmup))

    rollout_aliases = {"rollout", "autoregressive", "ar"}
    if pred_mode in rollout_aliases and task_data["input_size"] != task_data["output_size"]:
        print(
            "[PLOT] Autoregressive prediction requires input_size == output_size; "
            "using teacher forcing."
        )
        pred_mode = "teacher"

    if pred_mode in rollout_aliases:
        mode_label = "Rollout"
    elif pred_mode in {"teacher", "teacher_forcing"}:
        mode_label = "Teacher"
    else:
        mode_label = pred_mode.replace("_", " ").title()

    bundle: Dict[str, Any] = {
        "plot_type": plot_type,
        "sample_input": test_inputs_np[sample_index],
        "sample_target": test_targets_np[sample_index],
        "predictions": {},
        "model_names": prediction_cfg.get("models"),
        "local_only": bool(prediction_cfg.get("local_only", False)),
        "prediction_mode": pred_mode,
        "prediction_warmup": pred_warmup,
        "show_input": bool(prediction_cfg.get("show_input", False)),
        "mode_label": mode_label,
    }

    if plot_type == "image":
        frame_h = int(prediction_cfg.get("frame_h", 0))
        frame_w = int(prediction_cfg.get("frame_w", 0))
        frame_channels = max(1, int(prediction_cfg.get("frame_channels", 1)))
        frame_depth = frame_h * frame_w * frame_channels
        # Axis 1 of the targets must be a whole number of flattened frames.
        geometry_ok = (
            frame_h > 0
            and frame_w > 0
            and frame_depth > 0
            and test_targets_np.shape[1] % frame_depth == 0
        )
        if not geometry_ok:
            print("[PLOT] Skipping prediction plot: frame size mismatch.")
            return None
        time_indices = prediction_cfg.get("time_indices")
        if time_indices is not None:
            time_indices = [int(t) for t in time_indices]
        frame_index = prediction_cfg.get("frame_index")
        if frame_index is not None:
            frame_index = int(frame_index)
        bundle.update(
            {
                "frame_h": frame_h,
                "frame_w": frame_w,
                "frame_channels": frame_channels,
                "time_indices": time_indices,
                "frame_index": frame_index,
            }
        )
    else:
        # trajectory3d / timeseries: optional list of output dimensions.
        dims = prediction_cfg.get("dims")
        if dims is not None:
            dims = [int(d) for d in dims]
        bundle["dims"] = dims
    return bundle


# _plot_regression_predictions: dispatch the collected predictions to the matching plot routine.
# Note: expects a bundle produced by _build_regression_prediction_bundle.
def _plot_regression_predictions(
    task_label: str,
    prediction_bundle: Dict[str, Any],
    plot_dir: Any,
    plot_tag: Any,
) -> None:
    plot_type = prediction_bundle["plot_type"]
    predictions = prediction_bundle["predictions"]
    sample_target = prediction_bundle["sample_target"]
    if plot_type == "image":
        show_input = bool(prediction_bundle.get("show_input", False))
        plot_image_sequence_predictions(
            task_label,
            predictions,
            sample_target,
            prediction_bundle["frame_h"],
            prediction_bundle["frame_w"],
            prediction_bundle.get("frame_channels", 1),
            prediction_bundle["time_indices"],
            prediction_bundle.get("frame_index"),
            plot_dir,
            plot_tag,
            show=False,
            inputs=prediction_bundle["sample_input"] if show_input else None,
            mode_label=prediction_bundle.get("mode_label"),
        )
    elif plot_type == "trajectory3d":
        pred_plot_path = build_plot_path(plot_dir, plot_tag, "pred_trajectory")
        plot_trajectory_predictions(
            task_label,
            predictions,
            sample_target,
            prediction_bundle.get("dims"),
            pred_plot_path,
            show=False,
        )
    elif plot_type == "timeseries":
        pred_plot_path = build_plot_path(plot_dir, plot_tag, "pred_timeseries")
        plot_timeseries_predictions(
            task_label,
            predictions,
            sample_target,
            prediction_bundle.get("dims"),
            pred_plot_path,
            show=False,
        )


# run_regression_task: regression task entry point; prepares data, trains and evaluates every model.
# Note: supports teacher-forcing and rollout evaluation; plots/summaries are written only when plotting is enabled.
def run_regression_task(task_data: Dict[str, Any], args: argparse.Namespace, gains: np.ndarray) -> None:
    """Train and compare the regression models on one sequence task.

    Args:
        task_data: Task dictionary. Reads "train_inputs", "train_targets",
            "test_inputs", "test_targets", "input_size", "output_size" and the
            optional "task_name", "time_weighting", "prediction_plot",
            "evaluation" and "bptt_config" entries.
        args: Parsed CLI namespace. Reads seed, epochs, scan_epochs,
            batch_size, hidden, lr and the optional plot / eval / tbptt flags.
        gains: Candidate gains handed to the Local Rule gain scan.

    Returns:
        None. Progress is printed; when plotting is enabled the results
        summary and figures are written through the plotting helpers.
    """
    set_global_seed(int(getattr(args, "seed", 0)))

    # Step 1: data, time weights, evaluation settings and plot context.
    train_inputs_np = task_data["train_inputs"]
    train_targets_np = task_data["train_targets"]
    test_inputs_np = task_data["test_inputs"]
    test_targets_np = task_data["test_targets"]
    task_label = task_data.get("task_name", getattr(args, "task", "Sequence Task"))
    plot_enabled = getattr(args, "plot", True)
    plot_dir = None
    plot_tag = None
    if plot_enabled:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)

    # CLI time weighting wins over the task's own default.
    time_weighting = getattr(args, "time_weighting", None)
    if time_weighting is None:
        time_weighting = task_data.get("time_weighting")
    # The time axis is read from axis 2 of the training inputs.
    time_steps = int(train_inputs_np.shape[2])
    step_weights = build_time_weights(train_inputs_np.shape[2], time_weighting)

    device = DEFAULT_DEVICE
    train_inputs = to_tensor(train_inputs_np, device=device, dtype=torch.float32)
    train_targets = to_tensor(train_targets_np, device=device, dtype=torch.float32)
    test_inputs = to_tensor(test_inputs_np, device=device, dtype=torch.float32)
    test_targets = to_tensor(test_targets_np, device=device, dtype=torch.float32)
    step_weights_t = (
        to_tensor(step_weights, device=device, dtype=torch.float32)
        if step_weights is not None
        else None
    )

    tbptt_short = max(1, int(getattr(args, "tbptt_short", 1)))
    tbptt_long_value = getattr(args, "tbptt_long", None)
    if tbptt_long_value is None:
        tbptt_long = max(2, min(10, time_steps))
    else:
        tbptt_long = int(tbptt_long_value)

    prediction_cfg = task_data.get("prediction_plot") or {}
    eval_cfg = task_data.get("evaluation") or {}
    eval_mode = _resolve_regression_eval_mode(args, eval_cfg, prediction_cfg)
    eval_warmup = _resolve_regression_eval_warmup(args, eval_cfg, prediction_cfg, time_steps)
    # Rollout feeds outputs back in as inputs, so the sizes have to agree.
    if eval_mode == "rollout" and task_data["input_size"] != task_data["output_size"]:
        print("[EVAL] Autoregressive eval requires input_size == output_size; using teacher forcing.")
        eval_mode = "teacher"

    eval_label = "Rollout" if eval_mode == "rollout" else "Teacher"
    metric_label = f"Test MSE ({eval_label})"
    history_label = f"Val MSE ({eval_label})"
    val_label = "val_mse_rollout" if eval_mode == "rollout" else "val_mse"

    # split_train_val takes a label tensor; regression has none, so pass zeros.
    # The 0.1 argument is presumably the validation fraction -- TODO confirm.
    rng = np.random.default_rng(args.seed)
    dummy_labels = torch.zeros((train_inputs.shape[0],), dtype=torch.long, device=device)
    tr_inputs, tr_targets, _, val_inputs, val_targets, _ = split_train_val(
        train_inputs, train_targets, dummy_labels, 0.1, rng
    )
    batches_per_epoch = math.ceil(int(tr_inputs.shape[0]) / args.batch_size)

    run_config: Dict[str, Any] = {
        "task_id": getattr(args, "task", None),
        "task_type": "regression",
        "task_name": task_label,
        "seed": int(getattr(args, "seed", 0)),
        "epochs": int(getattr(args, "epochs", 0)),
        "scan_epochs": int(getattr(args, "scan_epochs", 0)),
        "batch_size": int(getattr(args, "batch_size", 0)),
        "hidden": int(getattr(args, "hidden", 0)),
        "lr": float(getattr(args, "lr", 0.0)),
        "time_weighting": time_weighting,
        "eval_mode": str(eval_mode),
        "eval_warmup": int(eval_warmup),
        "tbptt_short": int(tbptt_short),
        "tbptt_long": int(tbptt_long),
        "gains": [float(x) for x in np.asarray(gains).ravel().tolist()],
        "input_size": int(task_data["input_size"]),
        "output_size": int(task_data["output_size"]),
        "time_steps": int(time_steps),
        "train_size": int(train_inputs.shape[0]),
        "val_size": int(val_inputs.shape[0]),
        "test_size": int(test_inputs.shape[0]),
        "device": str(device),
        "args": dict(vars(args)),
        "argv": list(sys.argv),
        "eprop_feedback": EPROP_FEEDBACK,
        "eprop_seed": int(EPROP_SEED),
        "fptt_parts": int(FPTT_PARTS),
        "fptt_lambda": float(FPTT_LAMBDA),
    }

    # build_local: construct a fresh Local Rule model for this task.
    def build_local() -> TorchLocalRuleRNN:
        return TorchLocalRuleRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            loss_mode="mse",
            seed=args.seed,
            device=device,
        )

    # Step 2: gain scan with the Local Rule model; keep the best initialisation.
    lyapunov_driver = build_lyapunov_driver(val_inputs)
    best_g, init_params, stats = scan_gains_regression(
        gains,
        build_local,
        tr_inputs,
        tr_targets,
        val_inputs,
        val_targets,
        args.batch_size,
        args.scan_epochs,
        args.seed,
        lyapunov_driver=lyapunov_driver,
        step_weights=step_weights_t,
        eval_mode=eval_mode,
        eval_warmup=eval_warmup,
        plot_dir=plot_dir,
        plot_tag=plot_tag,
        task_label=task_label,
        plot_show=False,
    )
    summary = f"Best g={best_g:.3f} | {val_label}={stats['val_loss']:.6f}"
    if "lambda_pre" in stats and "lambda_post" in stats:
        delta = stats["lambda_post"] - stats["lambda_pre"]
        summary += f" | lyap(pre:{stats['lambda_pre']:.4f}, post:{stats['lambda_post']:.4f}, d:{delta:.4f})"
    print(summary)
    run_config["best_g"] = float(best_g)
    run_config["scan_best"] = {k: float(v) for k, v in stats.items()}

    bptt_cfg = task_data.get("bptt_config") or {}
    bptt_mode = str(bptt_cfg.get("mode", "full")).lower()
    bptt_steps = bptt_cfg.get("tbptt_steps")
    bptt_time_norm = bool(bptt_cfg.get("time_normalization", True))
    bptt_clip = float(bptt_cfg.get("grad_clip", 5.0))
    bptt_label = bptt_cfg.get("label")

    if bptt_mode in {"tbptt", "truncated"}:
        # Truncated baseline: fall back to the short TBPTT window if unset.
        bptt_steps = int(bptt_steps or tbptt_short)
        if bptt_label is None:
            bptt_label = f"BPTT(TBPTT-{bptt_steps})"
    else:
        bptt_steps = None
        if bptt_label is None:
            bptt_label = "BPTT"

    # build_bptt: construct a BPTT model; tbptt_steps=None means no truncation window is passed.
    def build_bptt(tbptt_steps: int | None) -> BPTTRNN:
        return BPTTRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            loss_mode="mse",
            max_grad_norm=bptt_clip,
            tbptt_steps=tbptt_steps,
            time_normalization=bptt_time_norm,
            seed=args.seed,
            device=device,
        )

    # Step 3: build the comparison models (Local / BPTT / E-Prop / FPTT / TBPTT).
    # Insertion order is the training order below.
    skip_eprop = bool(getattr(args, "no_eprop", False))
    models: Dict[str, Any] = {
        "Local Rule": build_local(),
        bptt_label: build_bptt(bptt_steps),
    }
    if not skip_eprop:
        models["E-Prop"] = StandardEPropRNN(
            task_data["input_size"],
            args.hidden,
            task_data["output_size"],
            eta=args.lr,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="mse",
            device=device,
        )
    models["FPTT"] = StrictFPTTRegressor(
        task_data["input_size"],
        args.hidden,
        task_data["output_size"],
        eta=args.lr,
        parts=FPTT_PARTS,
        clip=5.0,
        lmbda=FPTT_LAMBDA,
        device=device,
    )
    if bptt_mode == "full":
        # Extra truncated baselines only make sense next to a full-BPTT run
        # and only when the window is shorter than the sequence.
        if 0 < tbptt_short < time_steps:
            models[f"TBPTT-{tbptt_short}"] = build_bptt(tbptt_short)
        if tbptt_long != tbptt_short and 0 < tbptt_long < time_steps:
            models[f"TBPTT-{tbptt_long}"] = build_bptt(tbptt_long)

    prediction_bundle: Dict[str, Any] | None = None
    if plot_enabled and prediction_cfg:
        prediction_bundle = _build_regression_prediction_bundle(
            prediction_cfg, task_data, test_inputs_np, test_targets_np
        )
    use_rollout_pred = False
    pred_warmup = 1
    if prediction_bundle is not None:
        pred_mode = str(prediction_bundle.get("prediction_mode", "teacher")).lower()
        use_rollout_pred = pred_mode in {"rollout", "autoregressive", "ar"}
        pred_warmup = int(prediction_bundle.get("prediction_warmup", 1))

    # evaluate_split: MSE of `model` on one split using the resolved eval mode.
    def evaluate_split(model: Any, inputs: torch.Tensor, targets: torch.Tensor) -> Any:
        if eval_mode == "rollout":
            return evaluate_regression_mse_rollout(
                model,
                inputs,
                targets,
                args.batch_size,
                warmup_steps=eval_warmup,
                step_weights=step_weights_t,
            )
        return evaluate_regression_mse(
            model,
            inputs,
            targets,
            args.batch_size,
            step_weights=step_weights_t,
        )

    # record_prediction: store this model's prediction for the plotted sample, if requested.
    def record_prediction(name: str, model: Any) -> None:
        if prediction_bundle is None:
            return
        model_names = prediction_bundle["model_names"]
        if model_names is not None:
            wanted = name in model_names
        else:
            wanted = not prediction_bundle["local_only"] or name.startswith("Local Rule")
        if not wanted:
            return
        sample_input = prediction_bundle["sample_input"]
        if use_rollout_pred:
            prediction = predict_sequence_rollout(
                model,
                sample_input,
                warmup_steps=pred_warmup,
            )
        else:
            prediction = predict_sequence_outputs(
                model,
                sample_input,
            )
        prediction_bundle["predictions"][name] = prediction

    # Step 4: train / evaluate every model and collect the results.
    results: Dict[str, Dict[str, Any]] = {}
    log_every = max(1, args.epochs // 5)

    for name, model in models.items():
        # Every model starts from the initialisation selected by the gain scan.
        load_params(model, init_params)
        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        history: List[float] = []
        complexity = estimate_model_complexity(model)
        update_stats = estimate_training_counts(model, time_steps, batches_per_epoch, args.epochs)
        train_runtime_sec = 0.0
        eval_runtime_sec = 0.0
        best_epoch = 0
        best_val_metric = float("inf")
        best_params: Dict[str, np.ndarray] | None = None

        if plot_enabled:
            # Epoch-by-epoch training with validation after each epoch so the
            # history curve and the best-validation checkpoint are available.
            for epoch in range(args.epochs):
                train_start = time.perf_counter()
                train_batches(
                    model,
                    tr_inputs,
                    tr_targets,
                    args.batch_size,
                    1,
                    args.seed + 10 + epoch,
                    epoch_offset=epoch,
                    step_weights=step_weights_t,
                )
                train_runtime_sec += time.perf_counter() - train_start

                eval_start = time.perf_counter()
                val_mse = evaluate_split(model, val_inputs, val_targets)
                eval_runtime_sec += time.perf_counter() - eval_start

                history.append(float(val_mse))
                if float(val_mse) < best_val_metric:
                    best_val_metric = float(val_mse)
                    best_epoch = epoch + 1
                    best_params = extract_params(model)
                if (epoch + 1) % log_every == 0 or (epoch + 1) == args.epochs:
                    print(
                        f"[{name}] epoch={epoch+1:02d} | val_mse={val_mse:.6f} ({eval_label})"
                    )
        else:
            # Without plotting there is no per-epoch validation: train all
            # epochs in a single call.
            train_start = time.perf_counter()
            train_batches(
                model,
                tr_inputs,
                tr_targets,
                args.batch_size,
                args.epochs,
                args.seed + 10,
                epoch_offset=0,
                step_weights=step_weights_t,
            )
            train_runtime_sec = time.perf_counter() - train_start

        # Test with the best-validation parameters when a checkpoint exists.
        if best_params is not None:
            load_params(model, best_params)
        eval_start = time.perf_counter()
        final_test_mse = evaluate_split(model, test_inputs, test_targets)
        eval_runtime_sec += time.perf_counter() - eval_start

        if best_epoch == 0:
            # No validation improvement was recorded (e.g. validation skipped):
            # report the last epoch and an undefined validation metric.
            best_epoch = int(args.epochs)
            best_val_metric = float("nan")

        runtime_sec = train_runtime_sec + eval_runtime_sec
        updates_total = update_stats["updates_total"]
        steps_total = update_stats["steps_total"]
        runtime_per_update_sec = (
            train_runtime_sec / updates_total if updates_total > 0 else float("nan")
        )
        runtime_per_step_sec = (
            train_runtime_sec / steps_total if steps_total > 0 else float("nan")
        )
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        delta = lambda_post - lambda_pre
        print(
            f"[{name}] test_mse={final_test_mse:.6f} ({eval_label}) | "
            f"best_val_mse={best_val_metric:.6f} (epoch={best_epoch:02d}) | "
            f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
        )
        results[name] = {
            "metric": float(final_test_mse),
            "val_metric": float(best_val_metric),
            "val_loss": float(best_val_metric),
            "best_epoch": int(best_epoch),
            "lyap_pre": float(lambda_pre),
            "lyap_post": float(lambda_post),
            "history": history,
            "complexity_params": float(complexity["params"]),
            "complexity_state": float(complexity["state"]),
            "complexity_total": float(complexity["total"]),
            "runtime_sec": float(runtime_sec),
            "train_runtime_sec": float(train_runtime_sec),
            "eval_runtime_sec": float(eval_runtime_sec),
            "runtime_per_update_sec": float(runtime_per_update_sec),
            "runtime_per_step_sec": float(runtime_per_step_sec),
            "batches_per_epoch": int(update_stats["batches_per_epoch"]),
            "time_steps": int(update_stats["time_steps"]),
            "update_factor": int(update_stats["update_factor"]),
            "updates_per_epoch": int(update_stats["updates_per_epoch"]),
            "updates_total": int(update_stats["updates_total"]),
            "steps_total": int(update_stats["steps_total"]),
            "hparams": extract_model_hparams(model),
        }
        record_prediction(name, model)

    if plot_enabled and plot_dir is not None and plot_tag is not None:
        save_results_summary(
            task_label,
            metric_label,
            results,
            plot_dir,
            plot_tag,
            run_config=run_config,
        )
        plot_comparison_results(
            task_label,
            results,
            metric_label=metric_label,
            history_label=history_label,
            higher_is_better=False,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        plot_cost_profile(
            task_label,
            results,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        if prediction_bundle is not None and prediction_bundle["predictions"]:
            _plot_regression_predictions(task_label, prediction_bundle, plot_dir, plot_tag)


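# run_lm_task: language-model task entry point; prepares the PTB token data, then trains and evaluates each model.
# Note: runs a Local Rule gain scan first and loads the selected initial parameters into every model before training.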
def run_lm_task(task_data: Dict[str, Any], args: argparse.Namespace, gains: np.ndarray) -> None:
    set_global_seed(int(getattr(args, "seed", 0)))
    # 语言模型入口：PTB 序列切块 + 交叉熵评估。
    # 步骤1：准备数据、权重与绘图配置。
    task_label = task_data.get("task_name", getattr(args, "task", "Sequence Task"))
    plot_enabled = getattr(args, "plot", True)
    plot_dir = None
    plot_tag = None
    if plot_enabled:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)
    device = DEFAULT_DEVICE
    train_data = to_tensor(task_data["train_data"], device=device, dtype=torch.long)
    valid_data = to_tensor(task_data["valid_data"], device=device, dtype=torch.long)
    test_data = to_tensor(task_data["test_data"], device=device, dtype=torch.long)
    vocab_size = task_data["output_size"]
    step_label_mode = str(getattr(args, "step_labels", "final")).lower()
    use_surrogates = step_label_mode == "fptt"
    batches_per_epoch = int(args.ptb_steps_per_epoch)
    time_steps = int(args.ptb_block_size)
    tbptt_short = max(1, int(getattr(args, "tbptt_short", 1)))
    tbptt_long_default = max(2, min(10, time_steps))
    tbptt_long_value = getattr(args, "tbptt_long", None)
    tbptt_long = tbptt_long_default if tbptt_long_value is None else int(tbptt_long_value)

    run_config: Dict[str, Any] = {
        "task_id": getattr(args, "task", None),
        "task_type": "lm",
        "task_name": task_label,
        "seed": int(getattr(args, "seed", 0)),
        "epochs": int(getattr(args, "epochs", 0)),
        "scan_epochs": int(getattr(args, "scan_epochs", 0)),
        "batch_size": int(getattr(args, "batch_size", 0)),
        "hidden": int(getattr(args, "hidden", 0)),
        "lr": float(getattr(args, "lr", 0.0)),
        "step_labels": step_label_mode,
        "use_surrogates": bool(use_surrogates),
        "tbptt_short": int(tbptt_short),
        "tbptt_long": int(tbptt_long),
        "gains": [float(x) for x in np.asarray(gains).ravel().tolist()],
        "vocab_size": int(vocab_size),
        "block_size": int(getattr(args, "ptb_block_size", 0)),
        "steps_per_epoch": int(getattr(args, "ptb_steps_per_epoch", 0)),
        "val_steps": int(getattr(args, "ptb_val_steps", 0)),
        "eval_every": int(getattr(args, "eval_every", 0)),
        "train_tokens": int(train_data.shape[0]),
        "valid_tokens": int(valid_data.shape[0]),
        "test_tokens": int(test_data.shape[0]),
        "device": str(device),
        "args": dict(vars(args)),
        "argv": list(sys.argv),
        "eprop_feedback": EPROP_FEEDBACK,
        "eprop_seed": int(EPROP_SEED),
        "fptt_parts": int(FPTT_PARTS),
        "fptt_lambda": float(FPTT_LAMBDA),
        "fptt_oracle_momentum": float(FPTT_ORACLE_MOMENTUM),
    }

    # build_local：构建 Local Rule 模型及其优化器配置。
    # 说明：为上层训练/评估流程提供辅助支持。
    def build_local() -> TorchLocalRuleRNN:
        return TorchLocalRuleRNN(
            task_data["input_size"],
            args.hidden,
            vocab_size,
            eta=args.lr,
            loss_mode="ce",
            seed=args.seed,
            device=device,
        )

    # 步骤2：增益扫描（Local Rule）并选最优初始化。
    lyapunov_driver = build_lm_lyapunov_driver(
        train_data,
        vocab_size,
        args.ptb_block_size,
        args.seed,
    )
    best_g, init_params, stats = scan_gains_lm(
        gains,
        build_local,
        train_data,
        valid_data,
        vocab_size,
        args.ptb_block_size,
        args.batch_size,
        args.ptb_steps_per_epoch,
        args.ptb_val_steps,
        args.scan_epochs,
        args.seed,
        lyapunov_driver=lyapunov_driver,
        use_surrogates=use_surrogates,
        plot_dir=plot_dir,
        plot_tag=plot_tag,
        task_label=task_label,
        plot_show=False,
    )
    summary = f"Best g={best_g:.3f} | val_loss={stats['val_loss']:.4f}"
    if "lambda_pre" in stats and "lambda_post" in stats:
        delta = stats["lambda_post"] - stats["lambda_pre"]
        summary += f" | lyap(pre:{stats['lambda_pre']:.4f}, post:{stats['lambda_post']:.4f}, d:{delta:.4f})"
    print(summary)
    run_config["best_g"] = float(best_g)
    run_config["scan_best"] = {k: float(v) for k, v in stats.items()}

    local_name = "Local Rule (FPTT)" if use_surrogates else "Local Rule"
    # 步骤3：构建对照组模型（Local/BPTT/E-Prop/FPTT/TBPTT）。
    skip_eprop = bool(getattr(args, "no_eprop", False))
    models: Dict[str, Any] = {local_name: build_local()}
    models["BPTT"] = BPTTRNN(
        task_data["input_size"],
        args.hidden,
        vocab_size,
        eta=args.lr,
        loss_mode="ce",
        seed=args.seed,
        device=device,
    )
    if not skip_eprop:
        models["E-Prop"] = StandardEPropRNN(
            task_data["input_size"],
            args.hidden,
            vocab_size,
            eta=args.lr,
            feedback=EPROP_FEEDBACK,
            seed=EPROP_SEED,
            loss_mode="ce",
            device=device,
        )
    models["FPTT"] = StrictFPTTClassifier(
        task_data["input_size"],
        args.hidden,
        vocab_size,
        eta=args.lr,
        parts=FPTT_PARTS,
        clip=5.0,
        lmbda=FPTT_LAMBDA,
        oracle_momentum=FPTT_ORACLE_MOMENTUM,
        label_mode="all",
        oracle_id=f"{args.task}_lm",
        use_oracle=False,
        device=device,
    )
    if 0 < tbptt_short < time_steps:
        models[f"TBPTT-{tbptt_short}"] = BPTTRNN(
            task_data["input_size"],
            args.hidden,
            vocab_size,
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_short,
            seed=args.seed,
            device=device,
        )
    if tbptt_long != tbptt_short and 0 < tbptt_long < time_steps:
        models[f"TBPTT-{tbptt_long}"] = BPTTRNN(
            task_data["input_size"],
            args.hidden,
            vocab_size,
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_long,
            seed=args.seed,
            device=device,
        )

    results: Dict[str, Dict[str, Any]] = {}
    eval_every = max(1, int(getattr(args, "eval_every", 1)))
    log_every = max(1, int(args.epochs) // 5)

    for name, model in models.items():
        load_params(model, init_params)
        lambda_pre = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        history: List[float] = []
        use_surrogates_local = use_surrogates and name == local_name
        complexity = estimate_model_complexity(model)
        update_stats = estimate_training_counts(model, time_steps, batches_per_epoch, args.epochs)
        train_runtime_sec = 0.0
        eval_runtime_sec = 0.0
        val_loss = float("nan")
        for epoch in range(args.epochs):
            train_start = time.perf_counter()
            train_epoch_lm(
                model,
                train_data,
                vocab_size,
                args.ptb_block_size,
                args.batch_size,
                args.ptb_steps_per_epoch,
                args.seed + epoch + 10,
                epoch,
                use_surrogates=use_surrogates_local,
            )
            train_runtime_sec += time.perf_counter() - train_start
            if (epoch + 1) % eval_every == 0 or (epoch + 1) == args.epochs:
                eval_start = time.perf_counter()
                val_loss = evaluate_language_model(
                    model,
                    valid_data,
                    vocab_size,
                    args.ptb_block_size,
                    args.batch_size,
                    args.ptb_val_steps,
                )
                eval_runtime_sec += time.perf_counter() - eval_start
                if plot_enabled:
                    history.append(float(val_loss))
                print(f"[{name}] epoch={epoch+1:02d} | val_loss={val_loss:.4f}")
            elif (epoch + 1) % log_every == 0:
                print(f"[{name}] epoch={epoch+1:02d} | val_loss=skip")

        eval_start = time.perf_counter()
        test_loss = evaluate_language_model(
            model,
            test_data,
            vocab_size,
            args.ptb_block_size,
            args.batch_size,
            args.ptb_val_steps,
        )
        eval_runtime_sec += time.perf_counter() - eval_start
        runtime_sec = train_runtime_sec + eval_runtime_sec
        updates_total = update_stats["updates_total"]
        steps_total = update_stats["steps_total"]
        runtime_per_update_sec = (
            train_runtime_sec / updates_total if updates_total > 0 else float("nan")
        )
        runtime_per_step_sec = (
            train_runtime_sec / steps_total if steps_total > 0 else float("nan")
        )
        lambda_post = calculate_lyapunov_exponent_numpy(model, lyapunov_driver)
        delta = lambda_post - lambda_pre
        print(
            f"[{name}] test_loss={test_loss:.4f} | "
            f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
        )
        results[name] = {
            "metric": float(test_loss),
            "lyap_pre": float(lambda_pre),
            "lyap_post": float(lambda_post),
            "history": history,
            "complexity_params": float(complexity["params"]),
            "complexity_state": float(complexity["state"]),
            "complexity_total": float(complexity["total"]),
            "runtime_sec": float(runtime_sec),
            "train_runtime_sec": float(train_runtime_sec),
            "eval_runtime_sec": float(eval_runtime_sec),
            "runtime_per_update_sec": float(runtime_per_update_sec),
            "runtime_per_step_sec": float(runtime_per_step_sec),
            "batches_per_epoch": int(update_stats["batches_per_epoch"]),
            "time_steps": int(update_stats["time_steps"]),
            "update_factor": int(update_stats["update_factor"]),
            "updates_per_epoch": int(update_stats["updates_per_epoch"]),
            "updates_total": int(update_stats["updates_total"]),
            "steps_total": int(update_stats["steps_total"]),
            "hparams": extract_model_hparams(model),
        }

    if plot_enabled and plot_dir is not None and plot_tag is not None:
        save_results_summary(
            task_label,
            "Test Loss",
            results,
            plot_dir,
            plot_tag,
            run_config=run_config,
        )
        plot_comparison_results(
            task_label,
            results,
            metric_label="Test Loss",
            history_label="Val Loss",
            higher_is_better=False,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        plot_cost_profile(
            task_label,
            results,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )


# =============================
# End of helpers
# =============================
