import os
import pickle
import copy
import numpy as np
import logging
from neuron import h
import torch
from worm_network import Network
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

from worm_defaults import (
    ALPHA_D as alpha_d,
    ALPHA_W0 as alpha_w0,
    ALPHA_X0 as alpha_x0,
    K_MAX_T as K_max_t,
    K_NBLOCK as K_nblock,
    NGPU as ngpu,
    RANDOM_SEED as random_seed,
    W_GAP_MAX as w_gap_max,
    W_GAP_MIN as w_gap_min,
    W_SYN_MAX as w_syn_max,
    W_SYN_MIN as w_syn_min,
)
from worm_train_config import WormTrainConfig
from worm_training_objective import compute_corr_loss_and_dLtdv
from worm_training_simulation import simulate_and_record_epoch
from worm_training_updates import apply_updates, compute_dw_dx, prepare_dw_dx_for_update
from worm_training_plateau import maybe_emergency_brake, maybe_plateau_restore

# Optional dependency: prefer heliox_learn's AdamState implementation when it can be
# imported. Any import-time failure (not only ImportError) leaves AdamState as None.
try:
    from heliox_learn import AdamState  # type: ignore
except Exception:  # pragma: no cover
    AdamState = None  # type: ignore

# IMPORTANT (compatibility): the historical eworm scripts used a hand-rolled Adam update.
# `heliox_learn.AdamState` is allowed to differ in subtle ways (bias correction, eps
# handling, etc.) which can change convergence significantly, so legacy Adam stays the
# default. Opt in with EWORM_USE_HELIOX_LEARN_ADAM=1 (surrounding whitespace ignored);
# the flag is always False when AdamState could not be imported.
USE_HELIOX_LEARN_ADAM = (
    os.environ.get("EWORM_USE_HELIOX_LEARN_ADAM", "0").strip() == "1"
    and AdamState is not None
)

from worm_checkpoint import (
    _ckpt_path,
    _error_path,
    _infer_start_epoch_from_error,
    _load_resume_state,
    _plateau_vmin_snapshot_path,
    _run_best_snapshot_path,
    _save_ckpt,
    _weights_optimal_path,
    _weights_train_path,
    _x_optimal_path,
    _x_train_path,
)

# Legacy default output directory and filename components (copied into OUTPUT_PATH/PREFIX/SUFFIX below).
DEFAULT_OUTPUT_PATH = './trial10'
DEFAULT_PREFIX = 'eworm'
DEFAULT_SUFFIX = 'v4_torch'           # alternative (older) value: 'v4'

# NOTE:
# - This module is the training implementation library; the recommended entry points are
#   `train.py` or `run_from_global_optimal.sh` (do not run this file directly as a script).
# - `OUTPUT_PATH/PREFIX/SUFFIX` are legacy defaults, only used by `test()` (debug helper).
OUTPUT_PATH = DEFAULT_OUTPUT_PATH
PREFIX = DEFAULT_PREFIX
SUFFIX = DEFAULT_SUFFIX

# Legacy simulation-constant globals. They start as None; `train_epoch` falls back to them
# only when dt_ms / tstop_ms / v_init are not passed explicitly, and raises RuntimeError
# when neither the argument nor the global is set.
dt = None                        # time step (ms)
v_r = None                       # resting potential (mV)
tstop = None                    # simulation time (ms)
# Forwarded as `percise=` to the simulation/gradient helpers; the name (including its
# spelling) is kept for compatibility. Its exact meaning is defined by those helpers.
PERCISE = True
# Enable Adam-optimizer handling for the weights (ADAM_W) and for x (ADAM_X); these gate
# Adam state initialization, resume, and best-so-far snapshots in this module.
ADAM_W = True
ADAM_X = True

# Module-level logger (train_epoch itself logs through its `logger` argument).
_LOGGER = logging.getLogger("worm_demo")


class _EpochProfiler:
    def __init__(self):
        self.times_s: dict[str, float] = {}

    @contextmanager
    def phase(self, name: str):
        t0 = time.perf_counter()
        try:
            yield
        finally:
            self.times_s[name] = self.times_s.get(name, 0.0) + (time.perf_counter() - t0)

    def format_summary(self, total_s: float) -> str:
        parts = [f"total={total_s:.3f}s"]
        for k in sorted(self.times_s.keys()):
            v = self.times_s[k]
            pct = (100.0 * v / total_s) if total_s > 0 else 0.0
            parts.append(f"{k}={v:.3f}s({pct:.1f}%)")
        return ", ".join(parts)

@dataclass
class _TrainState:
    """Mutable training state assembled by `_init_train_state`.

    Groups the current `x`, the error history, the best-so-far snapshot used by the
    retreat strategy, and optional Adam optimizer state for the weights (`*_w`) and
    for `x` (`*_x`).
    """

    # Current x (a copy of `input_is` on a fresh start, or of the resumed `x`).
    x: np.ndarray
    # History of per-epoch mean errors (restored from the checkpoint when resuming).
    train_error: list[float]

    # best-so-far snapshot + retreat strategy state
    # opt_epoch == -1 with opt_mean_error == 1e100 means "no best recorded yet".
    opt_epoch: int
    opt_w: np.ndarray
    opt_x: np.ndarray
    opt_mean_error: float
    # Multiplier applied to the learning rates (1.0 = unscaled).
    alpha_multiplier: float

    # Adam state (optional): None when the corresponding ADAM_W / ADAM_X flag is off.
    # adam_m_* / adam_v_* start as scalar 0.0 on a fresh start; presumably they become
    # arrays after the first update (TODO confirm against prepare_dw_dx_for_update).
    adam_m_w: Any = None
    adam_v_w: Any = None
    # Presumably running powers beta_1**t / beta_2**t used for bias correction; start at 1.0.
    beta_1_t_w: float = 1.0
    beta_2_t_w: float = 1.0
    adam_m_x: Any = None
    adam_v_x: Any = None
    beta_1_t_x: float = 1.0
    beta_2_t_x: float = 1.0

    # Snapshot of optimizer state at the current best (used by retreat logic).
    # Each is a deep-copied tuple (adam_m, adam_v, beta_1_t, beta_2_t), or None when Adam is off.
    opt_adam_params_w: Any = None
    opt_adam_params_x: Any = None


def _init_train_state(net: Network, input_is: np.ndarray, resume_state: dict | None) -> tuple[_TrainState, int]:
    """Initialize the mutable training state (x/weights/optimizer/best-so-far).

    Args:
        net: Network being trained. When resuming with a non-None ``resume_state["w"]``
            its weights are overwritten via ``net.set_weights``; on a fresh start
            ``net.w`` seeds the best-so-far weight snapshot.
        input_is: Initial ``x`` for a fresh start (copied, never aliased).
        resume_state: Checkpoint dict to resume from, or ``None`` to start fresh.

    Returns:
        ``(state, start_epoch)`` where ``start_epoch`` is ``resume_state["start_epoch"]``
        (default 0) when resuming and 0 otherwise.

    Every array kept as state (``x``, ``opt_w``, ``opt_x``) is an independent copy.
    ``net.w.numpy()`` can share memory with the live weight tensor (torch CPU tensors
    do), and ``resume_state`` belongs to the caller. Aliasing either would let later
    in-place updates silently rewrite the best-so-far snapshot used for retreats.
    """
    start_epoch = int(resume_state.get("start_epoch", 0)) if resume_state is not None else 0

    if resume_state is not None:
        x = np.copy(resume_state["x"])
        if "w" in resume_state and resume_state["w"] is not None:
            net.set_weights(resume_state["w"])
    else:
        x = np.copy(input_is)

    # Adam state defaults (replaced below by resumed values when the checkpoint has them).
    adam_m_w = 0.0
    adam_v_w = 0.0
    beta_1_t_w = 1.0
    beta_2_t_w = 1.0
    adam_m_x = 0.0
    adam_v_x = 0.0
    beta_1_t_x = 1.0
    beta_2_t_x = 1.0

    if resume_state is not None:
        if ADAM_W and resume_state.get("adam_m_w", None) is not None:
            adam_m_w = resume_state["adam_m_w"]
            adam_v_w = resume_state["adam_v_w"]
            beta_1_t_w = float(resume_state.get("beta_1_t_w", beta_1_t_w))
            beta_2_t_w = float(resume_state.get("beta_2_t_w", beta_2_t_w))
        if ADAM_X and resume_state.get("adam_m_x", None) is not None:
            adam_m_x = resume_state["adam_m_x"]
            adam_v_x = resume_state["adam_v_x"]
            beta_1_t_x = float(resume_state.get("beta_1_t_x", beta_1_t_x))
            beta_2_t_x = float(resume_state.get("beta_2_t_x", beta_2_t_x))

    if resume_state is not None and resume_state.get("opt_w", None) is not None:
        opt_epoch = int(resume_state.get("opt_epoch", -1))
        # Copy (rather than np.asarray) so the snapshot never aliases the caller's arrays.
        opt_w = np.copy(resume_state["opt_w"])
        opt_x = np.copy(resume_state["opt_x"])
        opt_mean_error = float(resume_state.get("opt_mean_error", 1e100))
        alpha_multiplier = float(resume_state.get("alpha_multiplier", 1.0))
    else:
        # -1 / 1e100 sentinels: no best-so-far recorded yet.
        opt_epoch = -1
        # `.numpy()` may be a view of the live weight tensor; copy it so in-place weight
        # updates cannot move the best-so-far snapshot (matches train_epoch's `.numpy().copy()`).
        opt_w = net.w.numpy().copy()
        opt_x = np.copy(x)
        opt_mean_error = 1e100
        alpha_multiplier = 1.0

    train_error = list(resume_state.get("train_error", [])) if resume_state is not None else []

    # Optimizer snapshot paired with the best-so-far weights/x (restored on retreat).
    opt_adam_params_w = None
    opt_adam_params_x = None
    if ADAM_W:
        opt_adam_params_w = copy.deepcopy((adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w))
    if ADAM_X:
        opt_adam_params_x = copy.deepcopy((adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x))

    return (
        _TrainState(
            x=x,
            train_error=train_error,
            opt_epoch=opt_epoch,
            opt_w=opt_w,
            opt_x=opt_x,
            opt_mean_error=opt_mean_error,
            alpha_multiplier=alpha_multiplier,
            adam_m_w=adam_m_w if ADAM_W else None,
            adam_v_w=adam_v_w if ADAM_W else None,
            beta_1_t_w=beta_1_t_w,
            beta_2_t_w=beta_2_t_w,
            adam_m_x=adam_m_x if ADAM_X else None,
            adam_v_x=adam_v_x if ADAM_X else None,
            beta_1_t_x=beta_1_t_x,
            beta_2_t_x=beta_2_t_x,
            opt_adam_params_w=opt_adam_params_w,
            opt_adam_params_x=opt_adam_params_x,
        ),
        start_epoch,
    )


def train_one_epoch(
    net: Network,
    output_names,
    target,
    *,
    output_path: str,
    prefix: str,
    suffix: str,
    dt_ms: float | None = None,
    tstop_ms: float | None = None,
    v_init: float | None = None,
    k_mul: int | None = None,
    epoch: int,
    state: dict,
    config: WormTrainConfig | None = None,
    logger=None,
) -> dict:
    """Train for exactly one epoch and return the resulting state as a new dict.

    Thin wrapper over ``train_epoch``. Every argument is forwarded unchanged, except
    that ``state`` is replaced by a shallow copy whose ``"start_epoch"`` entry is set
    to ``int(epoch)``. The caller's ``state`` object itself is not mutated. The value
    returned by ``train_epoch`` is converted to a plain ``dict``.
    """
    epoch_idx = int(epoch)
    # Shallow copy first, so the caller's mapping keeps its original "start_epoch".
    seeded_state = dict(state, start_epoch=epoch_idx)
    passthrough = dict(
        output_path=output_path,
        prefix=prefix,
        suffix=suffix,
        dt_ms=dt_ms,
        tstop_ms=tstop_ms,
        v_init=v_init,
        k_mul=k_mul,
        config=config,
        logger=logger,
    )
    result = train_epoch(
        net,
        output_names,
        target,
        epoch=epoch_idx,
        state=seeded_state,
        **passthrough,
    )
    return dict(result)


def train_epoch(
    net: Network,
    output_names,
    target,
    *,
    output_path: str,
    prefix: str,
    suffix: str,
    dt_ms: float | None = None,
    tstop_ms: float | None = None,
    v_init: float | None = None,
    k_mul: int | None = None,
    epoch: int,
    state: dict,
    config: WormTrainConfig | None = None,
    logger=None,
) -> dict:
    """Single-epoch training step (state-in/state-out).

    This mirrors the legacy implementation inside `train()` as closely as possible.
    """
    epoch = int(epoch)
    state = dict(state)

    cfg = config or WormTrainConfig.from_env()

    net.set_outputs(output_names)
    backend = getattr(net, "_heliox_backend", None)
    if backend is None:
        raise RuntimeError(
            "Runtime stepping is HELIOX-only. Attach a HELIOX backend before calling train_epoch."
        )
    profile = bool(cfg.profile)
    use_backend_opt_w = bool(ADAM_W) and bool(cfg.opt_w_backend)
    if use_backend_opt_w and not hasattr(backend, "ensure_weight_adam_optimizer"):
        raise RuntimeError(
            "EWORM_OPT_W_BACKEND=1 requires a WormHelioXRuntime backend with ensure_weight_adam_optimizer()."
        )

    # Resolve simulation constants. Historically these were module-level globals populated in __main__.
    # We allow passing them explicitly to avoid implicit global state.
    if dt_ms is None:
        if dt is None:
            raise RuntimeError(
                "train_epoch requires dt_ms (set dt_ms explicitly)"
            )
        dt_ms = float(dt)
    else:
        dt_ms = float(dt_ms)

    if tstop_ms is None:
        if tstop is None:
            raise RuntimeError(
                "train_epoch requires tstop_ms (set tstop_ms explicitly)"
            )
        tstop_ms = float(tstop)
    else:
        tstop_ms = float(tstop_ms)

    if v_init is None:
        if v_r is None:
            raise RuntimeError(
                "train_epoch requires v_init (set v_init explicitly)"
            )
        v_init = float(v_r)
    else:
        v_init = float(v_init)

    lr_start = 0
    lr_end = int(float(tstop_ms) / float(dt_ms))

    # Pull mutable values from state
    x = np.asarray(state["x"])
    train_error = list(state.get("train_error", []))

    opt_epoch = int(state.get("opt_epoch", -1))
    opt_w = np.asarray(state.get("opt_w", net.w.numpy()))
    opt_x = np.asarray(state.get("opt_x", x))
    opt_mean_error = float(state.get("opt_mean_error", 1e100))
    alpha_multiplier = float(state.get("alpha_multiplier", 1.0))

    # Adam hyperparams (kept consistent with legacy).
    beta_1 = 0.9
    beta_2 = 0.999
    epsilon = 1e-9

    adam_m_w = state.get("adam_m_w", 0.0)
    adam_v_w = state.get("adam_v_w", 0.0)
    beta_1_t_w = float(state.get("beta_1_t_w", 1.0))
    beta_2_t_w = float(state.get("beta_2_t_w", 1.0))
    adam_m_x = state.get("adam_m_x", 0.0)
    adam_v_x = state.get("adam_v_x", 0.0)
    beta_1_t_x = float(state.get("beta_1_t_x", 1.0))
    beta_2_t_x = float(state.get("beta_2_t_x", 1.0))

    # Keep optimizer snapshots for retreat logic.
    # NOTE: these snapshots must be persisted across epochs (in `state`) to match the legacy
    # behavior of "restore to best-so-far + optionally restore best-so-far Adam state".
    def _copy_snapshot(which: str):
        if which == "w":
            return copy.deepcopy((adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w))
        if which == "x":
            return copy.deepcopy((adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x))
        raise ValueError("which must be 'w'|'x'")

    def _load_opt_snapshot(which: str):
        if which == "w":
            if state.get("opt_adam_m_w", None) is None or state.get("opt_adam_v_w", None) is None:
                return _copy_snapshot("w")
            return (
                copy.deepcopy(state["opt_adam_m_w"]),
                copy.deepcopy(state["opt_adam_v_w"]),
                float(state.get("opt_beta_1_t_w", 1.0)),
                float(state.get("opt_beta_2_t_w", 1.0)),
            )
        if which == "x":
            if state.get("opt_adam_m_x", None) is None or state.get("opt_adam_v_x", None) is None:
                return _copy_snapshot("x")
            return (
                copy.deepcopy(state["opt_adam_m_x"]),
                copy.deepcopy(state["opt_adam_v_x"]),
                float(state.get("opt_beta_1_t_x", 1.0)),
                float(state.get("opt_beta_2_t_x", 1.0)),
            )
        raise ValueError("which must be 'w'|'x'")

    opt_adam_params_w = _load_opt_snapshot("w") if ADAM_W else None
    opt_adam_params_x = _load_opt_snapshot("x") if ADAM_X else None

    # Plateau helper snapshot (optional): among near-best error states, keep a vmin-improved candidate.
    plateau_vmin_best = float(state.get("plateau_vmin_best", -1e100))
    plateau_vmin_best_epoch = int(state.get("plateau_vmin_best_epoch", -1))
    plateau_vmin_best_error = float(state.get("plateau_vmin_best_error", 1e100))
    plateau_vmin_best_w = state.get("plateau_vmin_best_w", None)
    plateau_vmin_best_x = state.get("plateau_vmin_best_x", None)
    plateau_vmin_best_alpha_multiplier = float(state.get("plateau_vmin_best_alpha_multiplier", 1.0))
    plateau_vmin_best_adam_params_w = None
    plateau_vmin_best_adam_params_x = None
    if ADAM_W and state.get("plateau_vmin_best_adam_m_w", None) is not None and state.get("plateau_vmin_best_adam_v_w", None) is not None:
        plateau_vmin_best_adam_params_w = (
            copy.deepcopy(state["plateau_vmin_best_adam_m_w"]),
            copy.deepcopy(state["plateau_vmin_best_adam_v_w"]),
            float(state.get("plateau_vmin_best_beta_1_t_w", 1.0)),
            float(state.get("plateau_vmin_best_beta_2_t_w", 1.0)),
        )
    if ADAM_X and state.get("plateau_vmin_best_adam_m_x", None) is not None and state.get("plateau_vmin_best_adam_v_x", None) is not None:
        plateau_vmin_best_adam_params_x = (
            copy.deepcopy(state["plateau_vmin_best_adam_m_x"]),
            copy.deepcopy(state["plateau_vmin_best_adam_v_x"]),
            float(state.get("plateau_vmin_best_beta_1_t_x", 1.0)),
            float(state.get("plateau_vmin_best_beta_2_t_x", 1.0)),
        )

    # Run-best snapshot (best mean_error within this run/phase; independent of the global `opt_*`).
    run_best_epoch = int(state.get("run_best_epoch", -1))
    run_best_mean_error = float(state.get("run_best_mean_error", 1e100))
    run_best_vmin = float(state.get("run_best_vmin", 0.0))
    run_best_vmax = float(state.get("run_best_vmax", 0.0))
    run_best_w = state.get("run_best_w", None)
    run_best_x = state.get("run_best_x", None)

    epoch_prof = _EpochProfiler() if (profile and logger) else None
    t_epoch0 = time.perf_counter()
    start_time = time.time()

    # Learning rate schedule (allow env overrides for quick tuning without code edits).
    # Defaults preserve historical behavior: alpha_w0/alpha_x0 with 1/(1+alpha_d*epoch) decay.
    alpha_w_base = float(cfg.alpha_w0)
    alpha_x_base = float(cfg.alpha_x0)
    alpha_w = alpha_w_base / (1 + alpha_d * epoch)
    alpha_x = alpha_x_base / (1 + alpha_d * epoch)
    alpha_w *= float(cfg.alpha_w_scale)
    alpha_x *= float(cfg.alpha_x_scale)

    h.t = 0
    h.tstop = float(tstop_ms)
    h.secondorder = 0

    if epoch_prof:
        with epoch_prof.phase("reset_lr_records"):
            net._reset_lr_records()
    else:
        net._reset_lr_records()

    if epoch_prof:
        with epoch_prof.phase("finitialize"):
            backend.finitialize(v_init)
        with epoch_prof.phase("push_weights"):
            net.set_weights()
    else:
        backend.finitialize(v_init)
        net.set_weights()

    backend_adam_w_state = None
    opt_backend_adam_w_state = None
    plateau_vmin_best_backend_adam_w_state = None
    if use_backend_opt_w:
        backend.ensure_weight_adam_optimizer(beta1=float(beta_1), beta2=float(beta_2), epsilon=float(epsilon))

        def _load_backend_adam_state(prefix: str) -> dict | None:
            step_k = f"{prefix}backend_adam_w_step"
            m_k = f"{prefix}backend_adam_w_m"
            v_k = f"{prefix}backend_adam_w_v"
            if state.get(step_k, None) is None or state.get(m_k, None) is None or state.get(v_k, None) is None:
                return None
            return {
                "step": int(state.get(step_k, 0)),
                "m": np.asarray(state.get(m_k), dtype=np.float64),
                "v": np.asarray(state.get(v_k), dtype=np.float64),
                "beta1": float(state.get(f"{prefix}backend_adam_w_beta1", beta_1)),
                "beta2": float(state.get(f"{prefix}backend_adam_w_beta2", beta_2)),
                "epsilon": float(state.get(f"{prefix}backend_adam_w_epsilon", epsilon)),
            }

        backend_adam_w_state = _load_backend_adam_state(prefix="")
        if backend_adam_w_state is not None:
            backend.set_weight_adam_state(backend_adam_w_state)
        else:
            backend.reset_weight_adam_state()
            backend_adam_w_state = backend.get_weight_adam_state()

        opt_backend_adam_w_state = _load_backend_adam_state(prefix="opt_")
        if opt_backend_adam_w_state is None:
            opt_backend_adam_w_state = copy.deepcopy(backend_adam_w_state)

        plateau_vmin_best_backend_adam_w_state = _load_backend_adam_state(prefix="plateau_vmin_best_")

    k_mul_v = 5 if k_mul is None else int(k_mul)
    if k_mul_v <= 0:
        raise ValueError(f"k_mul must be a positive int (got {k_mul_v})")

    # stimulation + objective
    total_steps = int(float(h.tstop) / float(h.dt)) if float(h.dt) > 0 else 0
    if total_steps <= 0:
        total_steps = 1

    output_vs, it_lr, ditdv_lr, ditdvpre_lr = simulate_and_record_epoch(
        net=net,
        backend=backend,
        output_names=list(output_names),
        x=x,
        epoch=int(epoch),
        cfg=cfg,
        epoch_prof=epoch_prof,
        tstop_ms=float(h.tstop),
        dt_ms=float(h.dt),
        v_init=float(v_init),
        total_steps=int(total_steps),
        k_mul=int(k_mul_v),
        percise=bool(PERCISE),
    )
    # Downstream update math expects dt-grid step count.
    tstep = int(total_steps)

    mean_error, dLtdv = compute_corr_loss_and_dLtdv(
        output_vs,
        target,
        lr_start=int(lr_start),
        lr_end=int(lr_end),
        cfg=cfg,
        epoch_prof=epoch_prof,
    )

    train_error.append(mean_error)
    if logger:
        logger.info(f'epoch: {epoch}, mean error: {mean_error:.5g}')
        if isinstance(dLtdv, torch.Tensor):
            logger.info(f'dLtdv max: {float(torch.max(dLtdv).detach().cpu().item()):.5g}, min: {float(torch.min(dLtdv).detach().cpu().item()):.5g}')
        else:
            logger.info(f'dLtdv max: {np.max(dLtdv):.5g}, min: {np.min(dLtdv):.5g}')

    # Persist best-within-run snapshot (to avoid losing good intermediate epochs when `weights_train_*.npy`
    # gets overwritten each epoch). This is especially useful for multi-stage training (e.g. freeze-x phases)
    # where the global `opt_mean_error` may never improve.
    save_run_best = bool(cfg.save_run_best)
    if save_run_best and float(mean_error) < float(run_best_mean_error):
        run_best_epoch = int(epoch)
        run_best_mean_error = float(mean_error)
        run_best_w = net.w.numpy().copy()
        run_best_x = np.copy(x)
        snap_path = _run_best_snapshot_path(output_path, prefix, suffix)
        np.savez_compressed(
            snap_path,
            epoch=int(run_best_epoch),
            mean_error=float(run_best_mean_error),
            alpha_multiplier=float(alpha_multiplier),
            w=np.asarray(run_best_w),
            x=np.asarray(run_best_x),
        )
        if logger:
            logger.info(f"run-best: update epoch={run_best_epoch} err={run_best_mean_error:.5g}; saved: {snap_path}")

    if mean_error < opt_mean_error:
        opt_epoch = epoch
        opt_w = net.w.numpy().copy()
        opt_x = np.copy(x)
        opt_mean_error = mean_error
        if ADAM_W:
            if use_backend_opt_w:
                opt_backend_adam_w_state = backend.get_weight_adam_state()
            else:
                opt_adam_params_w = copy.deepcopy((adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w))
        if ADAM_X:
            opt_adam_params_x = copy.deepcopy((adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x))
        np.save(_weights_optimal_path(output_path, prefix, suffix), opt_w)
        np.save(_x_optimal_path(output_path, prefix, suffix), opt_x)
        if logger:
            logger.info('optimal weights & x saved')

    # Plateau enhancement (optional): keep a vmin-preferred snapshot among "near-best" error states.
    # This DOES NOT change the training loss/gradients; it only changes which checkpoint we restore to.
    use_plateau_vmin = bool(cfg.plateau_use_vmin)
    if use_plateau_vmin:
        err_tol = float(cfg.plateau_vmin_err_tol)
        vmin_eps = float(cfg.plateau_vmin_eps)
        err_gate = float(opt_mean_error) + float(err_tol)
        if float(mean_error) <= err_gate:
            # Note: `plateau_use_vmin` affects restore strategy only (not the loss). We avoid printing voltages.
            vmin_now = float(np.min(output_vs))
            vmax_now = float(np.max(output_vs))
            if (plateau_vmin_best_epoch < 0) or (vmin_now >= float(plateau_vmin_best) + float(vmin_eps)):
                plateau_vmin_best = float(vmin_now)
                plateau_vmin_best_epoch = int(epoch)
                plateau_vmin_best_error = float(mean_error)
                plateau_vmin_best_w = net.w.numpy().copy()
                plateau_vmin_best_x = np.copy(x)
                plateau_vmin_best_alpha_multiplier = float(alpha_multiplier)
                if ADAM_W:
                    if use_backend_opt_w:
                        plateau_vmin_best_backend_adam_w_state = backend.get_weight_adam_state()
                    else:
                        plateau_vmin_best_adam_params_w = copy.deepcopy((adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w))
                if ADAM_X:
                    plateau_vmin_best_adam_params_x = copy.deepcopy((adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x))
                if logger:
                    logger.info(
                        f"plateau-vmin: update best at epoch={plateau_vmin_best_epoch} "
                        f"(err={plateau_vmin_best_error:.5g}, gate={err_gate:.5g})"
                    )
                if bool(cfg.plateau_vmin_save_snapshot):
                    snap_path = _plateau_vmin_snapshot_path(output_path, prefix, suffix)
                    snap = {
                        "epoch": int(plateau_vmin_best_epoch),
                        "mean_error": float(plateau_vmin_best_error),
                        "alpha_multiplier": float(plateau_vmin_best_alpha_multiplier),
                        "w": np.asarray(plateau_vmin_best_w),
                        "x": np.asarray(plateau_vmin_best_x),
                    }
                    if ADAM_W and plateau_vmin_best_adam_params_w is not None:
                        snap["adam_m_w"] = np.asarray(plateau_vmin_best_adam_params_w[0])
                        snap["adam_v_w"] = np.asarray(plateau_vmin_best_adam_params_w[1])
                        snap["beta_1_t_w"] = float(plateau_vmin_best_adam_params_w[2])
                        snap["beta_2_t_w"] = float(plateau_vmin_best_adam_params_w[3])
                    if ADAM_X and plateau_vmin_best_adam_params_x is not None:
                        snap["adam_m_x"] = np.asarray(plateau_vmin_best_adam_params_x[0])
                        snap["adam_v_x"] = np.asarray(plateau_vmin_best_adam_params_x[1])
                        snap["beta_1_t_x"] = float(plateau_vmin_best_adam_params_x[2])
                        snap["beta_2_t_x"] = float(plateau_vmin_best_adam_params_x[3])
                    np.savez_compressed(snap_path, **snap)
                    if logger:
                        logger.info(f"plateau-vmin: snapshot saved: {snap_path}")

    restored_state = maybe_emergency_brake(
        cfg=cfg,
        epoch=int(epoch),
        mean_error=float(mean_error),
        opt_mean_error=float(opt_mean_error),
        opt_epoch=int(opt_epoch),
        alpha_multiplier=float(alpha_multiplier),
        net=net,
        backend=backend,
        x=x,
        opt_w=np.asarray(opt_w),
        opt_x=np.asarray(opt_x),
        ADAM_W=bool(ADAM_W),
        ADAM_X=bool(ADAM_X),
        use_backend_opt_w=bool(use_backend_opt_w),
        beta_1=float(beta_1),
        beta_2=float(beta_2),
        epsilon=float(epsilon),
        adam_m_w=adam_m_w,
        adam_v_w=adam_v_w,
        beta_1_t_w=float(beta_1_t_w),
        beta_2_t_w=float(beta_2_t_w),
        adam_m_x=adam_m_x,
        adam_v_x=adam_v_x,
        beta_1_t_x=float(beta_1_t_x),
        beta_2_t_x=float(beta_2_t_x),
        opt_adam_params_w=opt_adam_params_w,
        opt_adam_params_x=opt_adam_params_x,
        opt_backend_adam_w_state=opt_backend_adam_w_state,
        plateau_vmin_best_backend_adam_w_state=plateau_vmin_best_backend_adam_w_state,
        run_best_epoch=int(run_best_epoch),
        run_best_mean_error=float(run_best_mean_error),
        run_best_vmin=float(run_best_vmin),
        run_best_vmax=float(run_best_vmax),
        run_best_w=run_best_w,
        run_best_x=run_best_x,
        plateau_vmin_best=float(plateau_vmin_best),
        plateau_vmin_best_epoch=int(plateau_vmin_best_epoch),
        plateau_vmin_best_error=float(plateau_vmin_best_error),
        plateau_vmin_best_w=plateau_vmin_best_w,
        plateau_vmin_best_x=plateau_vmin_best_x,
        plateau_vmin_best_alpha_multiplier=float(plateau_vmin_best_alpha_multiplier),
        plateau_vmin_best_adam_params_w=plateau_vmin_best_adam_params_w,
        plateau_vmin_best_adam_params_x=plateau_vmin_best_adam_params_x,
        train_error=train_error,
        output_path=str(output_path),
        prefix=str(prefix),
        suffix=str(suffix),
        logger=logger,
    )
    if restored_state is not None:
        return restored_state

    restored_state = maybe_plateau_restore(
        cfg=cfg,
        epoch=int(epoch),
        opt_epoch=int(opt_epoch),
        opt_mean_error=float(opt_mean_error),
        opt_w=np.asarray(opt_w),
        opt_x=np.asarray(opt_x),
        alpha_multiplier=float(alpha_multiplier),
        net=net,
        backend=backend,
        x=x,
        use_plateau_vmin=bool(use_plateau_vmin),
        plateau_vmin_best_epoch=int(plateau_vmin_best_epoch),
        plateau_vmin_best_error=float(plateau_vmin_best_error),
        plateau_vmin_best=float(plateau_vmin_best),
        plateau_vmin_best_w=plateau_vmin_best_w,
        plateau_vmin_best_x=plateau_vmin_best_x,
        plateau_vmin_best_alpha_multiplier=float(plateau_vmin_best_alpha_multiplier),
        plateau_vmin_best_adam_params_w=plateau_vmin_best_adam_params_w,
        plateau_vmin_best_adam_params_x=plateau_vmin_best_adam_params_x,
        plateau_vmin_best_backend_adam_w_state=plateau_vmin_best_backend_adam_w_state,
        ADAM_W=bool(ADAM_W),
        ADAM_X=bool(ADAM_X),
        use_backend_opt_w=bool(use_backend_opt_w),
        beta_1=float(beta_1),
        beta_2=float(beta_2),
        epsilon=float(epsilon),
        adam_m_w=adam_m_w,
        adam_v_w=adam_v_w,
        beta_1_t_w=float(beta_1_t_w),
        beta_2_t_w=float(beta_2_t_w),
        adam_m_x=adam_m_x,
        adam_v_x=adam_v_x,
        beta_1_t_x=float(beta_1_t_x),
        beta_2_t_x=float(beta_2_t_x),
        opt_adam_params_w=opt_adam_params_w,
        opt_adam_params_x=opt_adam_params_x,
        opt_backend_adam_w_state=opt_backend_adam_w_state,
        run_best_epoch=int(run_best_epoch),
        run_best_mean_error=float(run_best_mean_error),
        run_best_vmin=float(run_best_vmin),
        run_best_vmax=float(run_best_vmax),
        run_best_w=run_best_w,
        run_best_x=run_best_x,
        train_error=train_error,
        output_path=str(output_path),
        prefix=str(prefix),
        suffix=str(suffix),
        logger=logger,
    )
    if restored_state is not None:
        return restored_state

    alpha_w *= alpha_multiplier
    alpha_x *= alpha_multiplier

    dw, dx = compute_dw_dx(
        net=net,
        backend=backend,
        dLtdv=dLtdv,
        it_lr=it_lr,
        ditdv_lr=ditdv_lr,
        ditdvpre_lr=ditdvpre_lr,
        x=x,
        total_steps=int(total_steps),
        k_mul=int(k_mul_v),
        percise=bool(PERCISE),
        cfg=cfg,
        epoch_prof=epoch_prof,
        lr_start=int(lr_start),
        lr_end=int(lr_end),
        dt_ms=float(h.dt),
        tstop_ms=float(h.tstop),
        v_init=float(v_init),
    )

    (
        dw,
        dx,
        freeze_w,
        freeze_x,
        adam_m_w,
        adam_v_w,
        beta_1_t_w,
        beta_2_t_w,
        adam_m_x,
        adam_v_x,
        beta_1_t_x,
        beta_2_t_x,
    ) = prepare_dw_dx_for_update(
        dw=dw,
        dx=dx,
        epoch=int(epoch),
        cfg=cfg,
        logger=logger,
        ADAM_W=bool(ADAM_W),
        ADAM_X=bool(ADAM_X),
        use_backend_opt_w=bool(use_backend_opt_w),
        alpha_w=float(alpha_w),
        alpha_x=float(alpha_x),
        k_mul=int(k_mul_v),
        tstep=int(tstep),
        lr_start=int(lr_start),
        lr_end=int(lr_end),
        x=x,
        beta_1=float(beta_1),
        beta_2=float(beta_2),
        epsilon=float(epsilon),
        use_heliox_learn_adam=bool(USE_HELIOX_LEARN_ADAM),
        AdamState=AdamState,
        adam_m_w=adam_m_w,
        adam_v_w=adam_v_w,
        beta_1_t_w=float(beta_1_t_w),
        beta_2_t_w=float(beta_2_t_w),
        adam_m_x=adam_m_x,
        adam_v_x=adam_v_x,
        beta_1_t_x=float(beta_1_t_x),
        beta_2_t_x=float(beta_2_t_x),
        epoch_prof=epoch_prof,
    )

    x, backend_adam_w_state = apply_updates(
        net=net,
        backend=backend,
        dw=dw,
        dx=dx,
        x=x,
        lr_start=int(lr_start),
        lr_end=int(lr_end),
        freeze_w=bool(freeze_w),
        freeze_x=bool(freeze_x),
        use_backend_opt_w=bool(use_backend_opt_w),
        alpha_w=float(alpha_w),
        beta_1=float(beta_1),
        beta_2=float(beta_2),
        epsilon=float(epsilon),
        cfg=cfg,
        epoch_prof=epoch_prof,
        logger=logger,
    )

    tmp_w = net.w.numpy()
    if logger:
        logger.info(f'dw gap max: {np.max(dw[net.pgap]):.5g}, min: {np.min(dw[net.pgap]):.5g}')
        logger.info(f'dw syn max: {np.max(dw[net.psyn]):.5g}, min: {np.min(dw[net.psyn]):.5g}')
        logger.info(f'dx max: {np.max(dx):.5g}, min: {np.min(dx):.5g}')
        logger.info(f'w gap max: {np.max(tmp_w[net.pgap]):.5g}, min: {np.min(tmp_w[net.pgap]):.5g}')
        logger.info(f'w syn max: {np.max(tmp_w[net.psyn]):.5g}, min: {np.min(tmp_w[net.psyn]):.5g}')
        logger.info(f'x max: {np.max(x):.5g}, min: {np.min(x):.5g}')

    if epoch_prof:
        with epoch_prof.phase("save"):
            net.save_weights(path=_weights_train_path(output_path, prefix, suffix))
            np.save(_x_train_path(output_path, prefix, suffix), x)
            np.save(_error_path(output_path, prefix, suffix), np.asarray(train_error, dtype=np.float64))
    else:
        net.save_weights(path=_weights_train_path(output_path, prefix, suffix))
        np.save(_x_train_path(output_path, prefix, suffix), x)
        np.save(_error_path(output_path, prefix, suffix), np.asarray(train_error, dtype=np.float64))
    if logger:
        logger.info('weights & x saved')
        # Keep logs stable and avoid leaking performance details in shared demo output.

    # Always refresh backend Adam state right before checkpointing to avoid stale state being persisted across resume.
    if use_backend_opt_w:
        try:
            backend_adam_w_state = backend.get_weight_adam_state()
        except Exception:
            pass

    ckpt_state = {
        "start_epoch": epoch + 1,
        "x": x,
        "w": net.w.numpy(),
        "train_error": train_error,
        "opt_epoch": opt_epoch,
        "opt_mean_error": opt_mean_error,
        "opt_w": opt_w,
        "opt_x": opt_x,
        "alpha_multiplier": alpha_multiplier,
        "beta_1_t_w": beta_1_t_w if ADAM_W else 1.0,
        "beta_2_t_w": beta_2_t_w if ADAM_W else 1.0,
        "beta_1_t_x": beta_1_t_x if ADAM_X else 1.0,
        "beta_2_t_x": beta_2_t_x if ADAM_X else 1.0,
        "adam_m_w": adam_m_w if ADAM_W else None,
        "adam_v_w": adam_v_w if ADAM_W else None,
        "adam_m_x": adam_m_x if ADAM_X else None,
        "adam_v_x": adam_v_x if ADAM_X else None,
        # Best-so-far optimizer snapshot (for retreat logic).
        "opt_adam_m_w": opt_adam_params_w[0] if (ADAM_W and opt_adam_params_w is not None) else None,
        "opt_adam_v_w": opt_adam_params_w[1] if (ADAM_W and opt_adam_params_w is not None) else None,
        "opt_beta_1_t_w": opt_adam_params_w[2] if (ADAM_W and opt_adam_params_w is not None) else 1.0,
        "opt_beta_2_t_w": opt_adam_params_w[3] if (ADAM_W and opt_adam_params_w is not None) else 1.0,
        "opt_adam_m_x": opt_adam_params_x[0] if (ADAM_X and opt_adam_params_x is not None) else None,
        "opt_adam_v_x": opt_adam_params_x[1] if (ADAM_X and opt_adam_params_x is not None) else None,
        "opt_beta_1_t_x": opt_adam_params_x[2] if (ADAM_X and opt_adam_params_x is not None) else 1.0,
        "opt_beta_2_t_x": opt_adam_params_x[3] if (ADAM_X and opt_adam_params_x is not None) else 1.0,
        # Plateau vmin-preferred snapshot (optional).
        "plateau_vmin_best": plateau_vmin_best,
        "plateau_vmin_best_epoch": plateau_vmin_best_epoch,
        "plateau_vmin_best_error": plateau_vmin_best_error,
        "plateau_vmin_best_w": plateau_vmin_best_w if plateau_vmin_best_w is not None else None,
        "plateau_vmin_best_x": plateau_vmin_best_x if plateau_vmin_best_x is not None else None,
        "plateau_vmin_best_alpha_multiplier": plateau_vmin_best_alpha_multiplier,
        "plateau_vmin_best_adam_m_w": plateau_vmin_best_adam_params_w[0] if (ADAM_W and plateau_vmin_best_adam_params_w is not None) else None,
        "plateau_vmin_best_adam_v_w": plateau_vmin_best_adam_params_w[1] if (ADAM_W and plateau_vmin_best_adam_params_w is not None) else None,
        "plateau_vmin_best_beta_1_t_w": plateau_vmin_best_adam_params_w[2] if (ADAM_W and plateau_vmin_best_adam_params_w is not None) else 1.0,
        "plateau_vmin_best_beta_2_t_w": plateau_vmin_best_adam_params_w[3] if (ADAM_W and plateau_vmin_best_adam_params_w is not None) else 1.0,
        "plateau_vmin_best_adam_m_x": plateau_vmin_best_adam_params_x[0] if (ADAM_X and plateau_vmin_best_adam_params_x is not None) else None,
        "plateau_vmin_best_adam_v_x": plateau_vmin_best_adam_params_x[1] if (ADAM_X and plateau_vmin_best_adam_params_x is not None) else None,
        "plateau_vmin_best_beta_1_t_x": plateau_vmin_best_adam_params_x[2] if (ADAM_X and plateau_vmin_best_adam_params_x is not None) else 1.0,
        "plateau_vmin_best_beta_2_t_x": plateau_vmin_best_adam_params_x[3] if (ADAM_X and plateau_vmin_best_adam_params_x is not None) else 1.0,
        # Run-best snapshot (optional).
        "run_best_epoch": run_best_epoch,
        "run_best_mean_error": run_best_mean_error,
        "run_best_vmin": run_best_vmin,
        "run_best_vmax": run_best_vmax,
        "run_best_w": run_best_w if run_best_w is not None else None,
        "run_best_x": run_best_x if run_best_x is not None else None,
    }
    if use_backend_opt_w and backend_adam_w_state is not None:
        ckpt_state.update(
            {
                "backend_adam_w_step": int(backend_adam_w_state.get("step", 0)),
                "backend_adam_w_m": np.asarray(backend_adam_w_state.get("m", []), dtype=np.float64),
                "backend_adam_w_v": np.asarray(backend_adam_w_state.get("v", []), dtype=np.float64),
                "backend_adam_w_beta1": float(backend_adam_w_state.get("beta1", beta_1)),
                "backend_adam_w_beta2": float(backend_adam_w_state.get("beta2", beta_2)),
                "backend_adam_w_epsilon": float(backend_adam_w_state.get("epsilon", epsilon)),
            }
        )
        if opt_backend_adam_w_state is not None:
            ckpt_state.update(
                {
                    "opt_backend_adam_w_step": int(opt_backend_adam_w_state.get("step", 0)),
                    "opt_backend_adam_w_m": np.asarray(opt_backend_adam_w_state.get("m", []), dtype=np.float64),
                    "opt_backend_adam_w_v": np.asarray(opt_backend_adam_w_state.get("v", []), dtype=np.float64),
                    "opt_backend_adam_w_beta1": float(opt_backend_adam_w_state.get("beta1", beta_1)),
                    "opt_backend_adam_w_beta2": float(opt_backend_adam_w_state.get("beta2", beta_2)),
                    "opt_backend_adam_w_epsilon": float(opt_backend_adam_w_state.get("epsilon", epsilon)),
                }
            )
        if plateau_vmin_best_backend_adam_w_state is not None:
            ckpt_state.update(
                {
                    "plateau_vmin_best_backend_adam_w_step": int(
                        plateau_vmin_best_backend_adam_w_state.get("step", 0)
                    ),
                    "plateau_vmin_best_backend_adam_w_m": np.asarray(
                        plateau_vmin_best_backend_adam_w_state.get("m", []), dtype=np.float64
                    ),
                    "plateau_vmin_best_backend_adam_w_v": np.asarray(
                        plateau_vmin_best_backend_adam_w_state.get("v", []), dtype=np.float64
                    ),
                    "plateau_vmin_best_backend_adam_w_beta1": float(
                        plateau_vmin_best_backend_adam_w_state.get("beta1", beta_1)
                    ),
                    "plateau_vmin_best_backend_adam_w_beta2": float(
                        plateau_vmin_best_backend_adam_w_state.get("beta2", beta_2)
                    ),
                    "plateau_vmin_best_backend_adam_w_epsilon": float(
                        plateau_vmin_best_backend_adam_w_state.get("epsilon", epsilon)
                    ),
                }
            )
    if epoch_prof:
        with epoch_prof.phase("ckpt"):
            _save_ckpt(output_path, prefix, suffix, ckpt_state, logger=logger)
    else:
        _save_ckpt(output_path, prefix, suffix, ckpt_state, logger=logger)

    if epoch_prof and logger:
        total_s = time.perf_counter() - t_epoch0
        logger.info(f"TIMING epoch {epoch}: {epoch_prof.format_summary(total_s)}")
        if cfg.replay and hasattr(backend, "last_sim_capture_timing"):
            bt = getattr(backend, "last_sim_capture_timing", None)
            if isinstance(bt, dict) and bt:
                keys = sorted(bt.keys())
                line = ", ".join([f"{k}={bt[k]}" for k in keys])
                logger.info(f"TIMING heliox_backend.simulate_and_capture_lr_signals: {line}")

    # Optional: print only the total epoch wall time (kept off by default for release logs).
    if logger and bool(cfg.print_epoch_time):
        total_s = time.perf_counter() - t_epoch0
        logger.info(f"epoch time: {total_s:.3f}s")

    return ckpt_state


def test(
    net: Network,
    output_names,
    *,
    output_path: str | None = None,
    prefix: str | None = None,
    suffix: str | None = None,
    dt_ms: float | None = None,
    tstop_ms: float | None = None,
    v_init: float | None = None,
    logger=None,
):
    """Forward-only test helper (HELIOX-only runtime).

    Loads the saved "optimal" weights and input ``x`` for the given
    output_path/prefix/suffix, writes a listing of the network's synapses to
    ``<output_path>/net_synlist.pkl``, then runs a forward simulation through
    the HELIOX backend attached to ``net`` and records the output voltages.

    NOTE: This is a lightweight debug helper.

    Args:
        net: Network to evaluate. A HELIOX backend must be attached as
            ``net._heliox_backend``.
        output_names: Passed to ``net.set_outputs``; the result has one row
            per name.
        output_path, prefix, suffix: Locate the saved optimal weights / x.
            Default to the module-level OUTPUT_PATH / PREFIX / SUFFIX.
        dt_ms: Integration step in ms (assigned to ``h.dt``). Defaults to the
            module-level ``dt``. Must be > 0.
        tstop_ms: Stop time in ms (assigned to ``h.tstop``). Defaults to the
            module-level ``tstop``.
        v_init: Initial potential passed to ``backend.finitialize``. Defaults
            to the module-level ``v_r``.
        logger: Logger to use; falls back to the module logger ``_LOGGER``.

    Returns:
        float32 array of shape ``(len(output_names), total_steps + 1)`` where
        ``total_steps = max(1, round(tstop / dt))``. Column 0 is read right
        after initialization; column ``k + 1`` is read after the k-th
        ``backend.fadvance()``. Columns never reached stay 0.

    Raises:
        RuntimeError: if dt/tstop/v_init cannot be resolved, if ``dt_ms`` is
            not positive, or if no HELIOX backend is attached to ``net``.
    """
    cfg = WormTrainConfig.from_env()
    logger = logger or _LOGGER

    # Resolve legacy defaults from module-level settings.
    if output_path is None:
        output_path = OUTPUT_PATH
    if prefix is None:
        prefix = PREFIX
    if suffix is None:
        suffix = SUFFIX

    if dt_ms is None:
        if dt is None:
            raise RuntimeError("worm_train.test requires dt_ms (or set module-level dt before calling).")
        dt_ms = float(dt)
    # A zero step would raise ZeroDivisionError when computing the print
    # interval below, and a negative/NaN step is meaningless for a forward
    # simulation; fail fast with a clear message instead.
    if not float(dt_ms) > 0:
        raise RuntimeError(f"worm_train.test requires dt_ms > 0 (got {dt_ms!r}).")
    if tstop_ms is None:
        if tstop is None:
            raise RuntimeError("worm_train.test requires tstop_ms (or set module-level tstop before calling).")
        tstop_ms = float(tstop)
    if v_init is None:
        if v_r is None:
            raise RuntimeError("worm_train.test requires v_init (or set module-level v_r before calling).")
        v_init = float(v_r)

    opt_w = np.load(_weights_optimal_path(output_path, prefix, suffix))
    opt_x = np.load(_x_optimal_path(output_path, prefix, suffix))
    net.set_weights(opt_w)

    # Snapshot every synapse as [id, point, kind, weight, p]. Synapses whose
    # `syn` object exposes `Vth` are labeled 'syn' (weight scaled by 1e4);
    # all others are labeled 'gj' (presumably gap junctions).
    synlist = {}
    syn_cnt = 0
    for i, targets in net.synlist.items():
        synlist[i] = {}
        for j, syninfo_list in targets.items():
            entries = synlist[i][j] = []
            for syninfo in syninfo_list:
                try:
                    syninfo.syn.Vth  # probe only: decides the 'syn' vs 'gj' label
                except Exception:
                    entries.append([syninfo.id, syninfo.point, 'gj', syninfo.syn.w, syninfo.p])
                else:
                    entries.append([syninfo.id, syninfo.point, 'syn', syninfo.syn.w * 1e4, syninfo.p])
                syn_cnt += 1
    logger.info("synapse cnt: %s", syn_cnt)
    # Context manager guarantees the file is flushed and closed, even if
    # pickling raises.
    with open(os.path.join(output_path, "net_synlist.pkl"), "wb") as fh:
        pickle.dump(synlist, fh)
    logger.info("saved network")

    net.set_outputs(output_names)
    x = np.copy(opt_x)
    net.input_vs = x
    backend = getattr(net, "_heliox_backend", None)
    if backend is None:
        raise RuntimeError(
            "Runtime stepping is HELIOX-only. Attach a HELIOX backend before calling test()."
        )

    h.t = 0
    h.dt = float(dt_ms)
    h.tstop = float(tstop_ms)
    h.secondorder = 0
    backend.finitialize(float(v_init))
    # NOTE(review): arg-less call presumably re-applies the weights loaded
    # above after initialization — confirm against Network.set_weights.
    net.set_weights()

    # stimulation
    print_every_steps = max(1, int(round(float(cfg.print_interval_ms) / float(dt_ms))))
    step_ms = float(h.dt)
    total_steps = int(round(float(h.tstop) / step_ms)) if step_ms > 0 else 0
    total_steps = max(1, total_steps)
    output_vs = np.zeros((len(output_names), total_steps + 1), dtype=np.float32)
    output_vs[:, 0] = backend.read_output_v()

    tstep = 0
    while backend.get_t() < float(h.tstop):
        if cfg.print_timestep and (tstep % print_every_steps) == 0:
            t_ms = backend.get_t()
            print(f't: {t_ms}/{h.tstop}', end='\r')
        # NOTE(review): assumes x has at least as many columns as the steps
        # this loop takes — confirm against how optimal x is produced.
        backend.set_input_amps(x[:, tstep])
        backend.fadvance()
        # The loop is driven by backend time, so its step count need not match
        # total_steps exactly; readings past the buffer are dropped.
        if tstep + 1 < output_vs.shape[1]:
            output_vs[:, tstep + 1] = backend.read_output_v()
        tstep += 1
    if cfg.print_timestep:
        print("")

    return output_vs
