"""核心"""
from __future__ import annotations

import argparse
import csv
import importlib
import json
import math
import os
import platform
import random
import re
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.nn.utils import clip_grad_norm_

# File layout at a glance:
# 1) General utilities / statistics / plotting: resolve_device / build_time_weights / plot_* etc.
# 2) Data loading (MNIST/Fashion/CIFAR/DVS): load_mnist_images / load_cifar10_images / load_dvs_cifar10_images.
# 3) Model definitions: SimpleCNNEncoder / TorchLocalRuleConvRNN / TorchEPropConvRNN / TorchBPTTConvRNN / StrictFPTTConvClassifier.
# 4) Lyapunov + training/evaluation/scanning + task entry points: build_lyapunov_driver / train_batches / evaluate_* / scan_gains_* / run_classification_task.

# Shared default hyperparameters: reused when training FPTT/E-Prop and during stability scans.

# Device returned by resolve_device() when the caller passes None: CUDA if available, else CPU.
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fallback for a model's `parts` attribute in estimate_update_factor().
FPTT_PARTS = 10
# NOTE(review): the FPTT_* / EPROP_* / LYAPUNOV_* constants below are not used in the
# visible part of the file; presumably consumed by the model/training code — confirm there.
FPTT_ORACLE_MOMENTUM = 1.0
FPTT_LAMBDA = 1.0
FPTT_WARMUP_EPOCHS = 20
EPROP_FEEDBACK = "symmetric"
EPROP_SEED = 1234
LYAPUNOV_MIN_STEPS = 50

# =============================
# General utilities and training weights
# =============================


def resolve_device(device: torch.device | str | None = None) -> torch.device:
    """Resolve an optional device specifier to a concrete ``torch.device``.

    Args:
        device: A ``torch.device``, a device string such as ``"cpu"``, or
            ``None`` to fall back to the module-level ``DEFAULT_DEVICE``.

    Returns:
        ``DEFAULT_DEVICE`` when ``device`` is ``None``; otherwise
        ``torch.device(device)``.
    """
    return DEFAULT_DEVICE if device is None else torch.device(device)


def to_tensor(
    array: np.ndarray | torch.Tensor,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Return ``array`` as a tensor on ``device`` with the requested ``dtype``.

    Args:
        array: An existing tensor, or anything ``torch.as_tensor`` accepts
            (for example a NumPy array).
        device: Target device.
        dtype: Target dtype; defaults to ``torch.float32``.

    Returns:
        For tensor input, the result of ``array.to(device=..., dtype=...)``;
        for any other input, the result of ``torch.as_tensor``.
    """
    if not torch.is_tensor(array):
        return torch.as_tensor(array, dtype=dtype, device=device)
    return array.to(device=device, dtype=dtype)


def softmax(logits: torch.Tensor) -> torch.Tensor:
    """Compute a numerically stable softmax along dimension 1.

    The per-row maximum is subtracted before exponentiation so that large
    logits do not overflow, and a small epsilon (1e-12) is added to the
    normalizer.

    Args:
        logits: Tensor with at least two dimensions; dimension 1 is normalized.

    Returns:
        Tensor of the same shape whose entries along dimension 1 sum to
        (approximately) one.
    """
    row_max = logits.max(dim=1, keepdim=True).values
    numerator = torch.exp(logits - row_max)
    normalizer = numerator.sum(dim=1, keepdim=True) + 1e-12
    return numerator / normalizer


def iterate_minibatches(
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    rng: np.random.Generator,
) -> Iterable[Tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]]:
    """Yield shuffled ``(inputs, targets)`` mini-batches covering every sample once.

    Samples are indexed along dimension 0. The last batch may be smaller than
    ``batch_size``.

    Args:
        inputs: Tensor or array of samples.
        targets: Tensor or array indexed with the same sample indices.
        batch_size: Maximum number of samples per batch.
        rng: NumPy generator that decides the sample order.

    Yields:
        Pairs ``(inputs[idx], targets[idx])`` for consecutive index slices of
        the shuffled order.
    """
    num_samples = inputs.shape[0]
    if torch.is_tensor(inputs):
        # Tensor inputs: draw a permutation and move the indices to the
        # inputs' device so the fancy indexing happens there.
        order = torch.as_tensor(rng.permutation(num_samples), device=inputs.device)
    else:
        # Array inputs: shuffle an index range in place (kept distinct from
        # the tensor path so each path consumes the RNG exactly as before).
        order = np.arange(num_samples)
        rng.shuffle(order)
    for offset in range(0, num_samples, batch_size):
        chosen = order[offset : offset + batch_size]
        yield inputs[chosen], targets[chosen]


def split_train_val(
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    val_fraction: float,
    rng: np.random.Generator,
) -> Tuple[
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
    np.ndarray | torch.Tensor,
]:
    """Randomly split samples into a training part and a validation part.

    The validation set takes the first ``max(1, int(total * val_fraction))``
    entries of a random permutation, so it always holds at least one sample;
    the remaining entries form the training set.

    Args:
        inputs: Samples indexed along dimension 0 (array or tensor).
        targets: Targets indexed with the same sample indices.
        labels: Labels indexed with the same sample indices.
        val_fraction: Fraction of samples reserved for validation.
        rng: NumPy generator used to draw the permutation.

    Returns:
        ``(train_inputs, train_targets, train_labels,
        val_inputs, val_targets, val_labels)``.
    """
    total = inputs.shape[0]
    val_size = max(1, int(total * val_fraction))
    order = rng.permutation(total)
    val_idx, train_idx = order[:val_size], order[val_size:]
    if torch.is_tensor(inputs):
        # Index tensors must live on the same device as the data they index.
        val_idx = torch.as_tensor(val_idx, device=inputs.device)
        train_idx = torch.as_tensor(train_idx, device=inputs.device)
    fields = (inputs, targets, labels)
    train_part = tuple(field[train_idx] for field in fields)
    val_part = tuple(field[val_idx] for field in fields)
    return (*train_part, *val_part)


def build_repeated_targets(
    labels: np.ndarray | torch.Tensor,
    num_classes: int,
    time_steps: int,
) -> np.ndarray | torch.Tensor:
    """Expand integer class labels into one-hot targets repeated over time.

    Args:
        labels: One class index per sample (NumPy array or tensor).
        num_classes: Width of the one-hot encoding.
        time_steps: Number of times the encoding is repeated along the last axis.

    Returns:
        A float32 array/tensor of shape ``(num_samples, num_classes,
        time_steps)``. Tensor input yields a tensor on the labels' device;
        array input yields a NumPy array.
    """
    if not torch.is_tensor(labels):
        num_samples = labels.shape[0]
        onehot_np = np.zeros((num_samples, num_classes), dtype=np.float32)
        onehot_np[np.arange(num_samples), labels] = 1.0
        return np.repeat(onehot_np[:, :, None], time_steps, axis=2)

    class_index = labels.to(dtype=torch.long)
    onehot_t = torch.zeros(
        (class_index.shape[0], num_classes),
        dtype=torch.float32,
        device=class_index.device,
    )
    # scatter_ writes a 1.0 into the column named by each sample's label.
    onehot_t.scatter_(1, class_index.view(-1, 1), 1.0)
    return onehot_t.unsqueeze(2).repeat(1, 1, time_steps)


def normalize_train_test(
    train_data: np.ndarray,
    test_data: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Standardize both splits using statistics from the training split only.

    A single global mean and standard deviation are computed over all of
    ``train_data`` and applied to both arrays, so no test-set statistics are
    used. 1e-7 is added to the standard deviation to avoid division by zero.

    Args:
        train_data: Training samples.
        test_data: Test samples.

    Returns:
        ``(normalized_train, normalized_test)``.
    """
    center = float(np.mean(train_data))
    scale = float(np.std(train_data)) + 1e-7
    normalized_train = (train_data - center) / scale
    normalized_test = (test_data - center) / scale
    return normalized_train, normalized_test


def ensure_python_package(import_name: str, package_name: str | None = None) -> bool:
    """Make sure a module is importable, installing it with pip if it is missing.

    Args:
        import_name: Module name passed to ``importlib.import_module``.
        package_name: Name handed to ``pip install`` when it differs from
            ``import_name``; defaults to ``import_name``.

    Returns:
        ``True`` if the module can be imported (possibly after installation),
        ``False`` if installation or the subsequent import failed. Failures
        are reported on stdout rather than raised.
    """
    try:
        importlib.import_module(import_name)
        return True
    except ImportError:
        pass

    pkg = package_name or import_name
    print(f"[SETUP] Installing missing package: {pkg}")
    try:
        # List-form argv (no shell) with the running interpreter's own pip.
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
        # The import system caches directory listings; without this, a package
        # installed after interpreter start-up may not be found.
        importlib.invalidate_caches()
        importlib.import_module(import_name)
    except Exception as exc:  # best-effort setup helper: report, don't crash
        print(f"[SETUP] Failed to install {pkg}: {exc}")
        return False
    return True


def available_physical_memory_bytes() -> int | None:
    """Return the available physical memory in bytes via the Win32 API.

    Calls ``GlobalMemoryStatusEx`` through ``ctypes.windll.kernel32``.
    ``ctypes.windll`` only exists on Windows, so on other platforms the
    attribute lookup fails and the function returns ``None``.

    Returns:
        ``ullAvailPhys`` as an ``int``, or ``None`` when the call fails or any
        exception is raised.
    """
    try:
        import ctypes

        dword = ctypes.c_ulong
        qword = ctypes.c_ulonglong
        # Field layout of the Win32 MEMORYSTATUSEX structure.
        layout = [("dwLength", dword), ("dwMemoryLoad", dword)]
        layout += [
            (field_name, qword)
            for field_name in (
                "ullTotalPhys",
                "ullAvailPhys",
                "ullTotalPageFile",
                "ullAvailPageFile",
                "ullTotalVirtual",
                "ullAvailVirtual",
                "ullAvailExtendedVirtual",
            )
        ]

        class MEMORYSTATUSEX(ctypes.Structure):
            _fields_ = layout

        status = MEMORYSTATUSEX()
        # The API requires dwLength to be set to the structure size.
        status.dwLength = ctypes.sizeof(MEMORYSTATUSEX)
        succeeded = ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status))
        return int(status.ullAvailPhys) if succeeded else None
    except Exception:
        return None


def build_time_weights(time_steps: int, mode: str | None, min_weight: float = 0.05) -> np.ndarray | None:
    """Build a per-time-step weight vector for loss weighting.

    Args:
        time_steps: Sequence length; values below 1 are treated as 1.
        mode: Case-insensitive strategy name. ``None`` or ``"none"`` disables
            weighting, ``"final"``/``"last"`` weights only the last step, and
            ``"late"``/``"linear"`` ramps linearly from ``min_weight`` to 1.
        min_weight: First weight of the linear ramp; clamped to at least 1e-4.

    Returns:
        A float32 array of length ``time_steps``, or ``None`` when weighting
        is disabled.

    Raises:
        ValueError: If ``mode`` is not one of the recognized names.
    """
    if mode is None or mode.lower() == "none":
        return None
    mode = mode.lower()
    steps = max(1, int(time_steps))
    ramp_start = float(max(min_weight, 1e-4))
    if mode in {"late", "linear"}:
        return np.linspace(ramp_start, 1.0, steps, dtype=np.float32)
    if mode in {"final", "last"}:
        only_last = np.zeros(steps, dtype=np.float32)
        only_last[-1] = 1.0
        return only_last
    raise ValueError(f"Unknown time-weighting mode: {mode}")


def build_label_step_weights(
    time_steps: int,
    step_label_mode: str | None,
    time_weighting: str | None,
) -> np.ndarray | None:
    """Build per-step label weights from the step-label mode.

    Args:
        time_steps: Sequence length; values below 1 are treated as 1, matching
            ``build_time_weights``.
        step_label_mode: Case-insensitive mode. ``None``, empty, ``"final"``
            or ``"last"`` supervise only the last step. Any other value defers
            to ``time_weighting``.
        time_weighting: Strategy forwarded to ``build_time_weights`` when the
            step-label mode is not a "final" mode.

    Returns:
        A float32 weight vector, or ``None`` when ``build_time_weights``
        disables weighting.
    """
    mode = str(step_label_mode or "final").lower()
    if mode in {"final", "last"}:
        # Clamp like build_time_weights does so that a non-positive length
        # gives a one-element vector instead of an IndexError on weights[-1].
        steps = max(1, int(time_steps))
        weights = np.zeros(steps, dtype=np.float32)
        weights[-1] = 1.0
        return weights
    return build_time_weights(time_steps, time_weighting)


def clip_gradients(grads: List[torch.Tensor], max_norm: float) -> List[torch.Tensor]:
    """Rescale gradients so that their global L2 norm does not exceed ``max_norm``.

    Args:
        grads: Gradient tensors; the norm is computed over all of them jointly.
        max_norm: Norm threshold. Non-positive values disable clipping.

    Returns:
        The original list object when no clipping is needed (clipping
        disabled, empty list, norm within the threshold, or a non-finite
        norm); otherwise a new list of scaled tensors.
    """
    # An empty list has nothing to clip; it previously crashed on grads[0].
    if max_norm <= 0 or not grads:
        return grads
    total_norm_sq = torch.zeros((), device=grads[0].device)
    for grad in grads:
        total_norm_sq = total_norm_sq + torch.sum(grad ** 2)
    total_norm = torch.sqrt(total_norm_sq)
    # A NaN/inf norm cannot be rescaled meaningfully; leave gradients as-is.
    if not torch.isfinite(total_norm) or total_norm <= max_norm:
        return grads
    scale = max_norm / (total_norm + 1e-8)
    return [grad * scale for grad in grads]


# _count_elements：递归统计张量/列表/字典的元素数量。
# 关键步骤：遍历参数/状态 → 累计统计 → 返回计数。
# 算法要点：用于对比复杂度/内存开销。
def _count_elements(value: Any) -> int:
    if value is None:
        return 0
    if torch.is_tensor(value):
        return int(value.numel())
    return int(np.asarray(value).size)


def count_model_parameters(model: Any) -> int:
    """Count the parameters of a model.

    For an ``nn.Module`` this is the total ``numel()`` over
    ``model.parameters()``. For any other object it is the element count of
    whichever of the attributes ``W_hh``, ``W_xh``, ``b_h``, ``W_hy`` and
    ``b_y`` exist, plus the parameters of ``model.encoder`` if present.

    Args:
        model: An ``nn.Module`` or an object exposing the attributes above.

    Returns:
        Total number of parameter elements.
    """
    if isinstance(model, nn.Module):
        return int(sum(p.numel() for p in model.parameters()))

    weight_total = sum(
        _count_elements(getattr(model, attr))
        for attr in ("W_hh", "W_xh", "b_h", "W_hy", "b_y")
        if hasattr(model, attr)
    )
    if not hasattr(model, "encoder"):
        return weight_total
    return weight_total + int(sum(p.numel() for p in model.encoder.parameters()))


def count_persistent_state(model: Any) -> int:
    """Count elements held in a model's persistent (non-parameter) state.

    Sums the element counts of every attribute in a fixed list that exists on
    ``model``, plus the ``"sm"`` and ``"lm"`` entries of each state dict in
    ``model._regularizer._state`` when that attribute chain is present.

    Args:
        model: Object to inspect; missing attributes are skipped.

    Returns:
        Total number of state elements.
    """
    tracked_attrs = (
        "alpha_num",
        "alpha_den",
        "alpha_hat",
        "S_A2",
        "S_AB",
        "lambda_vals",
        "fptt_Q_prev",
        "fptt_Q_sum",
        "fptt_Q_count",
        "B_fb",
    )
    total = sum(
        _count_elements(getattr(model, attr))
        for attr in tracked_attrs
        if hasattr(model, attr)
    )

    regularizer = getattr(model, "_regularizer", None)
    if regularizer is None or not hasattr(regularizer, "_state"):
        return total
    for entry in regularizer._state.values():
        total += _count_elements(entry.get("sm")) + _count_elements(entry.get("lm"))
    return total


def estimate_model_complexity(model: Any) -> Dict[str, int]:
    """Summarize a model's size as parameter count plus persistent-state count.

    Args:
        model: Model passed to ``count_model_parameters`` and
            ``count_persistent_state``.

    Returns:
        Dict with keys ``"params"``, ``"state"`` and ``"total"`` (their sum).
    """
    summary = {
        "params": count_model_parameters(model),
        "state": count_persistent_state(model),
    }
    summary["total"] = summary["params"] + summary["state"]
    return summary


def estimate_update_factor(model: Any, time_steps: int) -> int:
    """Estimate how many parameter updates a model performs per batch.

    Args:
        model: Model instance; the estimate depends on its class.
        time_steps: Sequence length; values below 1 are treated as 1.

    Returns:
        * ``time_steps`` for ``TorchLocalRuleConvRNN`` / ``TorchEPropConvRNN``.
        * The length of ``build_chunk_schedule(time_steps, parts)`` for
          ``StrictFPTTConvClassifier`` (``parts`` falls back to ``FPTT_PARTS``).
        * ``ceil(time_steps / tbptt_steps)`` for ``TorchBPTTConvRNN`` when
          ``0 < tbptt_steps < time_steps``.
        * 1 in every other case.
    """
    steps = max(1, int(time_steps))
    if isinstance(model, (TorchLocalRuleConvRNN, TorchEPropConvRNN)):
        return steps
    if isinstance(model, StrictFPTTConvClassifier):
        return len(build_chunk_schedule(steps, getattr(model, "parts", FPTT_PARTS)))
    if not isinstance(model, TorchBPTTConvRNN):
        return 1

    window = getattr(model, "tbptt_steps", None)
    if window is None:
        return 1
    window = int(window)
    # Truncation only splits the sequence when the window is shorter than it.
    if 0 < window < steps:
        return int(math.ceil(steps / window))
    return 1


def estimate_training_counts(
    model: Any,
    time_steps: int,
    batches_per_epoch: int,
    epochs: int,
) -> Dict[str, int]:
    """Estimate update and step counts for a training run.

    All three numeric inputs are converted to ``int`` and clamped to at
    least 1.

    Args:
        model: Model passed to ``estimate_update_factor``.
        time_steps: Sequence length per sample.
        batches_per_epoch: Number of batches in one epoch.
        epochs: Number of epochs.

    Returns:
        Dict with ``batches_per_epoch``, ``time_steps``, ``update_factor``,
        ``updates_per_epoch`` (factor x batches), ``updates_total``
        (x epochs) and ``steps_total`` (batches x epochs x time_steps).
    """
    batches = max(1, int(batches_per_epoch))
    epochs = max(1, int(epochs))
    time_steps = max(1, int(time_steps))
    update_factor = estimate_update_factor(model, time_steps)
    updates_per_epoch = update_factor * batches
    return {
        "batches_per_epoch": batches,
        "time_steps": time_steps,
        "update_factor": update_factor,
        "updates_per_epoch": updates_per_epoch,
        "updates_total": updates_per_epoch * epochs,
        "steps_total": batches * epochs * time_steps,
    }


def _get_process_rss_bytes() -> int:
    if os.name == "nt":
        try:
            import ctypes
            from ctypes import wintypes

            class PROCESS_MEMORY_COUNTERS(ctypes.Structure):
                _fields_ = [
                    ("cb", wintypes.DWORD),
                    ("PageFaultCount", wintypes.DWORD),
                    ("PeakWorkingSetSize", ctypes.c_size_t),
                    ("WorkingSetSize", ctypes.c_size_t),
                    ("QuotaPeakPagedPoolUsage", ctypes.c_size_t),
                    ("QuotaPagedPoolUsage", ctypes.c_size_t),
                    ("QuotaPeakNonPagedPoolUsage", ctypes.c_size_t),
                    ("QuotaNonPagedPoolUsage", ctypes.c_size_t),
                    ("PagefileUsage", ctypes.c_size_t),
                    ("PeakPagefileUsage", ctypes.c_size_t),
                ]

            counters = PROCESS_MEMORY_COUNTERS()
            counters.cb = ctypes.sizeof(PROCESS_MEMORY_COUNTERS)
            handle = ctypes.windll.kernel32.GetCurrentProcess()
            if not ctypes.windll.psapi.GetProcessMemoryInfo(handle, ctypes.byref(counters), counters.cb):
                return 0
            return int(counters.WorkingSetSize)
        except Exception:
            return 0
    try:
        with open("/proc/self/statm", "r", encoding="utf-8") as handle:
            parts = handle.read().strip().split()
        rss_pages = int(parts[1]) if len(parts) > 1 else 0
        return int(rss_pages * os.sysconf("SC_PAGE_SIZE"))
    except Exception:
        return 0


class _RSSSampler:
    def __init__(self, interval_sec: float = 0.001) -> None:
        self.interval_sec = max(1e-4, float(interval_sec))
        self.baseline_bytes = 0
        self.peak_bytes = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while not self._stop.is_set():
            rss = _get_process_rss_bytes()
            if rss > self.peak_bytes:
                self.peak_bytes = rss
            time.sleep(self.interval_sec)

    def __enter__(self) -> "_RSSSampler":
        self.baseline_bytes = _get_process_rss_bytes()
        self.peak_bytes = self.baseline_bytes
        self._thread.start()
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
        self._stop.set()
        self._thread.join(timeout=1.0)
        return False


def profile_update_cost(
    model: Any,
    inputs_batch: np.ndarray | torch.Tensor,
    targets_batch: np.ndarray | torch.Tensor,
    *,
    use_surrogates: bool,
    step_weights: np.ndarray | torch.Tensor | None,
) -> Dict[str, float]:
    """Profile FLOPs and peak memory of one ``model.train_batch`` call.

    Profiling is skipped (an empty dict is returned) when the environment
    variable ``OLL_PROFILE_UPDATE_COST`` is ``0``/``false``/``no`` or when
    ``torch.profiler`` cannot be imported.

    Args:
        model: Object exposing ``train_batch(inputs, targets, None)``. Optional
            attributes (``device``, ``W_hh``, ``step_weights``, ``use_oracle``,
            FPTT surrogate hooks) are used when present.
        inputs_batch: Batch of inputs; dimension 0 is read as the batch size.
        targets_batch: Batch of targets; ``shape[1]`` is read as the output
            size and ``shape[2]`` as the number of time steps.
        use_surrogates: Whether to initialize/reset the model's FPTT
            surrogate state before profiling (only if the model supports it).
        step_weights: Optional per-step weights assigned to
            ``model.step_weights`` when the model has that attribute.

    Returns:
        Dict of float metrics: batch size, total FLOPs, FLOPs per time step
        and per update, and peak/delta memory figures (process RSS and
        profiler CPU allocations on the CPU path, CUDA allocated/reserved
        bytes on the CUDA path). Metrics of the path not taken are 0.

    NOTE(review): this executes a real ``train_batch`` call, so it presumably
    updates the model; ``model.step_weights`` is overwritten and not restored.
    Confirm callers expect both.
    """
    flag = str(os.getenv("OLL_PROFILE_UPDATE_COST", "1")).strip().lower()
    if flag in {"0", "false", "no"}:
        return {}
    try:
        from torch.profiler import ProfilerActivity, profile
    except Exception as exc:
        print(f"[PROFILE] Skipping FLOPs/memory profiling: torch.profiler unavailable ({exc}).")
        return {}

    # Prefer the device of the model's W_hh tensor when available; otherwise
    # use model.device (or the module default when that is missing/None).
    device = resolve_device(getattr(model, "device", None))
    if hasattr(model, "W_hh") and torch.is_tensor(getattr(model, "W_hh")):
        device = getattr(model, "W_hh").device
    inputs = to_tensor(inputs_batch, device=device, dtype=torch.float32)
    targets = to_tensor(targets_batch, device=device, dtype=torch.float32)
    batch_size = int(inputs.shape[0])
    time_steps = int(targets.shape[2])
    output_size = int(targets.shape[1])

    if hasattr(model, "step_weights"):
        model.step_weights = (
            to_tensor(step_weights, device=device, dtype=torch.float32) if step_weights is not None else None
        )

    if use_surrogates and hasattr(model, "enable_fptt_surrogates"):
        if getattr(model, "fptt_Q_prev", None) is None:
            # No surrogate state yet: start from a uniform value of
            # 1/output_size in an (output_size, time_steps) tensor.
            Q0 = torch.full(
                (output_size, time_steps),
                1.0 / float(output_size),
                dtype=torch.float32,
                device=device,
            )
            model.enable_fptt_surrogates(time_steps, output_size, Q0)
        if hasattr(model, "reset_fptt_epoch_accumulators"):
            model.reset_fptt_epoch_accumulators(time_steps, output_size)

    update_factor = estimate_update_factor(model, time_steps)

    def _run_step() -> None:
        # The single call being profiled.
        model.train_batch(inputs, targets, None)

    # Temporarily switch off use_oracle for the profiled step; the previous
    # value is restored in the finally block below.
    prev_use_oracle: bool | None = None
    if hasattr(model, "use_oracle"):
        prev_use_oracle = bool(getattr(model, "use_oracle"))
        setattr(model, "use_oracle", False)

    use_cuda = bool(torch.cuda.is_available()) and getattr(device, "type", "") == "cuda"
    activities: List[ProfilerActivity] = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)

    sampler: _RSSSampler | None = None
    peak_cuda_bytes = 0
    delta_cuda_bytes = 0
    peak_cuda_reserved_bytes = 0
    delta_cuda_reserved_bytes = 0
    peak_cpu_alloc_delta_bytes = 0
    try:
        if use_cuda:
            # CUDA path: take memory figures from the allocator statistics
            # (baseline before, peak after) rather than from the profiler.
            torch.cuda.synchronize(device)
            baseline_cuda_alloc = int(torch.cuda.memory_allocated(device))
            baseline_cuda_reserved = int(torch.cuda.memory_reserved(device))
            torch.cuda.reset_peak_memory_stats(device)
            with profile(
                activities=activities,
                with_flops=True,
                profile_memory=False,
                record_shapes=False,
            ) as prof:
                _run_step()
            torch.cuda.synchronize(device)
            peak_cuda_bytes = int(torch.cuda.max_memory_allocated(device))
            peak_cuda_reserved_bytes = int(torch.cuda.max_memory_reserved(device))
            delta_cuda_bytes = max(0, peak_cuda_bytes - baseline_cuda_alloc)
            delta_cuda_reserved_bytes = max(0, peak_cuda_reserved_bytes - baseline_cuda_reserved)
        else:
            # CPU path: sample process RSS in the background and also let the
            # profiler record per-event CPU memory usage.
            sampler = _RSSSampler()
            with sampler:
                with profile(
                    activities=activities,
                    with_flops=True,
                    profile_memory=True,
                    record_shapes=False,
                ) as prof:
                    _run_step()
            # Running sum of per-event self CPU memory usage; its maximum is
            # reported as the peak CPU allocation delta.
            running = 0
            peak = 0
            for evt in prof.events():
                mem = getattr(evt, "self_cpu_memory_usage", 0)
                if not mem:
                    continue
                running += int(mem)
                if running > peak:
                    peak = running
            peak_cpu_alloc_delta_bytes = max(0, int(peak))
    finally:
        if prev_use_oracle is not None:
            setattr(model, "use_oracle", prev_use_oracle)

    # Sum the FLOP counts of all aggregated events that report one.
    total_flops = 0.0
    for event in prof.key_averages():
        flops = getattr(event, "flops", None)
        if flops:
            total_flops += float(flops)

    # RSS figures exist only on the CPU path (sampler is None on CUDA).
    peak_rss_bytes = int(sampler.peak_bytes) if sampler is not None else 0
    baseline_rss_bytes = int(sampler.baseline_bytes) if sampler is not None else 0
    delta_rss_bytes = max(0, peak_rss_bytes - baseline_rss_bytes)

    flops_per_step = total_flops / max(1, time_steps)
    flops_per_update = total_flops / max(1, update_factor)

    return {
        "update_profile_batch_size": float(batch_size),
        "update_flops": float(total_flops),
        "update_flops_per_step": float(flops_per_step),
        "update_flops_per_update": float(flops_per_update),
        "update_peak_rss_bytes": float(peak_rss_bytes),
        "update_peak_rss_delta_bytes": float(delta_rss_bytes),
        "update_peak_cpu_alloc_delta_bytes": float(peak_cpu_alloc_delta_bytes),
        "update_peak_cuda_bytes": float(peak_cuda_bytes),
        "update_peak_cuda_delta_bytes": float(delta_cuda_bytes),
        "update_peak_cuda_reserved_bytes": float(peak_cuda_reserved_bytes),
        "update_peak_cuda_reserved_delta_bytes": float(delta_cuda_reserved_bytes),
    }


# _slugify：将字符串规范化为安全文件名。
# 关键步骤：去除非法字符 → 合并空白 → 生成 slug。
# 算法要点：避免文件名非法导致保存失败。
def _slugify(value: str) -> str:
    cleaned = re.sub(r"[^a-zA-Z0-9]+", "_", value.strip().lower())
    cleaned = cleaned.strip("_")
    return cleaned or "task"


def resolve_plot_context(
    args: argparse.Namespace,
    task_label: str,
) -> Tuple[Path, str]:
    """Determine the output directory and file tag for plots, creating the directory.

    Resolution rules:
      * If ``args.plot_path`` is set, its parent is the directory and its stem
        is the tag (falling back to the slugified task label if the stem is
        empty).
      * Otherwise the directory is ``args.plot_dir`` or ``plots/<task slug>``,
        and the tag is ``args.plot_tag`` or a ``YYYYmmdd_HHMMSS`` timestamp.

    Args:
        args: Parsed arguments; ``plot_path``, ``plot_dir`` and ``plot_tag``
            are optional attributes.
        task_label: Human-readable task name used for the default directory.

    Returns:
        ``(plot_dir, plot_tag)``; the directory exists on return.
    """
    task_slug = _slugify(task_label)
    explicit_path = getattr(args, "plot_path", None)

    if explicit_path:
        target = Path(explicit_path)
        plot_dir = target.parent
        plot_tag = target.stem or task_slug
    else:
        configured_dir = getattr(args, "plot_dir", None)
        plot_dir = Path("plots") / task_slug if configured_dir is None else Path(configured_dir)
        plot_tag = getattr(args, "plot_tag", None) or time.strftime("%Y%m%d_%H%M%S")

    plot_dir.mkdir(parents=True, exist_ok=True)
    return plot_dir, plot_tag


def build_plot_path(plot_dir: Path, plot_tag: str, suffix: str) -> str:
    """Compose the PNG path ``<plot_dir>/<plot_tag>_<slugified suffix>.png``.

    This only builds the path string; it does not create any directory.

    Args:
        plot_dir: Directory that will contain the figure.
        plot_tag: File-name prefix.
        suffix: Free-form figure name; slugified before use.

    Returns:
        The path as a string.
    """
    safe_suffix = _slugify(suffix)
    filename = f"{plot_tag}_{safe_suffix}.png"
    return str(plot_dir / filename)


def set_global_seed(seed: int, deterministic_cudnn: bool = True) -> None:
    """Seed Python ``random``, NumPy and PyTorch, optionally forcing deterministic cuDNN.

    Args:
        seed: Seed value; converted with ``int()``.
        deterministic_cudnn: When true (and ``torch.backends.cudnn`` exists),
            set ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
    """
    seed = int(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not deterministic_cudnn or not hasattr(torch.backends, "cudnn"):
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _find_git_root(start: Path) -> Path | None:
    for cand in (start, *start.parents):
        if (cand / ".git").exists():
            return cand
    return None


def collect_environment_info() -> Dict[str, Any]:
    """Collect interpreter, library, CUDA/cuDNN and git details for reproducibility.

    Git information is looked up from the repository containing this file (if
    any) by running ``git rev-parse HEAD`` and ``git status --porcelain``; if
    either command fails, both git fields are ``None``.

    Returns:
        Dict with Python/platform/NumPy/PyTorch versions, CUDA availability and
        version, cuDNN version and flags, the deterministic-algorithms flag,
        ``git_commit`` and ``git_dirty``.
    """
    git_commit = git_dirty = None
    git_root = _find_git_root(Path(__file__).resolve())
    if git_root is not None:

        def _git_output(*argv: str) -> str:
            return subprocess.check_output(["git", *argv], cwd=str(git_root), text=True).strip()

        try:
            git_commit = _git_output("rev-parse", "HEAD")
            # Any porcelain output means there are uncommitted changes.
            git_dirty = bool(_git_output("status", "--porcelain"))
        except Exception:
            git_commit = git_dirty = None

    has_cudnn = hasattr(torch.backends, "cudnn")
    try:
        cudnn_version = torch.backends.cudnn.version() if has_cudnn else None
    except Exception:
        cudnn_version = None

    if hasattr(torch, "are_deterministic_algorithms_enabled"):
        deterministic_algos = bool(torch.are_deterministic_algorithms_enabled())
    else:
        deterministic_algos = None

    if has_cudnn:
        cudnn_deterministic = getattr(torch.backends.cudnn, "deterministic", None)
        cudnn_benchmark = getattr(torch.backends.cudnn, "benchmark", None)
    else:
        cudnn_deterministic = cudnn_benchmark = None

    return {
        "python_version": platform.python_version(),
        "python_executable": sys.executable,
        "platform": platform.platform(),
        "numpy_version": np.__version__,
        "torch_version": torch.__version__,
        "cuda_available": bool(torch.cuda.is_available()),
        "cuda_version": torch.version.cuda,
        "cudnn_version": cudnn_version,
        "cudnn_deterministic": cudnn_deterministic,
        "cudnn_benchmark": cudnn_benchmark,
        "torch_deterministic_algorithms": deterministic_algos,
        "git_commit": git_commit,
        "git_dirty": git_dirty,
    }


def save_results_summary(
    task_label: str,
    metric_label: str,
    results: Dict[str, Dict[str, Any]],
    plot_dir: Path,
    plot_tag: str,
    *,
    run_config: Dict[str, Any] | None = None,
    env_info: Dict[str, Any] | None = None,
) -> None:
    """Write ``<plot_tag>_summary.csv`` and ``<plot_tag>_summary.json`` into ``plot_dir``.

    The CSV has one row per method: task/metric/method columns, the known
    result fields, ``history_len``, then ``cfg_*`` (run config), ``env_*``
    (environment) and ``hp_*`` (per-method hyperparameters) columns in sorted
    order. The JSON holds the full ``results`` mapping together with the run
    config and environment (``schema_version`` 2).

    Args:
        task_label: Task name recorded in both files.
        metric_label: Name of the headline metric.
        results: Mapping from method name to its result dict.
        plot_dir: Output directory (a ``Path`` or path string); created if
            missing. A falsy value disables saving.
        plot_tag: File-name prefix.
        run_config: Optional run configuration; defaults to an empty dict.
        env_info: Optional environment description; collected with
            ``collect_environment_info()`` when missing or empty.
    """
    if not plot_dir:
        return
    plot_dir = Path(plot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)

    run_config = run_config or {}
    env_info = env_info or collect_environment_info()

    leading_columns = ("task", "metric_label", "method")
    # Result fields copied verbatim from each method's dict, in column order.
    result_columns = (
        "metric",
        "val_metric",
        "val_loss",
        "best_epoch",
        "lyap_pre",
        "lyap_post",
        "runtime_sec",
        "train_runtime_sec",
        "eval_runtime_sec",
        "runtime_per_update_sec",
        "runtime_per_step_sec",
        "batches_per_epoch",
        "time_steps",
        "update_factor",
        "updates_per_epoch",
        "updates_total",
        "steps_total",
        "complexity_params",
        "complexity_state",
        "complexity_total",
        "update_profile_batch_size",
        "update_flops",
        "update_flops_per_step",
        "update_flops_per_update",
        "update_peak_rss_bytes",
        "update_peak_rss_delta_bytes",
        "update_peak_cpu_alloc_delta_bytes",
        "update_peak_cuda_bytes",
        "update_peak_cuda_delta_bytes",
        "update_peak_cuda_reserved_bytes",
        "update_peak_cuda_reserved_delta_bytes",
    )

    def _csv_value(value: Any) -> Any:
        """Flatten a config/env/hparam value into something a CSV cell can hold."""
        if value is None:
            return ""
        if isinstance(value, (str, int, float, bool)):
            return value
        try:
            return json.dumps(value, ensure_ascii=False, sort_keys=True)
        except Exception:
            return str(value)

    def _json_default(value: Any) -> Any:
        """Convert values the json module cannot serialize natively."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        if torch.is_tensor(value):
            return value.detach().cpu().tolist()
        # Deliberate last resort: a stringified value is better than losing
        # the whole summary to a TypeError after a finished run.
        return str(value)

    rows: List[Dict[str, Any]] = []
    for name, data in results.items():
        row: Dict[str, Any] = {
            "task": task_label,
            "metric_label": metric_label,
            "method": name,
        }
        row.update((f"cfg_{k}", _csv_value(v)) for k, v in run_config.items())
        row.update((f"env_{k}", _csv_value(v)) for k, v in env_info.items())
        row.update((column, data.get(column)) for column in result_columns)
        history = data.get("history")
        row["history_len"] = 0 if history is None else len(history)
        hparams = data.get("hparams")
        if isinstance(hparams, dict):
            row.update((f"hp_{k}", _csv_value(v)) for k, v in hparams.items())
        rows.append(row)

    csv_path = plot_dir / f"{plot_tag}_summary.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        # With no results the file is still created, but left empty.
        if rows:
            preferred = [*leading_columns, *result_columns, "history_len"]
            every_key = set().union(*(row.keys() for row in rows))
            fieldnames = preferred + sorted(every_key.difference(preferred))
            writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(rows)

    json_path = plot_dir / f"{plot_tag}_summary.json"
    payload = {
        "schema_version": 2,
        "task": task_label,
        "metric_label": metric_label,
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "run_config": run_config,
        "environment": env_info,
        "results": results,
    }
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False, default=_json_default)


# =============================
# Plotting and result output
# =============================

# _apply_icml_plot_style：设置 ICML 风格的字体、线宽与布局参数。
# 关键步骤：设置 rcParams → 统一字体/线宽 → 生效。
# 算法要点：统一风格以保证图表可比性。
def _apply_icml_plot_style() -> None:
    try:
        import matplotlib as mpl
    except Exception:
        return

    mpl.rcParams.update(
        {
            "font.family": "serif",
            "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
            "font.weight": "normal",
            "axes.labelsize": 11,
            "axes.labelweight": "normal",
            "axes.titlesize": 12,
            "axes.titleweight": "normal",
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "figure.titlesize": 14,
            "figure.titleweight": "normal",
            "legend.fontsize": 9,
            "legend.frameon": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.axisbelow": True,
            "grid.alpha": 0.35,
            "grid.linestyle": ":",
            "grid.linewidth": 0.8,
            "lines.linewidth": 2.0,
            "savefig.dpi": 300,
            "figure.dpi": 150,
        }
    )


# _lyapunov_plot_enabled: report whether Lyapunov figures should be produced.
# Key steps: read OLL_PHYSICAL_PROBE_NO_LYAPUNOV -> trim and lowercase -> compare with the opt-out values.
# Design note: plots are enabled by default; only "1", "true" or "yes" (any case) switch them off.
def _lyapunov_plot_enabled() -> bool:
    """Return ``False`` when the opt-out environment variable is set, else ``True``."""
    raw_value = os.getenv("OLL_PHYSICAL_PROBE_NO_LYAPUNOV", "")
    normalized = str(raw_value).strip().lower()
    if normalized in ("1", "true", "yes"):
        return False
    return True


# _method_palette: build a list of distinguishable colours, one per method.
# Key steps: take the fixed 8-colour base palette -> pick colour i % 8 for method i.
# Design note: the colour depends only on the method index, so assignments are stable across figures.
def _method_palette(count: int) -> List[str]:
    """Return ``count`` hex colour strings, cycling through the base palette.

    A non-positive ``count`` yields an empty list.
    """
    swatches = (
        "#0072B2",
        "#D55E00",
        "#009E73",
        "#CC79A7",
        "#E69F00",
        "#56B4E9",
        "#F0E442",
        "#999999",
    )
    wanted = max(0, count)
    palette: List[str] = []
    for position in range(wanted):
        # Wrap around once there are more methods than base colours.
        palette.append(swatches[position % len(swatches)])
    return palette


# _save_figure: write a figure to disk with the project's fixed export settings.
# Key steps: collect the export options -> call fig.savefig.
# Design note: this helper only saves; showing or closing the figure is left to the caller.
def _save_figure(fig: Any, path: str | Path) -> None:
    """Save ``fig`` to ``path`` at 300 dpi with a tight bounding box and 0.2 in padding."""
    export_options = {"dpi": 300, "bbox_inches": "tight", "pad_inches": 0.2}
    fig.savefig(path, **export_options)


# plot_comparison_results: draw per-method comparison figures (final metric, Lyapunov, learning curves).
# Key steps: collect per-method values -> draw each figure -> save it, then show or close it.
# Design note: colours and fonts are shared across figures so they stay comparable.
def plot_comparison_results(
    task_label: str,
    results: Dict[str, Dict[str, Any]],
    metric_label: str,
    history_label: str | None,
    higher_is_better: bool,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
) -> None:
    """Render up to three comparison figures for the methods in ``results``.

    Args:
        task_label: Prefix used in every figure title.
        results: Mapping from method name to its result record. Each record
            must contain ``"metric"``; ``"lyap_pre"``, ``"lyap_post"`` and
            ``"history"`` are optional.
        metric_label: Y-axis label of the final-metric bar chart.
        history_label: Y-axis label of the learning curve; ``metric_label``
            is used when this is falsy.
        higher_is_better: Only selects the note shown in the bar-chart title.
        plot_dir: Directory forwarded to ``build_plot_path``.
        plot_tag: Tag forwarded to ``build_plot_path``.
        show: If true, call ``plt.show()`` after saving; otherwise close the
            figure.

    Returns early (after printing a notice) when matplotlib cannot be
    imported, and silently when ``results`` is empty.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    names = list(results.keys())
    if not names:
        return

    def _store(figure: Any, suffix: str) -> Any:
        # Resolve the output path for this figure kind and write the file.
        target = build_plot_path(plot_dir, plot_tag, suffix)
        _save_figure(figure, target)
        return target

    def _release(figure: Any) -> None:
        # Interactive runs display the figure; batch runs free it right away.
        if show:
            plt.show()
        else:
            plt.close(figure)

    heading = {"fontsize": 12, "fontweight": "normal"}
    dotted = {"linestyle": ":", "linewidth": 0.7}
    colors = _method_palette(len(names))
    x = np.arange(len(names))

    # ---- Figure 1: final metric per method ----
    metric_values = [results[name]["metric"] for name in names]
    fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
    bars = ax.bar(x, metric_values, color=colors, edgecolor="black", linewidth=0.3, zorder=3)
    if higher_is_better:
        better_note = "higher is better"
    else:
        better_note = "lower is better"
    ax.set_title(f"{task_label}: Final Performance ({better_note})", **heading)
    ax.set_ylabel(metric_label, **heading)
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=15)
    ax.grid(True, axis="y", zorder=0, **dotted)
    for bar, val in zip(bars, metric_values):
        # Value label centred on top of each bar.
        center = bar.get_x() + bar.get_width() / 2.0
        ax.text(center, bar.get_height(), f"{val:.4f}", ha="center", va="bottom", fontsize=9)
    metric_path = _store(fig, "metric")
    print(f"[PLOT] Saved metric plot to {metric_path}")
    _release(fig)

    # ---- Figure 2: Lyapunov exponent before/after training ----
    if not _lyapunov_plot_enabled():
        print("[PLOT] Lyapunov plot disabled by OLL_PHYSICAL_PROBE_NO_LYAPUNOV.")
    else:
        lyap_pre = [results[name].get("lyap_pre", float("nan")) for name in names]
        lyap_post = [results[name].get("lyap_post", float("nan")) for name in names]
        fig, ax = plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)
        width = 0.32
        half_width = width / 2.0
        # Paired bars per method: faded for "Pre", solid for "Post".
        ax.bar(x - half_width, lyap_pre, width, label="Pre", color=colors, alpha=0.4, zorder=3)
        ax.bar(x + half_width, lyap_post, width, label="Post", color=colors, alpha=0.9, zorder=3)
        # Zero reference line.
        ax.axhline(0.0, color="crimson", linestyle="--", linewidth=1.0, zorder=2)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=15)
        ax.set_title(f"{task_label}: Lyapunov (pre/post)", **heading)
        ax.set_ylabel("Lyap max", **heading)
        ax.grid(True, axis="y", zorder=0, **dotted)
        ax.legend(fontsize=9)
        lyap_path = _store(fig, "lyapunov")
        print(f"[PLOT] Saved lyapunov plot to {lyap_path}")
        _release(fig)

    # ---- Figure 3: learning curves (only if any method recorded a history) ----
    histories = [results[name].get("history", []) for name in names]
    if not any(len(h) > 0 for h in histories):
        return

    markers = ["o", "s", "^", "D", "v", "P", "X", "*", "<", ">"]
    fig, ax = plt.subplots(figsize=(7.4, 4.4), constrained_layout=True)
    for idx, name in enumerate(names):
        history = histories[idx]
        if not history:
            continue
        n_epochs = len(history)
        ax.plot(
            np.arange(1, n_epochs + 1),
            history,
            label=name,
            color=colors[idx],
            linewidth=2.0,
            marker=markers[idx % len(markers)],
            markersize=4,
            # Roughly ten markers per curve regardless of its length.
            markevery=max(1, n_epochs // 10),
        )
    ax.set_xlabel("Epoch", **heading)
    ax.set_ylabel(history_label or metric_label, **heading)
    ax.set_title(f"{task_label}: Learning Curve", **heading)
    ax.grid(True, **dotted)
    for axis_name in ("x", "y"):
        ax.tick_params(axis=axis_name, labelsize=10)
    ncol = min(3, max(1, len(names)))
    ax.legend(fontsize=9, ncol=ncol, loc="upper center", bbox_to_anchor=(0.5, 1.22))
    curve_path = _store(fig, "curve")
    print(f"[PLOT] Saved learning-curve plot to {curve_path}")
    _release(fig)


# plot_scan_results: plot the scanned metric (and optionally Lyapunov exponents) against the gain.
# Key steps: align and sort the series by gain -> draw the metric panel -> draw the Lyapunov panel -> save.
# Design note: colours and fonts are shared across figures so they stay comparable.
def plot_scan_results(
    task_label: str,
    gains: Iterable[float],
    metric_values: Iterable[float],
    metric_label: str,
    lyap_pre: Iterable[float] | None,
    lyap_post: Iterable[float] | None,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
    best_g: float | None = None,
    higher_is_better: bool = True,
    aux_values: Iterable[float] | None = None,
    aux_label: str | None = None,
    suffix: str = "scan_lyapunov",
) -> None:
    """Draw a gain-scan figure and save it via ``build_plot_path``.

    ``gains`` and ``metric_values`` are truncated to their common length and
    sorted by gain. The optional series (``aux_values``, ``lyap_pre``,
    ``lyap_post``) are truncated the same way; if one of them is shorter than
    the common length it is padded with NaN so the missing points are simply
    not drawn.

    Args:
        task_label: Prefix used in the figure title.
        gains: Scanned gain values (x-axis).
        metric_values: Metric obtained at each gain.
        metric_label: Legend/y-axis label of the metric curve.
        lyap_pre: Optional Lyapunov values plotted as "Lyap pre".
        lyap_post: Optional Lyapunov values plotted as "Lyap post".
        plot_dir: Output directory; nothing is drawn when it is ``None``.
        plot_tag: Tag forwarded to ``build_plot_path``.
        show: If true, call ``plt.show()`` after saving; otherwise close the
            figure.
        best_g: Optional gain marked with a vertical dashed line when finite.
        higher_is_better: Only selects the note shown in the title.
        aux_values: Optional second series drawn on a twin y-axis; used only
            together with a non-empty ``aux_label``.
        aux_label: Label of the auxiliary series.
        suffix: File-name suffix; ``"lyapunov"`` inside it is replaced by
            ``"metric"`` when the Lyapunov panel is disabled.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping scan plot: matplotlib unavailable ({exc}).")
        return
    if plot_dir is None:
        return
    _apply_icml_plot_style()

    gains_arr = np.asarray(list(gains), dtype=float)
    metric_arr = np.asarray(list(metric_values), dtype=float)
    if gains_arr.size == 0 or metric_arr.size == 0:
        return
    count = min(gains_arr.size, metric_arr.size)
    gains_arr = gains_arr[:count]
    metric_arr = metric_arr[:count]

    # Sort everything by gain so the curves are drawn left to right.
    order = np.argsort(gains_arr)
    gains_arr = gains_arr[order]
    metric_arr = metric_arr[order]

    def _align_optional(values: Iterable[float] | None) -> np.ndarray | None:
        # Bring an optional series into the same length and order as the gains.
        if values is None:
            return None
        arr = np.asarray(list(values), dtype=float)[:count]
        if arr.ndim == 1 and arr.shape[0] < count:
            # A short series would make arr[order] raise IndexError; pad the
            # missing tail with NaN instead so those points are left blank.
            padded = np.full(count, np.nan, dtype=float)
            padded[: arr.shape[0]] = arr
            arr = padded
        return arr[order]

    aux_arr = _align_optional(aux_values)
    lyap_pre_arr = _align_optional(lyap_pre)
    lyap_post_arr = _align_optional(lyap_post)

    lyap_enabled = _lyapunov_plot_enabled()
    if lyap_enabled:
        fig, axes = plt.subplots(2, 1, figsize=(7.2, 6.4), constrained_layout=True)
        ax_metric = axes[0]
    else:
        fig, ax_metric = plt.subplots(figsize=(7.2, 3.6), constrained_layout=True)
        # The Lyapunov panel normally carries the x-label; in metric-only mode
        # the single panel has to label the gain axis itself.
        ax_metric.set_xlabel("Gain g")
    ax_metric.plot(gains_arr, metric_arr, color="#0072B2", marker="o", label=metric_label)
    if best_g is not None and np.isfinite(best_g):
        ax_metric.axvline(
            float(best_g),
            color="crimson",
            linestyle="--",
            linewidth=1.0,
            label=f"best g={best_g:.3f}",
        )
    ax_metric.set_ylabel(metric_label)
    note = "higher is better" if higher_is_better else "lower is better"
    ax_metric.set_title(f"{task_label}: Scan ({note})", fontsize=12, fontweight="normal")
    ax_metric.grid(True, axis="y", linestyle=":", linewidth=0.7)
    handles, labels = ax_metric.get_legend_handles_labels()
    if aux_arr is not None and aux_label:
        # Auxiliary series on a twin axis; merge its legend entry into the
        # metric legend so there is a single legend box.
        ax_aux = ax_metric.twinx()
        ax_aux.plot(gains_arr, aux_arr, color="#D55E00", marker="s", label=aux_label)
        ax_aux.set_ylabel(aux_label)
        handles2, labels2 = ax_aux.get_legend_handles_labels()
        handles += handles2
        labels += labels2
    if handles:
        ax_metric.legend(handles, labels, fontsize=9, loc="best")

    if lyap_enabled:
        ax_lyap = axes[1]
        has_lyap = False
        # A series is drawn only if it contains at least one finite value.
        if lyap_pre_arr is not None and np.any(np.isfinite(lyap_pre_arr)):
            ax_lyap.plot(gains_arr, lyap_pre_arr, color="#56B4E9", marker="o", label="Lyap pre")
            has_lyap = True
        if lyap_post_arr is not None and np.any(np.isfinite(lyap_post_arr)):
            ax_lyap.plot(gains_arr, lyap_post_arr, color="#009E73", marker="s", label="Lyap post")
            has_lyap = True
        # Zero reference line.
        ax_lyap.axhline(0.0, color="crimson", linestyle="--", linewidth=1.0)
        if best_g is not None and np.isfinite(best_g):
            ax_lyap.axvline(float(best_g), color="crimson", linestyle="--", linewidth=1.0)
        ax_lyap.set_xlabel("Gain g")
        ax_lyap.set_ylabel("Lyapunov (max)")
        ax_lyap.grid(True, axis="y", linestyle=":", linewidth=0.7)
        if has_lyap:
            ax_lyap.legend(fontsize=9, loc="best")

    scan_suffix = suffix if lyap_enabled else suffix.replace("lyapunov", "metric")
    scan_path = build_plot_path(plot_dir, plot_tag, scan_suffix)
    _save_figure(fig, scan_path)
    if lyap_enabled:
        print(f"[PLOT] Saved scan plot to {scan_path}")
    else:
        print(f"[PLOT] Saved metric-only scan plot to {scan_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# plot_cost_profile: draw per-method cost figures (model size, runtime, profiled FLOPs and memory).
# Key steps: gather the optional cost fields -> draw one chart per available quantity -> save, then show or close.
# Design note: colours and fonts are shared across figures so they stay comparable.
def plot_cost_profile(
    task_label: str,
    results: Dict[str, Dict[str, Any]],
    plot_dir: Path,
    plot_tag: str,
    show: bool,
) -> None:
    """Render the cost-related figures for the methods in ``results``.

    Every field is optional and read with ``dict.get``; a figure is drawn
    only when at least one method reports the corresponding field:

    * ``complexity_params`` / ``complexity_state``: stacked model-size chart.
    * ``train_runtime_sec`` (preferred) or ``runtime_sec``, plus
      ``eval_runtime_sec``: runtime chart.
    * ``runtime_per_update_sec``, ``runtime_per_step_sec``: time charts with
      automatically chosen units.
    * ``update_flops_per_step``, ``update_flops_per_update``: FLOPs charts.
    * ``update_peak_cuda_delta_bytes``, ``update_peak_cpu_alloc_delta_bytes``,
      ``update_peak_rss_delta_bytes``: peak-memory chart. The CUDA series is
      used if it has a positive value, else the CPU-allocator series if it
      has one, else the RSS series.
    * Memory and FLOPs together: a scatter "tradeoff" chart whose marker size
      grows with ``metric`` when that value lies in [0, 1].

    Args:
        task_label: Prefix used in every figure title.
        results: Mapping from method name to its result record.
        plot_dir: Directory forwarded to ``build_plot_path``.
        plot_tag: Tag forwarded to ``build_plot_path``.
        show: If true, call ``plt.show()`` after saving each figure;
            otherwise close it.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping cost plot: matplotlib unavailable ({exc}).")
        return
    _apply_icml_plot_style()

    names = list(results.keys())
    if not names:
        return

    def _column(key: str) -> List[Any]:
        # One optional field across all methods; None marks "not reported".
        return [results[name].get(key) for name in names]

    def _any_reported(values: List[Any]) -> bool:
        return any(v is not None for v in values)

    def _to_array(values: List[Any], missing: float) -> np.ndarray:
        # Float array with `missing` substituted for unreported entries.
        return np.array([float(v) if v is not None else missing for v in values])

    params = _column("complexity_params")
    state = _column("complexity_state")
    runtime = _column("runtime_sec")
    train_runtime = _column("train_runtime_sec")
    eval_runtime = _column("eval_runtime_sec")
    runtime_per_update = _column("runtime_per_update_sec")
    runtime_per_step = _column("runtime_per_step_sec")
    flops_per_step = _column("update_flops_per_step")
    flops_per_update = _column("update_flops_per_update")
    mem_delta_cuda = _column("update_peak_cuda_delta_bytes")
    mem_delta_cpu = _column("update_peak_cpu_alloc_delta_bytes")
    mem_delta_rss = _column("update_peak_rss_delta_bytes")

    # Pick the memory counter to display. `mem_source` is reused in the chart
    # title so the title always names the counter that is actually plotted.
    has_cuda_mem = any((v is not None and float(v) > 0.0) for v in mem_delta_cuda)
    has_cpu_mem = any((v is not None and float(v) > 0.0) for v in mem_delta_cpu)
    if has_cuda_mem:
        mem_delta = mem_delta_cuda
        mem_label = "Peak CUDA Δ"
        mem_source = "CUDA delta"
    elif has_cpu_mem:
        mem_delta = mem_delta_cpu
        mem_label = "Peak CPU alloc Δ"
        mem_source = "CPU alloc delta"
    else:
        mem_delta = mem_delta_rss
        mem_label = "Peak RSS Δ"
        mem_source = "RSS delta"

    has_complexity = _any_reported(params) or _any_reported(state)
    has_runtime = _any_reported(runtime) or _any_reported(train_runtime)
    has_runtime_per_update = _any_reported(runtime_per_update)
    has_runtime_per_step = _any_reported(runtime_per_step)
    has_flops_per_step = _any_reported(flops_per_step)
    has_flops_per_update = _any_reported(flops_per_update)
    has_mem_delta = _any_reported(mem_delta)

    # Nothing to draw at all. The FLOPs/memory flags are part of this check so
    # that records carrying only profile data still get their charts.
    if not any(
        (
            has_complexity,
            has_runtime,
            has_runtime_per_update,
            has_runtime_per_step,
            has_flops_per_step,
            has_flops_per_update,
            has_mem_delta,
        )
    ):
        return

    colors = _method_palette(len(names))
    x = np.arange(len(names))

    # _scale_time: choose s / ms / us from the largest finite value.
    def _scale_time(values: np.ndarray) -> Tuple[np.ndarray, str]:
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return values, "s"
        max_val = float(np.max(finite))
        if max_val < 1e-3:
            return values * 1e6, "us"
        if max_val < 1.0:
            return values * 1e3, "ms"
        return values, "s"

    # _scale_flops: choose FLOPs / MFLOPs / GFLOPs / TFLOPs (decimal steps).
    def _scale_flops(values: np.ndarray) -> Tuple[np.ndarray, str]:
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return values, "FLOPs"
        max_val = float(np.max(finite))
        if max_val < 1e6:
            return values, "FLOPs"
        if max_val < 1e9:
            return values / 1e6, "MFLOPs"
        if max_val < 1e12:
            return values / 1e9, "GFLOPs"
        return values / 1e12, "TFLOPs"

    # _scale_bytes: choose B / KB / MB / GB (1024-based steps).
    def _scale_bytes(values: np.ndarray) -> Tuple[np.ndarray, str]:
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return values, "B"
        max_val = float(np.max(finite))
        if max_val < 1024.0:
            return values, "B"
        if max_val < 1024.0**2:
            return values / 1024.0, "KB"
        if max_val < 1024.0**3:
            return values / (1024.0**2), "MB"
        return values / (1024.0**3), "GB"

    def _new_axes() -> Tuple[Any, Any]:
        return plt.subplots(figsize=(6.8, 4.2), constrained_layout=True)

    def _decorate(ax: Any, ylabel: str, title: str) -> None:
        # Method names on the x-axis, labels, and a dotted horizontal grid.
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=15)
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontsize=12, fontweight="normal")
        ax.grid(True, axis="y", linestyle=":", linewidth=0.7, zorder=0)

    def _annotate(ax: Any, raw_values: np.ndarray, shown_values: np.ndarray, unit: str) -> None:
        # Value label above each bar; "n/a" at the baseline when the raw
        # value is missing or non-finite.
        for xi, raw, shown in zip(x, raw_values, shown_values):
            if np.isfinite(raw):
                ax.text(xi, shown, f"{shown:.2f}{unit}", ha="center", va="bottom", fontsize=9)
            else:
                ax.text(xi, 0.0, "n/a", ha="center", va="bottom", fontsize=9)

    def _finish(fig: Any, suffix: str, description: str) -> None:
        # Save the figure, report the path, then show or close it.
        path = build_plot_path(plot_dir, plot_tag, suffix)
        _save_figure(fig, path)
        print(f"[PLOT] Saved {description} to {path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    def _scaled_bar_chart(
        values: List[Any],
        scale: Any,
        quantity: str,
        title: str,
        suffix: str,
        description: str,
    ) -> None:
        # One bar per method for a single quantity, auto-scaled to a readable unit.
        fig, ax = _new_axes()
        raw = _to_array(values, float("nan"))
        scaled, unit = scale(raw)
        scaled = np.nan_to_num(scaled, nan=0.0)  # missing entries become zero-height bars
        ax.bar(x, scaled, color=colors, edgecolor="black", linewidth=0.3, zorder=3)
        _annotate(ax, raw, scaled, unit)
        _decorate(ax, f"{quantity} ({unit})", f"{task_label}: {title}")
        _finish(fig, suffix, description)

    if has_complexity:
        fig, ax = _new_axes()
        # Sizes are shown in millions of elements; missing values count as zero.
        params_m = _to_array(params, 0.0) / 1e6
        state_m = _to_array(state, 0.0) / 1e6
        has_state = bool(np.any(state_m > 0))
        ax.bar(x, params_m, color=colors, edgecolor="black", linewidth=0.3, label="Params", zorder=3)
        if has_state:
            # Persistent state is stacked on top of the parameter count.
            ax.bar(
                x,
                state_m,
                bottom=params_m,
                color="0.7",
                edgecolor="black",
                linewidth=0.3,
                label="State",
                zorder=3,
            )
        for xi, total in zip(x, params_m + state_m):
            ax.text(xi, total, f"{total:.2f}", ha="center", va="bottom", fontsize=9)
        _decorate(ax, "Elements (M)", f"{task_label}: Model Size (params + persistent state)")
        if has_state:
            ax.legend(fontsize=9)
        _finish(fig, "complexity", "complexity plot")

    if has_runtime:
        fig, ax = _new_axes()
        # Prefer the train/eval split when any method reports it.
        use_train_runtime = _any_reported(train_runtime)
        runtime_source = train_runtime if use_train_runtime else runtime
        runtime_arr = _to_array(runtime_source, float("nan"))
        runtime_clean = np.nan_to_num(runtime_arr, nan=0.0)
        if use_train_runtime:
            eval_arr = _to_array(eval_runtime, 0.0)
            has_eval = bool(np.any(eval_arr > 0))
        else:
            eval_arr = np.zeros_like(runtime_clean)
            has_eval = False

        if use_train_runtime:
            ax.bar(x, runtime_clean, color=colors, edgecolor="black", linewidth=0.3, zorder=3, label="Train")
        else:
            ax.bar(x, runtime_clean, color=colors, edgecolor="black", linewidth=0.3, zorder=3)
        if has_eval:
            # Evaluation time is stacked on top of training time.
            ax.bar(
                x,
                eval_arr,
                bottom=runtime_clean,
                color="0.85",
                edgecolor="black",
                linewidth=0.3,
                zorder=3,
                label="Eval",
            )
        totals = runtime_clean + (eval_arr if has_eval else 0.0)
        _annotate(ax, runtime_arr, totals, "s")
        if use_train_runtime:
            title = (
                f"{task_label}: Runtime (train + eval, fixed epochs)"
                if has_eval
                else f"{task_label}: Runtime (train, fixed epochs)"
            )
        else:
            title = f"{task_label}: Runtime"
        _decorate(ax, "Seconds", title)
        if has_eval:
            ax.legend(fontsize=9)
        _finish(fig, "runtime", "runtime plot")

    if has_runtime_per_update:
        _scaled_bar_chart(
            runtime_per_update,
            _scale_time,
            "Time per update",
            "Runtime per Update (method-specific)",
            "runtime_per_update",
            "runtime-per-update plot",
        )

    if has_runtime_per_step:
        _scaled_bar_chart(
            runtime_per_step,
            _scale_time,
            "Time per step",
            "Runtime per Step (normalized)",
            "runtime_per_step",
            "runtime-per-step plot",
        )

    if has_flops_per_step:
        _scaled_bar_chart(
            flops_per_step,
            _scale_flops,
            "FLOPs per step",
            "Update Compute per Step (profiled)",
            "update_flops_per_step",
            "update FLOPs-per-step plot",
        )

    if has_flops_per_update:
        _scaled_bar_chart(
            flops_per_update,
            _scale_flops,
            "FLOPs per update",
            "Update Compute per Weight-Update (profiled)",
            "update_flops_per_update",
            "update FLOPs-per-update plot",
        )

    if has_mem_delta:
        _scaled_bar_chart(
            mem_delta,
            _scale_bytes,
            mem_label,
            f"Update Peak Memory ({mem_source}, profiled)",
            "update_memory_peak",
            "update memory plot",
        )

    if has_mem_delta and (has_flops_per_step or has_flops_per_update):
        fig, ax = _new_axes()
        mem_arr = _to_array(mem_delta, float("nan"))
        # Per-step FLOPs take precedence over per-update FLOPs for the y-axis.
        flops_arr = _to_array(flops_per_step if has_flops_per_step else flops_per_update, float("nan"))
        mem_scaled, mem_unit = _scale_bytes(mem_arr)
        flops_scaled, flops_unit = _scale_flops(flops_arr)
        for idx, name in enumerate(names):
            # Skip methods lacking either coordinate.
            if not np.isfinite(mem_scaled[idx]) or not np.isfinite(flops_scaled[idx]):
                continue
            marker_size = 80.0
            metric_val = results[name].get("metric")
            try:
                metric_float = float(metric_val) if metric_val is not None else None
            except Exception:
                # Best effort: a non-numeric metric just keeps the default size.
                metric_float = None
            if metric_float is not None and 0.0 <= metric_float <= 1.0:
                # Metrics in [0, 1] scale the marker area from 80 up to 500.
                marker_size = 80.0 + 420.0 * metric_float
            ax.scatter(
                mem_scaled[idx],
                flops_scaled[idx],
                s=marker_size,
                color=colors[idx],
                edgecolor="black",
                linewidth=0.3,
                zorder=4,
            )
            ax.text(
                mem_scaled[idx],
                flops_scaled[idx],
                f" {name}",
                fontsize=9,
                ha="left",
                va="center",
            )
        ax.set_xlabel(f"{mem_label} ({mem_unit})")
        y_label = "FLOPs per step" if has_flops_per_step else "FLOPs per update"
        ax.set_ylabel(f"{y_label} ({flops_unit})")
        ax.set_title(f"{task_label}: Update Cost Tradeoff (profiled)", fontsize=12, fontweight="normal")
        ax.grid(True, linestyle=":", linewidth=0.7, zorder=0)
        _finish(fig, "update_tradeoff", "update tradeoff plot")


# plot_classification_predictions: draw a grid of sample images titled with true vs. predicted labels.
# Key steps: sample images -> predict with the model -> draw each image with its labels -> save the figure.
# Design note: applies the shared plot style so the figure matches the other plots.
def plot_classification_predictions(
    task_label: str,
    model: Any,
    images: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    plot_dir: Path,
    plot_tag: str,
    show: bool,
    max_samples: int = 12,
    seed: int = 1234,
) -> None:
    """Save a grid of sampled images annotated with true and predicted labels.

    Args:
        task_label: Prefix used in the figure title.
        model: Classifier passed to ``predict_classifier_final_step``.
        images: Image batch indexed along its first axis. Each sample is
            drawn only if it is 3-D with the channel axis first; other
            shapes leave an empty panel.
        labels: Labels indexable by the same sample indices as ``images``.
        plot_dir: Directory forwarded to ``build_plot_path``.
        plot_tag: Tag forwarded to ``build_plot_path``.
        show: If true, call ``plt.show()`` after saving; otherwise close the
            figure.
        max_samples: Upper bound on the number of images drawn.
        seed: Seed of the NumPy generator that picks the samples.

    Returns early (after printing a notice) when matplotlib is unavailable,
    the dataset is ``None``/empty, or no samples are requested.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception as exc:
        print(f"[PLOT] Skipping prediction plot: matplotlib unavailable ({exc}).")
        return
    if images is None or images.shape[0] == 0:
        print("[PLOT] Skipping prediction plot: empty dataset.")
        return

    rng = np.random.default_rng(seed)
    total = int(images.shape[0])
    sample_count = min(int(max_samples), total)
    if sample_count <= 0:
        print("[PLOT] Skipping prediction plot: no samples requested.")
        return
    if total == sample_count:
        # Small dataset: use everything, in order.
        indices = np.arange(total)
    else:
        indices = rng.choice(total, size=sample_count, replace=False)

    if torch.is_tensor(images):
        # Index on the tensor's own device to avoid a device mismatch.
        idx_tensor = torch.as_tensor(indices, device=images.device)
        sample_images = images[idx_tensor]
    else:
        sample_images = images[indices]
    if torch.is_tensor(labels):
        label_arr = labels.detach().cpu().numpy()
    else:
        label_arr = np.asarray(labels)
    sample_labels = label_arr[indices]

    preds = predict_classifier_final_step(model, sample_images, batch_size=min(128, sample_count))

    # The other plot_* functions apply the shared style before drawing; do the
    # same here so this figure does not depend on which plot ran first.
    _apply_icml_plot_style()

    ncols = min(4, sample_count)
    nrows = int(math.ceil(sample_count / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.4 * ncols, 2.4 * nrows), constrained_layout=True)
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a 2-D array.
    axes = np.array(axes).reshape(nrows, ncols)
    for idx, ax in enumerate(axes.flat):
        if idx >= sample_count:
            # Unused trailing panel of the grid.
            ax.axis("off")
            continue
        img = sample_images[idx]
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        if img.ndim != 3:
            ax.axis("off")
            continue
        channels = int(img.shape[0])
        if channels == 1:
            frame = img[0]
            cmap = "gray"
        elif channels == 3:
            # Channel-first -> channel-last, as imshow expects for colour images.
            frame = np.transpose(img, (1, 2, 0))
            cmap = None
        else:
            # Any other channel count is collapsed to a grayscale mean image.
            frame = np.mean(img, axis=0)
            cmap = "gray"
        # Min-max normalise to [0, 1]; constant images are left untouched.
        vmin = float(np.min(frame))
        vmax = float(np.max(frame))
        if vmax > vmin:
            frame = (frame - vmin) / (vmax - vmin)
        ax.imshow(frame, cmap=cmap, interpolation="nearest")
        true_label = int(sample_labels[idx])
        # -1 marks a sample for which no prediction was returned.
        pred_label = int(preds[idx]) if idx < len(preds) else -1
        ok = true_label == pred_label
        color = "green" if ok else "crimson"
        ax.set_title(f"t:{true_label} p:{pred_label}", color=color, fontsize=9)
        ax.axis("off")

    fig.suptitle(f"{task_label}: Local Rule Predictions", fontsize=12, fontweight="normal")
    plot_path = build_plot_path(plot_dir, plot_tag, "pred_local_rule")
    _save_figure(fig, plot_path)
    print(f"[PLOT] Saved Local Rule prediction plot to {plot_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# parse_channel_list: parse a channel specification string into a pair of integers.
# Key steps: split on commas -> drop blank tokens -> convert to int -> return a 2-tuple.
# Design note: exactly two channel counts are required.
def parse_channel_list(value: str | None, default: Tuple[int, int]) -> Tuple[int, int]:
    """Parse ``"A,B"`` into ``(A, B)``; return ``default`` for a falsy ``value``.

    Whitespace around tokens and empty tokens (e.g. a trailing comma) are
    ignored.

    Raises:
        ValueError: If a token is not an integer, or if the number of
            non-empty tokens is not exactly two.
    """
    if not value:
        return default

    channels: List[int] = []
    for token in value.split(","):
        token = token.strip()
        if token:
            channels.append(int(token))

    if len(channels) == 2:
        first, second = channels
        return first, second
    raise ValueError("enc_channels must have two comma-separated integers, e.g. '16,32'")


# =============================
# 数据加载（MNIST/Fashion/CIFAR/DVS）
# =============================

# File name of the cached MNIST archive; load_mnist_images looks for it in the
# working directory and under ~/.keras/datasets.
MNIST_NPZ = "mnist.npz"
# File names for a cached Fashion-MNIST archive (hyphen and underscore
# spellings). Presumably probed the same way by the Fashion-MNIST loader —
# TODO confirm against that loader.
FASHION_MNIST_NPZ = "fashion-mnist.npz"
FASHION_MNIST_NPZ_ALT = "fashion_mnist.npz"


# load_mnist_images: load MNIST images and labels.
# Key steps: look for a local npz -> fall back to torchvision, then TensorFlow -> apply limits -> scale/normalize.
# Notes: a local npz is preferred so nothing has to be downloaded when it exists.
def load_mnist_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
    npz_path: str | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(train_images, train_labels, test_images, test_labels)`` for MNIST.

    Args:
        train_limit: Keep only the first ``train_limit`` training samples; ``None`` keeps all.
        test_limit: Keep only the first ``test_limit`` test samples; ``None`` keeps all.
        npz_path: Optional npz file that is checked before the default locations.

    Returns:
        Images are divided by 255, passed through ``normalize_train_test`` and
        given a new axis at position 1; labels are cast to int64.

    Raises:
        RuntimeError: If no local file exists and neither torchvision nor
            TensorFlow can provide the data.
    """
    # Search order: explicit path, working directory, Keras cache directory.
    search_order: List[Path] = [] if npz_path is None else [Path(npz_path)]
    search_order += [
        Path(MNIST_NPZ),
        Path.home() / ".keras" / "datasets" / MNIST_NPZ,
    ]

    splits = None
    found = next((location for location in search_order if location.exists()), None)
    if found is not None:
        with np.load(found) as archive:
            train_pair = (archive["x_train"], archive["y_train"])
            test_pair = (archive["x_test"], archive["y_test"])
        splits = (train_pair, test_pair)
        print(f"Loaded MNIST data from '{found}'.")

    if splits is None:
        try:
            from torchvision.datasets import MNIST  # type: ignore

            tv_train = MNIST(root="./data", train=True, download=True)
            tv_test = MNIST(root="./data", train=False, download=True)
            splits = (
                (np.array(tv_train.data), np.array(tv_train.targets)),
                (np.array(tv_test.data), np.array(tv_test.targets)),
            )
            print("Loaded MNIST data via torchvision.datasets.MNIST.")
        except Exception:
            # Any torchvision failure (import or download) falls through to TensorFlow.
            try:
                from tensorflow.keras.datasets import mnist  # type: ignore
            except Exception as exc:
                raise RuntimeError(
                    "MNIST dataset not found locally and TensorFlow is unavailable. "
                    "Place 'mnist.npz' next to this script or install tensorflow."
                ) from exc
            else:
                splits = mnist.load_data()
                print("Loaded MNIST data via tensorflow.keras.datasets.mnist.")

    (train_x, train_y), (test_x, test_y) = splits

    if train_limit is not None:
        train_x, train_y = train_x[:train_limit], train_y[:train_limit]
    if test_limit is not None:
        test_x, test_y = test_x[:test_limit], test_y[:test_limit]

    # Scale raw pixel values by 1/255 before the shared normalization step.
    train_x = train_x.astype(np.float32) / 255.0
    test_x = test_x.astype(np.float32) / 255.0
    train_x, test_x = normalize_train_test(train_x, test_x)

    # Insert a channel axis right after the sample axis.
    return (
        train_x[:, np.newaxis, :, :],
        train_y.astype(np.int64),
        test_x[:, np.newaxis, :, :],
        test_y.astype(np.int64),
    )


# load_fashion_mnist_images: load Fashion-MNIST images and labels.
# Key steps: look for a local npz -> fall back to torchvision, then TensorFlow -> apply limits -> scale/normalize.
# Notes: a local npz is preferred so nothing has to be downloaded when it exists.
def load_fashion_mnist_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
    npz_path: str | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(train_images, train_labels, test_images, test_labels)`` for Fashion-MNIST.

    Args:
        train_limit: Keep only the first ``train_limit`` training samples; ``None`` keeps all.
        test_limit: Keep only the first ``test_limit`` test samples; ``None`` keeps all.
        npz_path: Optional npz file that is checked before the default locations.

    Returns:
        Images are divided by 255, passed through ``normalize_train_test`` and
        given a new axis at position 1; labels are cast to int64.

    Raises:
        RuntimeError: If no local file exists and neither torchvision nor
            TensorFlow can provide the data.
    """
    # Search order: explicit path, both spellings in the working directory,
    # then both spellings in the Keras cache directory.
    search_order: List[Path] = [] if npz_path is None else [Path(npz_path)]
    search_order.append(Path(FASHION_MNIST_NPZ))
    search_order.append(Path(FASHION_MNIST_NPZ_ALT))
    keras_cache = Path.home() / ".keras" / "datasets"
    search_order.extend(keras_cache / file_name for file_name in (FASHION_MNIST_NPZ, FASHION_MNIST_NPZ_ALT))

    splits = None
    for location in search_order:
        if not location.exists():
            continue
        with np.load(location) as archive:
            train_pair = (archive["x_train"], archive["y_train"])
            test_pair = (archive["x_test"], archive["y_test"])
        splits = (train_pair, test_pair)
        print(f"Loaded Fashion MNIST data from '{location}'.")
        break

    if splits is None:
        try:
            from torchvision.datasets import FashionMNIST  # type: ignore

            tv_train = FashionMNIST(root="./data", train=True, download=True)
            tv_test = FashionMNIST(root="./data", train=False, download=True)
            splits = (
                (np.array(tv_train.data), np.array(tv_train.targets)),
                (np.array(tv_test.data), np.array(tv_test.targets)),
            )
            print("Loaded Fashion MNIST data via torchvision.datasets.FashionMNIST.")
        except Exception:
            # Any torchvision failure (import or download) falls through to TensorFlow.
            try:
                from tensorflow.keras.datasets import fashion_mnist  # type: ignore
            except Exception as exc:
                raise RuntimeError(
                    "Fashion MNIST dataset not found locally and TensorFlow is unavailable. "
                    "Place 'fashion-mnist.npz' next to this script or install tensorflow."
                ) from exc
            else:
                splits = fashion_mnist.load_data()
                print("Loaded Fashion MNIST data via tensorflow.keras.datasets.fashion_mnist.")

    (train_x, train_y), (test_x, test_y) = splits

    if train_limit is not None:
        train_x, train_y = train_x[:train_limit], train_y[:train_limit]
    if test_limit is not None:
        test_x, test_y = test_x[:test_limit], test_y[:test_limit]

    # Scale raw pixel values by 1/255 before the shared normalization step.
    train_x = train_x.astype(np.float32) / 255.0
    test_x = test_x.astype(np.float32) / 255.0
    train_x, test_x = normalize_train_test(train_x, test_x)

    # Insert a channel axis right after the sample axis.
    return (
        train_x[:, np.newaxis, :, :],
        train_y.astype(np.int64),
        test_x[:, np.newaxis, :, :],
        test_y.astype(np.int64),
    )


# load_permuted_mnist_images: load MNIST and apply one fixed pixel permutation.
# Key steps: load MNIST -> draw a seeded permutation of the 784 pixels -> reorder every image -> reshape back.
# Notes: the same permutation is applied to the train and test images.
def load_permuted_mnist_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
    permute_seed: int = 1234,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return MNIST with the pixels of every image shuffled by a seeded permutation.

    Args:
        train_limit: Forwarded to ``load_mnist_images``.
        test_limit: Forwarded to ``load_mnist_images``.
        permute_seed: Seed for ``np.random.default_rng`` that draws the permutation.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)`` with images
        reshaped to ``(-1, 1, 28, 28)``; labels are returned as loaded.
    """
    train_x, train_y, test_x, test_y = load_mnist_images(
        train_limit=train_limit, test_limit=test_limit
    )
    pixel_order = np.random.default_rng(permute_seed).permutation(28 * 28)

    def _shuffle_pixels(batch: np.ndarray) -> np.ndarray:
        # Flatten each sample, reorder its pixels, then restore the image layout.
        flattened = batch.reshape(batch.shape[0], -1)
        return flattened[:, pixel_order].reshape(-1, 1, 28, 28)

    return _shuffle_pixels(train_x), train_y, _shuffle_pixels(test_x), test_y


# load_cifar10_images: load CIFAR-10 images and labels.
# Key steps: fetch via torchvision, else TensorFlow -> apply limits -> scale/normalize -> move channels to axis 1.
# Notes: torchvision stores its download under ./data.
def load_cifar10_images(
    train_limit: int | None = None,
    test_limit: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(x_train, y_train, x_test, y_test)`` for CIFAR-10.

    Args:
        train_limit: Keep only the first ``train_limit`` training samples; ``None`` keeps all.
        test_limit: Keep only the first ``test_limit`` test samples; ``None`` keeps all.

    Returns:
        Images are divided by 255, passed through ``normalize_train_test`` and
        transposed with axis order ``(0, 3, 1, 2)``; labels are cast to int64.

    Raises:
        RuntimeError: If neither torchvision nor TensorFlow can provide the data.
    """
    try:
        from torchvision.datasets import CIFAR10  # type: ignore

        tv_train = CIFAR10(root="./data", train=True, download=True)
        tv_test = CIFAR10(root="./data", train=False, download=True)
        splits = (
            tv_train.data,
            np.array(tv_train.targets),
            tv_test.data,
            np.array(tv_test.targets),
        )
    except Exception:
        # Any torchvision failure (import or download) falls through to TensorFlow.
        try:
            from tensorflow.keras.datasets import cifar10  # type: ignore
        except Exception as exc:
            raise RuntimeError(
                "CIFAR-10 dataset not available. Install torchvision or tensorflow."
            ) from exc
        (tf_train_x, tf_train_y), (tf_test_x, tf_test_y) = cifar10.load_data()
        # Squeeze the label arrays so both sources yield the same label layout.
        splits = (tf_train_x, tf_train_y.squeeze(), tf_test_x, tf_test_y.squeeze())

    train_x, train_y, test_x, test_y = splits

    if train_limit is not None:
        train_x, train_y = train_x[:train_limit], train_y[:train_limit]
    if test_limit is not None:
        test_x, test_y = test_x[:test_limit], test_y[:test_limit]

    # Scale raw pixel values by 1/255 before the shared normalization step.
    train_x = train_x.astype(np.float32) / 255.0
    test_x = test_x.astype(np.float32) / 255.0
    train_x, test_x = normalize_train_test(train_x, test_x)

    # Move the last axis to position 1.
    axis_order = (0, 3, 1, 2)
    return (
        np.transpose(train_x, axis_order),
        train_y.astype(np.int64),
        np.transpose(test_x, axis_order),
        test_y.astype(np.int64),
    )


# load_dvs_gesture_images: load DVS-Gesture frames and labels from an npz cache.
# Key steps: locate or build the cache -> compare it with the request -> rebuild if incompatible -> apply limits, flatten to channels, normalize.
# Notes: rebuilding only happens when auto_download is True; otherwise an incompatible cache is loaded as-is.
def load_dvs_gesture_images(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
    time_bins: int = 20,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load DVS-Gesture data from an npz cache, building the cache on demand.

    Args:
        npz_path: Path of the npz cache. When falsy, ``dvs_gesture.npz`` inside
            ``root`` (or ``data/dvs_gesture`` when ``root`` is falsy) is used.
        train_limit: Keep only the first ``train_limit`` training samples; ``None`` keeps all.
        test_limit: Keep only the first ``test_limit`` test samples; ``None`` keeps all.
        auto_download: When True, a missing or incompatible cache is (re)built
            with ``_prepare_dvs_gesture_npz``. When False, a missing cache
            raises and an incompatible cache is loaded unchanged.
        root: Dataset root forwarded to ``_prepare_dvs_gesture_npz``.
        time_bins: Requested number of time bins (at least 1 is assumed when
            computing the expected channel count).
        spatial_downsample: Requested downsampling factor; values below 1 are
            treated as 1 for the compatibility checks and the rebuild.
        use_polarity: Whether two channels per time bin are expected instead of one.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Inputs go through
        ``_dvs_to_channels`` and ``normalize_train_test``; labels are cast to int64.

    Raises:
        RuntimeError: If the cache is missing and ``auto_download`` is False,
            if the npz lacks the required arrays, if a rebuilt cache still has
            an unexpected channel count, or if materializing the frames runs
            out of memory.
    """
    default_root = root or os.path.join("data", "dvs_gesture")
    candidate = npz_path or os.path.join(default_root, "dvs_gesture.npz")
    if not os.path.exists(candidate):
        if not auto_download:
            raise RuntimeError(
                "DVS-Gesture npz not found. Provide a preprocessed file via --gesture-npz.\n"
                "Expected keys: x_train/y_train/x_test/y_test (preferred) or train_inputs/train_labels."
            )
        _prepare_dvs_gesture_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            spatial_downsample=spatial_downsample,
            use_polarity=use_polarity,
            train_limit=train_limit,
            test_limit=test_limit,
        )
    expected_channels = int(max(1, int(time_bins))) * (2 if use_polarity else 1)
    spatial_downsample = max(1, int(spatial_downsample))

    def _load_npz_arrays(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
        # The archive is closed as soon as the arrays are read, so the file
        # handle is released before a possible rebuild rewrites the same path.
        with np.load(path) as data:

            def _first_present(*keys: str) -> Any:
                # Read only the first key that exists; legacy fallbacks are not
                # loaded when the preferred key is present.
                for key in keys:
                    if key in data:
                        return data[key]
                return None

            def _meta_int(key: str) -> int | None:
                value = _first_present(key)
                return None if value is None else int(np.asarray(value).item())

            def _meta_limit(key: str) -> int | None:
                # Negative stored limits are mapped to None.
                value = _meta_int(key)
                return None if value is None or value < 0 else value

            x_train = _first_present("x_train", "train_inputs")
            y_train = _first_present("y_train", "train_labels")
            x_test = _first_present("x_test", "test_inputs")
            y_test = _first_present("y_test", "test_labels")
            if x_train is None or y_train is None or x_test is None or y_test is None:
                raise RuntimeError("DVS-Gesture npz missing required arrays.")

            polarity_flag = _meta_int("meta_use_polarity")
            dataset_name = _first_present("meta_dataset")
            meta: Dict[str, Any] = {
                "time_bins": _meta_int("meta_time_bins"),
                "use_polarity": None if polarity_flag is None else bool(polarity_flag),
                "spatial_downsample": _meta_int("meta_spatial_downsample"),
                "train_limit": _meta_limit("meta_train_limit"),
                "test_limit": _meta_limit("meta_test_limit"),
                "dataset": None if dataset_name is None else str(np.asarray(dataset_name).item()),
            }
        return x_train, y_train, x_test, y_test, meta

    x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)

    # Collect every reason why the cached file does not match the request.
    mismatch: List[str] = []
    file_time_bins = meta.get("time_bins")
    if file_time_bins is not None and int(file_time_bins) != int(time_bins):
        mismatch.append(f"time_bins(file={file_time_bins}, req={int(time_bins)})")
    file_use_polarity = meta.get("use_polarity")
    if file_use_polarity is not None and bool(file_use_polarity) != bool(use_polarity):
        mismatch.append(f"use_polarity(file={file_use_polarity}, req={bool(use_polarity)})")
    file_spatial_downsample = meta.get("spatial_downsample")
    if file_spatial_downsample is not None and int(file_spatial_downsample) != spatial_downsample:
        mismatch.append(f"spatial_downsample(file={file_spatial_downsample}, req={spatial_downsample})")
    if file_spatial_downsample is None and spatial_downsample != 1:
        mismatch.append("spatial_downsample(file=unknown, req!=1)")

    def _infer_channels(arr: np.ndarray) -> int | None:
        # Channel count the array will have once flattened to (N, C, H, W).
        if arr.ndim == 4:
            return int(arr.shape[1])
        if arr.ndim == 5:
            # (N, T, H, W, C) or (N, T, C, H, W)
            if arr.shape[-1] <= 4 and arr.shape[2] > 4 and arr.shape[3] > 4:
                return int(arr.shape[1]) * int(arr.shape[-1])
            if arr.shape[2] <= 4 and arr.shape[3] > 4 and arr.shape[4] > 4:
                return int(arr.shape[1]) * int(arr.shape[2])
            if arr.shape[2] in (1, 2, 3, 4):
                return int(arr.shape[1]) * int(arr.shape[2])
            # Otherwise treat the last axis as the channel/polarity axis.
            return int(arr.shape[1]) * int(arr.shape[-1])
        return None

    inferred_channels = _infer_channels(x_train)
    if inferred_channels is not None and inferred_channels != expected_channels:
        mismatch.append(f"channels(file={inferred_channels}, exp={expected_channels})")

    # A cache built with a limit cannot serve an unlimited or larger request.
    file_train_limit = meta.get("train_limit")
    file_test_limit = meta.get("test_limit")
    if train_limit is None and file_train_limit is not None:
        mismatch.append(f"train_limit(file={file_train_limit}, req=None)")
    if test_limit is None and file_test_limit is not None:
        mismatch.append(f"test_limit(file={file_test_limit}, req=None)")
    if train_limit is not None and file_train_limit is not None and int(file_train_limit) < int(train_limit):
        mismatch.append(f"train_limit(file={file_train_limit}, req={int(train_limit)})")
    if test_limit is not None and file_test_limit is not None and int(file_test_limit) < int(test_limit):
        mismatch.append(f"test_limit(file={file_test_limit}, req={int(test_limit)})")

    if file_time_bins is None and inferred_channels is not None and inferred_channels != expected_channels:
        mismatch.append("legacy_npz_incompatible")
    # Unlimited requests also reject caches that hold suspiciously few samples.
    if train_limit is None and int(x_train.shape[0]) < 500:
        mismatch.append(f"train_samples(file={int(x_train.shape[0])}, req>=500)")
    if test_limit is None and int(x_test.shape[0]) < 50:
        mismatch.append(f"test_samples(file={int(x_test.shape[0])}, req>=50)")

    if mismatch and auto_download:
        print(f"[DVS] Cached DVS-Gesture npz incompatible ({', '.join(mismatch)}). Rebuilding: {candidate}")
        _prepare_dvs_gesture_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            spatial_downsample=spatial_downsample,
            use_polarity=use_polarity,
            train_limit=train_limit,
            test_limit=test_limit,
        )
        x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)
        inferred_channels = _infer_channels(x_train)
        if inferred_channels is not None and inferred_channels != expected_channels:
            raise RuntimeError(
                "DVS-Gesture preparation produced unexpected channel count. "
                f"Expected {expected_channels}, got {inferred_channels}. Path: {candidate}"
            )

    if train_limit is not None:
        x_train = x_train[:train_limit]
        y_train = y_train[:train_limit]
    if test_limit is not None:
        x_test = x_test[:test_limit]
        y_test = y_test[:test_limit]

    try:
        x_train = _dvs_to_channels(x_train)
        x_test = _dvs_to_channels(x_test)
        x_train, x_test = normalize_train_test(x_train, x_test)
    except MemoryError as exc:
        # Turn the bare MemoryError into an actionable message with a size estimate.
        n_train = int(getattr(x_train, "shape", (0,))[0])
        ch = int(expected_channels)
        h = int(getattr(x_train, "shape", (0, 0, 0, 0))[2]) if getattr(x_train, "ndim", 0) == 4 else 128
        w = int(getattr(x_train, "shape", (0, 0, 0, 0))[3]) if getattr(x_train, "ndim", 0) == 4 else 128
        approx_gb = (n_train * ch * h * w * 4) / (1024**3)
        avail = available_physical_memory_bytes()
        avail_gb = None if avail is None else (avail / (1024**3))
        hint = ""
        if avail_gb is not None:
            hint = f" (available RAM ~{avail_gb:.1f} GiB)"
        raise RuntimeError(
            "Out of RAM while materializing DVS-Gesture frames as float32. "
            f"Requested ~{approx_gb:.1f} GiB for x_train alone{hint}. "
            "Fix: reduce memory by using --gesture-spatial-downsample 4 (32x32), "
            "or --gesture-no-polarity, or smaller --gesture-time-bins, or set --train-limit/--test-limit, "
            "or run on a machine with more RAM."
        ) from exc
    return x_train, y_train.astype(np.int64), x_test, y_test.astype(np.int64)


# _dvs_to_channels：将 DVS 事件流转换为固定通道的张量/帧。
# 关键步骤：读取输入 → 处理逻辑 → 返回结果。
# 算法要点：作为通用工具支撑上层流程。
def _dvs_to_channels(data: np.ndarray) -> np.ndarray:
    if data.ndim == 4:
        n, t, h, w = data.shape
        return data.reshape(n, t, h, w).astype(np.float32)
    if data.ndim == 5:
        # Support both (N, T, H, W, C) and (N, T, C, H, W).
        n, t, a, b, c = data.shape
        if c <= 4 and a > 4 and b > 4:
            # (N, T, H, W, C) -> (N, T, C, H, W)
            data = np.transpose(data, (0, 1, 4, 2, 3))
            n, t, ch, h, w = data.shape
            return data.reshape(n, t * ch, h, w).astype(np.float32)
        if a <= 4 and b > 4 and c > 4:
            # (N, T, C, H, W)
            n, t, ch, h, w = data.shape
            return data.reshape(n, t * ch, h, w).astype(np.float32)
        # Fallback: assume last dim is channels/polarity.
        n, t, h, w, ch = data.shape
        return data.reshape(n, t * ch, h, w).astype(np.float32)
    raise ValueError(f"Unsupported DVS data shape: {data.shape}")


# load_dvs_cifar10_images: load DVS-CIFAR10 frames and labels from an npz cache.
# Key steps: locate or build the cache -> compare it with the request -> rebuild if incompatible -> apply limits, flatten to channels, normalize.
# Notes: rebuilding only happens when auto_download is True; otherwise an incompatible cache is loaded as-is.
def load_dvs_cifar10_images(
    npz_path: str | None,
    train_limit: int | None,
    test_limit: int | None,
    auto_download: bool = True,
    root: str | None = None,
    time_bins: int = 10,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load DVS-CIFAR10 data from an npz cache, building the cache on demand.

    Args:
        npz_path: Path of the npz cache. When falsy, ``dvs_cifar10.npz`` inside
            ``root`` (or ``data/dvs_cifar10`` when ``root`` is falsy) is used.
        train_limit: Keep only the first ``train_limit`` training samples; ``None`` keeps all.
        test_limit: Keep only the first ``test_limit`` test samples; ``None`` keeps all.
        auto_download: When True, a missing or incompatible cache is (re)built
            with ``_prepare_dvs_cifar10_npz``. When False, a missing cache
            raises and an incompatible cache is loaded unchanged.
        root: Dataset root forwarded to ``_prepare_dvs_cifar10_npz``.
        time_bins: Requested number of time bins (at least 1 is assumed when
            computing the expected channel count).
        spatial_downsample: Requested downsampling factor; values below 1 are
            treated as 1 for the compatibility checks and the rebuild.
        use_polarity: Whether two channels per time bin are expected instead of one.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. Inputs go through
        ``_dvs_to_channels`` and ``normalize_train_test``; labels are cast to int64.

    Raises:
        RuntimeError: If the cache is missing and ``auto_download`` is False,
            if the npz lacks the required arrays, if a rebuilt cache still has
            an unexpected channel count, or if materializing the frames runs
            out of memory.
    """
    default_root = root or os.path.join("data", "dvs_cifar10")
    candidate = npz_path or os.path.join(default_root, "dvs_cifar10.npz")
    if not os.path.exists(candidate):
        if not auto_download:
            raise RuntimeError(
                "DVS-CIFAR10 npz not found. Provide a preprocessed file via --dvs-npz.\n"
                "Expected keys: x_train/y_train/x_test/y_test (preferred) or train_inputs/train_labels."
            )
        _prepare_dvs_cifar10_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            spatial_downsample=spatial_downsample,
            use_polarity=use_polarity,
            train_limit=train_limit,
            test_limit=test_limit,
        )
    expected_channels = int(max(1, int(time_bins))) * (2 if use_polarity else 1)
    spatial_downsample = max(1, int(spatial_downsample))

    def _load_npz_arrays(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
        # The archive is closed as soon as the arrays are read, so the file
        # handle is released before a possible rebuild rewrites the same path.
        with np.load(path) as data:

            def _first_present(*keys: str) -> Any:
                # Read only the first key that exists; legacy fallbacks are not
                # loaded when the preferred key is present.
                for key in keys:
                    if key in data:
                        return data[key]
                return None

            def _meta_int(key: str) -> int | None:
                value = _first_present(key)
                return None if value is None else int(np.asarray(value).item())

            def _meta_limit(key: str) -> int | None:
                # Negative stored limits are mapped to None.
                value = _meta_int(key)
                return None if value is None or value < 0 else value

            x_train = _first_present("x_train", "train_inputs")
            y_train = _first_present("y_train", "train_labels")
            x_test = _first_present("x_test", "test_inputs")
            y_test = _first_present("y_test", "test_labels")
            if x_train is None or y_train is None or x_test is None or y_test is None:
                raise RuntimeError("DVS-CIFAR10 npz missing required arrays.")

            polarity_flag = _meta_int("meta_use_polarity")
            dataset_name = _first_present("meta_dataset")
            meta: Dict[str, Any] = {
                "time_bins": _meta_int("meta_time_bins"),
                "use_polarity": None if polarity_flag is None else bool(polarity_flag),
                "spatial_downsample": _meta_int("meta_spatial_downsample"),
                "train_limit": _meta_limit("meta_train_limit"),
                "test_limit": _meta_limit("meta_test_limit"),
                "dataset": None if dataset_name is None else str(np.asarray(dataset_name).item()),
            }
        return x_train, y_train, x_test, y_test, meta

    x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)

    # Collect every reason why the cached file does not match the request.
    mismatch: List[str] = []
    file_time_bins = meta.get("time_bins")
    if file_time_bins is not None and int(file_time_bins) != int(time_bins):
        mismatch.append(f"time_bins(file={file_time_bins}, req={int(time_bins)})")
    file_use_polarity = meta.get("use_polarity")
    if file_use_polarity is not None and bool(file_use_polarity) != bool(use_polarity):
        mismatch.append(f"use_polarity(file={file_use_polarity}, req={bool(use_polarity)})")
    file_spatial_downsample = meta.get("spatial_downsample")
    if file_spatial_downsample is not None and int(file_spatial_downsample) != spatial_downsample:
        mismatch.append(f"spatial_downsample(file={file_spatial_downsample}, req={spatial_downsample})")
    if file_spatial_downsample is None and spatial_downsample != 1:
        mismatch.append("spatial_downsample(file=unknown, req!=1)")

    def _infer_channels(arr: np.ndarray) -> int | None:
        # Channel count the array will have once flattened to (N, C, H, W).
        if arr.ndim == 4:
            return int(arr.shape[1])
        if arr.ndim == 5:
            # (N, T, H, W, C) or (N, T, C, H, W)
            if arr.shape[-1] <= 4 and arr.shape[2] > 4 and arr.shape[3] > 4:
                return int(arr.shape[1]) * int(arr.shape[-1])
            if arr.shape[2] <= 4 and arr.shape[3] > 4 and arr.shape[4] > 4:
                return int(arr.shape[1]) * int(arr.shape[2])
            if arr.shape[2] in (1, 2, 3, 4):
                return int(arr.shape[1]) * int(arr.shape[2])
            # Otherwise treat the last axis as the channel/polarity axis.
            return int(arr.shape[1]) * int(arr.shape[-1])
        return None

    inferred_channels = _infer_channels(x_train)
    if inferred_channels is not None and inferred_channels != expected_channels:
        mismatch.append(f"channels(file={inferred_channels}, exp={expected_channels})")

    # A cache built with a limit cannot serve an unlimited or larger request.
    file_train_limit = meta.get("train_limit")
    file_test_limit = meta.get("test_limit")
    if train_limit is None and file_train_limit is not None:
        mismatch.append(f"train_limit(file={file_train_limit}, req=None)")
    if test_limit is None and file_test_limit is not None:
        mismatch.append(f"test_limit(file={file_test_limit}, req=None)")
    if train_limit is not None and file_train_limit is not None and int(file_train_limit) < int(train_limit):
        mismatch.append(f"train_limit(file={file_train_limit}, req={int(train_limit)})")
    if test_limit is not None and file_test_limit is not None and int(file_test_limit) < int(test_limit):
        mismatch.append(f"test_limit(file={file_test_limit}, req={int(test_limit)})")

    if file_time_bins is None and inferred_channels is not None and inferred_channels != expected_channels:
        mismatch.append("legacy_npz_incompatible")
    # Unlimited requests also reject caches that hold suspiciously few samples.
    if train_limit is None and int(x_train.shape[0]) < 1000:
        mismatch.append(f"train_samples(file={int(x_train.shape[0])}, req>=1000)")
    if test_limit is None and int(x_test.shape[0]) < 100:
        mismatch.append(f"test_samples(file={int(x_test.shape[0])}, req>=100)")

    if mismatch and auto_download:
        print(f"[DVS] Cached DVS-CIFAR10 npz incompatible ({', '.join(mismatch)}). Rebuilding: {candidate}")
        _prepare_dvs_cifar10_npz(
            candidate,
            root=root,
            time_bins=time_bins,
            spatial_downsample=spatial_downsample,
            use_polarity=use_polarity,
            train_limit=train_limit,
            test_limit=test_limit,
        )
        x_train, y_train, x_test, y_test, meta = _load_npz_arrays(candidate)
        inferred_channels = _infer_channels(x_train)
        if inferred_channels is not None and inferred_channels != expected_channels:
            raise RuntimeError(
                "DVS-CIFAR10 preparation produced unexpected channel count. "
                f"Expected {expected_channels}, got {inferred_channels}. Path: {candidate}"
            )

    if train_limit is not None:
        x_train = x_train[:train_limit]
        y_train = y_train[:train_limit]
    if test_limit is not None:
        x_test = x_test[:test_limit]
        y_test = y_test[:test_limit]

    try:
        x_train = _dvs_to_channels(x_train)
        x_test = _dvs_to_channels(x_test)
        x_train, x_test = normalize_train_test(x_train, x_test)
    except MemoryError as exc:
        # Turn the bare MemoryError into an actionable message with a size estimate.
        n_train = int(getattr(x_train, "shape", (0,))[0])
        ch = int(expected_channels)
        h = int(getattr(x_train, "shape", (0, 0, 0, 0))[2]) if getattr(x_train, "ndim", 0) == 4 else 128
        w = int(getattr(x_train, "shape", (0, 0, 0, 0))[3]) if getattr(x_train, "ndim", 0) == 4 else 128
        approx_gb = (n_train * ch * h * w * 4) / (1024**3)
        avail = available_physical_memory_bytes()
        avail_gb = None if avail is None else (avail / (1024**3))
        hint = ""
        if avail_gb is not None:
            hint = f" (available RAM ~{avail_gb:.1f} GiB)"
        raise RuntimeError(
            "Out of RAM while materializing DVS-CIFAR10 frames as float32. "
            f"Requested ~{approx_gb:.1f} GiB for x_train alone{hint}. "
            "Fix: reduce memory by using --dvs-spatial-downsample 4 (32x32), "
            "or --dvs-no-polarity, or smaller --dvs-time-bins, or set --train-limit/--test-limit, "
            "or run on a machine with more RAM."
        ) from exc
    return x_train, y_train.astype(np.int64), x_test, y_test.astype(np.int64)


# _prepare_dvs_cifar10_npz: preprocess DVS-CIFAR10 through tonic and store the result as an npz file.
# Key steps: locate the tonic dataset class -> build/download it -> convert samples to frames -> save arrays + metadata.
# Note: this function never checks whether npz_path already exists; it always rebuilds and writes the file.
def _prepare_dvs_cifar10_npz(
    npz_path: str,
    root: str | None = None,
    time_bins: int = 10,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
    train_limit: int | None = None,
    test_limit: int | None = None,
    compress: bool = True,
) -> str:
    """Materialize DVS-CIFAR10 as uint16 frame arrays and save them to ``npz_path``.

    Args:
        npz_path: Destination passed to ``np.savez`` / ``np.savez_compressed``.
            Its parent directory is created when it has one.
        root: Directory handed to the tonic dataset constructor. Defaults to
            ``data/dvs_cifar10`` (relative to the working directory).
        time_bins: Value passed to ``ToFrame(n_time_bins=...)``; clamped to >= 1.
        spatial_downsample: Integer factor ``f``; every ``f x f`` pixel block is
            summed into one pixel. Clamped to >= 1; ``1`` leaves frames untouched.
        use_polarity: When False, 4-D frames get the axis that is treated as the
            channel/polarity axis summed away (see ``maybe_merge_polarity``).
        train_limit: Optional cap on the number of training samples materialized.
        test_limit: Optional cap on the number of test samples materialized.
        compress: Use ``np.savez_compressed`` when True, ``np.savez`` otherwise.

    Returns:
        ``npz_path`` exactly as it was passed in.

    Raises:
        RuntimeError: If ``ensure_python_package`` reports tonic or aedat as
            unavailable, if no known dataset class exists in ``tonic.datasets``,
            if the un-split dataset still looks incomplete after the repair
            attempt, if a limit leaves zero samples, or if the frame size is not
            divisible by ``spatial_downsample``.
    """
    time_bins = max(1, int(time_bins))
    spatial_downsample = max(1, int(spatial_downsample))
    import inspect

    # ensure_python_package is defined elsewhere in this file; only its truthy/falsy result is used here.
    if not ensure_python_package("tonic"):
        raise RuntimeError(
            "Auto-download for DVS-CIFAR10 requires the tonic package.\n"
            "Install it with: pip install tonic"
        )
    import tonic.datasets as tonic_datasets  # type: ignore
    from tonic.download_utils import extract_archive  # type: ignore
    from tonic.transforms import ToFrame  # type: ignore
    # aedat is not imported here; per the message below it is needed for .aedat4 parsing
    # (presumably inside tonic when samples are read -- TODO confirm).
    if not ensure_python_package("aedat"):
        raise RuntimeError(
            "DVS-CIFAR10 .aedat4 parsing requires the aedat package.\n"
            "Install it with: pip install aedat"
        )

    root_dir = root or os.path.join("data", "dvs_cifar10")
    os.makedirs(root_dir, exist_ok=True)
    out_dir = os.path.dirname(npz_path)
    # dirname is "" for a bare file name; os.makedirs("") would raise, hence the guard.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Try several spellings of the dataset class name and keep the first callable one.
    dataset_cls = None
    for name in ("DVSCIFAR10", "CIFAR10DVS", "DVSCifar10", "cifar10dvs"):
        candidate = getattr(tonic_datasets, name, None)
        if candidate is None or not callable(candidate):
            continue
        dataset_cls = candidate
        break
    if dataset_cls is None:
        raise RuntimeError(
            "Could not find DVS-CIFAR10 dataset in tonic.datasets. "
            "Try upgrading tonic: pip install -U tonic"
        )
    # Rewrite the figshare.com download URL to the ndownloader.figshare.com host.
    # NOTE: setattr mutates the class attribute, so the change is visible to every later user of the class.
    dataset_url = getattr(dataset_cls, "url", None)
    if isinstance(dataset_url, str) and dataset_url.startswith("https://figshare.com/ndownloader/files/"):
        file_id = dataset_url.rstrip("/").split("/")[-1]
        alt_url = f"https://ndownloader.figshare.com/files/{file_id}"
        if alt_url != dataset_url:
            print(f"[DVS] Figshare URL blocked by WAF; using direct downloader: {alt_url}")
            setattr(dataset_cls, "url", alt_url)

    # Inspect the constructor signature so the call below adapts to whichever
    # parameter names this dataset class exposes for root / split / download.
    dataset_sig = inspect.signature(dataset_cls)
    params = dataset_sig.parameters
    root_param = None
    for key in ("root", "data_path", "path", "root_dir", "data_dir", "data_root", "save_to", "base_folder"):
        if key in params:
            root_param = key
            break
    split_param = None
    for key in ("train", "train_split", "split", "mode", "set"):
        if key in params:
            split_param = key
            break
    download_param = None
    for key in ("download", "download_and_prepare"):
        if key in params:
            download_param = key
            break
    positional_params = [
        p
        for p in params.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    # Positional parameters without a default, i.e. ones the constructor requires.
    required_params = [p for p in positional_params if p.default is inspect._empty]

    # build_dataset: instantiate dataset_cls using the parameter names detected above.
    # train=None requests the un-split dataset; True/False select the train/test split.
    # Returns whatever dataset_cls(...) returns.
    def build_dataset(train: bool | None) -> Any:
        kwargs: Dict[str, Any] = {}
        args: List[Any] = []
        if root_param is not None:
            kwargs[root_param] = root_dir
        elif required_params:
            # No recognised root keyword: assume the first required positional is the root -- TODO confirm.
            args.append(root_dir)
        if train is not None and split_param is not None:
            if split_param == "train":
                # A parameter literally named "train" receives the boolean.
                kwargs[split_param] = train
            else:
                # Any other split parameter receives the string "train" or "test".
                kwargs[split_param] = "train" if train else "test"
        if download_param is not None:
            kwargs[download_param] = True
        return dataset_cls(*args, **kwargs)

    # split_full_dataset: derive train/test index arrays for a dataset that has no built-in split.
    # Uses a per-label (stratified) split when dataset.targets exists and matches len(dataset);
    # otherwise shuffles all indices and cuts once. The generator is seeded, so the split is repeatable.
    def split_full_dataset(
        dataset: Any,
        train_fraction: float = 0.9,
        seed: int = 0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        total = len(dataset)
        if total <= 0:
            raise RuntimeError("DVS-CIFAR10 dataset is empty.")
        targets = getattr(dataset, "targets", None)
        rng = np.random.default_rng(seed)
        if targets is None or len(targets) != total:
            # No usable labels: plain shuffled split.
            indices = np.arange(total)
            rng.shuffle(indices)
            split_idx = int(total * train_fraction)
            # Clamp to [1, total - 1]; with total == 1 this yields 0 and the only sample lands in test.
            split_idx = min(max(1, split_idx), total - 1)
            return indices[:split_idx], indices[split_idx:]

        targets = np.asarray(targets)
        train_indices: List[int] = []
        test_indices: List[int] = []
        for label in np.unique(targets):
            label_indices = np.where(targets == label)[0]
            rng.shuffle(label_indices)
            split_idx = int(len(label_indices) * train_fraction)
            # Same clamp per class; a class with a single sample contributes it to test only.
            split_idx = min(max(1, split_idx), len(label_indices) - 1)
            train_indices.extend(label_indices[:split_idx])
            test_indices.extend(label_indices[split_idx:])
        train_array = np.array(train_indices, dtype=np.int64)
        test_array = np.array(test_indices, dtype=np.int64)
        # Shuffle again so the output is not grouped by label.
        rng.shuffle(train_array)
        rng.shuffle(test_array)
        return train_array, test_array

    # materialize: read samples from the dataset one by one into dense numpy arrays.
    # Returns (x_data, y_data): x_data is uint16 with shape (count,) + per-sample frame shape, y_data is int64.
    # `split` is only used in progress messages; `limit` caps how many of `indices` are read.
    def materialize(
        dataset: Any,
        split: str,
        limit: int | None,
        indices: np.ndarray | None = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        if indices is None:
            indices = np.arange(len(dataset))
        indices = np.asarray(indices, dtype=np.int64)
        total = len(indices)
        count = total if limit is None else min(total, limit)
        # Also triggers for limit <= 0, not just for an empty index list.
        if count <= 0:
            raise RuntimeError("Requested zero samples for DVS-CIFAR10 preparation.")

        # Sum away the channel axis of 4-D frames when polarity is disabled (4-D -> 3-D).
        # The channel axis is guessed by size: an axis of length <= 4 is treated as channels.
        def maybe_merge_polarity(frames: Any) -> Any:
            if use_polarity or not hasattr(frames, "ndim") or frames.ndim != 4:
                return frames
            # Handle both (T, C, H, W) and (T, H, W, C).
            if frames.shape[1] <= 4 and frames.shape[-1] > 4:
                return frames.sum(axis=1)
            if frames.shape[-1] <= 4 and frames.shape[1] > 4:
                return frames.sum(axis=-1)
            # Ambiguous sizes: prefer axis 1 when it is small, otherwise the last axis.
            if frames.shape[1] in (1, 2, 3, 4):
                return frames.sum(axis=1)
            return frames.sum(axis=-1)

        # Sum-pool frames spatially by spatial_downsample.
        # With a factor <= 1 the array is returned unchanged (dtype untouched).
        def downsample_frames(frames: Any) -> np.ndarray:
            arr = np.asarray(frames)
            if spatial_downsample <= 1:
                return arr
            f = spatial_downsample

            if arr.ndim == 3:
                t, h, w = arr.shape
                if h % f != 0 or w % f != 0:
                    raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                # Expose each f x f block as two extra axes, then sum over them.
                reshaped = arr.reshape(t, h // f, f, w // f, f)
                # Accumulate in uint32, then clip so the cast to uint16 cannot wrap around.
                summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            if arr.ndim == 4:
                # Handle both (T, C, H, W) and (T, H, W, C).
                if arr.shape[1] <= 4 and arr.shape[2] > 4 and arr.shape[3] > 4:
                    t, c, h, w = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, c, h // f, f, w // f, f)
                    summed = reshaped.sum(axis=(3, 5), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)
                if arr.shape[-1] <= 4 and arr.shape[1] > 4 and arr.shape[2] > 4:
                    t, h, w, c = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, h // f, f, w // f, f, c)
                    summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            # Reached for any other rank and for 4-D arrays matching neither layout test.
            raise RuntimeError(f"Unsupported frame shape for spatial downsample: {arr.shape}")

        # The first sample is processed up front to learn the per-sample shape for pre-allocation.
        first_frames, first_label = dataset[int(indices[0])]
        first_frames = downsample_frames(maybe_merge_polarity(first_frames))
        frames_shape = first_frames.shape
        x_data = np.zeros((count,) + frames_shape, dtype=np.uint16)
        y_data = np.zeros((count,), dtype=np.int64)
        x_data[0] = np.asarray(first_frames).astype(np.uint16)
        y_data[0] = int(first_label)
        for idx in range(1, count):
            frames, label = dataset[int(indices[idx])]
            frames = downsample_frames(maybe_merge_polarity(frames))
            # Assigning into x_data requires every sample to match the first sample's shape.
            x_data[idx] = np.asarray(frames).astype(np.uint16)
            y_data[idx] = int(label)
            if idx % 1000 == 0:
                print(f"[DVS] Processed {idx}/{count} {split} samples")
        return x_data, y_data

    print("[DVS] Preparing DVS-CIFAR10 npz. This may take a while.")
    if split_param is None:
        # The constructor exposes no split argument: load everything and split locally.
        dataset = build_dataset(None)
        expected_total = 10000
        expected_classes = 10
        total = len(dataset)
        targets = getattr(dataset, "targets", None)
        # NOTE(review): a dataset without `targets` yields an empty label set and is
        # therefore always treated as incomplete below -- confirm this is intended.
        unique_labels = set(int(x) for x in targets) if targets is not None else set()
        if total < 9000 or len(unique_labels) < expected_classes:
            print(
                f"[DVS] Detected incomplete CIFAR10-DVS dataset (samples={total}, classes={len(unique_labels)}). "
                "Attempting repair via re-extract..."
            )
            # Best-effort repair: re-run the download and re-extract every archive found on disk.
            try:
                if hasattr(dataset, "download"):
                    dataset.download()
                data_filename = getattr(dataset_cls, "data_filename", None)
                base_dir = getattr(dataset, "location_on_system", None)
                if data_filename and base_dir:
                    # assumes data_filename is an iterable of archive names; a plain
                    # string would be iterated character by character -- TODO confirm
                    for fname in data_filename:
                        archive_path = os.path.join(base_dir, fname)
                        if os.path.exists(archive_path):
                            extract_archive(archive_path)
            except Exception as exc:
                # Failure here is only logged; the re-check below decides whether to abort.
                print(f"[DVS] Repair attempt failed: {exc}")
            dataset = build_dataset(None)
            total = len(dataset)
            targets = getattr(dataset, "targets", None)
            unique_labels = set(int(x) for x in targets) if targets is not None else set()
            if total < 9000 or len(unique_labels) < expected_classes:
                raise RuntimeError(
                    "CIFAR10-DVS download seems incomplete. "
                    f"Got {total} samples across {len(unique_labels)} classes; expected ~{expected_total}. "
                    f"Delete '{root_dir}' and retry."
                )
        # The transform is attached after construction, so indexing the dataset returns frames.
        transform = ToFrame(sensor_size=dataset.sensor_size, n_time_bins=time_bins)
        dataset.transform = transform
        train_idx, test_idx = split_full_dataset(dataset)
        # Limits are applied to the already shuffled index arrays (first `limit` entries).
        x_train, y_train = materialize(dataset, "train", train_limit, train_idx)
        x_test, y_test = materialize(dataset, "test", test_limit, test_idx)
    else:
        # The constructor has its own split argument: build train and test separately.
        train_dataset = build_dataset(True)
        train_transform = ToFrame(sensor_size=train_dataset.sensor_size, n_time_bins=time_bins)
        train_dataset.transform = train_transform
        x_train, y_train = materialize(train_dataset, "train", train_limit)

        test_dataset = build_dataset(False)
        test_transform = ToFrame(sensor_size=test_dataset.sensor_size, n_time_bins=time_bins)
        test_dataset.transform = test_transform
        x_test, y_test = materialize(test_dataset, "test", test_limit)

    # Preprocessing settings are stored next to the arrays under "meta_*" keys.
    # A limit of None is encoded as -1.
    meta: Dict[str, Any] = {
        "meta_version": np.asarray(1, dtype=np.int64),
        "meta_dataset": np.asarray(getattr(dataset_cls, "__name__", "unknown")),
        "meta_time_bins": np.asarray(int(time_bins), dtype=np.int64),
        "meta_spatial_downsample": np.asarray(int(spatial_downsample), dtype=np.int64),
        "meta_use_polarity": np.asarray(1 if use_polarity else 0, dtype=np.int8),
        "meta_train_limit": np.asarray(-1 if train_limit is None else int(train_limit), dtype=np.int64),
        "meta_test_limit": np.asarray(-1 if test_limit is None else int(test_limit), dtype=np.int64),
    }
    if compress:
        np.savez_compressed(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    else:
        np.savez(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    print(f"[DVS] Saved prepared dataset to {npz_path}")
    return npz_path


def _prepare_dvs_gesture_npz(
    npz_path: str,
    root: str | None = None,
    time_bins: int = 20,
    spatial_downsample: int = 1,
    use_polarity: bool = True,
    train_limit: int | None = None,
    test_limit: int | None = None,
    compress: bool = True,
) -> str:
    """Materialize DVS-Gesture as uint16 frame arrays and save them to ``npz_path``.

    The function does not check whether ``npz_path`` already exists; it always
    rebuilds the arrays and writes the file.

    Args:
        npz_path: Destination passed to ``np.savez`` / ``np.savez_compressed``.
            Its parent directory is created when it has one.
        root: Directory handed to the tonic dataset constructor. Defaults to
            ``data/dvs_gesture`` (relative to the working directory).
        time_bins: Value passed to ``ToFrame(n_time_bins=...)``; clamped to >= 1.
        spatial_downsample: Integer factor ``f``; every ``f x f`` pixel block is
            summed into one pixel. Clamped to >= 1; ``1`` leaves frames untouched.
        use_polarity: When False, 4-D frames get the axis that is treated as the
            channel/polarity axis summed away (see ``maybe_merge_polarity``).
        train_limit: Optional cap on the number of training samples materialized.
        test_limit: Optional cap on the number of test samples materialized.
        compress: Use ``np.savez_compressed`` when True, ``np.savez`` otherwise.

    Returns:
        ``npz_path`` exactly as it was passed in.

    Raises:
        RuntimeError: If ``ensure_python_package`` reports tonic as unavailable,
            if no known dataset class exists in ``tonic.datasets``, if the
            dataset constructor fails with a "File not found or corrupted"
            error (re-raised with manual download instructions), if a limit
            leaves zero samples, or if the frame size is not divisible by
            ``spatial_downsample``.
    """
    time_bins = max(1, int(time_bins))
    spatial_downsample = max(1, int(spatial_downsample))
    import inspect

    # ensure_python_package is defined elsewhere in this file; only its truthy/falsy result is used here.
    if not ensure_python_package("tonic"):
        raise RuntimeError(
            "Auto-download for DVS-Gesture requires the tonic package.\n"
            "Install it with: pip install tonic"
        )

    import tonic.datasets as tonic_datasets  # type: ignore
    from tonic.transforms import ToFrame  # type: ignore

    root_dir = root or os.path.join("data", "dvs_gesture")
    os.makedirs(root_dir, exist_ok=True)
    out_dir = os.path.dirname(npz_path)
    # dirname is "" for a bare file name; os.makedirs("") would raise, hence the guard.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Take the first dataset class name that exists in tonic.datasets.
    dataset_cls = None
    for name in ("DVSGesture", "dvsgesture"):
        candidate = getattr(tonic_datasets, name, None)
        if candidate is not None:
            dataset_cls = candidate
            break
    if dataset_cls is None:
        raise RuntimeError(
            "Could not find DVS-Gesture dataset in tonic.datasets. "
            "Try upgrading tonic: pip install -U tonic"
        )
    # Rewrite figshare.com download URLs to the ndownloader.figshare.com host.
    # NOTE: setattr mutates the class attributes, so the change is visible to every later user of the class.
    for attr in ("train_url", "test_url"):
        url = getattr(dataset_cls, attr, None)
        if not isinstance(url, str):
            continue
        if url.startswith("https://figshare.com/ndownloader/files/"):
            file_id = url.rstrip("/").split("/")[-1]
            alt_url = f"https://ndownloader.figshare.com/files/{file_id}"
            if alt_url != url:
                print(f"[DVS] Figshare URL blocked by WAF; using direct downloader: {alt_url}")
                setattr(dataset_cls, attr, alt_url)

    # Inspect the constructor signature so the call below adapts to whichever
    # parameter names this dataset class exposes for root / split / download.
    params = inspect.signature(dataset_cls).parameters
    root_param = None
    for key in ("root", "save_to", "data_path", "path"):
        if key in params:
            root_param = key
            break
    split_param = None
    for key in ("train", "train_split", "split", "mode", "set"):
        if key in params:
            split_param = key
            break
    download_param = None
    for key in ("download", "download_and_prepare"):
        if key in params:
            download_param = key
            break
    positional_params = [
        p
        for p in params.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    # Positional parameters without a default, i.e. ones the constructor requires.
    required_params = [p for p in positional_params if p.default is inspect._empty]

    # build_dataset: instantiate dataset_cls using the parameter names detected above.
    # train=None requests the un-split dataset; True/False select the train/test split.
    # A "File not found or corrupted" RuntimeError is converted into a message with manual download steps.
    def build_dataset(train: bool | None) -> Any:
        kwargs: Dict[str, Any] = {}
        args: List[Any] = []
        if root_param is not None:
            kwargs[root_param] = root_dir
        elif required_params:
            # No recognised root keyword: assume the first required positional is the root -- TODO confirm.
            args.append(root_dir)
        if train is not None and split_param is not None:
            if split_param == "train":
                # A parameter literally named "train" receives the boolean.
                kwargs[split_param] = train
            else:
                # Any other split parameter receives the string "train" or "test".
                kwargs[split_param] = "train" if train else "test"
        if download_param is not None:
            kwargs[download_param] = True
        try:
            return dataset_cls(*args, **kwargs)
        except RuntimeError as exc:
            msg = str(exc)
            # Only this specific failure is translated; every other RuntimeError propagates unchanged.
            if "File not found or corrupted" not in msg:
                raise

            # Collect what a user needs to fetch the archives by hand.
            # assumes the dataset stores its files under <root_dir>/<class name> -- TODO confirm
            dataset_dir = os.path.join(os.path.expanduser(root_dir), getattr(dataset_cls, "__name__", "DVSGesture"))
            train_url = getattr(dataset_cls, "train_url", None)
            test_url = getattr(dataset_cls, "test_url", None)
            train_md5 = getattr(dataset_cls, "train_md5", None)
            test_md5 = getattr(dataset_cls, "test_md5", None)
            train_filename = getattr(dataset_cls, "train_filename", "ibmGestureTrain.tar.gz")
            test_filename = getattr(dataset_cls, "test_filename", "ibmGestureTest.tar.gz")
            train_path = os.path.join(dataset_dir, train_filename)
            test_path = os.path.join(dataset_dir, test_filename)

            raise RuntimeError(
                "Auto-download for DVS-Gesture failed (Figshare is returning a WAF challenge in this environment).\n"
                "Fix: manually download the archives in a browser and place them here:\n"
                f"  {dataset_dir}\n"
                f"- {train_filename} (md5={train_md5}) from {train_url}\n"
                f"  -> {train_path}\n"
                f"- {test_filename} (md5={test_md5}) from {test_url}\n"
                f"  -> {test_path}\n"
                "Then re-run. Alternatively, provide a preprocessed file via --gesture-npz."
            ) from exc

    # split_full_dataset: derive train/test index arrays for a dataset that has no built-in split.
    # Uses a per-label (stratified) split when dataset.targets exists and matches len(dataset);
    # otherwise shuffles all indices and cuts once. The generator is seeded, so the split is repeatable.
    def split_full_dataset(
        dataset: Any,
        train_fraction: float = 0.9,
        seed: int = 0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        total = len(dataset)
        if total <= 0:
            raise RuntimeError("DVS-Gesture dataset is empty.")
        targets = getattr(dataset, "targets", None)
        rng = np.random.default_rng(seed)
        if targets is None or len(targets) != total:
            # No usable labels: plain shuffled split.
            indices = np.arange(total)
            rng.shuffle(indices)
            split_idx = int(total * train_fraction)
            # Clamp to [1, total - 1]; with total == 1 this yields 0 and the only sample lands in test.
            split_idx = min(max(1, split_idx), total - 1)
            return indices[:split_idx], indices[split_idx:]

        targets = np.asarray(targets)
        train_indices: List[int] = []
        test_indices: List[int] = []
        for label in np.unique(targets):
            label_indices = np.where(targets == label)[0]
            rng.shuffle(label_indices)
            split_idx = int(len(label_indices) * train_fraction)
            # Same clamp per class; a class with a single sample contributes it to test only.
            split_idx = min(max(1, split_idx), len(label_indices) - 1)
            train_indices.extend(label_indices[:split_idx])
            test_indices.extend(label_indices[split_idx:])
        train_array = np.array(train_indices, dtype=np.int64)
        test_array = np.array(test_indices, dtype=np.int64)
        # Shuffle again so the output is not grouped by label.
        rng.shuffle(train_array)
        rng.shuffle(test_array)
        return train_array, test_array

    # materialize: read samples from the dataset one by one into dense numpy arrays.
    # Returns (x_data, y_data): x_data is uint16 with shape (count,) + per-sample frame shape, y_data is int64.
    # `split` is only used in progress messages; `limit` caps how many of `indices` are read.
    def materialize(
        dataset: Any,
        split: str,
        limit: int | None,
        indices: np.ndarray | None = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        if indices is None:
            indices = np.arange(len(dataset))
        indices = np.asarray(indices, dtype=np.int64)
        total = len(indices)
        count = total if limit is None else min(total, limit)
        # Also triggers for limit <= 0, not just for an empty index list.
        if count <= 0:
            raise RuntimeError("Requested zero samples for DVS-Gesture preparation.")

        # Sum away the channel axis of 4-D frames when polarity is disabled (4-D -> 3-D).
        # The channel axis is guessed by size: an axis of length <= 4 is treated as channels.
        def maybe_merge_polarity(frames: Any) -> Any:
            if use_polarity or not hasattr(frames, "ndim") or frames.ndim != 4:
                return frames
            # Handle both (T, C, H, W) and (T, H, W, C).
            if frames.shape[1] <= 4 and frames.shape[-1] > 4:
                return frames.sum(axis=1)
            if frames.shape[-1] <= 4 and frames.shape[1] > 4:
                return frames.sum(axis=-1)
            # Ambiguous sizes: prefer axis 1 when it is small, otherwise the last axis.
            if frames.shape[1] in (1, 2, 3, 4):
                return frames.sum(axis=1)
            return frames.sum(axis=-1)

        # The first sample is loaded here, before downsample_frames is defined;
        # it is downsampled further below to learn the per-sample shape.
        first_frames, first_label = dataset[int(indices[0])]
        first_frames = maybe_merge_polarity(first_frames)

        # Sum-pool frames spatially by spatial_downsample.
        # With a factor <= 1 the array is returned unchanged (dtype untouched).
        def downsample_frames(frames: Any) -> np.ndarray:
            arr = np.asarray(frames)
            if spatial_downsample <= 1:
                return arr
            f = spatial_downsample

            if arr.ndim == 3:
                t, h, w = arr.shape
                if h % f != 0 or w % f != 0:
                    raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                # Expose each f x f block as two extra axes, then sum over them.
                reshaped = arr.reshape(t, h // f, f, w // f, f)
                # Accumulate in uint32, then clip so the cast to uint16 cannot wrap around.
                summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            if arr.ndim == 4:
                # Handle both (T, C, H, W) and (T, H, W, C).
                if arr.shape[1] <= 4 and arr.shape[2] > 4 and arr.shape[3] > 4:
                    t, c, h, w = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, c, h // f, f, w // f, f)
                    summed = reshaped.sum(axis=(3, 5), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)
                if arr.shape[-1] <= 4 and arr.shape[1] > 4 and arr.shape[2] > 4:
                    t, h, w, c = arr.shape
                    if h % f != 0 or w % f != 0:
                        raise RuntimeError(f"Spatial downsample factor {f} does not divide frame size {(h, w)}.")
                    reshaped = arr.reshape(t, h // f, f, w // f, f, c)
                    summed = reshaped.sum(axis=(2, 4), dtype=np.uint32)
                    return np.clip(summed, 0, np.iinfo(np.uint16).max).astype(np.uint16)

            # Reached for any other rank and for 4-D arrays matching neither layout test.
            raise RuntimeError(f"Unsupported frame shape for spatial downsample: {arr.shape}")

        first_frames = downsample_frames(first_frames)
        frames_shape = np.asarray(first_frames).shape
        x_data = np.zeros((count,) + frames_shape, dtype=np.uint16)
        y_data = np.zeros((count,), dtype=np.int64)
        x_data[0] = np.asarray(first_frames).astype(np.uint16)
        y_data[0] = int(first_label)
        for idx in range(1, count):
            frames, label = dataset[int(indices[idx])]
            frames = downsample_frames(maybe_merge_polarity(frames))
            # Assigning into x_data requires every sample to match the first sample's shape.
            x_data[idx] = np.asarray(frames).astype(np.uint16)
            y_data[idx] = int(label)
            if idx % 1000 == 0:
                print(f"[DVS] Processed {idx}/{count} {split} samples")
        return x_data, y_data

    print("[DVS] Preparing DVS-Gesture npz. This may take a while.")
    if split_param is None:
        # The constructor exposes no split argument: load everything and split locally.
        dataset = build_dataset(None)
        # The transform is attached after construction, so indexing the dataset returns frames.
        transform = ToFrame(sensor_size=dataset.sensor_size, n_time_bins=time_bins)
        dataset.transform = transform
        train_idx, test_idx = split_full_dataset(dataset)
        # Limits are applied to the already shuffled index arrays (first `limit` entries).
        x_train, y_train = materialize(dataset, "train", train_limit, train_idx)
        x_test, y_test = materialize(dataset, "test", test_limit, test_idx)
    else:
        # The constructor has its own split argument: build train and test separately.
        train_dataset = build_dataset(True)
        train_transform = ToFrame(sensor_size=train_dataset.sensor_size, n_time_bins=time_bins)
        train_dataset.transform = train_transform
        x_train, y_train = materialize(train_dataset, "train", train_limit)

        test_dataset = build_dataset(False)
        test_transform = ToFrame(sensor_size=test_dataset.sensor_size, n_time_bins=time_bins)
        test_dataset.transform = test_transform
        x_test, y_test = materialize(test_dataset, "test", test_limit)

    # Preprocessing settings are stored next to the arrays under "meta_*" keys.
    # A limit of None is encoded as -1.
    meta: Dict[str, Any] = {
        "meta_version": np.asarray(1, dtype=np.int64),
        "meta_dataset": np.asarray(getattr(dataset_cls, "__name__", "unknown")),
        "meta_time_bins": np.asarray(int(time_bins), dtype=np.int64),
        "meta_spatial_downsample": np.asarray(int(spatial_downsample), dtype=np.int64),
        "meta_use_polarity": np.asarray(1 if use_polarity else 0, dtype=np.int8),
        "meta_train_limit": np.asarray(-1 if train_limit is None else int(train_limit), dtype=np.int64),
        "meta_test_limit": np.asarray(-1 if test_limit is None else int(test_limit), dtype=np.int64),
    }
    if compress:
        np.savez_compressed(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    else:
        np.savez(npz_path, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, **meta)
    print(f"[DVS] Saved prepared dataset to {npz_path}")
    return npz_path


# =============================
# 模型定义（编码器 + Local Rule/E-Prop/BPTT/FPTT）
# =============================

class SimpleCNNEncoder(torch.nn.Module):
    # __init__: build two stride-2 convolutions (conv1, conv2) and a shared ReLU.
    # in_channels feeds conv1; channels[0] / channels[1] are the output widths of conv1 / conv2.
    # Both layers use the same kernel_size with padding = kernel_size // 2.
    def __init__(
        self,
        in_channels: int,
        channels: Tuple[int, int],
        kernel_size: int = 3,
    ) -> None:
        super().__init__()
        # Settings common to both layers; each stride-2 layer roughly halves H and W.
        shared_conv_args = {
            "kernel_size": kernel_size,
            "stride": 2,
            "padding": kernel_size // 2,
        }
        # conv1 is created before conv2 so parameter initialisation consumes the RNG in the same order.
        self.conv1 = torch.nn.Conv2d(in_channels, channels[0], **shared_conv_args)
        self.conv2 = torch.nn.Conv2d(channels[0], channels[1], **shared_conv_args)
        self.act = torch.nn.ReLU()

    # forward: apply conv1 -> ReLU -> conv2 -> ReLU and return the resulting feature map.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        first_stage = self.act(self.conv1(x))
        return self.act(self.conv2(first_stage))


# _conv_rnn_step: one convolutional-RNN state update.
# Computes tanh(conv2d(h_prev, W_hh, padding=padding) + conv2d(f, W_xh, padding=0) + b_h).
# The two convolution outputs must have broadcast-compatible shapes; b_h is added by broadcasting.
def _conv_rnn_step(
    h_prev: torch.Tensor,
    f: torch.Tensor,
    W_hh: torch.Tensor,
    W_xh: torch.Tensor,
    b_h: torch.Tensor,
    padding: int,
) -> torch.Tensor:
    # Input drive: unpadded convolution of the current features.
    input_term = F.conv2d(f, W_xh, padding=0)
    # Recurrent term: convolution of the previous state with the caller-chosen padding.
    recurrent_term = F.conv2d(h_prev, W_hh, padding=padding)
    # Same summation order as (recurrent + input) + bias to keep the numerics identical.
    pre_activation = (recurrent_term + input_term) + b_h
    return pre_activation.tanh()


# _readout_logits: average h over dims 2 and 3, then apply a linear readout.
# Returns mean(h, dim=(2, 3)) @ W_hy.T + b_y.
def _readout_logits(h: torch.Tensor, W_hy: torch.Tensor, b_y: torch.Tensor) -> torch.Tensor:
    pooled = torch.mean(h, dim=(2, 3))
    projected = torch.matmul(pooled, W_hy.T)
    return projected + b_y


# 局部学习规则的卷积RNN，可选 FPTT 软标签。
class TorchLocalRuleConvRNN:
    # __init__：初始化模块超参数、子模块与状态缓存。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def __init__(
        self,
        in_channels: int,
        enc_channels: Tuple[int, int],
        hidden_channels: int,
        output_size: int,
        steps: int = 8,
        eta: float = 1e-3,
        lambda_window: int = 50,
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        train_encoder: bool = False,
        seed: int | None = None,
        device: torch.device | str | None = None,
        kernel_size: int = 3,
    ) -> None:
        """Set up the encoder, the recurrent/readout weights and the rule state.

        Args:
            in_channels: Input channel count handed to the encoder.
            enc_channels: Channel widths of the two encoder layers.
            hidden_channels: Channel count of the recurrent state.
            output_size: Number of readout logits.
            steps: Iterations run by ``forward_cycle`` (clamped to >= 1).
            eta: Step size applied to every weight update.
            lambda_window: Window length defining the decay
                ``(window - 1) / window`` of the lambda statistics
                (clamped to >= 2).
            loss_mode: ``"ce"`` or ``"mse"``.
            max_grad_norm: Norm bound handed to ``clip_gradients``.
            train_encoder: Whether ``train_batch`` also updates the encoder.
            seed: Seed of the NumPy generator used to draw initial weights.
            device: Target device, resolved through ``resolve_device``.
            kernel_size: Kernel size of the encoder and recurrent convolutions.

        Raises:
            ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
        """
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        rng = np.random.default_rng(seed)

        def _gaussian(shape, scale=None):
            # float32 standard-normal draw, optionally scaled, moved to the model device.
            sample = rng.standard_normal(shape).astype(np.float32)
            if scale is not None:
                sample = sample * scale
            return to_tensor(sample, self.device)

        # ---- scalar configuration ----
        self.device = resolve_device(device)
        self.hidden_size = int(hidden_channels)
        self.input_size = int(in_channels)
        self.output_size = int(output_size)
        self.eta = float(eta)
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        self.train_encoder = bool(train_encoder)
        self.kernel_size = int(kernel_size)
        self.padding = self.kernel_size // 2
        self.steps = max(1, int(steps))
        self.step_weights: torch.Tensor | None = None

        # ---- feature encoder ----
        self.encoder = SimpleCNNEncoder(in_channels, enc_channels, kernel_size=self.kernel_size).to(self.device)
        self.enc_channels = enc_channels

        # ---- trainable tensors (draw order W_xh -> W_hh -> W_hy fixes the seeded values) ----
        recurrent_shape = (hidden_channels, hidden_channels, self.kernel_size, self.kernel_size)
        self.W_xh = _gaussian((hidden_channels, enc_channels[1], 1, 1), 0.1)
        self.W_hh = _gaussian(recurrent_shape)
        self.b_h = torch.zeros((1, hidden_channels, 1, 1), dtype=torch.float32, device=self.device)
        self.W_hy = _gaussian((output_size, hidden_channels), 0.1)
        self.b_y = torch.zeros((output_size,), dtype=torch.float32, device=self.device)

        # ---- running-estimate hyper-parameters for alpha_hat ----
        self.alpha_rho = 0.995
        self.alpha_clip_min = -0.99
        self.alpha_clip_max = 0.99

        # ---- running-estimate hyper-parameters for lambda ----
        window = max(2, int(lambda_window))
        self.lambda_window = window
        self.lambda_rho = (window - 1) / window
        self.lambda_cap = 0.99
        self.denom_floor = 1e-3
        self.eps_lambda = 1e-8

        # ---- FPTT soft-label state (disabled until enable_fptt_surrogates) ----
        self.use_fptt_surrogates = False
        self.beta_schedule = "linear"
        self.fptt_Q_prev: torch.Tensor | None = None
        self.fptt_Q_sum: torch.Tensor | None = None
        self.fptt_Q_count: torch.Tensor | None = None

        self.reset_learning_state()

    # reset_learning_state：清理训练过程中的缓存与统计量。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def reset_learning_state(self) -> None:
        shape = (1, self.hidden_size, 1, 1)
        self.alpha_num = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.alpha_den = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.alpha_hat = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.S_A2 = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.S_AB = torch.zeros(shape, dtype=torch.float32, device=self.device)
        self.lambda_vals = torch.zeros(shape, dtype=torch.float32, device=self.device)

    # initialize_weights_with_gain：按指定增益初始化权重矩阵。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        """Redraw ``W_hh`` as Gaussian noise with std ``g / sqrt(fan_in)``.

        ``fan_in`` is ``hidden_size * kernel_size**2``. The new tensor replaces
        ``self.W_hh`` and keeps its shape.

        Args:
            g: Gain that scales the standard deviation.
            seed: Seed for the NumPy generator used for the draw.
        """
        generator = np.random.default_rng(seed)
        fan_in = self.hidden_size * self.kernel_size * self.kernel_size
        sigma = g / math.sqrt(fan_in)
        noise = generator.standard_normal(self.W_hh.shape).astype(np.float32)
        self.W_hh = to_tensor(noise * sigma, self.device)

    # enable_fptt_surrogates：启用 FPTT 软标签/替代目标并设置权重。
    # 关键步骤：初始化 Q → 配置权重 → 启用标志。
    # 算法要点：软标签平滑提升训练稳定性。
    def enable_fptt_surrogates(
        self,
        time_steps: int,
        output_size: int,
        Q_init: np.ndarray | torch.Tensor | None = None,
        beta_schedule: str = "linear",
    ) -> None:
        """Turn on FPTT soft labels and install the initial soft-label table.

        Args:
            time_steps: Number of columns of the soft-label table.
            output_size: Number of rows of the soft-label table.
            Q_init: Optional initial table; when omitted every entry is
                ``1 / output_size``.
            beta_schedule: Schedule name stored for ``train_batch``.
        """
        prior = Q_init
        if prior is None:
            # No prior supplied: start from the uniform distribution over classes.
            prior = np.full((output_size, time_steps), 1.0 / output_size, dtype=np.float32)
        self.use_fptt_surrogates = True
        self.beta_schedule = beta_schedule
        # Clone so later updates never alias the caller's array/tensor.
        self.fptt_Q_prev = to_tensor(prior, self.device, dtype=torch.float32).clone()
        self.reset_fptt_epoch_accumulators(time_steps, output_size)

    # reset_fptt_epoch_accumulators：清空 FPTT 的 epoch 内累计统计。
    # 关键步骤：清空累加量 → 重置状态。
    # 算法要点：防止跨 epoch 污染统计量。
    def reset_fptt_epoch_accumulators(
        self,
        time_steps: int | None = None,
        output_size: int | None = None,
    ) -> None:
        if self.fptt_Q_prev is None and (time_steps is None or output_size is None):
            raise ValueError("Call enable_fptt_surrogates first.")
        if time_steps is None or output_size is None:
            time_steps = int(self.fptt_Q_prev.shape[1])
            output_size = int(self.fptt_Q_prev.shape[0])
        self.fptt_Q_sum = torch.zeros(
            (output_size, time_steps),
            dtype=torch.float64,
            device=self.device,
        )
        self.fptt_Q_count = torch.zeros((time_steps,), dtype=torch.int64, device=self.device)

    # finalize_fptt_epoch：在 epoch 结束时更新 FPTT 累计量。
    # 关键步骤：汇总累积 → 更新 Q → 返回统计。
    # 算法要点：用累计概率更新软标签分布。
    def finalize_fptt_epoch(self) -> None:
        if self.fptt_Q_sum is None or self.fptt_Q_count is None:
            return
        counts = torch.maximum(self.fptt_Q_count.to(torch.float64), torch.ones_like(self.fptt_Q_count))
        Q_new = (self.fptt_Q_sum / counts.unsqueeze(0)).to(torch.float32)
        self.fptt_Q_prev = Q_new

    # train_batch：训练单个 batch 并返回损失/指标。
    # 关键步骤：前向计算 → 计算损失 → 反向更新/统计。
    # 算法要点：依据 FPTT/E-Prop/BPTT 分支处理梯度更新。
    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor | None,
    ) -> Tuple[float, torch.Tensor]:
        """Run one online local-rule training pass over a batch.

        The inputs are encoded once and the same features drive every time
        step. At each step the state is advanced, the loss and its logit
        gradient are computed, a teaching signal is formed, and all recurrent
        and readout weights are updated immediately (after gradient clipping).
        When ``train_encoder`` is set, the feature gradients accumulated over
        all steps are back-propagated through the encoder at the end.

        Args:
            inputs_batch: Encoder input; axis 0 is the batch axis.
            targets_batch: Targets indexed as ``targets[:, :, t]``, i.e. time
                runs along axis 2.
            h_prev_batch: Initial hidden state, or ``None`` to start from zeros.

        Returns:
            ``(avg_loss, h_last)`` where ``avg_loss`` is the summed weighted
            step loss divided by the step-weight sum (at least 1.0) and
            ``h_last`` is the detached final hidden state.
        """
        inputs = to_tensor(inputs_batch, self.device, dtype=torch.float32)
        targets = to_tensor(targets_batch, self.device, dtype=torch.float32)
        batch_size = int(inputs.shape[0])
        time_steps = int(targets.shape[2])
        step_weights = self.step_weights
        if step_weights is not None and not torch.is_tensor(step_weights):
            step_weights = to_tensor(step_weights, self.device, dtype=torch.float32)
        weight_norm = float(step_weights.sum().item()) if step_weights is not None else float(time_steps)
        weight_norm = max(weight_norm, 1.0)
        eps = 1e-12
        # Batch normaliser for the convolution weight gradients (loop invariant).
        scale = float(batch_size)

        # Encode once; the recurrent updates always see a detached feature map.
        if self.train_encoder:
            f = self.encoder(inputs)
            f_detached = f.detach()
        else:
            with torch.no_grad():
                f = self.encoder(inputs)
            f_detached = f

        if h_prev_batch is None:
            h_prev = torch.zeros(
                (batch_size, self.hidden_size, f_detached.shape[2], f_detached.shape[3]),
                dtype=torch.float32,
                device=self.device,
            )
        else:
            h_prev = to_tensor(h_prev_batch, self.device, dtype=torch.float32)

        prev_g = None
        prev_u = None
        prev_delta = None
        total_loss = torch.zeros((), device=self.device)
        delta_f_total = torch.zeros_like(f_detached) if self.train_encoder else None

        for t in range(time_steps):
            step_weight = step_weights[t] if step_weights is not None else 1.0
            h_t = _conv_rnn_step(h_prev, f_detached, self.W_hh, self.W_xh, self.b_h, self.padding)
            logits = _readout_logits(h_t, self.W_hy, self.b_y)
            y_target_t = targets[:, :, t]

            if self.loss_mode == "ce":
                probs = softmax(logits)
                # FPTT: blend the label with the stored soft label Q_t and
                # accumulate this step's probabilities for the next table.
                if self.use_fptt_surrogates and self.fptt_Q_prev is not None:
                    beta_t = float(t + 1) / float(time_steps) if self.beta_schedule == "linear" else 1.0
                    Q_t = self.fptt_Q_prev[:, t].view(1, -1)
                    Q_t_batch = Q_t.repeat(batch_size, 1)
                    y_tilde = beta_t * y_target_t + (1.0 - beta_t) * Q_t_batch
                    dL_dlogits = (probs - y_tilde) * step_weight
                    loss_t = step_weight * -torch.mean(torch.sum(y_tilde * torch.log(probs + eps), dim=1))
                    if self.fptt_Q_sum is not None and self.fptt_Q_count is not None:
                        self.fptt_Q_sum[:, t] += torch.sum(probs, dim=0)
                        self.fptt_Q_count[t] += batch_size
                else:
                    dL_dlogits = (probs - y_target_t) * step_weight
                    loss_t = step_weight * -torch.mean(torch.sum(y_target_t * torch.log(probs + eps), dim=1))
            else:
                error = logits - y_target_t
                dL_dlogits = error * step_weight
                loss_t = step_weight * 0.5 * torch.mean(torch.sum(error**2, dim=1))

            total_loss = total_loss + loss_t

            # Push the logit gradient back through the readout; the spatial
            # mean pooling spreads it uniformly over all positions.
            g_t = dL_dlogits @ self.W_hy
            u_t = 1.0 - h_t**2  # derivative of tanh at the new state
            g_t_map = (g_t / float(h_t.shape[2] * h_t.shape[3])).view(batch_size, -1, 1, 1)
            g_t_map = g_t_map.expand_as(h_t)

            # Teaching signal with the lambda correction; the denominator is
            # kept at least denom_floor away from zero (sign preserved).
            lambda_used = self.lambda_vals
            denominator = 1.0 - lambda_used * u_t
            denom_mask = torch.abs(denominator) < self.denom_floor
            denominator = torch.where(
                denom_mask,
                self.denom_floor * torch.sign(denominator + 1e-12),
                denominator,
            )
            delta_t = (u_t * g_t_map) / denominator

            # Estimate alpha_hat from the teaching signal itself:
            # \widetilde{\Delta}_t \approx \alpha \widetilde{\Delta}_{t-1}.
            if prev_delta is not None:
                dtp_mean = (delta_t * prev_delta).mean(dim=(0, 2, 3), keepdim=True)
                dpp_mean = (prev_delta**2).mean(dim=(0, 2, 3), keepdim=True)
                self.alpha_num = self.alpha_rho * self.alpha_num + (1.0 - self.alpha_rho) * dtp_mean
                self.alpha_den = self.alpha_rho * self.alpha_den + (1.0 - self.alpha_rho) * dpp_mean
                raw_alpha = self.alpha_num / (self.alpha_den + 1e-8)
                self.alpha_hat = torch.clamp(raw_alpha, self.alpha_clip_min, self.alpha_clip_max)

            # Single-step weight gradients from the teaching signal.
            dW_hh = torch.nn.grad.conv2d_weight(
                h_prev,
                self.W_hh.shape,
                delta_t,
                stride=1,
                padding=self.padding,
            ) / scale
            dW_xh = torch.nn.grad.conv2d_weight(
                f_detached,
                self.W_xh.shape,
                delta_t,
                stride=1,
                padding=0,
            ) / scale
            db_h = delta_t.mean(dim=(0, 2, 3), keepdim=True)

            z_t = h_t.mean(dim=(2, 3))
            dW_hy = (dL_dlogits.T @ z_t) / max(batch_size, 1)
            db_y = dL_dlogits.mean(dim=0)

            grads = [dW_hh, dW_xh, db_h, dW_hy, db_y]
            dW_hh, dW_xh, db_h, dW_hy, db_y = clip_gradients(grads, self.max_grad_norm)

            if self.train_encoder:
                # Feature gradient uses W_xh before this step's update.
                delta_f = torch.nn.grad.conv2d_input(
                    f_detached.shape,
                    self.W_xh,
                    delta_t,
                    stride=1,
                    padding=0,
                )
                delta_f_total = delta_f_total + (delta_f / scale)

            with torch.no_grad():
                self.W_hh.add_(-self.eta * dW_hh)
                self.W_xh.add_(-self.eta * dW_xh)
                self.b_h.add_(-self.eta * db_h)
                self.W_hy.add_(-self.eta * dW_hy)
                self.b_y.add_(-self.eta * db_y)

            # Update the running statistics that define lambda for the next step.
            if prev_g is not None and prev_u is not None:
                A_s = prev_u * u_t * (self.alpha_hat * prev_g - g_t_map)
                B_s = self.alpha_hat * prev_u * prev_g - u_t * g_t_map
                A2_mean = A_s.pow(2).mean(dim=(0, 2, 3), keepdim=True)
                AB_mean = (A_s * B_s).mean(dim=(0, 2, 3), keepdim=True)
                self.S_A2 = self.lambda_rho * self.S_A2 + (1.0 - self.lambda_rho) * A2_mean
                self.S_AB = self.lambda_rho * self.S_AB + (1.0 - self.lambda_rho) * AB_mean

                lambda_unproj = self.S_AB / (self.S_A2 + self.eps_lambda)
                # Cap |lambda| so that 1 - lambda * u stays >= denom_floor.
                u_abs_max = torch.amax(torch.abs(u_t), dim=(0, 2, 3), keepdim=True) + 1e-12
                safe_cap = (1.0 - self.denom_floor) / u_abs_max
                cap = torch.minimum(safe_cap, torch.full_like(safe_cap, self.lambda_cap))
                self.lambda_vals = torch.clamp(lambda_unproj, min=-cap, max=cap)

            prev_g = g_t_map
            prev_u = u_t
            h_prev = h_t
            prev_delta = delta_t

        if self.train_encoder and delta_f_total is not None:
            self.encoder.zero_grad(set_to_none=True)
            f.backward(delta_f_total)
            # Keep each gradient paired with the parameter that owns it: zipping
            # the filtered gradient list against *all* parameters would shift
            # the pairs as soon as one parameter has no gradient.
            trained_params = [p for p in self.encoder.parameters() if p.grad is not None]
            if trained_params:
                clipped = clip_gradients([p.grad for p in trained_params], self.max_grad_norm)
                with torch.no_grad():
                    for param, grad in zip(trained_params, clipped):
                        param.add_(-self.eta * grad)

        avg_loss = float((total_loss / weight_norm).item())
        return avg_loss, h_prev.detach()

    # forward_cycle：执行完整时间序列循环的前向与状态更新。
    # 关键步骤：接收输入/状态 → 更新递推 → 输出结果。
    # 算法要点：保持状态递推与非线性一致性。
    def forward_cycle(
        self, inputs_cycle: np.ndarray | torch.Tensor, h_prev_cycle: np.ndarray | torch.Tensor | None
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Run ``self.steps`` gradient-free recurrent steps on one input batch.

        Args:
            inputs_cycle: Encoder input; axis 0 is the batch axis.
            h_prev_cycle: Initial hidden state, or ``None`` to start from zeros.

        Returns:
            ``(logits_per_step, h_last)``: one logits tensor per step and the
            final hidden state.
        """
        batch = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        with torch.no_grad():
            features = self.encoder(batch)
            num_samples = int(batch.shape[0])
            if h_prev_cycle is not None:
                state = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
            else:
                # Zero state matching the spatial size of the encoded features.
                state = torch.zeros(
                    (num_samples, self.hidden_size, features.shape[2], features.shape[3]),
                    dtype=torch.float32,
                    device=self.device,
                )
            logits_per_step: List[torch.Tensor] = []
            step_index = 0
            while step_index < self.steps:
                state = _conv_rnn_step(state, features, self.W_hh, self.W_xh, self.b_h, self.padding)
                logits_per_step.append(_readout_logits(state, self.W_hy, self.b_y))
                step_index += 1
        return logits_per_step, state


# E-Prop：显式维护 eligibility traces，并用反馈矩阵构造学习信号。
class TorchEPropConvRNN:
    # __init__：初始化模块超参数、子模块与状态缓存。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def __init__(
        self,
        in_channels: int,
        enc_channels: Tuple[int, int],
        hidden_channels: int,
        output_size: int,
        steps: int = 8,
        eta: float = 1e-3,
        decay_lambda: float = 0.95,
        feedback: str = "symmetric",
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        train_encoder: bool = False,
        seed: int | None = None,
        device: torch.device | str | None = None,
        kernel_size: int = 3,
    ) -> None:
        """Set up the encoder, the recurrent/readout weights and the feedback path.

        Args:
            in_channels: Input channel count handed to the encoder.
            enc_channels: Channel widths of the two encoder layers.
            hidden_channels: Channel count of the recurrent state.
            output_size: Number of readout logits.
            steps: Iterations run by ``forward_cycle`` (clamped to >= 1).
            eta: Step size applied to every weight update.
            decay_lambda: Decay factor of the presynaptic traces.
            feedback: ``"symmetric"`` or ``"random"`` (case-insensitive).
            loss_mode: ``"ce"`` or ``"mse"``.
            max_grad_norm: Norm bound handed to ``clip_gradients``.
            train_encoder: Whether ``train_batch`` also updates the encoder.
            seed: Seed of the NumPy generator used to draw initial weights.
            device: Target device, resolved through ``resolve_device``.
            kernel_size: Kernel size of the encoder and recurrent convolutions.

        Raises:
            ValueError: If ``loss_mode`` or ``feedback`` is not a supported value.
        """
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        rng = np.random.default_rng(seed)

        def _gaussian(shape):
            # Unscaled float32 standard-normal draw.
            return rng.standard_normal(shape).astype(np.float32)

        # ---- scalar configuration ----
        self.device = resolve_device(device)
        self.hidden_size = int(hidden_channels)
        self.input_size = int(in_channels)
        self.output_size = int(output_size)
        self.eta = float(eta)
        self.decay_lambda = float(decay_lambda)
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        self.train_encoder = bool(train_encoder)
        self.kernel_size = int(kernel_size)
        self.padding = self.kernel_size // 2
        self.steps = max(1, int(steps))
        self.step_weights: torch.Tensor | None = None

        feedback_mode = feedback.lower()
        if feedback_mode not in ("random", "symmetric"):
            raise ValueError("feedback must be 'random' or 'symmetric'.")
        self.feedback_type = feedback_mode

        # ---- feature encoder ----
        self.encoder = SimpleCNNEncoder(in_channels, enc_channels, kernel_size=self.kernel_size).to(self.device)

        # ---- trainable tensors (draw order W_xh -> W_hh -> W_hy fixes the seeded values) ----
        recurrent_shape = (hidden_channels, hidden_channels, self.kernel_size, self.kernel_size)
        self.W_xh = to_tensor(_gaussian((hidden_channels, enc_channels[1], 1, 1)) * 0.1, self.device)
        self.W_hh = to_tensor(_gaussian(recurrent_shape), self.device)
        self.b_h = torch.zeros((1, hidden_channels, 1, 1), dtype=torch.float32, device=self.device)
        self.W_hy = to_tensor(_gaussian((output_size, hidden_channels)) * 0.1, self.device)
        self.b_y = torch.zeros((output_size,), dtype=torch.float32, device=self.device)

        # ---- fixed random feedback matrix, only drawn in "random" mode ----
        self.B_fb: torch.Tensor | None = None
        if self.feedback_type == "random":
            feedback_matrix = _gaussian((hidden_channels, output_size)) / np.sqrt(max(1, output_size))
            self.B_fb = to_tensor(feedback_matrix, self.device)

    # initialize_weights_with_gain：按指定增益初始化权重矩阵。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        """Redraw ``W_hh`` as Gaussian noise with std ``g / sqrt(fan_in)``.

        ``fan_in`` is ``hidden_size * kernel_size**2``. The new tensor replaces
        ``self.W_hh`` and keeps its shape.

        Args:
            g: Gain that scales the standard deviation.
            seed: Seed for the NumPy generator used for the draw.
        """
        generator = np.random.default_rng(seed)
        fan_in = self.hidden_size * self.kernel_size * self.kernel_size
        sigma = g / math.sqrt(fan_in)
        noise = generator.standard_normal(self.W_hh.shape).astype(np.float32)
        self.W_hh = to_tensor(noise * sigma, self.device)

    # _learning_signal：计算 E-Prop 学习信号（反馈矩阵与误差结合）。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def _learning_signal(self, dL_dlogits: torch.Tensor) -> torch.Tensor:
        # 对称反馈使用 W_hy^T，随机反馈使用固定 B_fb。
        if self.feedback_type == "symmetric":
            return dL_dlogits @ self.W_hy
        if self.B_fb is None:
            raise RuntimeError("Random feedback matrix B_fb is None.")
        return dL_dlogits @ self.B_fb.T

    # train_batch：训练单个 batch 并返回损失/指标。
    # 关键步骤：前向计算 → 计算损失 → 反向更新/统计。
    # 算法要点：依据 FPTT/E-Prop/BPTT 分支处理梯度更新。
    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor | None,
    ) -> Tuple[float, torch.Tensor]:
        """Run one online E-Prop style training pass over a batch.

        The inputs are encoded once and the same features drive every time
        step. At each step the presynaptic traces are low-pass filtered, the
        state is advanced, the learning signal is formed from the logit
        gradient, and all recurrent and readout weights are updated
        immediately (after gradient clipping). When ``train_encoder`` is set,
        the feature gradients accumulated over all steps are back-propagated
        through the encoder at the end.

        Args:
            inputs_batch: Encoder input; axis 0 is the batch axis.
            targets_batch: Targets indexed as ``targets[:, :, t]``, i.e. time
                runs along axis 2.
            h_prev_batch: Initial hidden state, or ``None`` to start from zeros.

        Returns:
            ``(avg_loss, h_last)`` where ``avg_loss`` is the summed weighted
            step loss divided by the step-weight sum (at least 1.0) and
            ``h_last`` is the detached final hidden state.
        """
        inputs = to_tensor(inputs_batch, self.device, dtype=torch.float32)
        targets = to_tensor(targets_batch, self.device, dtype=torch.float32)
        batch_size = int(inputs.shape[0])
        time_steps = int(targets.shape[2])
        step_weights = self.step_weights
        if step_weights is not None and not torch.is_tensor(step_weights):
            step_weights = to_tensor(step_weights, self.device, dtype=torch.float32)
        weight_norm = float(step_weights.sum().item()) if step_weights is not None else float(time_steps)
        weight_norm = max(weight_norm, 1.0)
        eps = 1e-12

        # Encode once; the recurrent updates always see a detached feature map.
        if self.train_encoder:
            f = self.encoder(inputs)
            f_detached = f.detach()
        else:
            with torch.no_grad():
                f = self.encoder(inputs)
            f_detached = f

        if h_prev_batch is None:
            h_prev = torch.zeros(
                (batch_size, self.hidden_size, f_detached.shape[2], f_detached.shape[3]),
                dtype=torch.float32,
                device=self.device,
            )
        else:
            h_prev = to_tensor(h_prev_batch, self.device, dtype=torch.float32)

        # Eligibility traces: low-pass filtered presynaptic activity, which
        # keeps memory and compute far below a full RTRL-style trace.
        feat_h = int(f_detached.shape[2])
        feat_w = int(f_detached.shape[3])
        spatial_scale = float(feat_h * feat_w)
        batch_scale = max(batch_size, 1)
        h_trace = torch.zeros_like(h_prev)
        f_trace = torch.zeros_like(f_detached)
        total_loss = torch.zeros((), device=self.device)
        delta_f_total = torch.zeros_like(f_detached) if self.train_encoder else None

        for t in range(time_steps):
            step_weight = step_weights[t] if step_weights is not None else 1.0

            # Traces are updated with the state *before* this step's transition.
            h_trace = self.decay_lambda * h_trace + h_prev
            f_trace = self.decay_lambda * f_trace + f_detached
            h_t = _conv_rnn_step(h_prev, f_detached, self.W_hh, self.W_xh, self.b_h, self.padding)
            logits = _readout_logits(h_t, self.W_hy, self.b_y)
            y_true_t = targets[:, :, t]
            if self.loss_mode == "ce":
                probs = softmax(logits)
                dL_dlogits = (probs - y_true_t) * step_weight
                loss_t = step_weight * -torch.mean(torch.sum(y_true_t * torch.log(probs + eps), dim=1))
            else:
                error = logits - y_true_t
                dL_dlogits = error * step_weight
                loss_t = step_weight * 0.5 * torch.mean(torch.sum(error**2, dim=1))
            total_loss = total_loss + loss_t

            # Learning signal per hidden channel; the spatial mean pooling
            # spreads it uniformly over all positions.
            l_t = self._learning_signal(dL_dlogits)
            l_scale = l_t / spatial_scale

            psi = 1.0 - h_t**2  # derivative of tanh at the new state
            delta_t = psi * l_scale[:, :, None, None]

            dW_hh = (
                torch.nn.grad.conv2d_weight(
                    h_trace,
                    self.W_hh.shape,
                    delta_t,
                    stride=1,
                    padding=self.padding,
                )
                / batch_scale
            )
            dW_xh = (
                torch.nn.grad.conv2d_weight(
                    f_trace,
                    self.W_xh.shape,
                    delta_t,
                    stride=1,
                    padding=0,
                )
                / batch_scale
            )
            db_h = delta_t.sum(dim=(0, 2, 3), keepdim=True) / batch_scale

            z_t = h_t.mean(dim=(2, 3))
            dW_hy = (dL_dlogits.T @ z_t) / batch_scale
            db_y = dL_dlogits.mean(dim=0)

            grads = [dW_hh, dW_xh, db_h, dW_hy, db_y]
            dW_hh, dW_xh, db_h, dW_hy, db_y = clip_gradients(grads, self.max_grad_norm)

            if self.train_encoder and delta_f_total is not None:
                # Feature gradient uses W_xh before this step's update.
                delta_f = torch.nn.grad.conv2d_input(
                    f_detached.shape,
                    self.W_xh,
                    delta_t,
                    stride=1,
                    padding=0,
                )
                delta_f_total = delta_f_total + (delta_f / batch_scale)

            with torch.no_grad():
                self.W_hh.add_(-self.eta * dW_hh)
                self.W_xh.add_(-self.eta * dW_xh)
                self.b_h.add_(-self.eta * db_h)
                self.W_hy.add_(-self.eta * dW_hy)
                self.b_y.add_(-self.eta * db_y)

            h_prev = h_t

        if self.train_encoder and delta_f_total is not None:
            self.encoder.zero_grad(set_to_none=True)
            f.backward(delta_f_total)
            # Keep each gradient paired with the parameter that owns it: zipping
            # the filtered gradient list against *all* parameters would shift
            # the pairs as soon as one parameter has no gradient.
            trained_params = [p for p in self.encoder.parameters() if p.grad is not None]
            if trained_params:
                clipped = clip_gradients([p.grad for p in trained_params], self.max_grad_norm)
                with torch.no_grad():
                    for param, grad in zip(trained_params, clipped):
                        param.add_(-self.eta * grad)

        avg_loss = float((total_loss / weight_norm).item())
        return avg_loss, h_prev.detach()

    # forward_cycle：执行完整时间序列循环的前向与状态更新。
    # 关键步骤：接收输入/状态 → 更新递推 → 输出结果。
    # 算法要点：保持状态递推与非线性一致性。
    def forward_cycle(
        self, inputs_cycle: np.ndarray | torch.Tensor, h_prev_cycle: np.ndarray | torch.Tensor | None
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Run ``self.steps`` gradient-free recurrent steps on one input batch.

        Args:
            inputs_cycle: Encoder input; axis 0 is the batch axis.
            h_prev_cycle: Initial hidden state, or ``None`` to start from zeros.

        Returns:
            ``(logits_per_step, h_last)``: one logits tensor per step and the
            final hidden state.
        """
        batch = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        with torch.no_grad():
            features = self.encoder(batch)
            num_samples = int(batch.shape[0])
            if h_prev_cycle is not None:
                state = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
            else:
                # Zero state matching the spatial size of the encoded features.
                state = torch.zeros(
                    (num_samples, self.hidden_size, features.shape[2], features.shape[3]),
                    dtype=torch.float32,
                    device=self.device,
                )
            logits_per_step: List[torch.Tensor] = []
            step_index = 0
            while step_index < self.steps:
                state = _conv_rnn_step(state, features, self.W_hh, self.W_xh, self.b_h, self.padding)
                logits_per_step.append(_readout_logits(state, self.W_hy, self.b_y))
                step_index += 1
        return logits_per_step, state


class TorchBPTTConvRNN(torch.nn.Module):
    # __init__：初始化模块超参数、子模块与状态缓存。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def __init__(
        self,
        in_channels: int,
        enc_channels: Tuple[int, int],
        hidden_channels: int,
        output_size: int,
        steps: int = 8,
        eta: float = 1e-3,
        loss_mode: str = "ce",
        max_grad_norm: float = 5.0,
        tbptt_steps: int | None = None,
        time_normalization: bool = True,
        train_encoder: bool = False,
        seed: int | None = None,
        device: torch.device | str | None = None,
        kernel_size: int = 3,
    ) -> None:
        """Set up the encoder, the trainable parameters and an SGD optimizer.

        Args:
            in_channels: Input channel count handed to the encoder.
            enc_channels: Channel widths of the two encoder layers.
            hidden_channels: Channel count of the recurrent state.
            output_size: Number of readout logits.
            steps: Stored as ``self.steps`` (clamped to >= 1).
            eta: Learning rate of the SGD optimizer.
            loss_mode: ``"ce"`` or ``"mse"``.
            max_grad_norm: Gradient-norm bound used during training.
            tbptt_steps: Truncation length, stored as ``int`` or ``None``.
            time_normalization: Whether the loss is normalised over time.
            train_encoder: Whether the encoder runs with gradients in training.
            seed: Seed of the NumPy generator used to draw initial weights.
            device: Target device, resolved through ``resolve_device``.
            kernel_size: Kernel size of the encoder and recurrent convolutions.

        Raises:
            ValueError: If ``loss_mode`` is neither ``"ce"`` nor ``"mse"``.
        """
        super().__init__()
        if loss_mode not in {"ce", "mse"}:
            raise ValueError("loss_mode must be 'ce' or 'mse'.")
        rng = np.random.default_rng(seed)

        def _gaussian(shape):
            # Unscaled float32 standard-normal draw.
            return rng.standard_normal(shape).astype(np.float32)

        def _as_parameter(values):
            # Wrap a NumPy array as a trainable parameter on the model device.
            return torch.nn.Parameter(to_tensor(values, self.device))

        # ---- scalar configuration ----
        self.device = resolve_device(device)
        self.hidden_size = int(hidden_channels)
        self.input_size = int(in_channels)
        self.output_size = int(output_size)
        self.eta = float(eta)
        self.loss_mode = loss_mode
        self.max_grad_norm = float(max_grad_norm)
        self.tbptt_steps = int(tbptt_steps) if tbptt_steps is not None else None
        self.time_normalization = bool(time_normalization)
        self.train_encoder = bool(train_encoder)
        self.kernel_size = int(kernel_size)
        self.padding = self.kernel_size // 2
        self.steps = max(1, int(steps))
        self.step_weights: torch.Tensor | None = None

        # ---- feature encoder ----
        self.encoder = SimpleCNNEncoder(in_channels, enc_channels, kernel_size=self.kernel_size).to(self.device)

        # ---- trainable parameters (draw order W_xh -> W_hh -> W_hy fixes the seeded values) ----
        recurrent_shape = (hidden_channels, hidden_channels, self.kernel_size, self.kernel_size)
        self.W_xh = _as_parameter(_gaussian((hidden_channels, enc_channels[1], 1, 1)) * 0.1)
        self.W_hh = _as_parameter(_gaussian(recurrent_shape))
        self.b_h = torch.nn.Parameter(torch.zeros((1, hidden_channels, 1, 1), dtype=torch.float32, device=self.device))
        self.W_hy = _as_parameter(_gaussian((output_size, hidden_channels)) * 0.1)
        self.b_y = torch.nn.Parameter(torch.zeros((output_size,), dtype=torch.float32, device=self.device))

        # Plain SGD over every registered parameter, encoder included.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=self.eta)

    # reset_optimizer：重置优化器状态（动量/梯度）。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def reset_optimizer(self) -> None:
        self.optimizer = torch.optim.SGD(self.parameters(), lr=self.eta)

    # initialize_weights_with_gain：按指定增益初始化权重矩阵。
    # 关键步骤：读取输入 → 处理逻辑 → 返回结果。
    # 算法要点：作为通用工具支撑上层流程。
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        """Overwrite ``W_hh`` in place with Gaussian noise of std ``g / sqrt(fan_in)``.

        ``fan_in`` is ``hidden_size * kernel_size**2``. The parameter object is
        kept; only its values are replaced.

        Args:
            g: Gain that scales the standard deviation.
            seed: Seed for the NumPy generator used for the draw.
        """
        generator = np.random.default_rng(seed)
        fan_in = self.hidden_size * self.kernel_size * self.kernel_size
        sigma = g / math.sqrt(fan_in)
        noise = generator.standard_normal(self.W_hh.shape).astype(np.float32)
        replacement = to_tensor(noise * sigma, self.device)
        # Copy under no_grad so the in-place write is not tracked by autograd.
        with torch.no_grad():
            self.W_hh.copy_(replacement)

    # train_batch：训练单个 batch 并返回损失/指标。
    # 关键步骤：前向计算 → 计算损失 → 反向更新/统计。
    # 算法要点：依据 FPTT/E-Prop/BPTT 分支处理梯度更新。
    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor | None,
    ) -> Tuple[float, torch.Tensor]:
        inputs = to_tensor(inputs_batch, self.device, dtype=torch.float32)
        targets = to_tensor(targets_batch, self.device, dtype=torch.float32)
        batch_size = int(inputs.shape[0])
        time_steps = int(targets.shape[2])
        step_weights = self.step_weights
        if step_weights is not None and not torch.is_tensor(step_weights):
            step_weights = to_tensor(step_weights, self.device, dtype=torch.float32)
        weight_norm = float(step_weights.sum().item()) if step_weights is not None else float(time_steps)
        weight_norm = max(weight_norm, 1.0)

        tbptt_steps = 0 if self.tbptt_steps is None else int(self.tbptt_steps)
        use_tbptt = 0 < tbptt_steps < time_steps
        eps = 1e-12

        if not use_tbptt:
            self.optimizer.zero_grad(set_to_none=True)
            if self.train_encoder:
                f = self.encoder(inputs)
            else:
                with torch.no_grad():
                    f = self.encoder(inputs)
            if h_prev_batch is None:
                h_state = torch.zeros(
                    (batch_size, self.hidden_size, f.shape[2], f.shape[3]),
                    dtype=torch.float32,
                    device=self.device,
                )
            else:
                h_state = to_tensor(h_prev_batch, self.device, dtype=torch.float32)

            total_loss = torch.zeros((), device=self.device)
            for t in range(time_steps):
                step_weight = step_weights[t] if step_weights is not None else 1.0
                h_state = _conv_rnn_step(h_state, f, self.W_hh, self.W_xh, self.b_h, self.padding)
                logits = _readout_logits(h_state, self.W_hy, self.b_y)
                y_true_t = targets[:, :, t]
                if self.loss_mode == "ce":
                    probs = softmax(logits)
                    loss_t = -torch.sum(y_true_t * torch.log(probs + eps), dim=1).mean()
                else:
                    error = logits - y_true_t
                    loss_t = 0.5 * torch.sum(error**2, dim=1).mean()
                if step_weights is not None:
                    loss_t = loss_t * step_weight
                total_loss = total_loss + loss_t

            loss_scale = weight_norm if self.time_normalization else 1.0
            scaled_loss = total_loss / max(loss_scale, 1.0)
            scaled_loss.backward()
            if self.max_grad_norm > 0:
                clip_norm = self.max_grad_norm
                if not self.time_normalization:
                    clip_norm = self.max_grad_norm * math.sqrt(weight_norm)
                torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
            self.optimizer.step()

            avg_loss = float(scaled_loss.item())
            return avg_loss, h_state.detach()

        total_loss_value = 0.0
        if h_prev_batch is None:
            h_state = None
        else:
            h_state = to_tensor(h_prev_batch, self.device, dtype=torch.float32)

        for start in range(0, time_steps, tbptt_steps):
            self.optimizer.zero_grad(set_to_none=True)
            if self.train_encoder:
                f = self.encoder(inputs)
            else:
                with torch.no_grad():
                    f = self.encoder(inputs)
            if h_state is None:
                h_state = torch.zeros(
                    (batch_size, self.hidden_size, f.shape[2], f.shape[3]),
                    dtype=torch.float32,
                    device=self.device,
                )
            chunk_loss = torch.zeros((), device=self.device)
            end = min(time_steps, start + tbptt_steps)
            for t in range(start, end):
                step_weight = step_weights[t] if step_weights is not None else 1.0
                h_state = _conv_rnn_step(h_state, f, self.W_hh, self.W_xh, self.b_h, self.padding)
                logits = _readout_logits(h_state, self.W_hy, self.b_y)
                y_true_t = targets[:, :, t]
                if self.loss_mode == "ce":
                    probs = softmax(logits)
                    loss_t = -torch.sum(y_true_t * torch.log(probs + eps), dim=1).mean()
                else:
                    error = logits - y_true_t
                    loss_t = 0.5 * torch.sum(error**2, dim=1).mean()
                if step_weights is not None:
                    loss_t = loss_t * step_weight
                chunk_loss = chunk_loss + loss_t

            if self.time_normalization:
                if step_weights is not None:
                    chunk_scale = float(step_weights[start:end].sum().item())
                else:
                    chunk_scale = float(end - start)
                chunk_scale = max(chunk_scale, 1.0)
                chunk_loss_scale = chunk_scale
                chunk_clip_scale = chunk_scale
            else:
                chunk_loss_scale = 1.0
                if step_weights is not None:
                    chunk_clip_scale = float(step_weights[start:end].sum().item())
                else:
                    chunk_clip_scale = float(end - start)
                chunk_clip_scale = max(chunk_clip_scale, 1.0)
            scaled_chunk_loss = chunk_loss / max(chunk_loss_scale, 1.0)
            scaled_chunk_loss.backward()
            if self.max_grad_norm > 0:
                clip_norm = self.max_grad_norm
                if not self.time_normalization:
                    clip_norm = self.max_grad_norm * math.sqrt(chunk_clip_scale)
                torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
            self.optimizer.step()

            total_loss_value += float(chunk_loss.detach().item())
            h_state = h_state.detach()

        avg_loss_scale = weight_norm if self.time_normalization else 1.0
        avg_loss = total_loss_value / max(avg_loss_scale, 1.0)
        return avg_loss, h_state.detach()

    # forward_cycle：执行完整时间序列循环的前向与状态更新。
    # 关键步骤：接收输入/状态 → 更新递推 → 输出结果。
    # 算法要点：保持状态递推与非线性一致性。
    def forward_cycle(
        self, inputs_cycle: np.ndarray | torch.Tensor, h_prev_cycle: np.ndarray | torch.Tensor | None
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Run ``self.steps`` recurrent updates without gradients and collect the logits.

        The encoder is applied once to ``inputs_cycle`` and the resulting feature
        map drives every recurrent step. When ``h_prev_cycle`` is ``None`` the
        hidden state starts from zeros shaped
        ``(batch, hidden_size, f.shape[2], f.shape[3])`` where ``f`` is the
        encoder output.

        Returns:
            A list with one readout-logit tensor per step, and the hidden state
            after the last step.
        """
        frames = to_tensor(inputs_cycle, self.device, dtype=torch.float32)
        logits_per_step: List[torch.Tensor] = []
        with torch.no_grad():
            features = self.encoder(frames)
            n_samples = int(frames.shape[0])
            if h_prev_cycle is not None:
                hidden = to_tensor(h_prev_cycle, self.device, dtype=torch.float32)
            else:
                state_shape = (n_samples, self.hidden_size, features.shape[2], features.shape[3])
                hidden = torch.zeros(state_shape, dtype=torch.float32, device=self.device)
            for _step in range(self.steps):
                hidden = _conv_rnn_step(hidden, features, self.W_hh, self.W_xh, self.b_h, self.padding)
                logits_per_step.append(_readout_logits(hidden, self.W_hy, self.b_y))
        return logits_per_step, hidden


# build_chunk_schedule: split `time_steps` into consecutive (start, end) chunks for FPTT.
# Steps: clamp both arguments to >= 1 -> chunk width = max(1, time_steps // parts)
#        -> emit half-open ranges of that width -> append a tail chunk for any remainder.
# Note: the returned ranges are contiguous and together cover [0, time_steps); because of
#       the remainder handling the list may hold more than `parts` entries.
def build_chunk_schedule(time_steps: int, parts: int) -> List[Tuple[int, int]]:
    total = max(1, int(time_steps))
    n_parts = max(1, int(parts))
    width = max(1, total // n_parts)
    # One extra slot when equally sized chunks cannot cover the whole sequence.
    slot_count = n_parts + 1 if n_parts * width < total else n_parts

    chunks: List[Tuple[int, int]] = [
        (begin, min(total, begin + width))
        for begin in (slot * width for slot in range(slot_count))
        if begin < total
    ]

    if not chunks:
        return [(0, total)]
    covered = chunks[-1][1]
    if covered < total:
        # Whatever the regular slots left uncovered becomes one final chunk.
        chunks.append((covered, total))
    return chunks


# =============================
# FPTT 支持组件（oracle buffer / 正则器）
# =============================

class ClassOracleBuffer:
    """Per-class store of surrogate ("oracle") target distributions, one slot per chunk.

    ``_storage`` has shape ``(num_classes, parts, num_classes)``: ``_storage[y, p]``
    is the probability vector used as the soft target for samples whose label
    is ``y`` at chunk index ``p``. Every slot starts as the uniform distribution.
    """

    # __init__: allocate the uniform buffer.
    # num_classes: number of classes (size of each stored distribution).
    # max_parts:   initial number of chunk slots; at least one slot is always kept so
    #              that the index clamping in get()/update() never produces -1.
    # momentum:    blend factor for update(), clipped to [1e-4, 1.0].
    def __init__(
        self,
        num_classes: int,
        max_parts: int,
        momentum: float = 1.0,
    ) -> None:
        self.num_classes = int(num_classes)
        self.momentum = float(np.clip(momentum, 1e-4, 1.0))
        initial_parts = max(1, int(max_parts))
        self._storage = np.full(
            (self.num_classes, initial_parts, self.num_classes),
            1.0 / float(self.num_classes),
            dtype=np.float32,
        )

    # ensure: grow the chunk axis to hold at least `required_parts` slots.
    # New slots are filled with the uniform distribution; existing slots are kept.
    def ensure(self, required_parts: int) -> None:
        required_parts = int(required_parts)
        current_parts = self._storage.shape[1]
        if required_parts <= current_parts:
            return
        filler = np.full(
            (self.num_classes, required_parts - current_parts, self.num_classes),
            1.0 / float(self.num_classes),
            dtype=np.float32,
        )
        self._storage = np.concatenate([self._storage, filler], axis=1)

    # get: look up the stored distribution for every label at chunk `idx`.
    # `idx` is clamped to the last available slot.
    # Returns an array of shape (num_classes, batch): one column per label.
    def get(self, labels: np.ndarray, idx: int) -> np.ndarray:
        labels = np.asarray(labels).astype(np.int64)
        idx = int(min(idx, self._storage.shape[1] - 1))
        oracle = self._storage[labels, idx]
        return oracle.T

    # update: refresh the slot `idx` of each class from the current batch.
    # labels: (batch,) true class per sample.
    # probs:  (classes, batch) predicted distribution per sample (one column each).
    # preds:  (batch,) predicted class per sample.
    # For each class, the first correctly classified sample in the batch supplies the
    # new distribution (blended with the old one when momentum < 1). Labels outside
    # [0, num_classes) are ignored.
    # Raises ValueError when `probs` is not 2-D.
    def update(
        self,
        labels: np.ndarray,
        idx: int,
        probs: np.ndarray,
        preds: np.ndarray,
    ) -> None:
        if probs.ndim != 2:
            raise ValueError("Expected probs with shape (classes, batch).")
        labels = np.asarray(labels).astype(np.int64)
        preds = np.asarray(preds).astype(np.int64)
        idx = int(min(idx, self._storage.shape[1] - 1))

        filled = np.zeros(self.num_classes, dtype=bool)
        remaining = self.num_classes
        for col, (y, y_hat) in enumerate(zip(labels, preds)):
            if remaining == 0:
                # Every class already received its sample; nothing left to do.
                break
            if y < 0 or y >= self.num_classes:
                continue
            if filled[y]:
                continue
            # Only a correctly classified sample may define the class's soft target:
            # the distribution of a misclassified sample peaks on a wrong class and
            # would pull later training toward that error.
            if y_hat != y:
                continue
            current = probs[:, col]
            if self.momentum >= 0.999:
                new = current
            else:
                old = self._storage[y, idx]
                new = (1.0 - self.momentum) * old + self.momentum * current
            self._storage[y, idx] = new
            filled[y] = True
            remaining -= 1


class OracleBufferStore:
    """Process-wide registry of ``ClassOracleBuffer`` objects keyed by a string id.

    The registry lives in a class attribute, so every caller using the same key
    shares one buffer.
    """

    _buffers: Dict[str, ClassOracleBuffer] = {}

    # get: return the buffer registered under `key`, creating it when needed.
    # A missing buffer, or one built for a different class count, is replaced by a
    # fresh one. A matching buffer is reused after refreshing its momentum
    # (clipped to [1e-4, 1.0]) and growing it to `max_parts` slots.
    @classmethod
    def get(
        cls,
        key: str,
        num_classes: int,
        max_parts: int,
        momentum: float,
    ) -> ClassOracleBuffer:
        existing = cls._buffers.get(key)
        if existing is not None and existing.num_classes == num_classes:
            existing.momentum = float(np.clip(momentum, 1e-4, 1.0))
            existing.ensure(max_parts)
            return existing
        fresh = ClassOracleBuffer(num_classes, max_parts, momentum)
        cls._buffers[key] = fresh
        return fresh

    # reset: drop the buffer stored under `key`; with no key, drop every buffer.
    # Unknown keys are ignored.
    @classmethod
    def reset(cls, key: str | None = None) -> None:
        if key is not None:
            cls._buffers.pop(key, None)
            return
        cls._buffers.clear()


# FPTTRegularizer: keeps, for every registered parameter, a running average ("sm")
# and a multiplier term ("lm"), and turns them into the FPTT penalty added to the loss.
class FPTTRegularizer:
    # __init__: store the coefficients and create the per-parameter state.
    # alpha is floored at 1e-8 (it is used as a divisor in step()), beta at 0.
    # For each (name, param): sm starts as a detached copy of the parameter, lm as zeros.
    def __init__(
        self,
        named_params: Iterable[Tuple[str, nn.Parameter]],
        alpha: float,
        beta: float,
        rho: float,
        lmbda: float = 1.0,
    ) -> None:
        self.alpha = max(1e-8, float(alpha))
        self.beta = max(0.0, float(beta))
        self.rho = float(rho)
        self.lmbda = float(lmbda)

        self._state: Dict[str, Dict[str, torch.Tensor]] = {
            name: {
                "param": param,
                "sm": param.detach().clone(),
                "lm": torch.zeros_like(param),
            }
            for name, param in named_params
        }

        # Scalar results are created on the device of the first registered parameter.
        first_entry = next(iter(self._state.values()), None)
        self._device = torch.device("cpu") if first_entry is None else first_entry["param"].device

    # reset: re-anchor every running average at the current parameter value and
    # clear the multiplier terms.
    def reset(self) -> None:
        for entry in self._state.values():
            entry["lm"].zero_()
            entry["sm"].copy_(entry["param"].detach())

    # loss: return the scalar penalty
    #   sum over params of (rho - 1) * <param, lm> + scale * 0.5 * alpha * ||param - sm||^2
    # where scale is `lmbda` when given, otherwise the value stored at construction.
    # With no registered parameters the result is a zero scalar.
    def loss(self, lmbda: float | None = None) -> torch.Tensor:
        penalty = torch.zeros((), device=self._device)
        if not self._state:
            return penalty
        scale = float(lmbda if lmbda is not None else self.lmbda)
        for entry in self._state.values():
            weight = entry["param"]
            linear_term = (self.rho - 1.0) * torch.sum(weight * entry["lm"])
            proximal_term = scale * 0.5 * self.alpha * torch.sum((weight - entry["sm"]) ** 2)
            penalty = penalty + linear_term
            penalty = penalty + proximal_term
        return penalty

    # step: update the state in place (no gradients), meant to run after the optimizer:
    #   lm <- lm - alpha * (param - sm)
    #   sm <- (1 - beta) * sm + beta * param - (beta / alpha) * lm
    # The sm update uses the lm value that was just refreshed.
    def step(self) -> None:
        if not self._state:
            return
        with torch.no_grad():
            for entry in self._state.values():
                weight = entry["param"].detach()
                running_avg = entry["sm"]
                multiplier = entry["lm"]
                drift = weight - running_avg
                multiplier.add_(-self.alpha * drift)
                running_avg.mul_(1.0 - self.beta)
                running_avg.add_(self.beta * weight - (self.beta / self.alpha) * multiplier)


# Strict-FPTT convolutional RNN classifier: the sequence is split into chunks, one
# optimizer step is taken per chunk, and an optional per-class oracle buffer supplies
# soft targets for the early chunks.
class StrictFPTTConvClassifier(nn.Module):
    # __init__: build the encoder, the recurrent/readout parameters, the SGD optimizer,
    # the FPTT regularizer and the oracle settings.
    # - W_xh: (hidden, enc_channels[1], 1, 1), standard normal scaled by 0.1.
    # - W_hh: (hidden, hidden, k, k), unscaled standard normal; use
    #   initialize_weights_with_gain() to rescale it.
    # - W_hy: (output, hidden), standard normal scaled by 0.1; b_h and b_y start at zero.
    # The numpy draws happen in the order W_xh, W_hh, W_hy from default_rng(seed).
    # NOTE(review): the optimizer and the regularizer are built from self.parameters() /
    # self.named_parameters() after self.encoder is assigned, so they presumably also
    # cover the encoder weights even when train_encoder is False — confirm intended.
    def __init__(
        self,
        in_channels: int,
        enc_channels: Tuple[int, int],
        hidden_channels: int,
        output_size: int,
        steps: int = 8,
        eta: float = 1e-3,
        parts: int = 6,
        clip: float = 1.0,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
        lmbda: float = 1.0,
        oracle_momentum: float = 1.0,
        warmup_epochs: int = 20,
        oracle_id: str = "default",
        use_oracle: bool = True,
        train_encoder: bool = False,
        seed: int | None = None,
        device: torch.device | str | None = None,
        kernel_size: int = 3,
    ) -> None:
        super().__init__()
        rng = np.random.default_rng(seed)
        self.device = resolve_device(device)
        self.hidden_size = int(hidden_channels)
        self.input_size = int(in_channels)
        self.output_size = int(output_size)
        self.eta = float(eta)
        self.parts = max(1, int(parts))
        self.grad_clip = max(0.0, float(clip))
        self.train_encoder = bool(train_encoder)
        self.kernel_size = int(kernel_size)
        self.padding = self.kernel_size // 2
        self.steps = max(1, int(steps))

        self.encoder = SimpleCNNEncoder(in_channels, enc_channels, kernel_size=self.kernel_size).to(self.device)
        self.W_xh = nn.Parameter(
            to_tensor(
                rng.standard_normal((hidden_channels, enc_channels[1], 1, 1)).astype(np.float32) * 0.1,
                self.device,
            )
        )
        self.W_hh = nn.Parameter(
            to_tensor(
                rng.standard_normal((hidden_channels, hidden_channels, self.kernel_size, self.kernel_size)).astype(
                    np.float32
                ),
                self.device,
            )
        )
        self.b_h = nn.Parameter(torch.zeros((1, hidden_channels, 1, 1), dtype=torch.float32, device=self.device))
        self.W_hy = nn.Parameter(
            to_tensor(rng.standard_normal((output_size, hidden_channels)).astype(np.float32) * 0.1, self.device)
        )
        self.b_y = nn.Parameter(torch.zeros((output_size,), dtype=torch.float32, device=self.device))

        self.optimizer = optim.SGD(self.parameters(), lr=self.eta)
        self._regularizer = FPTTRegularizer(
            list(self.named_parameters()),
            alpha,
            beta,
            rho,
            lmbda=lmbda,
        )
        self.current_epoch = 0

        self.oracle_momentum = float(np.clip(oracle_momentum, 1e-4, 1.0))
        self.warmup_epochs = max(0, int(warmup_epochs))
        self.oracle_id = oracle_id
        self.use_oracle = bool(use_oracle)

        # Uniform class distribution, shape (1, output_size); expanded per batch during
        # the warm-up epochs in train_batch.
        self._uniform_template = torch.full(
            (1, self.output_size),
            1.0 / float(self.output_size),
            dtype=torch.float32,
            device=self.device,
        )

    # initialize_weights_with_gain: overwrite W_hh with N(0, (g / sqrt(fan_in))^2) samples,
    # where fan_in = hidden_size * kernel_size^2. Other parameters are left untouched.
    def initialize_weights_with_gain(self, g: float, seed: int | None = None) -> None:
        rng = np.random.default_rng(seed)
        fan_in = self.hidden_size * self.kernel_size * self.kernel_size
        std_dev = g / math.sqrt(fan_in)
        with torch.no_grad():
            self.W_hh.copy_(to_tensor(rng.standard_normal(self.W_hh.shape) * std_dev, self.device))

    # reset_state_buffers: reset the FPTT regularizer state (delegates to its reset()).
    def reset_state_buffers(self) -> None:
        self._regularizer.reset()

    # set_epoch: record the current epoch; train_batch compares it with warmup_epochs.
    def set_epoch(self, epoch: int) -> None:
        self.current_epoch = int(epoch)

    # forward_cycle: gradient-free inference over self.steps recurrent updates.
    # The input is encoded once; a missing h_prev starts the hidden state at zeros.
    # Returns the list of per-step readout logits and the final hidden state.
    def forward_cycle(
        self, inputs: np.ndarray | torch.Tensor, h_prev: np.ndarray | torch.Tensor | None
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        inputs_t = to_tensor(inputs, self.device, dtype=torch.float32)
        with torch.no_grad():
            f = self.encoder(inputs_t)
            batch_size = int(inputs_t.shape[0])
            if h_prev is None:
                h_state = torch.zeros(
                    (batch_size, self.hidden_size, f.shape[2], f.shape[3]),
                    dtype=torch.float32,
                    device=self.device,
                )
            else:
                h_state = to_tensor(h_prev, self.device, dtype=torch.float32)
            outputs: List[torch.Tensor] = []
            for _ in range(self.steps):
                h_state = _conv_rnn_step(h_state, f, self.W_hh, self.W_xh, self.b_h, self.padding)
                outputs.append(_readout_logits(h_state, self.W_hy, self.b_y))
        return outputs, h_state

    # train_batch: one FPTT pass over a batch, taking one optimizer step per chunk.
    # targets_batch is indexed as [:, :, t]; its last axis is used as the time axis and
    # the labels are read from the final time slice (argmax over dim 1).
    # Per chunk: encode -> run the chunk's recurrent steps -> readout on the chunk's last
    # state -> cross-entropy against the mixed target plus the regularizer penalty ->
    # backward, optional gradient clipping, optimizer step, regularizer step -> detach the
    # hidden state -> refresh the oracle buffer (only when the oracle is active and the
    # warm-up is over).
    # Returns the mean per-chunk cross-entropy (penalty excluded) and the detached final
    # hidden state.
    def train_batch(
        self,
        inputs_batch: np.ndarray | torch.Tensor,
        targets_batch: np.ndarray | torch.Tensor,
        h_prev_batch: np.ndarray | torch.Tensor | None,
    ) -> Tuple[float, torch.Tensor]:
        inputs = to_tensor(inputs_batch, self.device, dtype=torch.float32)
        targets = to_tensor(targets_batch, self.device, dtype=torch.float32)
        batch_size = int(inputs.shape[0])
        time_steps = int(targets.shape[2])
        schedule = build_chunk_schedule(time_steps, self.parts)
        total_chunks = len(schedule)
        if total_chunks == 0:
            return 0.0, h_prev_batch

        # Extra gradient-free encoder pass, used only to learn the feature-map size for
        # the initial hidden state.
        with torch.no_grad():
            f_shape = self.encoder(inputs).shape
        if h_prev_batch is None:
            h_state = torch.zeros(
                (batch_size, self.hidden_size, f_shape[2], f_shape[3]),
                dtype=torch.float32,
                device=self.device,
            )
        else:
            h_state = to_tensor(h_prev_batch, self.device, dtype=torch.float32)

        labels = targets[:, :, -1]
        label_indices = torch.argmax(labels, dim=1).detach().cpu().numpy().astype(np.int64)
        # Chunks with index >= oracle_cutoff are trained on the true labels only.
        oracle_cutoff = max(0, min(self.parts - 1, total_chunks - 1))
        oracle: ClassOracleBuffer | None = None
        warmup = self.current_epoch < self.warmup_epochs
        if self.use_oracle:
            oracle = OracleBufferStore.get(
                self.oracle_id,
                self.output_size,
                self.parts,
                self.oracle_momentum,
            )
            oracle.ensure(total_chunks)

        total_loss = 0.0
        for chunk_idx, (start, end) in enumerate(schedule):
            # The features are recomputed for every chunk; with a trainable encoder each
            # chunk's backward pass needs its own graph.
            # NOTE(review): with a frozen encoder this repeats the same forward pass per
            # chunk — presumably it could be computed once; confirm the encoder has no
            # per-call side effects before hoisting.
            if self.train_encoder:
                f = self.encoder(inputs)
            else:
                with torch.no_grad():
                    f = self.encoder(inputs)
            for _ in range(start, end):
                h_state = _conv_rnn_step(h_state, f, self.W_hh, self.W_xh, self.b_h, self.padding)
            logits = _readout_logits(h_state, self.W_hy, self.b_y)
            log_probs = torch.log_softmax(logits, dim=1)
            probs = torch.softmax(logits, dim=1)

            # alpha grows linearly with the chunk index and reaches 1 on the last chunk;
            # the surrogate target gets the complementary weight while the oracle is active.
            alpha = float(chunk_idx + 1) / float(max(1, total_chunks))
            oracle_active = (oracle is not None) and (chunk_idx < oracle_cutoff)
            oracle_weight = (1.0 - alpha) if oracle_active else 0.0
            surrogate = None
            if oracle_active:
                if warmup:
                    # During warm-up the surrogate is the uniform distribution.
                    surrogate = self._uniform_template.expand_as(labels)
                else:
                    # oracle.get returns (classes, batch); transpose to (batch, classes).
                    surrogate_np = oracle.get(label_indices, chunk_idx)
                    surrogate = torch.from_numpy(surrogate_np.T).to(self.device)
            mix_target = labels if surrogate is None else alpha * labels + oracle_weight * surrogate

            loss_ce = torch.sum(-mix_target * log_probs, dim=1).mean()
            loss = loss_ce + self._regularizer.loss()
            self.optimizer.zero_grad()
            loss.backward()
            if self.grad_clip > 0:
                clip_grad_norm_(self.parameters(), self.grad_clip)
            self.optimizer.step()
            self._regularizer.step()
            total_loss += float(loss_ce.item())
            # Cut the graph so the next chunk does not backpropagate into this one.
            h_state = h_state.detach()

            if oracle_active and (not warmup) and (surrogate is not None) and (oracle is not None):
                # The buffer is refreshed with the probabilities computed before this
                # chunk's optimizer step; update() expects (classes, batch).
                probs_np = probs.detach().cpu().numpy().T
                preds = np.argmax(probs_np, axis=0)
                oracle.update(label_indices, chunk_idx, probs_np, preds)

        return total_loss / max(1, total_chunks), h_state.detach()


# extract_params: snapshot the recurrent/readout parameters of `model` as numpy arrays.
# Steps: copy W_hh, W_xh, b_h, W_hy, b_y (tensors are detached and moved to CPU, anything
#        else is converted to a float32 array) -> when the model has an `encoder`, also
#        copy the weight and bias of its conv1 and conv2 layers.
# Note: every returned array is an independent copy, so later training does not alter it.
def extract_params(model: Any) -> Dict[str, np.ndarray]:
    snapshot: Dict[str, np.ndarray] = {}
    for key in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        attr = getattr(model, key)
        if not torch.is_tensor(attr):
            snapshot[key] = np.asarray(attr, dtype=np.float32).copy()
            continue
        snapshot[key] = attr.detach().cpu().numpy().copy()

    if hasattr(model, "encoder"):
        encoder_layout = (
            ("conv1", "enc_conv1_w", "enc_conv1_b"),
            ("conv2", "enc_conv2_w", "enc_conv2_b"),
        )
        for layer_name, weight_key, bias_key in encoder_layout:
            layer = getattr(model.encoder, layer_name)
            snapshot[weight_key] = layer.weight.detach().cpu().numpy().copy()
            snapshot[bias_key] = layer.bias.detach().cpu().numpy().copy()
    return snapshot


# load_params: write a parameter snapshot (as produced by extract_params) back into `model`.
# Steps: for W_hh, W_xh, b_h, W_hy, b_y copy in place when the attribute is a tensor,
#        otherwise replace it with a float32 array -> restore the encoder conv1/conv2
#        weights and biases when the model has an `encoder` -> discard optimizer state.
# Note: optimizer state is dropped through model.reset_optimizer() when available,
#       otherwise by clearing model.optimizer.state.
def load_params(model: Any, params: Dict[str, np.ndarray]) -> None:
    for key in ("W_hh", "W_xh", "b_h", "W_hy", "b_y"):
        incoming = params[key]
        target = getattr(model, key)
        if not torch.is_tensor(target):
            setattr(model, key, np.asarray(incoming, dtype=np.float32))
            continue
        replacement = to_tensor(incoming, device=target.device, dtype=target.dtype)
        with torch.no_grad():
            target.copy_(replacement)

    if hasattr(model, "encoder"):
        encoder_layout = (
            ("conv1", "weight", "enc_conv1_w"),
            ("conv1", "bias", "enc_conv1_b"),
            ("conv2", "weight", "enc_conv2_w"),
            ("conv2", "bias", "enc_conv2_b"),
        )
        with torch.no_grad():
            for layer_name, field, key in encoder_layout:
                destination = getattr(getattr(model.encoder, layer_name), field)
                destination.copy_(to_tensor(params[key], destination.device))

    # Momentum or other optimizer statistics belong to the old weights.
    if hasattr(model, "reset_optimizer"):
        model.reset_optimizer()
        return
    if hasattr(model, "optimizer") and hasattr(model.optimizer, "state"):
        model.optimizer.state.clear()


def extract_model_hparams(model: Any) -> Dict[str, Any]:
    """Collect the hyper-parameters present on ``model`` into a plain dict.

    Only attributes that exist on the model are reported:

    * generic training settings: ``int``/``float``/``bool``/``str``/``None``
      values are kept as they are, anything else is converted with ``float()``
      and, if that fails, with ``str()``;
    * ``decay_lambda`` (as ``float``) and ``feedback_type`` (as ``str``);
    * strict-FPTT settings: plain values are kept, anything else is stringified;
    * ``alpha``/``beta``/``rho``/``lmbda`` of ``model._regularizer``, reported
      as ``reg_<name>`` floats.
    """
    plain_types = (int, float, bool, str)
    collected: Dict[str, Any] = {}

    for name in ("eta", "loss_mode", "max_grad_norm", "time_normalization", "tbptt_steps"):
        if not hasattr(model, name):
            continue
        raw = getattr(model, name)
        if raw is None or isinstance(raw, plain_types):
            collected[name] = raw
            continue
        # Scalar-like objects are stored as floats; everything else as text.
        try:
            collected[name] = float(raw)
        except Exception:
            collected[name] = str(raw)

    # E-Prop / RFLO (Conv variant)
    if hasattr(model, "decay_lambda"):
        collected["decay_lambda"] = float(getattr(model, "decay_lambda"))
    if hasattr(model, "feedback_type"):
        collected["feedback_type"] = str(getattr(model, "feedback_type"))

    # Strict FPTT (Conv variant)
    fptt_names = ("parts", "grad_clip", "oracle_momentum", "warmup_epochs", "oracle_id", "label_mode", "use_oracle")
    for name in fptt_names:
        if not hasattr(model, name):
            continue
        raw = getattr(model, name)
        collected[name] = raw if (raw is None or isinstance(raw, plain_types)) else str(raw)

    if hasattr(model, "_regularizer"):
        regularizer = getattr(model, "_regularizer")
        for name in ("alpha", "beta", "rho", "lmbda"):
            if hasattr(regularizer, name):
                collected[f"reg_{name}"] = float(getattr(regularizer, name))

    return collected

# =============================
# Lyapunov 指标计算
# =============================

# build_lyapunov_driver: pick one representative input sample to drive the Lyapunov estimate.
# Steps: move `inputs` to the model's device as float32 (a 3-D input gets a leading batch
#        axis) -> keep the first `max_samples` samples -> rank them by the L2 norm of
#        their flattened values -> return the sample at the middle rank.
# Note: the result always keeps its batch axis; with zero or one candidate the candidates
#       are returned unchanged.
def build_lyapunov_driver(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    max_samples: int = 32,
) -> torch.Tensor:
    target_device = resolve_device(getattr(model, "device", None))
    batch = to_tensor(inputs, device=target_device, dtype=torch.float32)
    if batch.ndim == 3:
        batch = batch.unsqueeze(0)

    keep = min(max_samples, batch.shape[0])
    candidates = batch[:keep]
    if keep <= 1:
        return candidates

    magnitudes = torch.norm(candidates.reshape(keep, -1), dim=1)
    ranking = torch.argsort(magnitudes)
    pick = ranking[keep // 2]
    # Slice rather than index so the batch axis survives.
    return candidates[pick : pick + 1]


# calculate_lyapunov_exponent_conv: estimate the largest Lyapunov exponent of the conv-RNN
# hidden-state dynamics under a constant driver input.
# Steps: iterate the recurrent map from a zero state -> push `num_vectors` orthonormal
#        perturbation vectors through its Jacobian (JVP) -> re-orthonormalise with QR and
#        accumulate log|diag(R)| -> average over the steps and return the maximum.
# Notes:
# - The number of steps is `steps` (or model.steps, default 8) but never fewer than
#   LYAPUNOV_MIN_STEPS.
# - driver_is_encoded=True uses the driver as the feature map; False always runs
#   model.encoder; None guesses from the channel count (a driver whose channels match the
#   encoder input is encoded, one that matches only the encoder output is used as is).
# - Returns NaN when a diagonal entry of R is not finite.
def calculate_lyapunov_exponent_conv(
    model: Any,
    driver_input: torch.Tensor,
    steps: int | None = None,
    eps: float = 1e-8,
    num_vectors: int | None = None,
    driver_is_encoded: bool | None = None,
) -> float:
    steps = int(steps or getattr(model, "steps", 8))
    steps = max(steps, LYAPUNOV_MIN_STEPS)
    device = resolve_device(getattr(model, "device", None))
    # The recurrence is evaluated in float64 for the whole estimate.
    W_hh = to_tensor(getattr(model, "W_hh"), device=device, dtype=torch.float64)
    W_xh = to_tensor(getattr(model, "W_xh"), device=device, dtype=torch.float64)
    b_h = to_tensor(getattr(model, "b_h"), device=device, dtype=torch.float64)
    padding = int(getattr(model, "padding", 1))

    # The driver stays float32 here because it may still have to go through the encoder.
    driver = to_tensor(driver_input, device=device, dtype=torch.float32)
    if driver.ndim == 3:
        driver = driver.unsqueeze(0)
    f: torch.Tensor
    if driver_is_encoded is True:
        f = driver
    else:
        encoder = getattr(model, "encoder", None)
        enc_in_channels = getattr(getattr(encoder, "conv1", None), "in_channels", None)
        enc_out_channels = getattr(getattr(encoder, "conv2", None), "out_channels", None)
        needs_encoding = True
        if driver_is_encoded is None:
            # Channel-count heuristic; a match with the encoder input wins over a match
            # with the encoder output.
            if enc_out_channels is not None and driver.shape[1] == enc_out_channels:
                if enc_in_channels is None or driver.shape[1] != enc_in_channels:
                    needs_encoding = False
            if enc_in_channels is not None and driver.shape[1] == enc_in_channels:
                needs_encoding = True
        if needs_encoding:
            with torch.no_grad():
                f = model.encoder(driver)
        else:
            f = driver
    f = f.to(dtype=torch.float64)

    # The hidden state starts at zero with batch size 1.
    h = torch.zeros((1, W_hh.shape[0], f.shape[2], f.shape[3]), device=device, dtype=torch.float64)
    if num_vectors is None:
        num_vectors = int(getattr(model, "lyap_vectors", 1))
    num_vectors = max(1, min(int(num_vectors), h.numel()))
    # Random perturbation directions, orthonormalised with a reduced QR.
    vecs = torch.randn((num_vectors,) + h.shape, dtype=torch.float64, device=device)
    vecs_flat = vecs.reshape(num_vectors, -1).T
    Q, _ = torch.linalg.qr(vecs_flat, mode="reduced")
    vecs = Q.T.reshape((num_vectors,) + h.shape)
    log_diag_sum = torch.zeros((num_vectors,), dtype=torch.float64, device=device)

    for _ in range(steps):
        # Benettin-style JVP + QR to track growth of perturbation vectors.
        h = h.detach()
        vecs = vecs.detach()

        jv_list: List[torch.Tensor] = []
        h_next: torch.Tensor | None = None
        with torch.enable_grad():
            h_req = h.detach().requires_grad_(True)

            # step_fn: one application of the recurrent map with the driver features and
            # weights held fixed; only the hidden state is the variable.
            def step_fn(h_in: torch.Tensor) -> torch.Tensor:
                return _conv_rnn_step(h_in, f, W_hh, W_xh, b_h, padding)

            for idx in range(num_vectors):
                v = vecs[idx]
                h_out, jv = torch.autograd.functional.jvp(step_fn, (h_req,), (v,), create_graph=False)
                # Every call evaluates the same next state; keep the first one.
                if h_next is None:
                    h_next = h_out
                jv_list.append(jv)

        if h_next is None:
            return float("nan")
        # Columns are the propagated vectors; QR gives the new orthonormal basis (Q) and
        # the per-direction stretch factors on the diagonal of R.
        jv_mat = torch.stack([jv.reshape(-1) for jv in jv_list], dim=1)
        Q, R = torch.linalg.qr(jv_mat, mode="reduced")
        diag = torch.diagonal(R)
        if not torch.all(torch.isfinite(diag)):
            return float("nan")
        # eps keeps the logarithm finite when a stretch factor is zero.
        log_diag_sum += torch.log(torch.abs(diag) + eps)
        vecs = Q.T.reshape((num_vectors,) + h.shape)
        h = h_next.detach()

    lyaps = log_diag_sum / max(steps, 1)
    return float(torch.max(lyaps).item())


# =============================
# 训练 / 评估 / 扫描
# =============================

# train_batches: train `model` for `epochs` passes over the data, one minibatch at a time.
# Steps: resolve the device (a tensor W_hh on the model takes precedence over
#        model.device, which takes precedence over DEFAULT_DEVICE) -> move inputs/targets
#        there as float32 -> publish step_weights on the model when it has that attribute
#        -> per epoch: set_epoch, optional FPTT surrogate bookkeeping, train every
#        minibatch with a fresh (None) hidden state, optional finalize_fptt_epoch.
# Notes: `targets` is read as (samples, classes, time); a `step_weights` whose length
#        differs from the time axis raises ValueError. All model hooks are optional and
#        are only called when the model defines them.
def train_batches(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    batch_size: int,
    epochs: int,
    seed: int,
    use_surrogates: bool = False,
    step_weights: np.ndarray | torch.Tensor | None = None,
    epoch_offset: int = 0,
) -> None:
    shuffle_rng = np.random.default_rng(seed)

    run_device = DEFAULT_DEVICE
    if hasattr(model, "device"):
        run_device = resolve_device(getattr(model, "device"))
    recurrent_weights = getattr(model, "W_hh", None)
    if torch.is_tensor(recurrent_weights):
        run_device = recurrent_weights.device

    all_inputs = to_tensor(inputs, device=run_device, dtype=torch.float32)
    all_targets = to_tensor(targets, device=run_device, dtype=torch.float32)
    num_classes = int(all_targets.shape[1])
    seq_len = int(all_targets.shape[2])

    has_weights = step_weights is not None
    if has_weights and step_weights.shape[0] != seq_len:
        raise ValueError("step_weights length must match the sequence length.")
    if hasattr(model, "step_weights"):
        if has_weights:
            model.step_weights = to_tensor(step_weights, device=run_device, dtype=torch.float32)
        else:
            model.step_weights = None

    # FPTT soft labels are accumulated within each epoch and updated when the epoch ends.
    for local_epoch in range(int(epochs)):
        if hasattr(model, "set_epoch"):
            model.set_epoch(local_epoch + epoch_offset)

        if use_surrogates and hasattr(model, "enable_fptt_surrogates"):
            if getattr(model, "fptt_Q_prev", None) is None:
                # No surrogate table yet: start from a uniform one.
                uniform_q = torch.full(
                    (num_classes, seq_len),
                    1.0 / num_classes,
                    dtype=torch.float32,
                    device=run_device,
                )
                model.enable_fptt_surrogates(seq_len, num_classes, uniform_q)
            if hasattr(model, "reset_fptt_epoch_accumulators"):
                model.reset_fptt_epoch_accumulators(seq_len, num_classes)

        minibatches = iterate_minibatches(all_inputs, all_targets, batch_size, shuffle_rng)
        for batch_inputs, batch_targets in minibatches:
            model.train_batch(batch_inputs, batch_targets, None)

        if use_surrogates and hasattr(model, "finalize_fptt_epoch"):
            model.finalize_fptt_epoch()


# evaluate_classifier_final_step: mean cross-entropy and accuracy of the last-step prediction.
# Steps: move the data to the model's device -> run model.forward_cycle on each batch from
#        a fresh (None) hidden state -> softmax the logits of the final step -> accumulate
#        the sample-weighted loss and the number of correct predictions.
# Notes:
# - `targets` is read as (samples, classes, time); only the last time slice is scored.
# - `labels` holds the integer class of each sample and is compared with the argmax.
# - This function does not switch the model to eval mode or disable gradients itself; it
#   relies on forward_cycle for that.
# - A `batch_size` below 1 is treated as 1.
# Returns (average loss, accuracy); both are 0 for an empty input.
def evaluate_classifier_final_step(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    targets: np.ndarray | torch.Tensor,
    labels: np.ndarray | torch.Tensor,
    batch_size: int,
) -> Tuple[float, float]:
    total_loss = 0.0
    total_samples = 0
    correct = 0
    eps = 1e-12  # keeps log() finite when a probability underflows to zero

    device = resolve_device(getattr(model, "device", None))
    inputs_t = to_tensor(inputs, device=device, dtype=torch.float32)
    targets_t = to_tensor(targets, device=device, dtype=torch.float32)
    labels_t = to_tensor(labels, device=device, dtype=torch.long)

    # range() rejects a zero step and a negative one would silently evaluate nothing.
    batch_size = max(1, int(batch_size))
    num_samples = int(inputs_t.shape[0])

    for start_idx in range(0, num_samples, batch_size):
        end_idx = min(start_idx + batch_size, num_samples)
        input_batch = inputs_t[start_idx:end_idx]
        target_batch = targets_t[start_idx:end_idx]
        label_batch = labels_t[start_idx:end_idx]
        current_batch = int(input_batch.shape[0])

        outputs_seq, _ = model.forward_cycle(input_batch, None)
        # Only the final step is scored, so there is no need to stack every step.
        final_logits = outputs_seq[-1]
        probs = torch.softmax(final_logits, dim=1)
        y_final = target_batch[:, :, -1]
        ce = -torch.sum(y_final * torch.log(probs + eps), dim=1)
        # Weight the batch mean by the batch size so a short last batch counts correctly.
        total_loss += float(ce.mean().item()) * current_batch

        preds = torch.argmax(final_logits, dim=1)
        correct += int((preds == label_batch).sum().item())
        total_samples += current_batch

    avg_loss = total_loss / max(total_samples, 1)
    acc = correct / max(total_samples, 1)
    return avg_loss, acc


# predict_classifier_final_step: predicted class index of every sample at the last step.
# Steps: slice `inputs` into batches -> move each batch to the model's device as float32
#        -> run model.forward_cycle from a fresh (None) hidden state -> argmax over the
#        logits of the final step -> concatenate the per-batch predictions on the CPU.
# Notes: a `None` or empty input returns an empty int64 array; a `batch_size` below 1 is
#        treated as 1.
def predict_classifier_final_step(
    model: Any,
    inputs: np.ndarray | torch.Tensor,
    batch_size: int,
) -> np.ndarray:
    if inputs is None or inputs.shape[0] == 0:
        return np.empty((0,), dtype=np.int64)
    device = resolve_device(getattr(model, "device", None))
    total = int(inputs.shape[0])
    # range() rejects a zero step and a negative one would produce no batches at all.
    batch_size = max(1, int(batch_size))

    preds_out: List[torch.Tensor] = []
    for start_idx in range(0, total, batch_size):
        end_idx = min(start_idx + batch_size, total)
        # to_tensor accepts both tensors and arrays; only the slice is moved to the device.
        batch = to_tensor(inputs[start_idx:end_idx], device=device, dtype=torch.float32)
        outputs_seq, _ = model.forward_cycle(batch, None)
        # Only the final step is needed, so the per-step logits are not stacked.
        final_logits = outputs_seq[-1]
        preds_out.append(torch.argmax(final_logits, dim=1).detach().cpu())
    return torch.cat(preds_out, dim=0).numpy()


# scan_gains_classification: sweep candidate initialization gains and keep the best one.
# Key steps: per gain, build + initialize a model -> snapshot initial params -> short training -> validate.
# Selection rule: strictly highest validation accuracy wins, so ties keep the earliest gain.
def scan_gains_classification(
    gains: np.ndarray,
    build_model: callable,
    train_inputs: np.ndarray | torch.Tensor,
    train_targets: np.ndarray | torch.Tensor,
    train_labels: np.ndarray | torch.Tensor,
    val_inputs: np.ndarray | torch.Tensor,
    val_targets: np.ndarray | torch.Tensor,
    val_labels: np.ndarray | torch.Tensor,
    batch_size: int,
    scan_epochs: int,
    seed: int,
    lyapunov_driver: torch.Tensor | None = None,
    use_surrogates: bool = False,
    step_weights: np.ndarray | torch.Tensor | None = None,
    plot_dir: Path | None = None,
    plot_tag: str | None = None,
    task_label: str | None = None,
    plot_show: bool = False,
) -> Tuple[float, Dict[str, np.ndarray], Dict[str, float]]:
    """Train a fresh model per gain for ``scan_epochs`` and pick the best gain.

    Args:
        gains: Candidate gain values; must be non-empty (``gains[0]`` is read).
        build_model: Zero-argument factory returning a model that provides
            ``initialize_weights_with_gain`` and, optionally,
            ``reset_learning_state``.
        train_inputs, train_targets: Data forwarded to ``train_batches``.
        train_labels: Accepted for signature symmetry; not used in this function.
        val_inputs, val_targets, val_labels: Data forwarded to
            ``evaluate_classifier_final_step``.
        batch_size: Mini-batch size for training and validation.
        scan_epochs: Number of epochs trained per gain.
        seed: Seed used for weight initialization; ``seed + 1`` is passed to
            ``train_batches``.
        lyapunov_driver: When given, a Lyapunov exponent is computed before and
            after training for every gain.
        use_surrogates, step_weights: Forwarded unchanged to ``train_batches``.
        plot_dir, plot_tag: Both must be set for the scan plot to be produced.
        task_label: Plot title; falls back to ``"Scan"``.
        plot_show: Forwarded to ``plot_scan_results``.

    Returns:
        ``(best_gain, initial_params_of_best_model, stats_of_best_model)``. The
        parameters are the snapshot taken right after initialization, before any
        training. Stats hold ``val_acc``/``val_loss`` and, when a driver was
        given, ``lambda_pre``/``lambda_post``.
    """
    track_lyapunov = lyapunov_driver is not None

    winner_gain = float(gains[0])
    winner_acc = float("-inf")
    winner_params: Dict[str, np.ndarray] = {}
    winner_stats: Dict[str, float] = {}
    # One row per scanned gain: (gain, val_acc, val_loss, lambda_pre, lambda_post).
    records: List[Tuple[float, float, float, float, float]] = []

    for g in gains:
        candidate = build_model()
        candidate.initialize_weights_with_gain(float(g), seed=seed)
        if hasattr(candidate, "reset_learning_state"):
            candidate.reset_learning_state()
        # Snapshot before training: the scan returns the initialization, not the trained weights.
        snapshot = extract_params(candidate)

        lambda_pre = (
            calculate_lyapunov_exponent_conv(candidate, lyapunov_driver)
            if track_lyapunov
            else float("nan")
        )

        train_batches(
            candidate,
            train_inputs,
            train_targets,
            batch_size,
            scan_epochs,
            seed + 1,
            use_surrogates=use_surrogates,
            step_weights=step_weights,
        )
        val_loss, val_acc = evaluate_classifier_final_step(
            candidate, val_inputs, val_targets, val_labels, batch_size
        )

        lambda_post = (
            calculate_lyapunov_exponent_conv(candidate, lyapunov_driver)
            if track_lyapunov
            else float("nan")
        )

        records.append(
            (float(g), float(val_acc), float(val_loss), float(lambda_pre), float(lambda_post))
        )

        if val_acc > winner_acc:
            winner_gain = float(g)
            winner_acc = float(val_acc)
            winner_params = snapshot
            winner_stats = {"val_acc": float(val_acc), "val_loss": float(val_loss)}
            if track_lyapunov:
                winner_stats["lambda_pre"] = lambda_pre
                winner_stats["lambda_post"] = lambda_post

        line = f"[SCAN] g={g:.3f} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}"
        if track_lyapunov:
            delta = lambda_post - lambda_pre
            line += f" | lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
        print(line)

    if records and plot_dir is not None and plot_tag is not None:
        # Transpose the rows into one list per plotted series.
        scan_g, scan_metric, scan_loss, scan_lyap_pre, scan_lyap_post = (
            list(column) for column in zip(*records)
        )
        plot_scan_results(
            task_label or "Scan",
            scan_g,
            scan_metric,
            "Val Accuracy",
            scan_lyap_pre,
            scan_lyap_post,
            plot_dir,
            plot_tag,
            plot_show,
            best_g=winner_gain,
            higher_is_better=True,
            aux_values=scan_loss,
            aux_label="Val Loss",
        )

    return winner_gain, winner_params, winner_stats


# =============================
# 任务入口（组装模型 + 扫描 + 训练对比组）
# =============================

# run_classification_task: entry point for the ConvRNN classification benchmark.
# Key steps: prepare data -> gain scan (or fixed gain) -> build comparison models -> train/evaluate -> save/plot.
# Note: every comparison model is started from the same initial parameters (`init_params`).
def run_classification_task(task_data: Dict[str, Any], args: Any, gains: np.ndarray) -> None:
    """Run the full classification experiment and report results.

    Args:
        task_data: Mapping that must contain ``train_images``, ``train_labels``,
            ``test_images``, ``test_labels``, ``input_channels`` and
            ``output_size``; ``time_weighting`` and ``task_name`` are optional.
        args: Namespace-like object. Read directly (must exist): ``seed``,
            ``batch_size``, ``hidden``, ``lr``, ``epochs``, ``scan_epochs``.
            Read with defaults: ``steps``, ``step_labels``, ``time_weighting``,
            ``task``, ``plot``, ``enc_channels``, ``kernel_size``,
            ``train_encoder``, ``gain``, ``tbptt_short``, ``tbptt_long``,
            ``no_eprop``, ``only_eprop``.
        gains: Candidate gains for the scan; ignored for selection when
            ``args.gain`` is set (still recorded in the run config).

    Returns:
        None. Results are printed and, when plotting is enabled, handed to the
        summary/plot helpers.

    Raises:
        RuntimeError: if ``only_eprop`` is requested while E-Prop is disabled
            via ``no_eprop``.
    """
    set_global_seed(int(getattr(args, "seed", 0)))
    # Step 1: prepare data, step weights and the plotting context.
    train_images_np = task_data["train_images"]
    train_labels_np = task_data["train_labels"]
    test_images_np = task_data["test_images"]
    test_labels_np = task_data["test_labels"]
    input_channels = int(task_data["input_channels"])
    output_size = int(task_data["output_size"])
    time_steps = int(getattr(args, "steps", 8))
    # step_labels == "fptt" switches on `use_surrogates` (per the original note:
    # soft labels and surrogate weights — handled inside the training helpers).
    step_label_mode = str(getattr(args, "step_labels", "final")).lower()
    use_surrogates = step_label_mode == "fptt"
    # A task-level time_weighting takes precedence over the CLI value.
    time_weighting = task_data.get("time_weighting", getattr(args, "time_weighting", None))
    step_weights = build_time_weights(time_steps, time_weighting)
    task_label = task_data.get("task_name", getattr(args, "task", "Image Task"))
    plot_enabled = getattr(args, "plot", True)
    plot_dir = None
    plot_tag = None
    if plot_enabled:
        plot_dir, plot_tag = resolve_plot_context(args, task_label)

    # Targets are built from the labels for `time_steps` steps (see build_repeated_targets).
    train_targets_np = build_repeated_targets(train_labels_np, output_size, time_steps)
    test_targets_np = build_repeated_targets(test_labels_np, output_size, time_steps)

    # All data is moved to the default device up front.
    device = DEFAULT_DEVICE
    train_images = to_tensor(train_images_np, device=device, dtype=torch.float32)
    train_targets = to_tensor(train_targets_np, device=device, dtype=torch.float32)
    train_labels = to_tensor(train_labels_np, device=device, dtype=torch.long)
    test_images = to_tensor(test_images_np, device=device, dtype=torch.float32)
    test_targets = to_tensor(test_targets_np, device=device, dtype=torch.float32)
    test_labels = to_tensor(test_labels_np, device=device, dtype=torch.long)
    step_weights_t = to_tensor(step_weights, device=device, dtype=torch.float32) if step_weights is not None else None

    # Hold out part of the training set for validation
    # (0.1 is presumably the validation fraction — confirm against split_train_val).
    rng = np.random.default_rng(args.seed)
    tr_inputs, tr_targets, tr_labels, val_inputs, val_targets, val_labels = split_train_val(
        train_images, train_targets, train_labels, 0.1, rng
    )
    train_inputs_fit = tr_inputs
    train_targets_fit = tr_targets
    train_labels_fit = tr_labels
    batches_per_epoch = math.ceil(int(train_inputs_fit.shape[0]) / args.batch_size)

    enc_channels = parse_channel_list(getattr(args, "enc_channels", None), (16, 32))

    # Snapshot of the experiment configuration, stored with the results summary.
    # NOTE(review): "train_size" is the size before the validation split, so it
    # includes the samples also counted in "val_size".
    run_config: Dict[str, Any] = {
        "task_id": getattr(args, "task", None),
        "task_type": "convrnn_classification",
        "task_name": task_label,
        "seed": int(getattr(args, "seed", 0)),
        "epochs": int(getattr(args, "epochs", 0)),
        "scan_epochs": int(getattr(args, "scan_epochs", 0)),
        "batch_size": int(getattr(args, "batch_size", 0)),
        "hidden": int(getattr(args, "hidden", 0)),
        "lr": float(getattr(args, "lr", 0.0)),
        "steps": int(time_steps),
        "time_weighting": time_weighting,
        "step_labels": step_label_mode,
        "use_surrogates": bool(use_surrogates),
        "gains": [float(x) for x in np.asarray(gains).ravel().tolist()],
        "input_channels": int(input_channels),
        "output_size": int(output_size),
        "train_size": int(train_images.shape[0]),
        "val_size": int(val_inputs.shape[0]),
        "test_size": int(test_images.shape[0]),
        "device": str(device),
        "enc_channels": list(enc_channels),
        "kernel_size": int(getattr(args, "kernel_size", 3)),
        "train_encoder": bool(getattr(args, "train_encoder", False)),
        "args": dict(vars(args)) if hasattr(args, "__dict__") else {},
        "argv": list(sys.argv),
        "eprop_feedback": EPROP_FEEDBACK,
        "eprop_seed": int(EPROP_SEED),
        "fptt_parts": int(FPTT_PARTS),
        "fptt_lambda": float(FPTT_LAMBDA),
        "fptt_oracle_momentum": float(FPTT_ORACLE_MOMENTUM),
        "fptt_warmup_epochs": int(FPTT_WARMUP_EPOCHS),
    }

    # build_local: factory for a fresh Local Rule model using the task/CLI hyper-parameters.
    # It is passed to the gain scan (one model per gain) and reused for the comparison run.
    def build_local() -> TorchLocalRuleConvRNN:
        return TorchLocalRuleConvRNN(
            in_channels=input_channels,
            enc_channels=enc_channels,
            hidden_channels=args.hidden,
            output_size=output_size,
            steps=time_steps,
            eta=args.lr,
            loss_mode="ce",
            max_grad_norm=5.0,
            train_encoder=getattr(args, "train_encoder", False),
            seed=args.seed,
            device=device,
            kernel_size=getattr(args, "kernel_size", 3),
        )

    # Step 2: gain scan with the Local Rule model to choose the initialization,
    # unless a fixed gain was supplied via args.gain.
    lyapunov_driver = build_lyapunov_driver(build_local(), val_inputs)
    fixed_gain_value = getattr(args, "gain", None)
    if fixed_gain_value is None:
        best_g, init_params, stats = scan_gains_classification(
            gains,
            build_local,
            tr_inputs,
            tr_targets,
            tr_labels,
            val_inputs,
            val_targets,
            val_labels,
            args.batch_size,
            args.scan_epochs,
            args.seed,
            lyapunov_driver=lyapunov_driver,
            use_surrogates=use_surrogates,
            step_weights=step_weights_t,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            task_label=task_label,
            plot_show=False,
        )
        summary = f"Best g={best_g:.3f} | val_acc={stats['val_acc']:.4f} | val_loss={stats['val_loss']:.4f}"
        if "lambda_pre" in stats and "lambda_post" in stats:
            delta = stats["lambda_post"] - stats["lambda_pre"]
            summary += (
                f" | lyap(pre:{stats['lambda_pre']:.4f}, post:{stats['lambda_post']:.4f}, d:{delta:.4f})"
            )
        print(summary)
    else:
        # Fixed gain: no training happens here, so validation metrics and the
        # post-training exponent are recorded as NaN.
        best_g = float(fixed_gain_value)
        model = build_local()
        model.initialize_weights_with_gain(best_g, seed=args.seed)
        if hasattr(model, "reset_learning_state"):
            model.reset_learning_state()
        init_params = extract_params(model)
        lambda_pre = calculate_lyapunov_exponent_conv(model, lyapunov_driver)
        stats = {
            "val_acc": float("nan"),
            "val_loss": float("nan"),
            "lambda_pre": float(lambda_pre),
            "lambda_post": float("nan"),
        }
        print(f"Fixed g={best_g:.3f} | lyap(pre:{lambda_pre:.4f})")
    run_config["best_g"] = float(best_g)
    run_config["scan_best"] = {k: float(v) for k, v in stats.items()}

    local_name = "Local Rule (FPTT)" if use_surrogates else "Local Rule"
    # Truncation lengths for the two TBPTT baselines; the long one defaults to
    # min(10, time_steps) but never less than 2.
    tbptt_short = max(1, int(getattr(args, "tbptt_short", 1)))
    tbptt_long_default = max(2, min(10, time_steps))
    tbptt_long_value = getattr(args, "tbptt_long", None)
    tbptt_long = tbptt_long_default if tbptt_long_value is None else int(tbptt_long_value)
    train_encoder = bool(getattr(args, "train_encoder", False))
    kernel_size = getattr(args, "kernel_size", 3)
    skip_eprop = bool(getattr(args, "no_eprop", False))

    # Comparison groups: Local Rule / BPTT / E-Prop / FPTT (plus TBPTT variants).
    # Step 3: build the comparison models. Dict insertion order is the training order.
    models = {
        local_name: build_local(),
        "BPTT": TorchBPTTConvRNN(
            input_channels,
            enc_channels,
            args.hidden,
            output_size,
            steps=time_steps,
            eta=args.lr,
            loss_mode="ce",
            time_normalization=False,
            train_encoder=train_encoder,
            seed=args.seed,
            device=device,
            kernel_size=kernel_size,
        ),
    }
    if not skip_eprop:
        # E-Prop is constructed with its own fixed seed (EPROP_SEED) rather than args.seed.
        models["E-Prop"] = TorchEPropConvRNN(
            input_channels,
            enc_channels,
            args.hidden,
            output_size,
            steps=time_steps,
            eta=args.lr,
            decay_lambda=0.95,
            feedback=EPROP_FEEDBACK,
            loss_mode="ce",
            train_encoder=train_encoder,
            seed=EPROP_SEED,
            device=device,
            kernel_size=kernel_size,
        )
    models["FPTT"] = StrictFPTTConvClassifier(
        input_channels,
        enc_channels,
        args.hidden,
        output_size,
        steps=time_steps,
        eta=args.lr,
        parts=FPTT_PARTS,
        clip=5.0,
        alpha=0.1,
        beta=0.5,
        rho=0.0,
        lmbda=FPTT_LAMBDA,
        oracle_momentum=FPTT_ORACLE_MOMENTUM,
        warmup_epochs=FPTT_WARMUP_EPOCHS,
        oracle_id=getattr(args, "task", "cnn_task"),
        use_oracle=True,
        train_encoder=train_encoder,
        seed=args.seed,
        device=device,
        kernel_size=kernel_size,
    )
    # TBPTT baselines are only added when the truncation is shorter than the sequence.
    if 0 < tbptt_short < time_steps:
        models[f"TBPTT-{tbptt_short}"] = TorchBPTTConvRNN(
            input_channels,
            enc_channels,
            args.hidden,
            output_size,
            steps=time_steps,
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_short,
            time_normalization=False,
            train_encoder=train_encoder,
            seed=args.seed,
            device=device,
            kernel_size=kernel_size,
        )
    # Skip the long variant when it would duplicate the short one.
    if tbptt_long != tbptt_short and 0 < tbptt_long < time_steps:
        models[f"TBPTT-{tbptt_long}"] = TorchBPTTConvRNN(
            input_channels,
            enc_channels,
            args.hidden,
            output_size,
            steps=time_steps,
            eta=args.lr,
            loss_mode="ce",
            tbptt_steps=tbptt_long,
            time_normalization=False,
            train_encoder=train_encoder,
            seed=args.seed,
            device=device,
            kernel_size=kernel_size,
        )

    # Optional restriction of the run to the E-Prop model only.
    only_eprop = bool(getattr(args, "only_eprop", False))
    if only_eprop:
        if "E-Prop" not in models:
            raise RuntimeError("Requested E-Prop-only run, but E-Prop is disabled. Remove --no-eprop.")
        models = {"E-Prop": models["E-Prop"]}

    # Step 4: train/evaluate each model and collect the results.
    results: Dict[str, Dict[str, Any]] = {}
    log_every = max(1, args.epochs // 5)
    local_rule_model: Any | None = None

    for name, model in models.items():
        # Start every model from the shared initial parameters and clear any learner state.
        load_params(model, init_params)
        if hasattr(model, "reset_learning_state"):
            model.reset_learning_state()
        if hasattr(model, "reset_state_buffers"):
            model.reset_state_buffers()
        lambda_pre = calculate_lyapunov_exponent_conv(model, lyapunov_driver)
        history: List[float] = []
        # Surrogates are only passed on for the Local Rule model.
        use_surrogates_local = use_surrogates and name == local_name
        step_weights_local = step_weights_t
        complexity = estimate_model_complexity(model)
        update_stats = estimate_training_counts(model, time_steps, batches_per_epoch, args.epochs)
        profile_batch = min(int(args.batch_size), int(train_inputs_fit.shape[0]))
        update_profile = profile_update_cost(
            model,
            train_inputs_fit[:profile_batch],
            train_targets_fit[:profile_batch],
            use_surrogates=use_surrogates_local,
            step_weights=step_weights_local,
        )
        if update_profile:
            # Restore the initial state after profiling
            # (presumably profiling performs a real update — confirm in profile_update_cost).
            load_params(model, init_params)
            if hasattr(model, "reset_learning_state"):
                model.reset_learning_state()
            if hasattr(model, "reset_state_buffers"):
                model.reset_state_buffers()
        train_runtime_sec = 0.0
        eval_runtime_sec = 0.0
        best_epoch = 0
        best_val_metric = -float("inf")
        best_val_loss = float("inf")
        best_params: Dict[str, np.ndarray] | None = None

        # NOTE(review): per-epoch validation and best-checkpoint selection only run
        # when plotting is enabled; otherwise the final-epoch weights are tested.
        if plot_enabled:
            for epoch in range(args.epochs):
                train_start = time.perf_counter()
                # One epoch per call, with a per-epoch seed and epoch offset.
                train_batches(
                    model,
                    train_inputs_fit,
                    train_targets_fit,
                    args.batch_size,
                    1,
                    args.seed + 10 + epoch,
                    use_surrogates=use_surrogates_local,
                    step_weights=step_weights_local,
                    epoch_offset=epoch,
                )
                train_runtime_sec += time.perf_counter() - train_start
                eval_start = time.perf_counter()
                val_loss, val_acc = evaluate_classifier_final_step(
                    model, val_inputs, val_targets, val_labels, args.batch_size
                )
                eval_runtime_sec += time.perf_counter() - eval_start
                history.append(float(val_acc))
                # Keep the parameters of the epoch with the best validation accuracy.
                if val_acc > best_val_metric:
                    best_val_metric = float(val_acc)
                    best_val_loss = float(val_loss)
                    best_epoch = epoch + 1
                    best_params = extract_params(model)
                if (epoch + 1) % log_every == 0 or (epoch + 1) == args.epochs:
                    print(f"[{name}] epoch={epoch+1:02d} | val_acc={val_acc:.4f} | val_loss={val_loss:.4f}")
        else:
            # No per-epoch validation: train all epochs in a single call.
            train_start = time.perf_counter()
            train_batches(
                model,
                train_inputs_fit,
                train_targets_fit,
                args.batch_size,
                args.epochs,
                args.seed + 10,
                use_surrogates=use_surrogates_local,
                step_weights=step_weights_local,
                epoch_offset=0,
            )
            train_runtime_sec = time.perf_counter() - train_start
        # Restore the best-validation checkpoint (if one was recorded) before testing.
        if best_params is not None:
            load_params(model, best_params)
        eval_start = time.perf_counter()
        final_test_loss, final_test_acc = evaluate_classifier_final_step(
            model, test_images, test_targets, test_labels, args.batch_size
        )
        eval_runtime_sec += time.perf_counter() - eval_start
        if best_epoch == 0:
            # No validation was tracked: report the last epoch and NaN validation metrics.
            best_epoch = int(args.epochs)
            best_val_metric = float("nan")
            best_val_loss = float("nan")

        runtime_sec = train_runtime_sec + eval_runtime_sec
        updates_total = update_stats["updates_total"]
        steps_total = update_stats["steps_total"]
        # Normalized training cost; NaN when the corresponding count is zero.
        runtime_per_update_sec = (
            train_runtime_sec / updates_total if updates_total > 0 else float("nan")
        )
        runtime_per_step_sec = train_runtime_sec / steps_total if steps_total > 0 else float("nan")
        # Post-training exponent is measured on the parameters that were tested.
        lambda_post = calculate_lyapunov_exponent_conv(model, lyapunov_driver)
        delta = lambda_post - lambda_pre
        print(
            f"[{name}] test_acc={final_test_acc:.4f} | test_loss={final_test_loss:.4f} | "
            f"best_val_acc={best_val_metric:.4f} (epoch={best_epoch:02d}) | "
            f"lyap=(pre:{lambda_pre:.4f}, post:{lambda_post:.4f}, d:{delta:.4f})"
        )
        results[name] = {
            "metric": float(final_test_acc),
            "val_metric": float(best_val_metric),
            "val_loss": float(best_val_loss),
            "best_epoch": int(best_epoch),
            "lyap_pre": float(lambda_pre),
            "lyap_post": float(lambda_post),
            "history": history,
            "complexity_params": float(complexity["params"]),
            "complexity_state": float(complexity["state"]),
            "complexity_total": float(complexity["total"]),
            "runtime_sec": float(runtime_sec),
            "train_runtime_sec": float(train_runtime_sec),
            "eval_runtime_sec": float(eval_runtime_sec),
            "runtime_per_update_sec": float(runtime_per_update_sec),
            "runtime_per_step_sec": float(runtime_per_step_sec),
            "batches_per_epoch": int(update_stats["batches_per_epoch"]),
            "time_steps": int(update_stats["time_steps"]),
            "update_factor": int(update_stats["update_factor"]),
            "updates_per_epoch": int(update_stats["updates_per_epoch"]),
            "updates_total": int(update_stats["updates_total"]),
            "steps_total": int(update_stats["steps_total"]),
            "hparams": extract_model_hparams(model),
        }
        if update_profile:
            results[name].update(update_profile)
        # Keep a handle on the Local Rule model for the prediction plot below.
        if name == local_name:
            local_rule_model = model

    # Persist the summary and render the comparison plots (only when plotting is enabled).
    if plot_enabled and plot_dir is not None and plot_tag is not None:
        save_results_summary(
            task_label,
            "Test Accuracy",
            results,
            plot_dir,
            plot_tag,
            run_config=run_config,
        )
        plot_comparison_results(
            task_label,
            results,
            metric_label="Test Accuracy",
            history_label="Val Accuracy",
            higher_is_better=True,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        plot_cost_profile(
            task_label,
            results,
            plot_dir=plot_dir,
            plot_tag=plot_tag,
            show=False,
        )
        if local_rule_model is not None:
            plot_classification_predictions(
                task_label,
                local_rule_model,
                test_images_np,
                test_labels_np,
                plot_dir=plot_dir,
                plot_tag=plot_tag,
                show=False,
            )