import numpy as np
# import cupy as cp
import os
import math
import logging
try:
    import opt_einsum as oe  # type: ignore
except ModuleNotFoundError:  # pragma: no cover
    oe = None
from neuron import h
import torch
# from utils.func import load_json
import pickle
import json
from dataclasses import dataclass
from typing import Iterable, List, Tuple, Any
import re
import time


# Module-level logger, registered under the fixed name "worm_demo".
_LOGGER = logging.getLogger("worm_demo")

if oe is None:
    class _OptEinsumTorchFallback:
        """Stand-in used when `opt_einsum` could not be imported; only the torch path is supported."""

        @staticmethod
        def contract(expr: str, *operands, backend: str | None = None, **_kwargs):
            """
            Evaluate `expr` over `operands` with `torch.einsum`.

            `backend` must be None or "torch"; anything else raises ModuleNotFoundError because
            the real opt_einsum is unavailable. Extra keyword arguments are accepted and ignored.
            """
            torch_compatible = backend is None or backend == "torch"
            if torch_compatible:
                return torch.einsum(expr, *operands)
            raise ModuleNotFoundError(
                "opt_einsum is not installed, but a non-torch backend was requested "
                f"(backend={backend!r}). Install opt_einsum or use backend='torch'."
            )

    # Rebind the module-level name so call sites can keep writing `oe.contract(...)`.
    oe = _OptEinsumTorchFallback()  # type: ignore[assignment]


def _verbose_enabled() -> bool:
    """Return True only when EWORM_VERBOSE, stripped of surrounding whitespace, equals "1"."""
    raw_flag = os.environ.get("EWORM_VERBOSE", "0")
    return raw_flag.strip() == "1"


def _nonfinite_log_enabled() -> bool:
    """Return True only when EWORM_NONFINITE_LOG, stripped of surrounding whitespace, equals "1"."""
    raw_flag = os.environ.get("EWORM_NONFINITE_LOG", "0")
    return raw_flag.strip() == "1"


def _k_progress_enabled() -> bool:
    """Return True only when EWORM_PRINT_K_PROGRESS, stripped of surrounding whitespace, equals "1"."""
    raw_flag = os.environ.get("EWORM_PRINT_K_PROGRESS", "0")
    return raw_flag.strip() == "1"

def _k_progress_every_blocks() -> int:
    """
    Read EWORM_K_PROGRESS_EVERY_BLOCKS as a non-negative integer.

    Returns:
        The parsed value clamped to be >= 0, or 0 when the variable is unset
        or is not a valid integer literal.
    """
    raw = os.environ.get("EWORM_K_PROGRESS_EVERY_BLOCKS", "0")
    try:
        return max(0, int(raw.strip()))
    except ValueError:
        # `int()` on a string can only fail with ValueError; anything else would be a real bug
        # and should not be silently mapped to 0.
        return 0


def load_json(file_name):
    """
    Parse the JSON document stored at `file_name` and return the decoded object.

    The file is opened read-only (the previous 'r+' mode needlessly required write
    permission, which fails on read-only files and mounts) and decoded as UTF-8,
    the interchange encoding for JSON, instead of the platform default.

    Raises:
        OSError: the file cannot be opened for reading.
        json.JSONDecodeError: the content is not valid JSON.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        return json.load(f)

# Non-finite handling policies, read once at import time from the environment.
# Both default to "zero" and are normalized (stripped, lower-cased) here.
# NOTE(review): presumably these are passed as `action` to `_sanitize_nonfinite_`
# for dv/dw values and gradients respectively — confirm against the call sites.
_DVDW_NAN_ACTION = os.environ.get("EWORM_DVDW_NAN_ACTION", "zero").strip().lower()
_GRAD_NAN_ACTION = os.environ.get("EWORM_GRAD_NAN_ACTION", "zero").strip().lower()


def _sanitize_nonfinite_(tensor: torch.Tensor, *, name: str, action: str) -> None:
    """
    Apply the non-finite policy `action` to `tensor`.

    `action` is stripped and lower-cased; a falsy value means "zero".
    - "zero" / "nan_to_num" / "clamp": replace NaN, +Inf and -Inf with 0.0 in place.
    - "raise" / "error" / "assert": raise FloatingPointError. The tensor is NOT inspected
      here, so the error is raised unconditionally for these modes.
    - anything else: raise ValueError.
    """
    mode = (action or "zero").strip().lower()
    if mode in ("raise", "error", "assert"):
        raise FloatingPointError(f"{name} contains NaN/Inf")
    if mode not in ("zero", "nan_to_num", "clamp"):
        raise ValueError(f"Unknown non-finite action: {mode} (for {name})")
    torch.nan_to_num_(tensor, nan=0.0, posinf=0.0, neginf=0.0)


# Module-level cache slot for the lazily built replay function; starts as None
# and is read/written by `_get_dvdw_replay_jit_fn()` through `global`.
_DVDW_REPLAY_JIT_FN: Any = None


def _default_heliox_python_lib() -> str:
    """
    Locate the HELIOX python library folder.

    Resolution order:
    1. HELIOX_PYTHON_LIB (stripped, with "~" expanded) when it is non-empty;
    2. `<parent of this file's directory>/heliox/python_lib` when that directory exists,
       which matches the shipped bundle layout
       (<bundle_root>/worm_demo/worm_network.py next to <bundle_root>/heliox/python_lib);
    3. the empty string otherwise (including when probing the filesystem fails).
    """
    configured = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
    if configured:
        return os.path.expanduser(configured)
    try:
        from pathlib import Path

        bundle_root = Path(__file__).resolve().parent.parent
        candidate = (bundle_root / "heliox" / "python_lib").resolve()
        return str(candidate) if candidate.is_dir() else ""
    except Exception:
        # Best effort only: any failure while probing means "no default available".
        return ""


@dataclass(frozen=True)
class _HandleAccum:
    """Immutable record: one backend variable handle and how its value is accumulated."""

    # Backend variable handle (as returned by `get_variable_handle`).
    handle: int
    # Index in the destination vector that this handle's value is added into.
    dest: int
    # Multiplier applied to the value before it is added at `dest`.
    scale: float


@dataclass(frozen=True)
class _WeightHandle:
    """Immutable record pairing a weight variable handle with its index."""

    # Backend variable handle of the "w" variable of one synapse object.
    handle: int
    # The `p` index taken from the corresponding syninfo entry.
    p_index: int


class WormHelioXRuntime:
    """
    Run worm time stepping on HELIOX.

    This was historically implemented as a standalone module (`worm_heliox_backend.py`).
    Since this repo now supports only one runtime backend (HELIOX), we keep the runtime
    implementation colocated with Network and expose it from here to avoid an extra
    "backend wrapper layer" file.
    """

    def __init__(self, net, output_names: Iterable[str], *, dt: float, v_init: float, export_path: str):
        """
        Build the HELIOX runtime for `net` and resolve every variable handle used later.

        Steps, in order (the order matters, see the inline comments):
        1. import `heliox_core`, adding the folder from `_default_heliox_python_lib()` to
           `sys.path` when it exists;
        2. create a `HelioXManager` configured from HELIOX_DEVICE / HELIOX_PERMUTE_TYPE;
        3. assign one GID per simulated cell;
        4. create one VecPlay wrapper per input IClamp;
        5. export and load the model through `setup_and_load_model`;
        6. resolve handles for currents, derivatives, weights, input amps and output voltages;
        7. build the index arrays (`poutput`, `pinput`, `pre_of_col`) used by replay calls.

        Args:
            net: network object providing the cells, synapse lists and index metadata read below.
            output_names: names of the output cells. Any iterable is accepted, including a
                one-shot iterator; it is materialized exactly once.
            dt: time step forwarded to `setup_and_load_model`.
            v_init: initial voltage forwarded to `setup_and_load_model` and used as the
                default for `finitialize`.
            export_path: export directory; "~" is expanded and the directory is created if missing.

        Raises:
            ImportError: `heliox_core` cannot be imported.
            RuntimeError: an input cell has no IClamp, or preplist/postplist shapes differ.
        """
        import sys

        # `output_names` is consumed several times below (stored names, output handles,
        # `set_outputs`). Materialize it once so a generator does not silently yield nothing
        # on the second pass.
        output_names = list(output_names)

        heliox_lib = _default_heliox_python_lib()
        if heliox_lib not in sys.path and os.path.isdir(heliox_lib):
            sys.path.insert(0, heliox_lib)

        try:
            from heliox_core import HandleBatch, HelioXManager, VecPlayBatch, get_variable_handle  # type: ignore
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "Failed to import HELIOX python wrapper core layer. "
                "Set HELIOX_PYTHON_LIB to the folder containing heliox_core/ (and heliox_wrapper.py) "
                f"(tried {heliox_lib})."
            ) from e

        self.net = net
        self.dt = float(dt)
        self.v_init = float(v_init)
        self.export_path = export_path

        device = os.environ.get("HELIOX_DEVICE", "gpu").strip().lower()
        permute_type = int(os.environ.get("HELIOX_PERMUTE_TYPE", "3"))

        self.manager = HelioXManager()
        self.manager.set_default_device(device)
        self.manager.set_default_permute_type(permute_type)

        self._pc = None
        self._pc_ncs = []
        # NOTE: HELIOX recorders/monitors are currently unstable in some GPU builds.
        # For now we rely on per-step handle reads for output v(t).
        self._output_recorders = []
        self._input_vecplays = []
        self._output_names = list(output_names)

        # Assign GIDs (required by nrnbbcore_write for CoreNEURON export).
        # This must be done before calling HelioXManager.setup_and_load_model().
        pc = h.ParallelContext()
        self._pc = pc
        gid = 0
        for cell_id in self.net.cells_id_sim:
            cn = self.net.cells_name_dic[cell_id]
            cell = self.net.cells[cn]
            soma = cell.Soma
            pc.set_gid2node(int(gid), int(pc.id()))
            nc = h.NetCon(soma(0.5)._ref_v, None, sec=soma)
            pc.cell(int(gid), nc)
            # Keep a reference so the NetCon objects are not garbage collected.
            self._pc_ncs.append(nc)
            gid += 1

        # VecPlay channels must be created before export so they can be initialized after export+load.
        # (The wrapper already implements delayed initialization for VecPlayWrapper.)
        input_clamps: list[tuple[int, object, int]] = []
        for cell_id in self.net.cells_id_sim:
            if cell_id not in self.net.input_ids:
                continue
            cn = self.net.cells_name_dic[cell_id]
            cell = self.net.cells[cn]
            p0 = int(self.net.K_start[cell_id])
            iclamp = self.net.input_synlist[cell_id]
            # Destination index for the clamp current is the slot right after the soma segments.
            input_clamps.append((int(cell_id), iclamp, p0 + int(cell.Soma.nseg)))
            self._input_vecplays.append(self.manager.create_vecplay_wrapper(iclamp, "amp"))
        input_clamp_by_id = {cid: iclamp for cid, iclamp, _ in input_clamps}

        export_path = os.path.abspath(os.path.expanduser(export_path))
        os.makedirs(export_path, exist_ok=True)
        self.manager.setup_and_load_model(export_path, dt=self.dt, v_init=self.v_init)

        self._pure_i_accum: list[_HandleAccum] = []
        self._didv_accum: list[_HandleAccum] = []
        self._didvpre_accum: list[_HandleAccum] = []
        self._weight_handles: list[_WeightHandle] = []

        # Soma mechanisms (per-segment, per-mechanism) -> It/didv accumulators.
        for cell_id in self.net.cells_id_sim:
            cn = self.net.cells_name_dic[cell_id]
            cell = self.net.cells[cn]
            p0 = int(self.net.K_start[cell_id])
            for seg_id, seg in enumerate(cell.Soma):
                # Mechanism values are scaled by segment area * 1e-2 before accumulation.
                # NOTE(review): presumably a density -> absolute current unit conversion; confirm.
                seg_area_scale = float(seg.area()) * 1e-2
                dest = p0 + int(seg_id)
                for mech_name in self.net.mech_list:
                    mech = getattr(seg, mech_name)
                    h_pure = get_variable_handle(self.manager, mech, "pure_i", 0)
                    h_didv = get_variable_handle(self.manager, mech, "didv", 0)
                    self._pure_i_accum.append(_HandleAccum(int(h_pure), int(dest), float(seg_area_scale)))
                    self._didv_accum.append(_HandleAccum(int(h_didv), int(dest), float(seg_area_scale)))

        # IClamp contribution into It at the "extra slot" (p0 + nseg).
        for _cell_id, iclamp, dest in input_clamps:
            h_i = get_variable_handle(self.manager, iclamp, "i", 0)
            self._pure_i_accum.append(_HandleAccum(int(h_i), int(dest), 1.0))

        # Synapse/gap weights and local learning signals.
        for cell_id in self.net.cells_id_sim:
            for point in self.net.synlist[cell_id].keys():
                for syninfo in self.net.synlist[cell_id][point]:
                    syn = syninfo.syn
                    p_index = int(syninfo.p)
                    self._pure_i_accum.append(_HandleAccum(int(get_variable_handle(self.manager, syn, "pure_i", 0)), p_index, 1.0))
                    self._didv_accum.append(_HandleAccum(int(get_variable_handle(self.manager, syn, "didv", 0)), p_index, 1.0))
                    self._didvpre_accum.append(
                        _HandleAccum(int(get_variable_handle(self.manager, syn, "didvpre", 0)), p_index, 1.0)
                    )
                    self._weight_handles.append(_WeightHandle(int(get_variable_handle(self.manager, syn, "w", 0)), p_index))

        # Input amp handles (for per-step mode); preserve stable ordering by net.input_ids.
        self._input_amp_handles = []
        for cell_id in self.net.input_ids:
            cid = int(cell_id)
            iclamp = input_clamp_by_id.get(cid, None)
            if iclamp is None:
                raise RuntimeError(
                    f"Missing IClamp for input cell_id={cid}. "
                    "Check EWORM_DEBUG_DISABLE_INPUTS / model construction."
                )
            self._input_amp_handles.append(int(get_variable_handle(self.manager, iclamp, "amp", 0)))
        self._input_amp_batch = HandleBatch.from_iterable(self.manager, self._input_amp_handles)
        self._input_vecplay_batch = VecPlayBatch(self._input_vecplays)

        # Output voltage handles, one per output cell, in `output_names` order.
        self._output_v_handles: list[int] = []
        for cn in output_names:
            cell_id = self.net.cells_id_dic[cn]
            seg = self.net.get_3dp_segment(cell_id, 0)
            self._output_v_handles.append(int(get_variable_handle(self.manager, seg, "v", 0)))
        self._output_v_batch = HandleBatch.from_iterable(self.manager, self._output_v_handles)

        # Flatten the accumulator records into parallel handle/dest/scale arrays.
        self._pure_i_handles = [x.handle for x in self._pure_i_accum]
        self._pure_i_dest = np.asarray([x.dest for x in self._pure_i_accum], dtype=np.int32)
        self._pure_i_scale = np.asarray([x.scale for x in self._pure_i_accum], dtype=np.float32)

        self._didv_handles = [x.handle for x in self._didv_accum]
        self._didv_dest = np.asarray([x.dest for x in self._didv_accum], dtype=np.int32)
        self._didv_scale = np.asarray([x.scale for x in self._didv_accum], dtype=np.float32)

        self._didvpre_handles = [x.handle for x in self._didvpre_accum]
        self._didvpre_dest = np.asarray([x.dest for x in self._didvpre_accum], dtype=np.int32)
        self._didvpre_scale = np.asarray([x.scale for x in self._didvpre_accum], dtype=np.float32)

        self._weight_p_indices = np.asarray([x.p_index for x in self._weight_handles], dtype=np.int32)
        self._weight_handles_only = [x.handle for x in self._weight_handles]

        # Learning index metadata (used by backend replay dw/dx).
        # `Network.set_outputs()` is normally called inside the training loop, but backend init
        # also needs `poutput` for replay gradients, so ensure it exists here.
        if not hasattr(self.net, "poutput"):
            self.net.set_outputs(list(output_names))
        self._poutput = np.ascontiguousarray(np.asarray(self.net.poutput, dtype=np.int32))
        self._pinput = np.ascontiguousarray(np.asarray(self.net.pinput, dtype=np.int32))
        pre = np.asarray(getattr(self.net, "preplist", []), dtype=np.int32)
        post = np.asarray(getattr(self.net, "postplist", []), dtype=np.int32)
        if pre.shape != post.shape:
            raise RuntimeError(f"preplist/postplist shape mismatch: {pre.shape} vs {post.shape}")
        # pre_of_col[post_index] = pre_index; -1 marks entries with no recorded pre index.
        pre_of_col = np.full((int(self.net.N),), -1, dtype=np.int32)
        if pre.size:
            pre_of_col[post] = pre
        self._pre_of_col = np.ascontiguousarray(pre_of_col)

        self._dense_blocks_uploaded = False
        self._weight_adam_optimizer_id: int | None = None

    def _require_backend_api_(self, attr: str, *, feature: str):
        cli = self.manager.client
        if not hasattr(cli, attr):
            raise RuntimeError(f"HELIOX build missing required API for {feature}; please rebuild HelioX")
        return cli

    def _ensure_dense_blocks_uploaded(self) -> None:
        if self._dense_blocks_uploaded:
            return
        if not hasattr(self.manager.client, "set_dense_blocks_f32"):
            raise RuntimeError("HELIOX build missing required API for dense blocks; please rebuild HelioX")
        if not hasattr(self.net, "K_blocks_cpu"):
            raise RuntimeError(
                "Network has no K_blocks_cpu. Ensure K is loaded (do not set EWORM_SKIP_K=1) and rebuild if needed."
            )
        blocks_cpu = []
        for b in self.net.K_blocks_cpu:
            blocks_cpu.append(np.ascontiguousarray(np.asarray(b, dtype=np.float32)))
        self.manager.client.set_dense_blocks_f32(blocks_cpu)
        self._dense_blocks_uploaded = True

    @property
    def weight_p_indices(self) -> np.ndarray:
        return self._weight_p_indices

    def finitialize(self, v_init: float | None = None) -> None:
        self.manager.finitialize(self.v_init if v_init is None else float(v_init))

    def fadvance(self) -> None:
        self.manager.fadvance()

    def get_t(self) -> float:
        return float(self.manager.get_t())

    def set_input_amps(self, amps: np.ndarray) -> None:
        amps = np.asarray(amps, dtype=np.float64).reshape((-1,))
        if amps.shape[0] != len(self._input_amp_handles):
            raise ValueError(f"expected {len(self._input_amp_handles)} amps, got {amps.shape[0]}")
        self._input_amp_batch.write(amps.tolist())

    def read_output_v(self) -> np.ndarray:
        return self._output_v_batch.read_f32()

    def _fill_from_accum(self, out: np.ndarray, handles: list[int], dest: np.ndarray, scale: np.ndarray) -> None:
        if hasattr(self.manager, "get_variables_by_handles_f32"):
            vals = self.manager.get_variables_by_handles_f32(handles)
        else:
            vals = np.asarray(self.manager.get_variables_by_handles(handles), dtype=np.float32)
        out.fill(0.0)
        np.add.at(out, dest, vals * scale)

    def fill_it(self, it_cpu: np.ndarray) -> None:
        self._fill_from_accum(it_cpu, self._pure_i_handles, self._pure_i_dest, self._pure_i_scale)

    def fill_ditdv(self, ditdv_cpu: np.ndarray, ditdvpre_cpu: np.ndarray) -> None:
        self._fill_from_accum(ditdv_cpu, self._didv_handles, self._didv_dest, self._didv_scale)
        self._fill_from_accum(ditdvpre_cpu, self._didvpre_handles, self._didvpre_dest, self._didvpre_scale)

    def push_weights(self, w_values: np.ndarray) -> None:
        w_values = np.asarray(w_values, dtype=np.float64).reshape((-1,))
        if w_values.shape[0] != self._weight_p_indices.shape[0]:
            raise ValueError(f"expected {self._weight_p_indices.shape[0]} weights, got {w_values.shape[0]}")
        self.manager.set_variables_by_handles(self._weight_handles_only, w_values.tolist())

    def ensure_weight_adam_optimizer(self, *, beta1: float, beta2: float, epsilon: float) -> int:
        if self._weight_adam_optimizer_id is not None:
            return int(self._weight_adam_optimizer_id)
        cli = self._require_backend_api_("optimizer_add_external_grads", feature="backend optimizer (external grads)")
        self._require_backend_api_("optimizer_set_external_grads_f32", feature="backend optimizer (external grads)")
        self._require_backend_api_("optimizer_step_with_inv_record_steps", feature="backend optimizer (step)")
        opt_id = int(cli.create_optimizer("adam"))
        cli.configure_optimizer(int(opt_id), beta1=float(beta1), beta2=float(beta2), epsilon=float(epsilon))
        cli.optimizer_add_external_grads(int(opt_id), self._weight_handles_only, 1.0)
        self._weight_adam_optimizer_id = int(opt_id)
        return int(opt_id)

    def get_weight_adam_state(self) -> dict:
        if self._weight_adam_optimizer_id is None:
            raise RuntimeError("weight Adam optimizer not initialized")
        cli = self._require_backend_api_("optimizer_get_adam_state", feature="backend optimizer (get state)")
        step, m, v, beta1, beta2, eps = cli.optimizer_get_adam_state(int(self._weight_adam_optimizer_id))
        return {
            "step": int(step),
            "m": np.asarray(m, dtype=np.float64),
            "v": np.asarray(v, dtype=np.float64),
            "beta1": float(beta1),
            "beta2": float(beta2),
            "epsilon": float(eps),
        }

    def set_weight_adam_state(self, state: dict) -> None:
        if self._weight_adam_optimizer_id is None:
            raise RuntimeError("weight Adam optimizer not initialized")
        cli = self._require_backend_api_("optimizer_set_adam_state", feature="backend optimizer (set state)")
        step = int(state.get("step", 0))
        m = np.asarray(state.get("m", []), dtype=np.float64).reshape((-1,))
        v = np.asarray(state.get("v", []), dtype=np.float64).reshape((-1,))
        beta1 = float(state.get("beta1", 0.9))
        beta2 = float(state.get("beta2", 0.999))
        eps = float(state.get("epsilon", 1e-8))
        cli.optimizer_set_adam_state(int(self._weight_adam_optimizer_id), int(step), m.tolist(), v.tolist(), beta1, beta2, eps)

    def reset_weight_adam_state(self) -> None:
        if self._weight_adam_optimizer_id is None:
            return
        cli = self._require_backend_api_("optimizer_reset_state", feature="backend optimizer (reset state)")
        cli.optimizer_reset_state(int(self._weight_adam_optimizer_id))

    def adam_step_weights_from_dw(self, dw_full: np.ndarray, *, learning_rate: float) -> None:
        if self._weight_adam_optimizer_id is None:
            raise RuntimeError("weight Adam optimizer not initialized; call ensure_weight_adam_optimizer() first")
        dw_full = np.asarray(dw_full, dtype=np.float32).reshape((-1,))
        if dw_full.shape[0] != int(self.net.N):
            raise ValueError(f"expected dw_full shape ({int(self.net.N)},), got {dw_full.shape}")
        dw_subset = np.ascontiguousarray(dw_full[self._weight_p_indices], dtype=np.float32)
        cli = self.manager.client
        cli.optimizer_set_external_grads_f32(int(self._weight_adam_optimizer_id), dw_subset)
        cli.optimizer_step_with_inv_record_steps(int(self._weight_adam_optimizer_id), float(learning_rate), 1.0)

    def pull_weights(self) -> np.ndarray:
        vals = self.manager.get_variables_by_handles(self._weight_handles_only)
        return np.asarray(vals, dtype=np.float64)

    def play_inputs(self, x: np.ndarray, *, dt: float) -> None:
        self._input_vecplay_batch.play_matrix(x, dt_ms=float(dt))

    def read_output_recorders(self) -> np.ndarray:
        if not self._output_recorders:
            raise RuntimeError("No output recorders initialized")
        return np.stack([np.asarray(r.data, dtype=np.float32) for r in self._output_recorders], axis=0)

    def simulate_output_vs(
        self,
        x: np.ndarray,
        *,
        tstop_ms: float,
        dt_ms: float,
        use_vecplay: bool,
        v_init: float,
        assume_weights_already_set: bool = False,
        assume_inputs_already_played: bool = False,
    ) -> np.ndarray:
        x = np.asarray(x, dtype=np.float64)
        if x.ndim != 2:
            raise ValueError(f"expected x shape (N_input, T), got {x.shape}")
        total_steps = int(float(tstop_ms) / float(dt_ms))
        if total_steps <= 0:
            total_steps = 1

        if use_vecplay and hasattr(self.manager, "client") and hasattr(self.manager.client, "simulate_output_vs_into"):
            # NOTE: the backend API calls finitialize(v_init) internally. We only need to ensure
            # weights/vecplay are updated before we enter the backend call.
            if not assume_weights_already_set:
                self.net.set_weights()
            if not assume_inputs_already_played:
                self.play_inputs(x, dt=float(dt_ms))

            output_vs_tn = np.empty((total_steps + 1, len(self._output_v_handles)), dtype=np.float32)
            rc = int(
                self.manager.client.simulate_output_vs_into(
                    output_vs_tn,
                    list(self._output_v_handles),
                    float(tstop_ms),
                    float(v_init),
                )
            )
            if rc != 0:
                raise RuntimeError(f"simulate_output_vs_into failed (rc={rc})")
            return np.ascontiguousarray(output_vs_tn.T, dtype=np.float32)

        # Fallback: explicit step loop (writes input amps each step when vecplay disabled).
        output_vs = np.zeros((len(self._output_names), total_steps + 1), dtype=np.float32)
        self.finitialize(v_init)
        if not assume_weights_already_set:
            self.net.set_weights()
        if use_vecplay:
            if not assume_inputs_already_played:
                self.play_inputs(x, dt=float(dt_ms))
        output_vs[:, 0] = self.read_output_v()
        tstep = 0
        while self.get_t() < float(tstop_ms):
            if not use_vecplay:
                self.set_input_amps(x[:, tstep])
            self.fadvance()
            if tstep + 1 < output_vs.shape[1]:
                output_vs[:, tstep + 1] = self.read_output_v()
            tstep += 1
        return output_vs

    def simulate_and_capture_lr_signals(
        self,
        x: np.ndarray,
        *,
        tstop_ms: float,
        dt_ms: float,
        k_mul: int,
        percise: bool,
        use_vecplay: bool,
        v_init: float,
    ):
        x = np.asarray(x, dtype=np.float64)
        if x.ndim != 2:
            raise ValueError(f"expected x shape (N_input, T), got {x.shape}")
        total_steps = int(float(tstop_ms) / float(dt_ms))
        if total_steps <= 0:
            total_steps = 1
        ksteps_total = int(total_steps // int(k_mul))

        output_vs = np.zeros((len(self._output_names), total_steps + 1), dtype=np.float32)
        it_lr = np.zeros((int(self.net.N), ksteps_total + 1), dtype=np.float32)
        ditdv_lr = np.zeros((int(self.net.N), ksteps_total + 1), dtype=np.float32) if percise else None
        ditdvpre_lr = np.zeros((int(self.net.N), ksteps_total + 1), dtype=np.float32) if percise else None

        self.finitialize(v_init)
        self.net.set_weights()

        if use_vecplay:
            self.play_inputs(x, dt=float(dt_ms))

        tstep = 0
        output_vs[:, 0] = self.read_output_v()

        self.last_sim_capture_timing = []
        pre_time = time.time()
        while self.get_t() < float(tstop_ms):
            if not use_vecplay:
                self.set_input_amps(x[:, tstep])
            self.fadvance()
            if tstep + 1 < output_vs.shape[1]:
                output_vs[:, tstep + 1] = self.read_output_v()
            tstep += 1
            if tstep % int(k_mul) == 0:
                kstep = int(tstep // int(k_mul))
                it_col = it_lr[:, kstep]
                self.fill_it(it_col)
                if percise:
                    assert ditdv_lr is not None and ditdvpre_lr is not None
                    self.fill_ditdv(ditdv_lr[:, kstep], ditdvpre_lr[:, kstep])
                now = time.time()
                self.last_sim_capture_timing.append(f"t={self.get_t():.1f}ms used={now-pre_time:.3f}s")
                pre_time = now

        return output_vs, it_lr, ditdv_lr, ditdvpre_lr

    def simulate_and_capture_lr_signals_cached(
        self,
        x: np.ndarray,
        *,
        tstop_ms: float,
        dt_ms: float,
        k_mul: int,
        percise: bool,
        use_vecplay: bool,
        v_init: float,
        assume_weights_already_set: bool = False,
        assume_inputs_already_played: bool = False,
    ) -> np.ndarray:
        """
        One-simulation-pass helper:
        - runs simulation and captures output_vs to CPU
        - caches lr signals on GPU for later replay (no CPU matrices)

        Args:
            x: 2-D input matrix, documented by the error message as (N_input, T).
            tstop_ms, dt_ms: stop time and step; `int(tstop_ms / dt_ms)` (at least 1) is
                passed to the capture helper as `total_steps`.
            k_mul: forwarded to `CaptureSpec`.
            percise: when false, the "didv"/"didvpre" signals are replaced by `empty_signal(...)`.
            use_vecplay: must be true; this path has no per-step input writes.
            v_init: forwarded to `CaptureSpec`.
            assume_weights_already_set: skip `net.set_weights()`.
            assume_inputs_already_played: skip `play_inputs(...)`.

        Returns:
            The transpose of `pack.output_vs_tn` as a C-contiguous float32 array.

        Raises:
            ValueError: `use_vecplay` is false, or `x` is not 2-D.
            RuntimeError: dense blocks cannot be uploaded (see `_ensure_dense_blocks_uploaded`).
        """
        if not use_vecplay:
            raise ValueError("cached capture requires use_vecplay=True (inputs must be installed via VecPlay)")

        # Ensure dense blocks are uploaded before caching signals, because (by design) uploading K
        # clears replay/capture workspaces on the backend. If we upload after capture, we would
        # invalidate the cached signals before replay.
        self._ensure_dense_blocks_uploaded()

        x = np.asarray(x, dtype=np.float64)
        if x.ndim != 2:
            raise ValueError(f"expected x shape (N_input, T), got {x.shape}")
        total_steps = int(float(tstop_ms) / float(dt_ms))
        if total_steps <= 0:
            total_steps = 1

        # Weights and inputs must be in place before the capture call below starts the run.
        if not assume_weights_already_set:
            self.net.set_weights()
        if not assume_inputs_already_played:
            self.play_inputs(x, dt=float(dt_ms))

        # Use the helper API to avoid depending on low-level backend entrypoint names in training code.
        from heliox_learn import CaptureSpec, MappedSignal, capture_signals_cached, empty_signal  # type: ignore

        # Each signal is described by its handles plus the dest/scale arrays built in __init__.
        pure_i = MappedSignal(
            name="pure_i",
            handles=tuple(int(h) for h in self._pure_i_handles),
            dest=np.asarray(self._pure_i_dest, dtype=np.int32),
            scale=np.asarray(self._pure_i_scale, dtype=np.float32),
        )
        didv = (
            MappedSignal(
                name="didv",
                handles=tuple(int(h) for h in self._didv_handles),
                dest=np.asarray(self._didv_dest, dtype=np.int32),
                scale=np.asarray(self._didv_scale, dtype=np.float32),
            )
            if percise
            else empty_signal("didv")
        )
        didvpre = (
            MappedSignal(
                name="didvpre",
                handles=tuple(int(h) for h in self._didvpre_handles),
                dest=np.asarray(self._didvpre_dest, dtype=np.int32),
                scale=np.asarray(self._didvpre_scale, dtype=np.float32),
            )
            if percise
            else empty_signal("didvpre")
        )

        spec = CaptureSpec(
            output_v_handles=tuple(int(h) for h in self._output_v_handles),
            signals={"pure_i": pure_i, "didv": didv, "didvpre": didvpre},
            tstop_ms=float(tstop_ms),
            k_mul=int(k_mul),
            percise=bool(percise),
            v_init=float(v_init),
        )
        pack = capture_signals_cached(self.manager, spec, total_steps=int(total_steps))
        # `output_vs_tn` is transposed so the returned array has one row per output,
        # matching what `simulate_output_vs` returns.
        return np.ascontiguousarray(pack.output_vs_tn.T, dtype=np.float32)

    def replay_grads_from_cached_signals(
        self,
        dLtdv_lr_to: np.ndarray,
        *,
        percise: bool,
        dt_ms: float,
        grad_scale: float = 1.0,
        eps: float = 1e-6,
    ):
        """
        Compute replay gradients from signals previously cached on the backend.

        Thin adapter around `heliox_learn.replay_grads_from_cached_signals`: it normalizes the
        loss-gradient layout, reads the clipping settings from the environment
        (EWORM_REPLAY_GRAD_L2NORM_THRESHOLD, EWORM_REPLAY_CLIP_STRATEGY,
        EWORM_REPLAY_CLIP_CHECK_EVERY) and forwards the index arrays built in `__init__`.

        Args:
            dLtdv_lr_to: 2-D loss gradient. When its first dimension equals `net.N_output`
                it is transposed; otherwise it is passed through as given.
            percise, dt_ms, grad_scale, eps: forwarded to the helper.

        Returns:
            Whatever the `heliox_learn` helper returns.

        Raises:
            ValueError: `dLtdv_lr_to` is not 2-D.
        """
        self._ensure_dense_blocks_uploaded()

        # Backend expects dLtdv in time-major layout: (T_lr, N_output), contiguous.
        grad_in = np.asarray(dLtdv_lr_to, dtype=np.float32)
        if grad_in.ndim != 2:
            raise ValueError(f"expected dLtdv_lr_to 2D, got {grad_in.shape}")
        output_major = grad_in.shape[0] == int(self.net.N_output)
        dLtdv_lr_ot = np.ascontiguousarray(grad_in.T if output_major else grad_in, dtype=np.float32)

        env = os.environ
        grad_l2norm_threshold = float(env.get("EWORM_REPLAY_GRAD_L2NORM_THRESHOLD", "1e6"))
        clip_strategy = int(env.get("EWORM_REPLAY_CLIP_STRATEGY", "1"))
        clip_check_every = int(env.get("EWORM_REPLAY_CLIP_CHECK_EVERY", "1"))

        from heliox_learn import replay_grads_from_cached_signals as _backend_replay  # type: ignore

        # Only pass input indices when the network has inputs and the index array is non-empty.
        pinput = None
        if int(self.net.N_input) > 0:
            candidate = np.asarray(self._pinput, dtype=np.int32)
            if candidate.size != 0:
                pinput = candidate

        return _backend_replay(
            self.manager,
            dLtdv_lr_ot=dLtdv_lr_ot,
            poutput=np.asarray(self._poutput, dtype=np.int32),
            pre_of_col=np.asarray(self._pre_of_col, dtype=np.int32),
            dt_ms=float(dt_ms),
            percise=bool(percise),
            pinput=pinput,
            grad_scale=float(grad_scale),
            eps=float(eps),
            grad_l2norm_threshold=float(grad_l2norm_threshold),
            clip_strategy=int(clip_strategy),
            clip_check_every=int(clip_check_every),
        )

    def simulate_and_replay_grads_streaming(
        self,
        x: np.ndarray,
        dLtdv_lr_to: np.ndarray,
        *,
        tstop_ms: float,
        dt_ms: float,
        k_mul: int,
        percise: bool,
        use_vecplay: bool,
        v_init: float,
        grad_scale: float = 1.0,
        eps: float = 1e-6,
        assume_weights_already_set: bool = False,
        assume_inputs_already_played: bool = False,
    ):
        """
        Simulate and compute dw/dx in a single backend call
        (`simulate_and_replay_dw_dx_streaming_into`).

        Args:
            x: 2-D input matrix, documented by the error message as (N_input, T).
            dLtdv_lr_to: 2-D loss gradient. When its first dimension equals `net.N_output`
                it is transposed; otherwise it is passed through as given.
            tstop_ms, dt_ms, k_mul, percise, v_init, grad_scale, eps: forwarded to the backend.
            use_vecplay: must be true.
            assume_weights_already_set: skip `net.set_weights()`.
            assume_inputs_already_played: skip `play_inputs(...)`.

        Returns:
            `heliox_learn.ReplayGrads` holding `dw_out_n` (length `net.N`) and `dx_lr_it`
            (`net.N_input` rows, one column per row of the normalized loss gradient).

        Raises:
            ValueError: `use_vecplay` is false, `x` or `dLtdv_lr_to` is not 2-D, or the
                loss gradient is empty.
            RuntimeError: the backend lacks the streaming API, or the call returns non-zero.
        """
        if not use_vecplay:
            raise ValueError("streaming replay requires use_vecplay=True (inputs must be installed via VecPlay)")
        self._ensure_dense_blocks_uploaded()

        x = np.asarray(x, dtype=np.float64)
        if x.ndim != 2:
            raise ValueError(f"expected x shape (N_input, T), got {x.shape}")

        # Backend expects dLtdv in time-major layout: (T_lr, N_output), contiguous.
        dLtdv_lr_to = np.asarray(dLtdv_lr_to, dtype=np.float32)
        if dLtdv_lr_to.ndim != 2:
            raise ValueError(f"expected dLtdv_lr_to 2D, got {dLtdv_lr_to.shape}")
        if dLtdv_lr_to.shape[0] == int(self.net.N_output):
            dLtdv_lr_ot = np.ascontiguousarray(dLtdv_lr_to.T, dtype=np.float32)
        else:
            dLtdv_lr_ot = np.ascontiguousarray(dLtdv_lr_to, dtype=np.float32)

        # Here the number of replay steps is taken from the loss gradient, not from tstop/dt.
        ksteps_total = int(dLtdv_lr_ot.shape[0])
        if ksteps_total <= 0:
            raise ValueError(f"dLtdv_lr is empty: {dLtdv_lr_ot.shape}")

        # Output buffers that the backend call fills in place.
        dw_out = np.empty((int(self.net.N),), dtype=np.float32)
        dx_out = np.empty((int(self.net.N_input), int(ksteps_total)), dtype=np.float32)

        grad_l2norm_threshold = float(os.environ.get("EWORM_REPLAY_GRAD_L2NORM_THRESHOLD", "1e6"))
        clip_strategy = int(os.environ.get("EWORM_REPLAY_CLIP_STRATEGY", "1"))
        clip_check_every = int(os.environ.get("EWORM_REPLAY_CLIP_CHECK_EVERY", "1"))

        cli = self._require_backend_api_("simulate_and_replay_dw_dx_streaming_into", feature="streaming replay")

        # NOTE: the backend API calls finitialize(v_init) internally. We only need to ensure
        # weights/vecplay are updated before we enter the backend call.
        if not assume_weights_already_set:
            self.net.set_weights()
        if not assume_inputs_already_played:
            self.play_inputs(x, dt=float(dt_ms))

        # With percise=False the derivative signals are passed as empty handle/dest/scale sets.
        if percise:
            didv_handles = list(self._didv_handles)
            didv_dest = self._didv_dest
            didv_scale = self._didv_scale
            didvpre_handles = list(self._didvpre_handles)
            didvpre_dest = self._didvpre_dest
            didvpre_scale = self._didvpre_scale
        else:
            didv_handles = []
            didv_dest = np.empty((0,), dtype=np.int32)
            didv_scale = np.empty((0,), dtype=np.float32)
            didvpre_handles = []
            didvpre_dest = np.empty((0,), dtype=np.int32)
            didvpre_scale = np.empty((0,), dtype=np.float32)

        # All arguments are positional; their order is part of the backend contract.
        rc = int(
            cli.simulate_and_replay_dw_dx_streaming_into(
                dLtdv_lr_ot,
                self._poutput,
                self._pinput,
                self._pre_of_col,
                dw_out,
                dx_out,
                list(self._pure_i_handles),
                self._pure_i_dest,
                self._pure_i_scale,
                didv_handles,
                didv_dest,
                didv_scale,
                didvpre_handles,
                didvpre_dest,
                didvpre_scale,
                float(tstop_ms),
                int(k_mul),
                bool(percise),
                float(v_init),
                float(dt_ms),
                float(grad_scale),
                float(eps),
                float(grad_l2norm_threshold),
                int(clip_strategy),
                int(clip_check_every),
            )
        )
        if rc != 0:
            raise RuntimeError(f"streaming replay failed (rc={rc})")

        from heliox_learn import ReplayGrads  # type: ignore

        return ReplayGrads(dw_out_n=np.asarray(dw_out, dtype=np.float32), dx_lr_it=np.asarray(dx_out, dtype=np.float32))

    def replay_grads_from_signals(
        self,
        it_lr: np.ndarray,
        ditdv_lr: np.ndarray | None,
        ditdvpre_lr: np.ndarray | None,
        dLtdv_lr_to: np.ndarray,
        *,
        percise: bool,
        dt_ms: float,
        grad_scale: float = 1.0,
        eps: float = 1e-6,
    ):
        """
        Compute dw/dx on the backend from learning signals held in CPU arrays.

        Args:
            it_lr: 2-D signal array, documented by the error message as (N, T).
            ditdv_lr, ditdvpre_lr: derivative signals; required when `percise` is true.
                When `percise` is false `it_lr` is passed in their place as a placeholder.
            dLtdv_lr_to: 2-D loss gradient. When its first dimension equals `net.N_output`
                it is transposed; otherwise it is passed through as given.
            percise, dt_ms, grad_scale, eps: forwarded to the backend.

        Returns:
            `heliox_learn.ReplayGrads` holding `dw_out_n` (length `net.N`) and `dx_lr_it`
            (`net.N_input` rows, `it_lr.shape[1] - 1` columns).

        Raises:
            ValueError: an array has the wrong rank, derivative signals are missing while
                `percise` is true, or `it_lr` has fewer than two columns.
            RuntimeError: the backend lacks `replay_compute_dw_dx_from_signals_into`.

        NOTE(review): the return value of the backend call is not inspected here, unlike
        the streaming variant which checks a return code — confirm whether this API reports
        failures through a return code.
        """
        self._ensure_dense_blocks_uploaded()

        it_lr = np.ascontiguousarray(np.asarray(it_lr, dtype=np.float32))
        if it_lr.ndim != 2:
            raise ValueError(f"expected it_lr shape (N, T), got {it_lr.shape}")

        # Backend expects dLtdv in time-major layout: (T_lr, N_output), contiguous.
        grad_in = np.asarray(dLtdv_lr_to, dtype=np.float32)
        if grad_in.ndim != 2:
            raise ValueError(f"expected dLtdv_lr_to 2D, got {grad_in.shape}")
        output_major = grad_in.shape[0] == int(self.net.N_output)
        dLtdv_lr_ot = np.ascontiguousarray(grad_in.T if output_major else grad_in, dtype=np.float32)

        if not percise:
            # API requires arrays; use it_lr as a dummy placeholder when percise=False.
            ditdv_lr_nt = it_lr
            ditdvpre_lr_nt = it_lr
        else:
            if ditdv_lr is None or ditdvpre_lr is None:
                raise ValueError("percise=True requires ditdv_lr and ditdvpre_lr")
            ditdv_lr_nt = np.ascontiguousarray(np.asarray(ditdv_lr, dtype=np.float32))
            ditdvpre_lr_nt = np.ascontiguousarray(np.asarray(ditdvpre_lr, dtype=np.float32))

        # Column 0 is the initial slice, so the number of replay steps is one less than T.
        ksteps_total = int(it_lr.shape[1] - 1)
        if ksteps_total <= 0:
            raise ValueError(f"it_lr has too few steps: {it_lr.shape}")

        # Output buffers that the backend call fills in place.
        dw_out = np.empty((int(self.net.N),), dtype=np.float32)
        dx_out = np.empty((int(self.net.N_input), int(ksteps_total)), dtype=np.float32)

        env = os.environ
        grad_l2norm_threshold = float(env.get("EWORM_REPLAY_GRAD_L2NORM_THRESHOLD", "1e6"))
        clip_strategy = int(env.get("EWORM_REPLAY_CLIP_STRATEGY", "1"))
        clip_check_every = int(env.get("EWORM_REPLAY_CLIP_CHECK_EVERY", "1"))

        cli = self._require_backend_api_("replay_compute_dw_dx_from_signals_into", feature="replay from CPU signals")

        # All arguments are positional; their order is part of the backend contract.
        backend_args = (
            it_lr,
            ditdv_lr_nt,
            ditdvpre_lr_nt,
            dLtdv_lr_ot,
            self._poutput,
            self._pinput,
            self._pre_of_col,
            dw_out,
            dx_out,
            float(dt_ms),
            bool(percise),
            float(grad_scale),
            float(eps),
            float(grad_l2norm_threshold),
            int(clip_strategy),
            int(clip_check_every),
        )
        cli.replay_compute_dw_dx_from_signals_into(*backend_args)

        from heliox_learn import ReplayGrads  # type: ignore

        return ReplayGrads(dw_out_n=np.asarray(dw_out, dtype=np.float32), dx_lr_it=np.asarray(dx_out, dtype=np.float32))


def _get_dvdw_replay_jit_fn():
    """Lazy-compile the TorchScript replay dv/dw loop to avoid Python-loop overhead.

    On the first call the nested ``_dvdw_replay_loop_impl`` is compiled with
    ``torch.jit.script`` and stored in the module-level ``_DVDW_REPLAY_JIT_FN``;
    every later call returns that cached object.

    Returns:
        The scripted loop function. It writes into the buffer tensors it is
        given (ring buffers, ``dvtdw``, ``dvpretdw``, ``dVoutputtdw``,
        ``dVinputtdw``) in place and returns the tuple
        ``(it_pos, ditdv_pos, dV_pos, lr_valid_len, grad_scale)``.
    """
    global _DVDW_REPLAY_JIT_FN
    if _DVDW_REPLAY_JIT_FN is not None:
        return _DVDW_REPLAY_JIT_FN

    # NOTE: only `#` comments are used inside the scripted function so that the
    # source handed to TorchScript contains no extra statements.
    def _dvdw_replay_loop_impl(
        it_lr: torch.Tensor,
        ditdv_lr: torch.Tensor,
        ditdvpre_lr: torch.Tensor,
        K_blocks: List[torch.Tensor],
        block_starts: List[int],
        block_ns: List[int],
        pre_idx: torch.Tensor,
        post_idx: torch.Tensor,
        poutput: torch.Tensor,
        pinput: torch.Tensor,
        replay_it_buf: torch.Tensor,
        replay_ditdv_buf: torch.Tensor,
        replay_ditdvpre_buf: torch.Tensor,
        replay_dVtdw_buf: torch.Tensor,
        replay_dVpretdw_buf: torch.Tensor,
        dvtdw: torch.Tensor,
        dvpretdw: torch.Tensor,
        dVoutputtdw: torch.Tensor,
        dVinputtdw: torch.Tensor,
        ksteps_total: int,
        K_len: int,
        dt: float,
        grad_scale: float,
        grad_l2norm_threshold: float,
        lr_valid_len: int,
        lr_buf_len: int,
    ) -> Tuple[int, int, int, int, float]:
        # Write cursors into the history buffers. Every sample is stored twice,
        # at `pos` and at `pos + K_len`, so a window of up to K_len consecutive
        # samples can always be read as one contiguous slice.
        # NOTE(review): this requires the time axis of each replay_* buffer to
        # hold at least 2 * K_len entries - confirm against the allocation site.
        it_pos = 0
        ditdv_pos = 0
        dV_pos = 1  # leading zero slice at t=0

        # Column 0 of the *_lr signals is never read; steps run 1..ksteps_total.
        for tstep in range(1, int(ksteps_total) + 1):
            # Number of history samples used this step: min(tstep, K_len).
            t_window = int(K_len) if int(tstep) >= int(K_len) else int(tstep)

            # --- Append It/ditdv signals (current tick) ---
            it_t = it_lr[:, tstep]
            replay_it_buf[:, it_pos] = it_t
            replay_it_buf[:, it_pos + K_len] = it_t
            it_pos = (it_pos + 1) % K_len

            ditdv_t = ditdv_lr[:, tstep]
            ditdvpre_t = ditdvpre_lr[:, tstep]
            replay_ditdv_buf[:, ditdv_pos] = ditdv_t
            replay_ditdv_buf[:, ditdv_pos + K_len] = ditdv_t
            replay_ditdvpre_buf[:, ditdv_pos] = ditdvpre_t
            replay_ditdvpre_buf[:, ditdv_pos + K_len] = ditdvpre_t
            ditdv_pos = (ditdv_pos + 1) % K_len

            # Window ends (exclusive) in the double buffers.
            # The It/ditdv cursors were already advanced, so their windows end on
            # the sample just written. dV_pos has not been advanced yet, so the
            # dV windows end on the sample written during the previous step.
            it_end = it_pos + K_len
            ditdv_end = ditdv_pos + K_len
            dV_end = dV_pos + K_len

            it_win = replay_it_buf[:, it_end - t_window : it_end]
            ditdv_win = replay_ditdv_buf[:, ditdv_end - t_window : ditdv_end]
            ditdvpre_win = replay_ditdvpre_buf[:, ditdv_end - t_window : ditdv_end]
            dVtdw_win = replay_dVtdw_buf[:, :, dV_end - t_window : dV_end]
            dVpretdw_win = replay_dVpretdw_buf[:, :, dV_end - t_window : dV_end]

            # --- Compute dvtdw on GPU (reuse buffer) ---
            dvtdw.zero_()

            # Base (block-diagonal) term.
            # For each block, contract the last t_window kernel samples with the
            # It window over time, scale by dt * grad_scale, and store the result
            # in the block's diagonal square of dvtdw.
            for b in range(len(K_blocks)):
                start = int(block_starts[b])
                bn = int(block_ns[b])
                end = start + bn
                It_b = it_win[start:end, :]
                K_b = K_blocks[b][:, :, -t_window:]
                dv_block = torch.einsum("ijt,it->ij", K_b, It_b).to(torch.float32) * float(dt) * float(grad_scale)
                dvtdw[start:end, start:end] = dv_block

            # Precise correction term.
            # dItdw_b combines the stored dV/dw and dVpre/dw histories with the
            # ditdv / ditdvpre windows; it is contracted with the same kernel
            # slice, scaled by dt only, and accumulated into all rows of the
            # block's columns.
            for b in range(len(K_blocks)):
                start = int(block_starts[b])
                bn = int(block_ns[b])
                end = start + bn
                K_b = K_blocks[b][:, :, -t_window:]
                dItdv_b = ditdv_win[start:end, :]
                dItdvpre_b = ditdvpre_win[start:end, :]
                dVtdw_b = dVtdw_win[:, start:end, :]
                dVpretdw_b = dVpretdw_win[:, start:end, :]
                dItdw_b = dVtdw_b * dItdv_b.unsqueeze(0) + dVpretdw_b * dItdvpre_b.unsqueeze(0)
                dv_corr = torch.einsum("ikt,jkt->ij", dItdw_b, K_b).to(torch.float32) * float(dt)
                dvtdw[:, start:end] += dv_corr

            # Match original grad clip behavior (rare in practice).
            # A non-finite Frobenius norm triggers in-place replacement of
            # NaN/Inf entries with 0 and a recomputation of the norm.
            dvtdw_l2 = torch.linalg.norm(dvtdw, ord="fro")
            dvtdw_l2_f = float(dvtdw_l2.item())
            if not bool(torch.isfinite(dvtdw_l2).item()):
                torch.nan_to_num_(dvtdw, nan=0.0, posinf=0.0, neginf=0.0)
                dvtdw_l2_f = float(torch.linalg.norm(dvtdw, ord="fro").item())
            # Above the threshold, the current result, the recorded outputs, both
            # dV history buffers and grad_scale (for all later steps) are scaled
            # by the same factor.
            if dvtdw_l2_f > float(grad_l2norm_threshold):
                scaler = float(grad_l2norm_threshold) / dvtdw_l2_f
                dvtdw.mul_(scaler)
                dVoutputtdw.mul_(scaler)  # keep legacy behavior (input buffer not scaled here)
                replay_dVtdw_buf.mul_(scaler)
                replay_dVpretdw_buf.mul_(scaler)
                grad_scale = float(grad_scale) * scaler

            # dvpretdw (GPU): columns remap.
            # Columns of dvtdw selected by pre_idx are copied into the columns of
            # dvpretdw given by post_idx; every other column stays zero.
            dvpretdw.zero_()
            dvpretdw.index_copy_(1, post_idx, torch.index_select(dvtdw, 1, pre_idx))

            # --- Append dvtdw/dvpretdw into the history buffers (dV_pos points to next write slot) ---
            replay_dVtdw_buf[:, :, dV_pos] = dvtdw
            replay_dVtdw_buf[:, :, dV_pos + K_len] = dvtdw
            replay_dVpretdw_buf[:, :, dV_pos] = dvpretdw
            replay_dVpretdw_buf[:, :, dV_pos + K_len] = dvpretdw
            dV_pos = (dV_pos + 1) % K_len

            # Record output/input sensitivities for this LR tick (GPU).
            # Once lr_valid_len reaches lr_buf_len, further ticks are not recorded.
            if int(lr_valid_len) < int(lr_buf_len):
                dVoutputtdw[:, :, lr_valid_len] = torch.index_select(dvtdw, 1, poutput)
                dVinputtdw[:, :, lr_valid_len] = torch.index_select(dvtdw, 1, pinput)
                lr_valid_len += 1

        # Final cursor positions, number of recorded ticks, and the (possibly
        # clipped) gradient scale.
        return it_pos, ditdv_pos, dV_pos, lr_valid_len, float(grad_scale)

    _DVDW_REPLAY_JIT_FN = torch.jit.script(_dvdw_replay_loop_impl)
    return _DVDW_REPLAY_JIT_FN

class SynInfo:
    """Plain record describing one synaptic connection.

    Attributes:
        id: Identifier stored as given (callers in this file pass the
            pre-synaptic cell id).
        point: Point index stored as given (callers in this file pass the
            pre-synaptic point).
        syn: The synapse / gap-junction object.
        p: Weight index; defaults to -1 until one is assigned.
    """

    def __init__(self, id, point, syn, p=-1):
        # Parameter names (including ``id``) are part of the call interface.
        self.id, self.point, self.syn, self.p = id, point, syn, p


class Network(object):
    def __init__(self, net_config, lr_config, random_seed):
        self.random_seed = random_seed
        np.random.seed(self.random_seed)
        self._heliox_backend = None

        self.load_network_info(net_config)
        self.create_cells()
        self.set_cell_segments()
        self.define_biophysics()

        # learning configs
        self.v_r = lr_config['v_r']             # resting potential
        self.ngpu = lr_config['ngpu']
        self.K_max_t = lr_config['K_max_t']     # transfer impedance maximum time length
        self.K_filename = lr_config['K_filename']
        self.K_nblock = lr_config['K_nblock']
        self.K_nblock = np.clip(self.K_nblock, a_min=1, a_max=self.ncell)
        self.K_mul = lr_config['K_mul']
        self.w_gap_max = lr_config['w_gap_max']
        self.w_gap_min = lr_config['w_gap_min']
        self.w_syn_max = lr_config['w_syn_max']
        self.w_syn_min = lr_config['w_syn_min']
        # Perf helpers
        self.syninfos_flat = None
        self._lr_buf_len = None
        self._lr_valid_len = None


    def load_network_info(self, net_config):
        """Copy cell identity, segment length and directory settings from ``net_config``.

        Reads ``net_config["cell_info"]``, ``net_config["cnt_info"]`` and
        ``net_config["dir_info"]`` and stores the values on the instance.
        """
        info = net_config["cell_info"]
        self.cells_id_sim = info["cells_id_sim"]

        # Forward lookup: cell id -> cell name.
        self.cells_name_dic = info["cells_name_dic"]
        # Reverse lookup: cell name -> cell id.
        name_to_id = {}
        for cell_id, cell_name in self.cells_name_dic.items():
            name_to_id[cell_name] = cell_id
        self.cells_id_dic = name_to_id

        self.length_per_seg = net_config["cnt_info"]["length_per_seg"]
        if _verbose_enabled():
            _LOGGER.info("length_per_seg=%s", self.length_per_seg)

        directories = net_config["dir_info"]
        self.model_dir = directories["model_dir"]
        self.cell_param_dir = directories["cell_param_dir"]


    def create_cells(self):
        """Instantiate one cell template per simulated cell id.

        For every id in ``self.cells_id_sim`` the file ``<name>.hoc`` (path built
        by plain concatenation with ``self.model_dir``) is loaded through ``h``
        and the template of the same name is instantiated. Instances are stored
        in ``self.cells`` keyed by cell name; ``self.ncell`` holds the id count.
        """
        sim_ids = self.cells_id_sim
        self.ncell = len(sim_ids)
        self.cells = {}
        for cell_id in sim_ids:
            cell_name = self.cells_name_dic[cell_id]
            hoc_path = self.model_dir + cell_name + ".hoc"
            h.load_file(hoc_path)
            template = getattr(h, cell_name)
            self.cells[cell_name] = template()


    def define_biophysics(self):
        default_mech_list = ['slo1_unc2_lr', 
                          'egl2_lr', 
                          'shl1_lr', 
                          'kqt3_lr',  
                          'unc2_lr', 
                          'kvs1_lr', 
                          'slo1_egl19_lr', 
                          'slo2_unc2_lr', 
                          'irk_lr', 
                          'egl36_lr', 
                          'egl19_lr', 
                          'cca1_lr', 
                          'shk1_lr', 
                          'slo2_egl19_lr', 
                          'nca_lr', 
                          'kcnl_lr',
                          'cainternm_lr',
                          ]

        # Debug toggles for backend equivalence checking.
        # These are intentionally env-driven so normal training behavior remains unchanged.
        passive_only = os.environ.get("EWORM_DEBUG_PASSIVE_ONLY", "0").strip() == "1"
        mech_list_override = os.environ.get("EWORM_DEBUG_MECH_LIST", "").strip()
        if mech_list_override:
            # Comma-separated list; whitespace allowed.
            self.mech_list = [x.strip() for x in mech_list_override.split(",") if x.strip()]
        elif passive_only:
            self.mech_list = []
        else:
            self.mech_list = default_mech_list

        for id in self.cells_id_sim:
            cn = self.cells_name_dic[id]
            cell = self.cells[cn]
            cell_param = load_json(self.cell_param_dir + cn + ".json")
            
            # set biophysical mechanism
            sec = cell.Soma
            sec.Ra = cell_param["soma"]["Ra"]  # (Ohm*cm)
            sec.cm = cell_param["soma"]["cm"]  # (uF/cm2)
            
            sec.insert('pas')
            for seg in sec:
                seg.pas.g = cell_param["soma"]["gpas"]  # Passive conductance in S/cm2
                seg.pas.e = cell_param["soma"]["epas"]  # Leak reversal potential mV

            for sec in cell.all: # check neuron_name.hoc
                if "Soma" in sec.name():
                    continue
                sec.Ra = cell_param["neurite"]["Ra"]  # (Ohm*cm)
                sec.cm = cell_param["soma"]["cm"]  # (uF/cm2)
                
                sec.insert('pas')
                for seg in sec:
                    seg.pas.g = cell_param["neurite"]["gpas"]  # Passive conductance in S/cm2 
                    seg.pas.e = cell_param["neurite"]["epas"]  # Leak reversal potential mV

            for m in self.mech_list:
                cell.Soma.insert(m)
            for seg in cell.Soma:
                if hasattr(seg, "shl1_lr"):
                    seg.shl1_lr.gbshl1 = cell_param["soma"]["gbshl1"]  # (nS/um2)
                if hasattr(seg, "shk1_lr"):
                    seg.shk1_lr.gbshk1 = cell_param["soma"]["gbshk1"]
                if hasattr(seg, "kvs1_lr"):
                    seg.kvs1_lr.gbkvs1 = cell_param["soma"]["gbkvs1"]
                if hasattr(seg, "egl2_lr"):
                    seg.egl2_lr.gbegl2 = cell_param["soma"]["gbegl2"]
                if hasattr(seg, "egl36_lr"):
                    seg.egl36_lr.gbegl36 = cell_param["soma"]["gbegl36"]
                if hasattr(seg, "kqt3_lr"):
                    seg.kqt3_lr.gbkqt3 = cell_param["soma"]["gbkqt3"]
                if hasattr(seg, "egl19_lr"):
                    seg.egl19_lr.gbegl19 = cell_param["soma"]["gbegl19"]
                if hasattr(seg, "unc2_lr"):
                    seg.unc2_lr.gbunc2 = cell_param["soma"]["gbunc2"]
                if hasattr(seg, "cca1_lr"):
                    seg.cca1_lr.gbcca1 = cell_param["soma"]["gbcca1"]
                if hasattr(seg, "slo1_egl19_lr"):
                    seg.slo1_egl19_lr.gbslo1 = cell_param["soma"]["gbslo1_egl19"]
                if hasattr(seg, "slo1_unc2_lr"):
                    seg.slo1_unc2_lr.gbslo1 = cell_param["soma"]["gbslo1_unc2"]
                if hasattr(seg, "slo2_egl19_lr"):
                    seg.slo2_egl19_lr.gbslo2 = cell_param["soma"]["gbslo2_egl19"]
                if hasattr(seg, "slo2_unc2_lr"):
                    seg.slo2_unc2_lr.gbslo2 = cell_param["soma"]["gbslo2_unc2"]
                if hasattr(seg, "kcnl_lr"):
                    seg.kcnl_lr.gbkcnl = cell_param["soma"]["gbkcnl"]
                if hasattr(seg, "nca_lr"):
                    seg.nca_lr.gbnca = cell_param["soma"]["gbnca"]
                if hasattr(seg, "irk_lr"):
                    seg.irk_lr.gbirk = cell_param["soma"]["gbirk"]


    def set_cell_segments(self):
        """section, segment separation, prepare self.segments"""
        for _, cell in self.cells.items():
            for section in cell.all:
                soma_flag = 'Soma' in section.name()
                section.nseg = 1 if soma_flag else int(np.ceil(section.L / self.length_per_seg))


    def get_3dp_segment(self, nrn_id, seg_id):
        cn = self.cells_name_dic[nrn_id]
        cell = self.cells[cn]
        seg_cnt = 0
        for section in cell.all:
            if seg_id < seg_cnt + section.nseg:
                loading_bar = np.linspace(0, 1, section.nseg + 1)
                loading_cnt = seg_id - seg_cnt
                return section((loading_bar[loading_cnt] + loading_bar[loading_cnt + 1]) / 2)
            seg_cnt += section.nseg


    def read_cells_neurite_connection(self, infile, input_names):
        """Create all connections and inputs, then build the weight index layout.

        Steps, in order:
          1. Load the pickled connection table from ``infile``.
          2. Create a gap junction or synapse for every listed connection whose
             pre- and post-synaptic cells are both simulated.
          3. Attach one ``IClamp`` to soma segment 0 of every input cell.
          4. Assign each cell a contiguous range of weight indices ``p``:
             first the fixed-weight slots (soma segments, the input slot if the
             cell is an input, purely pre-synaptic points), then one slot per
             gap junction / synapse.
          5. Fill ``self.w`` and the pre/post index lists, then call
             ``self._cal_K()`` unless ``EWORM_SKIP_K=1``.

        Args:
            infile: Path of a pickle file holding a dict
                ``id_post_cell -> [[id_pre_cell, id_pre_point, id_post_point, type, weight], ...]``
                with ``type`` 0 (gap junction), 1 (excitatory) or 2 (inhibitory).
            input_names: Names of the cells that receive an external input.

        Raises:
            ValueError: If a connection has a type other than 0, 1 or 2.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the file handle as soon as loading is done.
        with open(infile, 'rb') as f:
            connection_info = pickle.load(f)

        self.relpointlist = {}      # id -> [point, ...] (relative segments, non-repetitive)
        self.synlist = {}           # id_post_cell -> id_post_point -> [SynInfo(id_pre_cell, id_pre_point, syn, p), ...]
        self.synapse_list = []
        for id in self.cells_id_sim:
            cn = self.cells_name_dic[id]
            cell = self.cells[cn]
            self.relpointlist[id] = list(range(cell.Soma.nseg))
            self.synlist[id] = {}
        # Debug switches: skip creating some or all connection objects.
        disable_all_connections = os.environ.get("EWORM_DEBUG_DISABLE_CONNECTIONS", "0").strip() == "1"
        disable_gap = os.environ.get("EWORM_DEBUG_DISABLE_GAP", "0").strip() == "1"
        disable_syn = os.environ.get("EWORM_DEBUG_DISABLE_SYN", "0").strip() == "1"

        for id_post_cell in self.cells_id_sim:  # for each post neuron
            if id_post_cell not in connection_info.keys():
                continue
            connection = connection_info[id_post_cell]
            for [id_pre_cell, id_pre_point, id_post_point, syntype, w] in connection:
                if id_pre_cell in self.cells_id_sim:  # for each pre neuron
                    cn_pre = self.cells_name_dic[id_pre_cell]
                    # The points are registered before the disable switches are
                    # checked, so they are recorded even when no object is created.
                    if id_post_point not in self.synlist[id_post_cell].keys():
                        self.synlist[id_post_cell][id_post_point] = []
                    if id_post_point not in self.relpointlist[id_post_cell]:
                        self.relpointlist[id_post_cell].append(id_post_point)
                    if id_pre_point not in self.relpointlist[id_pre_cell]:
                        self.relpointlist[id_pre_cell].append(id_pre_point)
                    pre_seg = self.get_3dp_segment(id_pre_cell, id_pre_point)
                    post_seg = self.get_3dp_segment(id_post_cell, id_post_point)
                    if disable_all_connections:
                        continue
                    if syntype == 0: # construct gap junction
                        if disable_gap:
                            continue
                        gapjunction = h.gapjunction_lr(post_seg)
                        gapjunction.w = w
                        gapjunction._ref_vpre = pre_seg._ref_v
                        self.synapse_list.append(gapjunction)
                        self.synlist[id_post_cell][id_post_point].append(SynInfo(id_pre_cell, id_pre_point, gapjunction))
                    elif syntype == 1 or syntype == 2:
                        if disable_syn:
                            continue
                        # construct synapse; inhibitory synapses get a negative weight
                        synapse = h.neuron_to_neuron_syn_lr(post_seg)
                        synapse.w = w * 1e-4 if syntype == 1 else -w * 1e-4
                        synapse._ref_vpre = pre_seg._ref_v
                        self.synapse_list.append(synapse)
                        self.synlist[id_post_cell][id_post_point].append(SynInfo(id_pre_cell, id_pre_point, synapse))
                    else:
                        raise ValueError("Synapse type unrecognized, expect 0/1/2, get", syntype)

        # Flatten per-cell synapse lists for faster iteration in update_dvdw().
        # This avoids nested dict loops each step and does not change semantics.
        self.syninfos_flat = {cid: [] for cid in self.cells_id_sim}
        for cid in self.cells_id_sim:
            for point in self.synlist[cid].keys():
                self.syninfos_flat[cid].extend(self.synlist[cid][point])

        # External inputs: one current clamp on soma segment 0 of each input cell.
        disable_inputs = os.environ.get("EWORM_DEBUG_DISABLE_INPUTS", "0").strip() == "1"
        self.N_input = len(input_names)
        self.input_ids = [self.cells_id_dic[cn] for cn in input_names]
        self.input_synlist = {}
        if not disable_inputs:
            for id in self.input_ids:
                seg = self.get_3dp_segment(id, 0)   # Soma 0
                syn = h.IClamp(seg)
                syn.amp = 0.
                syn.delay = 0.
                syn.dur = 1e9
                self.input_synlist[id] = syn

        self.N = 0              # total number of relative segments
        self.K_start = {}       # id -> pstart in self.K
        self.K_n = {}           # id -> number of relative segments
        self.pinput = []        # list of input p in self.K
        self.point2p = {}       # id_cell -> id_point -> [p, ...] in self.K
        self.pwmaskall = []     # list of w fixed p (active channels & pure pre-synaptic) in self.K
        self.pwmask = {}        # id -> [p, ...], w fixed ps for each id
        self.allpointlist = {}  # id -> [point, ...] (can be repetitive, including the multiple-synapses-on-one-segment scenario)
        for id in self.cells_id_sim:
            cn = self.cells_name_dic[id]
            self.K_start[id] = self.N
            self.relpointlist[id].sort()
            self.allpointlist[id] = []

            # w fixed points, including active channel segments, input segments & pure pre-synaptic segments
            wmask_points = list(range(self.cells[cn].Soma.nseg))    # somatic active channel segments
            if id in self.input_ids:
                # The input slot sits directly after the soma segment slots.
                self.pinput.append(self.N + len(wmask_points))
                wmask_points.append(0)                              # input onto soma seg 0
            for point in self.relpointlist[id]:
                if point not in self.synlist[id].keys() and point not in wmask_points:  # pure pre-synaptic segments
                    wmask_points.append(point)
            pwmask = list(range(self.N, self.N + len(wmask_points)))
            self.pwmaskall += pwmask
            self.pwmask[id] = pwmask
            self.allpointlist[id] += wmask_points

            n_id = len(wmask_points)
            # gap junction & post-synaptic segments: one slot per connection object
            for point in self.synlist[id].keys():
                n_id += len(self.synlist[id][point])
                self.allpointlist[id] += [point] * len(self.synlist[id][point])

            self.K_n[id] = n_id
            self.N += n_id

            # Map every point to the weight indices that refer to it, in the same
            # order the slots were counted above.
            self.point2p[id] = {}
            p = self.K_start[id]
            for point in wmask_points:
                if point not in self.point2p[id].keys():
                    self.point2p[id][point] = []
                self.point2p[id][point].append(p)
                p += 1
            for point in self.synlist[id].keys():
                if point not in self.point2p[id].keys():
                    self.point2p[id][point] = []
                for syninfo in self.synlist[id][point]:
                    self.point2p[id][point].append(p)
                    syninfo.p = p
                    p += 1

        # Weight vector: fixed slots hold 1, connection slots hold the object's w.
        self.w = torch.zeros((self.N,), dtype=torch.float32)
        self.w[self.pwmaskall] = 1.
        self.pgap = []
        self.psyn = []
        for id in self.cells_id_sim:
            for point in self.synlist[id].keys():
                for syninfo in self.synlist[id][point]:
                    pw = syninfo.p
                    syn = syninfo.syn
                    self.w[pw] = syn.w
                    if 'gapjunction' in syn.hname():
                        self.pgap.append(pw)
                    elif 'syn' in syn.hname():
                        self.psyn.append(pw)

        if len(self.pgap) > 0:
            if _verbose_enabled():
                _LOGGER.info(
                    "w_gap: max=%.5g, min=%.5g",
                    float(np.max(self.w[self.pgap].numpy())),
                    float(np.min(self.w[self.pgap].numpy())),
                )
        else:
            if _verbose_enabled():
                _LOGGER.info("w_gap: none")
        if len(self.psyn) > 0:
            if _verbose_enabled():
                _LOGGER.info(
                    "w_syn: max=%.5g, min=%.5g",
                    float(np.max(self.w[self.psyn].numpy())),
                    float(np.min(self.w[self.psyn].numpy())),
                )
        else:
            if _verbose_enabled():
                _LOGGER.info("w_syn: none")

        if _verbose_enabled():
            _LOGGER.info("N(weights)=%s", int(self.N))

        # Parallel lists: weight index of each connection and the weight index of
        # its pre-synaptic point.
        self.postplist = []
        self.preplist = []
        for id_post in self.cells_id_sim:
            for point_post in self.synlist[id_post].keys():
                for syninfo in self.synlist[id_post][point_post]:
                    id_pre = syninfo.id
                    point_pre = syninfo.point
                    p_post = syninfo.p
                    self.postplist.append(p_post)
                    self.preplist.append(self.point2p[id_pre][point_pre][0])    # first occurrence

        # Debug/forward-only mode: allow skipping K construction/loading since it is only needed for dvdw learning.
        # This speeds up NEURON vs HELIOX forward equivalence checks.
        if os.environ.get("EWORM_SKIP_K", "0").strip() == "1":
            return
        self._cal_K()


    def _cal_K(self):
        self.K_len = int(self.K_max_t / (h.dt * self.K_mul))
        if _verbose_enabled():
            _LOGGER.info(
                "K config: K_max_t=%.3f(ms) dt=%.3f(ms) K_mul=%d => K_len=%d (LR ticks)",
                float(self.K_max_t),
                float(h.dt),
                int(self.K_mul),
                int(self.K_len),
            )
            _LOGGER.info("K cache target: %s", self.K_filename)

        def split_integer(m, n):
            assert n > 0
            quotient = m // n
            remainder = m % n
            if remainder > 0:
                return [quotient] * (n - remainder) + [quotient + 1] * remainder
            if remainder < 0:
                return [quotient - 1] * -remainder + [quotient] * (n + remainder)
            return  [quotient] * n
        
        id_split = split_integer(self.ncell, self.K_nblock)     # split ncell evenly into K_nblock
        self.K_block_n = []         # number of relative segments in each block
        self.K_block_n_start = []   # start of each split for block
        self.id2block = {}          # cell id -> block id
        self.block2ids = {}         # block id -> [cell id, ...]
        cell_id = 0
        block_start = 0
        for i in range(self.K_nblock):
            self.K_block_n_start.append(block_start)
            self.block2ids[i] = []
            block_n = 0
            for _ in range(id_split[i]):
                block_n += self.K_n[self.cells_id_sim[cell_id]]
                self.id2block[self.cells_id_sim[cell_id]] = i
                self.block2ids[i].append(self.cells_id_sim[cell_id])
                cell_id += 1
            self.K_block_n.append(block_n)
            block_start += block_n
        if _verbose_enabled():
            _LOGGER.info("K id_split=%s", id_split)
        # print(self.id2block)
        # print(self.block2ids)
        if _verbose_enabled():
            _LOGGER.info("K_block_n=%s", self.K_block_n)
            _LOGGER.info("K_block_n_start=%s", self.K_block_n_start)

        self.K_gpu_n = split_integer(self.N, self.ngpu)   # split N evenly into ngpu
        if _verbose_enabled():
            _LOGGER.info("K_gpu_n=%s", self.K_gpu_n)
        gpu_start = 0
        self.K_gpu_n_start = []     # start of each split for gpu
        for i in range(self.ngpu):
            self.K_gpu_n_start.append(gpu_start)
            gpu_start += self.K_gpu_n[i]
        if _verbose_enabled():
            _LOGGER.info("K_gpu_n_start=%s", self.K_gpu_n_start)
        
        try:
            if _verbose_enabled():
                _LOGGER.info("read K from %s", self.K_filename)
            with open(self.K_filename, 'rb') as f:
                # tmp_K = np.load(f, allow_pickle=True)['K']
                tmp_K = pickle.load(f)
                if _verbose_enabled():
                    _LOGGER.info("load K done")
            assert len(tmp_K) == self.K_nblock, 'Unexpected number of blocks read from K'
            tmp_shape = 0
            for i in range(self.K_nblock):
                ki = tmp_K[i]
                assert ki.shape[0] == ki.shape[1] and ki.shape[2] == self.K_len, f'Unexpected shape of block {i} read from K'
                tmp_shape += ki.shape[0]
            assert tmp_shape == self.N, 'Unexpected total number of relative segments read from K'
        except Exception as e:
            if _verbose_enabled():
                if isinstance(e, FileNotFoundError):
                    _LOGGER.info("%s not found, generating K (K_mul=%s)", self.K_filename, self.K_mul)
                else:
                    _LOGGER.info("K load failed (%s), regenerating K (K_mul=%s)", e, self.K_mul)

            tmp_K = None

            # Fast path: synthesize K for the requested K_mul by resampling an existing K file.
            #
            # Rationale:
            # - Computing K from scratch is extremely expensive (runs many NEURON simulations).
            # - The demo typically ships with a single precomputed K (historically K_mul=5).
            # - For other K_mul values (1/2/10/...), we can derive a compatible K by resampling
            #   the time axis and cache the result to disk.
            k_dir = os.path.dirname(self.K_filename) or "."
            alt_path = None
            alt_mul = None
            try:
                target_t = int(round(float(self.K_max_t)))
                prefer_candidates = [
                    os.path.join(k_dir, f"K_eworm_v4_x5_t{target_t}.npz"),
                    os.path.join(k_dir, "K_eworm_v4_x5.npz"),
                ]
                if int(self.K_mul) != 5:
                    for prefer in prefer_candidates:
                        if os.path.exists(prefer):
                            alt_path = prefer
                            alt_mul = 5
                            break

                if alt_path is None:
                    for name in os.listdir(k_dir):
                        m = re.match(r"^K_eworm_v4_x(\d+)(?:_t(\d+))?\.npz$", name)
                        if not m:
                            continue
                        m_mul = int(m.group(1))
                        if m_mul == int(self.K_mul):
                            continue
                        alt_path = os.path.join(k_dir, name)
                        alt_mul = m_mul
                        break
            except Exception:
                alt_path = None
                alt_mul = None

            if alt_path is not None and alt_mul is not None:
                try:
                    if _verbose_enabled():
                        _LOGGER.info("synthesizing K from %s (K_mul=%s)", alt_path, alt_mul)
                    with open(alt_path, "rb") as f:
                        src_K = pickle.load(f)
                    assert len(src_K) == self.K_nblock, "Unexpected number of blocks in source K"
                    old_len = int(src_K[0].shape[2])
                    new_len = int(self.K_len)

                    dt_ms = float(h.dt)
                    old_step = dt_ms * float(alt_mul)
                    new_step = dt_ms * float(self.K_mul)
                    if old_step <= 0 or new_step <= 0:
                        raise RuntimeError(f"invalid dt/step: dt={dt_ms} old_step={old_step} new_step={new_step}")

                    # Only synthesize when the source K covers the target horizon.
                    #
                    # - If target <= source: safe (we are effectively truncating the time window).
                    # - If target > source: unsafe (would require padding/extrapolation), so we refuse and
                    #   fall back to NEURON K computation.
                    target_horizon = float(self.K_max_t)
                    old_horizon = float(old_len) * float(old_step)
                    tol = float(max(old_step, new_step) * 2.0)
                    if (target_horizon - old_horizon) > tol:
                        raise RuntimeError(
                            f"source K horizon too short: src≈{old_horizon:.3f}ms, "
                            f"target={target_horizon:.3f}ms (tol={tol:.3f}ms)"
                        )

                    pos = (np.arange(new_len, dtype=np.float32) * np.float32(new_step)) / np.float32(old_step)
                    pos = np.clip(pos, 0.0, float(old_len - 1))
                    i0 = np.floor(pos).astype(np.int64)
                    i1 = np.minimum(i0 + 1, old_len - 1).astype(np.int64)
                    w = (pos - i0.astype(np.float32)).astype(np.float32)
                    w0 = (1.0 - w).astype(np.float32)[None, None, :]
                    w1 = w.astype(np.float32)[None, None, :]

                    tmp_K = []
                    t0 = time.time()
                    for bi in range(self.K_nblock):
                        blk = np.asarray(src_K[bi])
                        if blk.shape[2] != old_len:
                            raise RuntimeError(f"block {bi} len mismatch: {blk.shape[2]} != {old_len}")
                        blk_f = blk.astype(np.float32, copy=False)
                        v0 = np.take(blk_f, i0, axis=2)
                        v1 = np.take(blk_f, i1, axis=2)
                        out = v0 * w0 + v1 * w1
                        tmp_K.append(out.astype(blk.dtype, copy=False))
                        if _k_progress_enabled() and _k_progress_every_blocks() > 0:
                            every = _k_progress_every_blocks()
                            if ((bi + 1) % every) == 0 or (bi + 1) == self.K_nblock:
                                elapsed = time.time() - t0
                                frac = float(bi + 1) / float(self.K_nblock)
                                eta = (elapsed * (1.0 / frac - 1.0)) if frac > 0 else 0.0
                                _LOGGER.info(
                                    "K synth progress: block %d/%d (%.1f%%), elapsed=%.1fs, eta=%.1fs",
                                    bi + 1,
                                    self.K_nblock,
                                    frac * 100.0,
                                    elapsed,
                                    eta,
                                )

                    if _verbose_enabled():
                        _LOGGER.info("save synthesized K to %s", self.K_filename)
                    os.makedirs(k_dir, exist_ok=True)
                    with open(self.K_filename, "wb") as f:
                        pickle.dump(tmp_K, f, protocol=4)
                    if _verbose_enabled():
                        _LOGGER.info("synthesized K saved")
                except Exception as e:
                    if _verbose_enabled():
                        _LOGGER.info("synthesize K failed (%s); falling back to NEURON K computation", e)
                    tmp_K = None

            if tmp_K is None:
                if _verbose_enabled():
                    _LOGGER.info("Computing K via NEURON (slow path). This may take a long time.")
                k_start_t = time.time()
                last_log_t = k_start_t
                def get_3dp_segment_tmp(cell, seg_id):
                    seg_cnt = 0
                    for section in cell.all:
                        if seg_id < seg_cnt + section.nseg:
                            loading_bar = np.linspace(0, 1, section.nseg + 1)
                            loading_cnt = seg_id - seg_cnt
                            return section((loading_bar[loading_cnt] + loading_bar[loading_cnt + 1]) / 2)
                        seg_cnt += section.nseg

                tmp_K = []
                for block_n in self.K_block_n:
                    tmp_K.append(np.zeros((block_n, block_n, self.K_len), dtype=np.float16))
                total_cells = len(self.cells_id_sim)
                for cell_i, id in enumerate(self.cells_id_sim):  # for each post neuron
                    if _k_progress_enabled():
                        print(f'cell {id} ({cell_i+1}/{total_cells})', end='\r')
                        now = time.time()
                        if (now - last_log_t) > 5.0:
                            elapsed = now - k_start_t
                            frac = float(cell_i + 1) / float(max(1, total_cells))
                            eta = (elapsed * (1.0 / frac - 1.0)) if frac > 0 else 0.0
                            _LOGGER.info(
                                "K compute progress: %d/%d (%.1f%%), elapsed=%.1fs, eta=%.1fs",
                                cell_i + 1,
                                total_cells,
                                frac * 100.0,
                                elapsed,
                                eta,
                            )
                            last_log_t = now
                    tmpcn = self.cells_name_dic[id]
                    tmpcell = getattr(h, tmpcn)()
                    cell_param = load_json(self.cell_param_dir + tmpcn + ".json")
                    for section in tmpcell.all:
                        soma_flag = 'Soma' in section.name()
                        section.nseg = 1 if soma_flag else int(np.ceil(section.L / self.length_per_seg))

                    # set biophysical mechanism
                    sec = tmpcell.Soma
                    sec.Ra = cell_param["soma"]["Ra"]  # (Ohm*cm)
                    sec.cm = cell_param["soma"]["cm"]  # (uF/cm2)
                    sec.insert('pas')
                    for seg in sec:
                        seg.pas.g = cell_param["soma"]["gpas"]  # Passive conductance in S/cm2
                        seg.pas.e = cell_param["soma"]["epas"]  # Leak reversal potential mV

                    for sec in tmpcell.all:  # check neuron_name.hoc
                        if "Soma" in sec.name():
                            continue
                        sec.Ra = cell_param["neurite"]["Ra"]  # (Ohm*cm)
                        sec.cm = cell_param["neurite"]["cm"]  # (uF/cm2)
                        sec.insert('pas')
                        for seg in sec:
                            seg.pas.g = cell_param["neurite"]["gpas"]  # Passive conductance in S/cm2
                            seg.pas.e = cell_param["neurite"]["epas"]  # Leak reversal potential mV

                    tmpseglist = []
                    tmpvlist = []
                    for point in self.allpointlist[id]:
                        seg = get_3dp_segment_tmp(tmpcell, point)
                        tmpseglist.append(seg)
                        tmpvlist.append(h.Vector().record(seg._ref_v))

                    block_id = self.id2block[id]
                    block_ids = self.block2ids[block_id]
                    offset_in_block = 0
                    for id_same_block in block_ids:
                        if int(id_same_block) < int(id):
                            offset_in_block += self.K_n[id_same_block]
                    old_dt = h.dt
                    h.dt = old_dt * self.K_mul
                    for i in range(self.K_n[id]):
                        tmpclamp = h.IClamp(tmpseglist[i])
                        tmpclamp.delay = 500.
                        tmpclamp.dur = h.dt
                        tmpclamp.amp = 1. / h.dt
                        h.finitialize(cell_param["soma"]["epas"])
                        h.continuerun(self.K_max_t + tmpclamp.delay)
                        for j in range(self.K_n[id]):
                            v = np.array(tmpvlist[j], dtype=np.float16)
                            i_start = int(tmpclamp.delay / h.dt)    # assert divisible
                            v_rev = v[i_start]
                            tmp_K[block_id][offset_in_block + i, offset_in_block + j] = v[i_start + 1: i_start + 1 + self.K_len][::-1] - v_rev  # already reversed
                        tmpclamp.amp = 0.
                    h.dt = old_dt

                if _verbose_enabled():
                    _LOGGER.info("save K to %s", self.K_filename)
                with open(self.K_filename, 'wb') as f:
                    pickle.dump(tmp_K, f, protocol=4)
                if _k_progress_enabled():
                    print("")
                if _verbose_enabled():
                    _LOGGER.info("save K done")

        # Keep a CPU copy around for backends that want to upload K blocks themselves (e.g. HELIOX replay).
        # `tmp_K` is a list of numpy arrays shaped (bn, bn, K_len).
        self.K_blocks_cpu = tmp_K

        for i in range(self.ngpu):
            # with cp.cuda.Device(i):
            #     setattr(self, 'K' + str(i), [cp.asarray(ki, dtype=cp.float16) for ki in tmp_K])  # copy of self.Ki on gpu i
            k_dtype_name = os.environ.get("EWORM_K_DTYPE", "float32").strip().lower()
            k_dtype = torch.float32
            if k_dtype_name in ("float16", "fp16", "half"):
                k_dtype = torch.float16
            setattr(
                self,
                "K" + str(i),
                [torch.asarray(ki, dtype=k_dtype, device=torch.device(f"cuda:{i}")) for ki in tmp_K],
            )  # copy of self.Ki on gpu i

    
    def set_outputs(self, output_names):
        """Register the readout cells by name.

        Sets three attributes:

        - ``N_output``: number of names given.
        - ``output_ids``: the ids looked up in ``cells_id_dic`` for each name.
        - ``poutput``: for each output id, the entry ``point2p[id][0][0]``.
        """
        self.N_output = len(output_names)
        self.output_ids = [self.cells_id_dic[name] for name in output_names]
        if _verbose_enabled():
            _LOGGER.info("outputs=%s output_ids=%s", list(output_names), self.output_ids)
        self.poutput = []
        # Only the first occurrence stored in point2p is used for each output cell.
        self.poutput.extend(self.point2p[cell_id][0][0] for cell_id in self.output_ids)

    def attach_heliox_backend(self, backend) -> None:
        """Store ``backend`` on the instance as ``_heliox_backend``.

        ``update_dvdw`` checks this attribute: when it is not ``None`` it calls
        ``backend.fill_it(...)`` and ``backend.fill_ditdv(...)`` to obtain the
        per-step signals instead of gathering them from ``self.cells``.
        """
        self._heliox_backend = backend
    

    def _reset_lr_records(self):
        """Re-create the per-run learning-rule records.

        Allocates, in this order:

        - ``It`` on ``cuda:0`` and, per GPU ``i``, ``dItdv{i}`` / ``dItdvpre{i}``
          as empty ``(N, 0)`` float32 histories;
        - per GPU ``i``, ``dVtdw{i}`` / ``dVpretdw{i}`` as zeros of shape
          ``(K_gpu_n[i], N, 1)``;
        - the CPU recording buffers ``dVoutputtdw`` / ``dVinputtdw`` sized for the
          whole run, with only the leading column zeroed;
        - gradient clipping state and one CUDA stream per GPU.
        """
        def _empty_history(dev):
            # (N, 0) float32 tensor; columns are appended as the run advances.
            return torch.tensor([], dtype=torch.float32, device=dev).reshape((self.N, 0))

        self.It = _empty_history(torch.device('cuda:0'))
        for gpu in range(self.ngpu):
            dev = torch.device(f'cuda:{gpu}')
            sens_shape = (self.K_gpu_n[gpu], self.N, 1)
            setattr(self, f'dItdv{gpu}', _empty_history(dev))      # copy of dItdv on this gpu
            setattr(self, f'dItdvpre{gpu}', _empty_history(dev))   # copy of dItdvpre on this gpu
            setattr(self, f'dVtdw{gpu}', torch.zeros(sens_shape, dtype=torch.float32, device=dev))     # [i, j] = dvjdwi
            setattr(self, f'dVpretdw{gpu}', torch.zeros(sens_shape, dtype=torch.float32, device=dev))  # [i, j] = dvprejdwi

        # Size the recording buffers once for the whole run so they never have to grow.
        # The number of update_dvdw calls per run is ~ int(tstop/dt)//K_mul; one extra
        # leading column holds the initial zeros.
        step = float(h.dt)
        n_steps = int(float(h.tstop) / step) if step > 0 else 0
        n_steps = max(n_steps, 1)
        n_ksteps = int(n_steps // self.K_mul)
        self._lr_buf_len = int(n_ksteps + 1)
        self._lr_valid_len = 1

        # CPU buffers; everything past column 0 is left uninitialised until written.
        self.dVoutputtdw = torch.empty((self.N, self.N_output, self._lr_buf_len), dtype=torch.float32)
        self.dVinputtdw = torch.empty((self.N, self.N_input, self._lr_buf_len), dtype=torch.float32)
        self.dVoutputtdw[..., 0].zero_()
        self.dVinputtdw[..., 0].zero_()

        self.grad_l2norm_thresold = 1.e6
        self.grad_scale = 1.

        self.streams = [torch.cuda.Stream(torch.device(f'cuda:{gpu}')) for gpu in range(self.ngpu)]

    def update_dvdw_from_signals_batch(
        self,
        it_lr_cpu: np.ndarray,
        ditdv_lr_cpu: np.ndarray | None = None,
        ditdvpre_lr_cpu: np.ndarray | None = None,
        *,
        percise: bool = True,
    ) -> None:
        """
        Replay mode: update dv/dw for the whole run from pre-recorded signal matrices.

        This replaces the outer Python loop in `worm_training_impl.py` and avoids per-step
        numpy->CUDA transfers and CPU-side dvtdw assembly.

        Parameters:
        - it_lr_cpu: shape (N, K_steps+1)
        - ditdv_lr_cpu/ditdvpre_lr_cpu: shape (N, K_steps+1), required when percise=True
        """
        if it_lr_cpu is None:
            raise ValueError("it_lr_cpu must be provided")
        it_lr_cpu = np.asarray(it_lr_cpu, dtype=np.float32)
        if it_lr_cpu.ndim != 2 or it_lr_cpu.shape[0] != int(self.N):
            raise ValueError(f"expected it_lr_cpu shape (N, T), got {it_lr_cpu.shape}")
        if int(it_lr_cpu.shape[1]) < 2:
            return

        if percise:
            if ditdv_lr_cpu is None or ditdvpre_lr_cpu is None:
                raise ValueError("ditdv_lr_cpu and ditdvpre_lr_cpu required when percise=True")
            ditdv_lr_cpu = np.asarray(ditdv_lr_cpu, dtype=np.float32)
            ditdvpre_lr_cpu = np.asarray(ditdvpre_lr_cpu, dtype=np.float32)
            if ditdv_lr_cpu.shape != it_lr_cpu.shape or ditdvpre_lr_cpu.shape != it_lr_cpu.shape:
                raise ValueError("ditdv_lr_cpu/ditdvpre_lr_cpu must have the same shape as it_lr_cpu")

        # Feature flag: allow quick rollback to the slow per-step API.
        use_fast = os.environ.get("EWORM_FAST_DVDW", "1").strip() != "0"
        if not use_fast:
            ksteps_total = int(it_lr_cpu.shape[1] - 1)
            for t_lr in range(1, ksteps_total + 1):
                self.update_dvdw_from_signals(
                    t_lr,
                    it_lr_cpu[:, t_lr],
                    ditdv_lr_cpu[:, t_lr] if percise else None,
                    ditdvpre_lr_cpu[:, t_lr] if percise else None,
                    percise=percise,
                )
            return

        if int(self.ngpu) != 1:
            # Keep multi-GPU behavior unchanged for now to avoid correctness risks.
            ksteps_total = int(it_lr_cpu.shape[1] - 1)
            for t_lr in range(1, ksteps_total + 1):
                self.update_dvdw_from_signals(
                    t_lr,
                    it_lr_cpu[:, t_lr],
                    ditdv_lr_cpu[:, t_lr] if percise else None,
                    ditdvpre_lr_cpu[:, t_lr] if percise else None,
                    percise=percise,
                )
            return

        if self._lr_valid_len is None or self._lr_buf_len is None:
            raise RuntimeError("lr buffers not initialized; call _reset_lr_records() first")

        device = torch.device("cuda:0")
        ksteps_total = int(it_lr_cpu.shape[1] - 1)

        replay_compute_dtype_name = os.environ.get("EWORM_REPLAY_COMPUTE_DTYPE", "float32").strip().lower()
        replay_compute_dtype = torch.float32
        if replay_compute_dtype_name in ("float16", "fp16", "half"):
            replay_compute_dtype = torch.float16

        # Move the whole signal matrices to GPU once.
        it_lr = torch.from_numpy(it_lr_cpu).to(device)
        if percise:
            ditdv_lr = torch.from_numpy(ditdv_lr_cpu).to(device)
            ditdvpre_lr = torch.from_numpy(ditdvpre_lr_cpu).to(device)
        else:
            ditdv_lr = None
            ditdvpre_lr = None

        # Optional: run replay math in fp16 for performance.
        # NOTE: dvtdw itself stays fp32; we only change the history/signal buffers and K dtype.
        if replay_compute_dtype == torch.float16:
            it_lr = it_lr.to(replay_compute_dtype)
            if percise:
                ditdv_lr = ditdv_lr.to(replay_compute_dtype)
                ditdvpre_lr = ditdvpre_lr.to(replay_compute_dtype)

        K_len = int(self.K_len)
        if K_len <= 0:
            raise RuntimeError(f"invalid K_len={K_len}")

        # Allocate replay ring buffers (double-length) on GPU for contiguous window slicing.
        # NOTE: dVtdw/dVpretdw are large; we reuse buffers across epochs.
        def _need_alloc(attr: str, shape, *, dtype: torch.dtype | None = None) -> bool:
            if not hasattr(self, attr):
                return True
            t = getattr(self, attr)
            if not (hasattr(t, "shape") and tuple(t.shape) == tuple(shape) and t.device == device):
                return True
            if dtype is not None and getattr(t, "dtype", None) != dtype:
                return True
            return False

        it_buf_shape = (int(self.N), 2 * K_len)
        if _need_alloc("_replay_it_buf", it_buf_shape, dtype=replay_compute_dtype):
            self._replay_it_buf = torch.zeros(it_buf_shape, dtype=replay_compute_dtype, device=device)
            self._replay_it_pos = 0

        if percise:
            if _need_alloc("_replay_ditdv_buf", it_buf_shape, dtype=replay_compute_dtype):
                self._replay_ditdv_buf = torch.zeros(it_buf_shape, dtype=replay_compute_dtype, device=device)
                self._replay_ditdvpre_buf = torch.zeros(it_buf_shape, dtype=replay_compute_dtype, device=device)
                self._replay_ditdv_pos = 0

        dvbuf_dtype_name = os.environ.get("EWORM_REPLAY_DVBUF_DTYPE", "float32").strip().lower()
        dvbuf_dtype = torch.float32
        if dvbuf_dtype_name in ("float16", "fp16", "half"):
            dvbuf_dtype = torch.float16

        dV_buf_shape = (int(self.N), int(self.N), 2 * K_len)
        if _need_alloc("_replay_dVtdw_buf", dV_buf_shape, dtype=dvbuf_dtype):
            self._replay_dVtdw_buf = torch.zeros(dV_buf_shape, dtype=dvbuf_dtype, device=device)
            self._replay_dVpretdw_buf = torch.zeros(dV_buf_shape, dtype=dvbuf_dtype, device=device)
            # dVtdw history has an implicit leading zero slice (t=0) in the original code.
            self._replay_dV_pos = 1

        if _need_alloc("_replay_dvtdw", (int(self.N), int(self.N)), dtype=torch.float32):
            self._replay_dvtdw = torch.zeros((int(self.N), int(self.N)), dtype=torch.float32, device=device)
            self._replay_dvpretdw = torch.zeros_like(self._replay_dvtdw)

        # Cache index tensors for dvpretdw column remap.
        if not hasattr(self, "_replay_pre_idx") or not hasattr(self, "_replay_post_idx"):
            self._replay_pre_idx = torch.tensor(self.preplist, dtype=torch.long, device=device)
            self._replay_post_idx = torch.tensor(self.postplist, dtype=torch.long, device=device)

        # Place dVoutputtdw/dVinputtdw on GPU in fast mode to avoid per-step D2H writes.
        # (get_dw_dx handles both CPU and GPU buffers.)
        if (
            hasattr(self, "dVoutputtdw")
            and hasattr(self.dVoutputtdw, "device")
            and self.dVoutputtdw.device.type == "cpu"
        ):
            # Only migrate if the buffers are large enough for this run.
            self.dVoutputtdw = self.dVoutputtdw.to(device, non_blocking=True)
            self.dVinputtdw = self.dVinputtdw.to(device, non_blocking=True)

        # Reset valid length (keep the leading zero column at t=0).
        self._lr_valid_len = 1

        # Reset replay buffers per run (do not carry state across epochs).
        # Positions are offset by one: It/ditdv includes current tick, while dVtdw/dVpretdw
        # history used in the conv term includes up to previous tick (starts with a leading 0).
        self._replay_it_buf.zero_()
        self._replay_it_pos = 0
        it_pos = 0
        if percise:
            self._replay_ditdv_buf.zero_()
            self._replay_ditdvpre_buf.zero_()
            self._replay_ditdv_pos = 0
            ditdv_pos = 0
        else:
            ditdv_pos = 0
        self._replay_dVtdw_buf.zero_()
        self._replay_dVpretdw_buf.zero_()
        self._replay_dV_pos = 1
        dV_pos = 1

        # Precompute block ranges for slicing.
        block_starts = list(self.K_block_n_start)
        block_ns = list(self.K_block_n)
        K_blocks = getattr(self, "K0")

        poutput_gpu = torch.tensor(self.poutput, dtype=torch.long, device=device)
        pinput_gpu = torch.tensor(self.pinput, dtype=torch.long, device=device)

        use_replay_jit = percise and (os.environ.get("EWORM_DVDW_REPLAY_JIT", "1").strip() != "0")
        if use_replay_jit:
            try:
                jit_fn = _get_dvdw_replay_jit_fn()
                it_pos, ditdv_pos, dV_pos, lr_valid_len, grad_scale = jit_fn(
                    it_lr,
                    ditdv_lr,
                    ditdvpre_lr,
                    K_blocks,
                    [int(x) for x in block_starts],
                    [int(x) for x in block_ns],
                    self._replay_pre_idx,
                    self._replay_post_idx,
                    poutput_gpu,
                    pinput_gpu,
                    self._replay_it_buf,
                    self._replay_ditdv_buf if percise else self._replay_it_buf,
                    self._replay_ditdvpre_buf if percise else self._replay_it_buf,
                    self._replay_dVtdw_buf,
                    self._replay_dVpretdw_buf,
                    self._replay_dvtdw,
                    self._replay_dvpretdw,
                    self.dVoutputtdw,
                    self.dVinputtdw,
                    int(ksteps_total),
                    int(K_len),
                    float(h.dt),
                    float(self.grad_scale),
                    float(self.grad_l2norm_thresold),
                    int(self._lr_valid_len),
                    int(self._lr_buf_len),
                )
                self._replay_it_pos = int(it_pos)
                self._replay_ditdv_pos = int(ditdv_pos)
                self._replay_dV_pos = int(dV_pos)
                self._lr_valid_len = int(lr_valid_len)
                self.grad_scale = float(grad_scale)
                return
            except Exception as e:
                # Fall back to the Python loop (correctness-first).
                _LOGGER.warning("[EWORM] replay JIT disabled; falling back to Python loop: %s", e)

        for tstep in range(1, ksteps_total + 1):
            t_window = int(min(tstep, K_len))

            # --- Append It/ditdv signals (current tick) ---
            it_t = it_lr[:, tstep]
            self._replay_it_buf[:, it_pos] = it_t
            self._replay_it_buf[:, it_pos + K_len] = it_t
            it_pos = (it_pos + 1) % K_len

            if percise:
                ditdv_t = ditdv_lr[:, tstep]
                ditdvpre_t = ditdvpre_lr[:, tstep]
                self._replay_ditdv_buf[:, ditdv_pos] = ditdv_t
                self._replay_ditdv_buf[:, ditdv_pos + K_len] = ditdv_t
                self._replay_ditdvpre_buf[:, ditdv_pos] = ditdvpre_t
                self._replay_ditdvpre_buf[:, ditdv_pos + K_len] = ditdvpre_t
                ditdv_pos = (ditdv_pos + 1) % K_len

            # Window ends (exclusive) in the double buffers.
            it_end = it_pos + K_len
            dV_end = dV_pos + K_len
            it_win = self._replay_it_buf[:, it_end - t_window: it_end]
            if percise:
                ditdv_end = ditdv_pos + K_len
                ditdv_win = self._replay_ditdv_buf[:, ditdv_end - t_window: ditdv_end]
                ditdvpre_win = self._replay_ditdvpre_buf[:, ditdv_end - t_window: ditdv_end]
                dVtdw_win = self._replay_dVtdw_buf[:, :, dV_end - t_window: dV_end]
                dVpretdw_win = self._replay_dVpretdw_buf[:, :, dV_end - t_window: dV_end]

            # --- Compute dvtdw on GPU (reuse buffer) ---
            dvtdw = self._replay_dvtdw
            dvtdw.zero_()

            dt = float(h.dt)
            grad_scale = float(self.grad_scale)

            # Base (block-diagonal) term.
            for b, (start, bn) in enumerate(zip(block_starts, block_ns)):
                end = start + bn
                It_b = it_win[start:end, :]
                K_b = K_blocks[b][:, :, -t_window:]
                dv_block = torch.einsum("ijt,it->ij", K_b, It_b) * dt * grad_scale
                dvtdw[start:end, start:end] = dv_block

            # Precise correction term.
            if percise:
                for b, (start, bn) in enumerate(zip(block_starts, block_ns)):
                    end = start + bn
                    K_b = K_blocks[b][:, :, -t_window:]
                    dItdv_b = ditdv_win[start:end, :]
                    dItdvpre_b = ditdvpre_win[start:end, :]
                    dVtdw_b = dVtdw_win[:, start:end, :]
                    dVpretdw_b = dVpretdw_win[:, start:end, :]
                    dItdw_b = dVtdw_b * dItdv_b[torch.newaxis, :, :] + dVpretdw_b * dItdvpre_b[torch.newaxis, :, :]
                    dv_corr = torch.einsum("ikt,jkt->ij", dItdw_b, K_b) * dt
                    dvtdw[:, start:end] += dv_corr

            # Match original grad clip behavior (rare in practice).
            dvtdw_l2norm = float(torch.linalg.norm(dvtdw, ord="fro").item())
            if not math.isfinite(dvtdw_l2norm):
                _sanitize_nonfinite_(dvtdw, name="dvtdw (fast replay)", action=_DVDW_NAN_ACTION)
                dvtdw_l2norm = float(torch.linalg.norm(dvtdw, ord="fro").item())
            if dvtdw_l2norm > self.grad_l2norm_thresold:
                scaler = float(self.grad_l2norm_thresold) / dvtdw_l2norm
                dvtdw.mul_(scaler)
                self.dVoutputtdw.mul_(scaler)
                self.grad_scale *= scaler
                self._replay_dVtdw_buf.mul_(scaler)
                self._replay_dVpretdw_buf.mul_(scaler)

            # dvpretdw (GPU)
            dvpretdw = self._replay_dvpretdw
            dvpretdw.zero_()
            dvpretdw[:, self._replay_post_idx] = dvtdw[:, self._replay_pre_idx]

            # --- Append dvtdw/dvpretdw into the history buffers (dV_pos points to next write slot) ---
            self._replay_dVtdw_buf[:, :, dV_pos] = dvtdw
            self._replay_dVtdw_buf[:, :, dV_pos + K_len] = dvtdw
            self._replay_dVpretdw_buf[:, :, dV_pos] = dvpretdw
            self._replay_dVpretdw_buf[:, :, dV_pos + K_len] = dvpretdw
            dV_pos = (dV_pos + 1) % K_len

            # Record output/input sensitivities for this LR tick (GPU).
            if self._lr_valid_len < self._lr_buf_len:
                self.dVoutputtdw[:, :, self._lr_valid_len] = dvtdw[:, poutput_gpu]
                self.dVinputtdw[:, :, self._lr_valid_len] = dvtdw[:, pinput_gpu]
                self._lr_valid_len += 1

        # Persist last positions (mainly for debugging/inspection).
        self._replay_it_pos = it_pos
        if percise:
            self._replay_ditdv_pos = ditdv_pos
        self._replay_dV_pos = dV_pos

    def update_dvdw(self, tstep, percise=True):
        # called after each timestep advance
        if self._heliox_backend is not None:
            if not hasattr(self, "_it_cpu") or getattr(self._it_cpu, "shape", None) != (self.N,):
                self._it_cpu = np.zeros((self.N,), dtype=np.float32)
                self._ditdv_cpu = np.zeros((self.N,), dtype=np.float32)
                self._ditdvpre_cpu = np.zeros((self.N,), dtype=np.float32)
            self._heliox_backend.fill_it(self._it_cpu)
            it = torch.from_numpy(self._it_cpu).to(torch.device("cuda:0"))
        else:
            itlist = []
            for id in self.cells_id_sim:
                cn = self.cells_name_dic[id]
                it_local = [0.] * len(self.pwmask[id])
                for id_seg, seg in enumerate(self.cells[cn].Soma):
                    seg_area = seg.area()
                    for ch_name in self.mech_list:
                        it_local[id_seg] += getattr(seg, ch_name).pure_i * seg_area * 1e-2
                if id in self.input_ids:
                    syn = self.input_synlist[id]
                    if 'syn' in syn.hname():
                        it_local[self.cells[cn].Soma.nseg] += syn.pure_i
                    else:
                        it_local[self.cells[cn].Soma.nseg] += syn.i
                for syninfo in self.syninfos_flat[id]:
                    it_local.append(syninfo.syn.pure_i)
                itlist.append(torch.asarray(it_local, dtype=torch.float32))
            it = torch.concatenate(itlist).to(torch.device('cuda:0')) # at t-1
        self.It = torch.hstack([self.It, it[:, torch.newaxis]])[:, -self.K_len:] # shape (N, min(tstep, K_len))

        if percise:
            if self._heliox_backend is not None:
                self._heliox_backend.fill_ditdv(self._ditdv_cpu, self._ditdvpre_cpu)
                self.ditdv = torch.from_numpy(self._ditdv_cpu)
                self.ditdvpre = torch.from_numpy(self._ditdvpre_cpu)
            else:
                ditdvlist = []
                for id in self.cells_id_sim:
                    cn = self.cells_name_dic[id]
                    ditdv_local = [0.] * len(self.pwmask[id])
                    for id_seg, seg in enumerate(self.cells[cn].Soma):
                        seg_area = seg.area()
                        for ch_name in self.mech_list:
                            ditdv_local[id_seg] += getattr(seg, ch_name).didv * seg_area * 1e-2
                    if id in self.input_ids:
                        syn = self.input_synlist[id]
                        if 'syn' in syn.hname():
                            ditdv_local[self.cells[cn].Soma.nseg] += syn.didv
                    for syninfo in self.syninfos_flat[id]:
                        ditdv_local.append(syninfo.syn.didv)
                    ditdvlist.append(torch.asarray(ditdv_local, dtype=torch.float32))
                self.ditdv = torch.concatenate(ditdvlist)
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(self, f'dItdv{i}', torch.hstack([getattr(self, f'dItdv{i}'), self.ditdv.to(torch.device(f'cuda:{i}'), non_blocking=True)[:, torch.newaxis]])[:, -self.K_len:])

            if self._heliox_backend is None:
                ditdvprelist = []
                for id in self.cells_id_sim:
                    ditdvpre_local = [0.] * len(self.pwmask[id])
                    for syninfo in self.syninfos_flat[id]:
                        ditdvpre_local.append(syninfo.syn.didvpre)
                    ditdvprelist.append(torch.asarray(ditdvpre_local, dtype=torch.float32))
                self.ditdvpre = torch.concatenate(ditdvprelist)
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(self, f'dItdvpre{i}', torch.hstack([getattr(self, f'dItdvpre{i}'), self.ditdvpre.to(torch.device(f'cuda:{i}'), non_blocking=True)[:, torch.newaxis]])[:, -self.K_len:])
        
        # It_split = cp.vsplit(self.It, self.K_block_n_start[1:])
        It_split = torch.vsplit(self.It, self.K_block_n_start[1:])
        # dvtdw_split = [cp.einsum('ijl,il->ij', ki[:, :, -tstep:], Iti, optimize='greedy') * h.dt * self.grad_scale for ki, Iti in zip(self.K0, It_split)]
        # dvtdw_split = [oe.contract('ijl,il->ij', ki[:, :, -tstep:], Iti, backend='cupy') * h.dt * self.grad_scale for ki, Iti in zip(self.K0, It_split)]
        dvtdw_split = [oe.contract('ijt,it->ij', ki[:, :, -tstep:], Iti, backend='torch') * h.dt * self.grad_scale for ki, Iti in zip(self.K0, It_split)]
        # self.dvtdw = cp.zeros((self.N, self.N))
        self.dvtdw = torch.zeros((self.N, self.N), dtype=torch.float32)
        for i in range(self.K_nblock):
            start = self.K_block_n_start[i]
            end = start + self.K_block_n[i]
            self.dvtdw[start: end, start: end] = dvtdw_split[i].to(torch.device('cpu'), non_blocking=True)
        # dvtdw is consumed on CPU below; synchronize only device 0.
        torch.cuda.synchronize(torch.device('cuda:0'))

        if percise:
            for i in range(self.ngpu):
                # with cp.cuda.Device(i):
                #     exec(f"dItdw_conv_0tot_1_{i} = self.dItdv{i}[cp.newaxis, :, :] * self.dVtdw{i} + self.dItdvpre{i}[cp.newaxis, :, :] * self.dVpretdw{i}\n"
                #         f"dItdw_conv_0tot_1_{i} = cp.hsplit(dItdw_conv_0tot_1_{i}, self.K_block_n_start[1:])\n"
                #         #  f"self.dvtdw_{i} = [cp.einsum('ikl,jkl->ij', ai, ki[:, :, -tstep:], optimize='greedy') * dt for ai, ki in zip(dItdw_conv_0tot_1_{i}, self.K{i})]\n"
                #         f"self.dvtdw_{i} = [oe.contract('ikl,jkl->ij', ai, ki[:, :, -tstep:], backend='cupy') * dt for ai, ki in zip(dItdw_conv_0tot_1_{i}, self.K{i})]\n"
                #         f"self.dvtdw_{i} = cp.hstack(self.dvtdw_{i})\n"
                #         f"del dItdw_conv_0tot_1_{i}",
                #         {'cp': cp, 'oe': oe, 'self': self, 'dt': h.dt, 'tstep': tstep})
                with torch.cuda.stream(self.streams[i]):
                    setattr(self, f'dItdw_conv_0tot_1_{i}', getattr(self, f'dItdv{i}')[torch.newaxis, :, :] * getattr(self, f'dVtdw{i}') + getattr(self, f'dItdvpre{i}')[torch.newaxis, :, :] * getattr(self, f'dVpretdw{i}'))
                    setattr(self, f'dItdw_conv_0tot_1_{i}', torch.hsplit(getattr(self, f'dItdw_conv_0tot_1_{i}'), self.K_block_n_start[1:]))
                    setattr(self, f'dvtdw_{i}', [oe.contract('ikt,jkt->ij', ai, ki[:, :, -tstep:], backend='torch') * h.dt for ai, ki in zip(getattr(self, f'dItdw_conv_0tot_1_{i}'), getattr(self, f'K{i}'))])
                    setattr(self, f'dvtdw_{i}', torch.hstack(getattr(self, f'dvtdw_{i}')))
                    setattr(self, f'dItdw_conv_0tot_1_{i}', None)
            dvtdw_cpu_parts = []
            dvtdw_events = []
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    dvtdw_cpu_parts.append(getattr(self, f'dvtdw_{i}').to(torch.device('cpu'), non_blocking=True))
                evt = torch.cuda.Event()
                evt.record(self.streams[i])
                dvtdw_events.append(evt)
            for evt in dvtdw_events:
                evt.synchronize()
            self.dvtdw += torch.vstack(dvtdw_cpu_parts)
        
        # if cp.any(cp.isnan(self.dvtdw)):
        #     print('\n')
        #     print(f"nan detected, t: {h.t}")
        #     print(f"it max: {cp.max(it):.5g}, min: {cp.min(it):.5g}")
        #     if percise:
        #         print(f"ditdv max: {cp.max(self.ditdv):.5g}, min: {cp.min(self.ditdv):.5g}")
        #         print(f"ditdvpre max: {cp.max(self.ditdvpre):.5g}, min: {cp.min(self.ditdvpre):.5g}")
        #     print(f"dvtdw max: {cp.max(self.dvtdw):.5g}, min: {cp.min(self.dvtdw):.5g}")
        #     assert 0
        dvtdw_l2norm = float(torch.linalg.norm(self.dvtdw, ord='fro').item())
        if not math.isfinite(dvtdw_l2norm):
            _sanitize_nonfinite_(self.dvtdw, name="dvtdw", action=_DVDW_NAN_ACTION)
            dvtdw_l2norm = float(torch.linalg.norm(self.dvtdw, ord='fro').item())
        
        # dvtdw_l2norm = float(cp.linalg.norm(self.dvtdw, ord=2))
        # if dvtdw_l2norm > self.grad_l2norm_thresold:
        #     scaler = self.grad_l2norm_thresold / dvtdw_l2norm
        #     self.dvtdw *= scaler
        #     self.dVoutputtdw *= scaler
        #     self.grad_scale *= scaler
        #     for i in range(self.ngpu):
        #         with cp.cuda.Device(i):
        #             exec(f"self.dVtdw{i} *= scaler\n"
        #                 f"self.dVpretdw{i} *= scaler", 
        #                 {'cp': cp, 'self': self, 'scaler': scaler})
        # dvtdw_l2norm = float(torch.linalg.norm(self.dvtdw, ord=2))
        if dvtdw_l2norm > self.grad_l2norm_thresold:
            scaler = self.grad_l2norm_thresold / dvtdw_l2norm
            self.dvtdw *= scaler
            self.dVoutputtdw *= scaler
            self.grad_scale *= scaler
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(self, f'dVtdw{i}', scaler * getattr(self, f'dVtdw{i}'))
                    setattr(self, f'dVpretdw{i}', scaler * getattr(self, f'dVpretdw{i}'))
        
        
        # self.dvpretdw = cp.zeros_like(self.dvtdw)
        self.dvpretdw = torch.zeros_like(self.dvtdw)
        self.dvpretdw[:, self.postplist] = self.dvtdw[:, self.preplist]
        
        if percise:
            # self.dvtdw_split = cp.vsplit(self.dvtdw, self.K_gpu_n_start[1:])
            # self.dvpretdw_split = cp.vsplit(self.dvpretdw, self.K_gpu_n_start[1:])
            # for i in range(self.ngpu):
            #     with cp.cuda.Device(i):
            #         exec(f"self.dVtdw{i} = cp.dstack((self.dVtdw{i}, cp.asarray(self.dvtdw_split[{i}], dtype=cp.float32)[:, :, cp.newaxis]))[:, :, -self.K_len:]\n"
            #             f"self.dVpretdw{i} = cp.dstack((self.dVpretdw{i}, cp.asarray(self.dvpretdw_split[{i}], dtype=cp.float32)[:, :, cp.newaxis]))[:, :, -self.K_len:]", 
            #             {'cp': cp, 'self': self})
            self.dvtdw_split = torch.vsplit(self.dvtdw, self.K_gpu_n_start[1:])
            self.dvpretdw_split = torch.vsplit(self.dvpretdw, self.K_gpu_n_start[1:])
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(self, f'dVtdw{i}', torch.dstack([getattr(self, f'dVtdw{i}'), self.dvtdw_split[i].to(torch.device(f'cuda:{i}'), non_blocking=True)[:, :, torch.newaxis]])[:, :, -self.K_len:])
                    setattr(self, f'dVpretdw{i}', torch.dstack([getattr(self, f'dVpretdw{i}'), self.dvpretdw_split[i].to(torch.device(f'cuda:{i}'), non_blocking=True)[:, :, torch.newaxis]])[:, :, -self.K_len:])

        # Record dV/dw for output/input channels without growing tensors each time.
        if self._lr_valid_len is None or self._lr_buf_len is None:
            raise RuntimeError("lr buffers not initialized; call _reset_lr_records() first")
        if self._lr_valid_len < self._lr_buf_len:
            self.dVoutputtdw[:, :, self._lr_valid_len] = self.dvtdw[:, self.poutput]
            self.dVinputtdw[:, :, self._lr_valid_len] = self.dvtdw[:, self.pinput]
        self._lr_valid_len += 1

    def update_dvdw_from_signals(
        self,
        tstep: int,
        it_cpu: np.ndarray,
        ditdv_cpu: np.ndarray | None = None,
        ditdvpre_cpu: np.ndarray | None = None,
        *,
        percise: bool = True,
    ):
        """
        Replay mode: update dv/dw using pre-recorded signals instead of reading NEURON/HELIOX state.

        One call performs, in order:
        1. append the new signals to the rolling histories (``self.It`` and,
           when ``percise`` is set, the per-GPU ``dItdv{i}`` / ``dItdvpre{i}``),
           keeping only the last ``K_len`` samples;
        2. rebuild ``self.dvtdw`` (CPU, N x N) from the block-diagonal
           contraction of the ``K0`` kernel blocks with ``self.It``;
        3. when ``percise`` is set, add the per-GPU term built from the dI/dv
           histories and the ``dVtdw{i}`` / ``dVpretdw{i}`` histories;
        4. sanitize non-finite entries, derive ``self.dvpretdw`` by copying
           the ``preplist`` columns into the ``postplist`` columns;
        5. push the newest slices onto the per-GPU ``dVtdw{i}`` /
           ``dVpretdw{i}`` histories (done regardless of ``percise``);
        6. record the output/input columns in the LR buffers while there is
           room left.

        No gradient-norm rescaling is performed in this method.

        Parameters:
        - tstep: number of trailing kernel samples used; kernels are sliced
          as ``K[:, :, -tstep:]``. NOTE(review): presumably this has to match
          the current history length -- confirm against the caller.
        - it_cpu: shape (N,)
        - ditdv_cpu/ditdvpre_cpu: shape (N,), required when percise=True
        - percise: keyword-only (spelling kept for compatibility); also
          include the dI/dv-driven term of step 3.

        Raises:
        - ValueError: a required signal is missing or has the wrong size.
        - RuntimeError: the LR record buffers have not been initialised.
        """
        tstep = int(tstep)
        if it_cpu is None:
            raise ValueError("it_cpu must be provided")
        it_cpu = np.asarray(it_cpu, dtype=np.float32).reshape((-1,))
        if it_cpu.shape[0] != int(self.N):
            raise ValueError(f"expected it_cpu shape ({self.N},), got {it_cpu.shape}")

        # Append the new current sample as the newest column of the rolling history.
        it = torch.from_numpy(it_cpu).to(torch.device("cuda:0"))
        self.It = torch.hstack([self.It, it[:, torch.newaxis]])[:, -self.K_len:]

        if percise:
            if ditdv_cpu is None or ditdvpre_cpu is None:
                raise ValueError("ditdv_cpu and ditdvpre_cpu required when percise=True")
            ditdv_cpu = np.asarray(ditdv_cpu, dtype=np.float32).reshape((-1,))
            ditdvpre_cpu = np.asarray(ditdvpre_cpu, dtype=np.float32).reshape((-1,))
            if ditdv_cpu.shape[0] != int(self.N) or ditdvpre_cpu.shape[0] != int(self.N):
                raise ValueError("ditdv_cpu/ditdvpre_cpu must have shape (N,)")
            self.ditdv = torch.from_numpy(ditdv_cpu)
            self.ditdvpre = torch.from_numpy(ditdvpre_cpu)
            # Every GPU keeps its own copy of the full dI/dv histories.
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(
                        self,
                        f'dItdv{i}',
                        torch.hstack(
                            [
                                getattr(self, f'dItdv{i}'),
                                self.ditdv.to(torch.device(f'cuda:{i}'), non_blocking=True)[:, torch.newaxis],
                            ]
                        )[:, -self.K_len:],
                    )
                    setattr(
                        self,
                        f'dItdvpre{i}',
                        torch.hstack(
                            [
                                getattr(self, f'dItdvpre{i}'),
                                self.ditdvpre.to(torch.device(f'cuda:{i}'), non_blocking=True)[:, torch.newaxis],
                            ]
                        )[:, -self.K_len:],
                    )

        # Block-diagonal term: contract each kernel block with its rows of It.
        It_split = torch.vsplit(self.It, self.K_block_n_start[1:])
        dvtdw_split = [
            oe.contract('ijt,it->ij', ki[:, :, -tstep:], Iti, backend='torch') * h.dt * self.grad_scale
            for ki, Iti in zip(self.K0, It_split)
        ]
        # Queue all device-to-host copies first and synchronise BEFORE the host
        # reads them: a non_blocking copy to the CPU returns immediately, so
        # reading the result earlier may observe a buffer that is not filled yet.
        dvtdw_blocks_cpu = [
            dvtdw_split[i].to(torch.device('cpu'), non_blocking=True)
            for i in range(self.K_nblock)
        ]
        torch.cuda.synchronize(torch.device('cuda:0'))
        self.dvtdw = torch.zeros((self.N, self.N), dtype=torch.float32)
        for i in range(self.K_nblock):
            start = self.K_block_n_start[i]
            end = start + self.K_block_n[i]
            self.dvtdw[start: end, start: end] = dvtdw_blocks_cpu[i]

        if percise:
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    setattr(
                        self,
                        f'dItdw_conv_0tot_1_{i}',
                        getattr(self, f'dItdv{i}')[torch.newaxis, :, :] * getattr(self, f'dVtdw{i}')
                        + getattr(self, f'dItdvpre{i}')[torch.newaxis, :, :] * getattr(self, f'dVpretdw{i}'),
                    )
                    setattr(self, f'dItdw_conv_0tot_1_{i}', torch.hsplit(getattr(self, f'dItdw_conv_0tot_1_{i}'), self.K_block_n_start[1:]))
                    setattr(self, f'dvtdw_{i}', [oe.contract('ikt,jkt->ij', ai, ki[:, :, -tstep:], backend='torch') * h.dt for ai, ki in zip(getattr(self, f'dItdw_conv_0tot_1_{i}'), getattr(self, f'K{i}'))])
                    setattr(self, f'dvtdw_{i}', torch.hstack(getattr(self, f'dvtdw_{i}')))
                    # Drop the large temporary as soon as it has been consumed.
                    setattr(self, f'dItdw_conv_0tot_1_{i}', None)
            # Gather the per-GPU parts; each copy is fenced by an event recorded
            # on its stream and waited for before the host touches the data.
            dvtdw_cpu_parts = []
            dvtdw_events = []
            for i in range(self.ngpu):
                with torch.cuda.stream(self.streams[i]):
                    dvtdw_cpu_parts.append(getattr(self, f'dvtdw_{i}').to(torch.device('cpu'), non_blocking=True))
                evt = torch.cuda.Event()
                evt.record(self.streams[i])
                dvtdw_events.append(evt)
            for evt in dvtdw_events:
                evt.synchronize()
            self.dvtdw += torch.vstack(dvtdw_cpu_parts)

        if not torch.isfinite(self.dvtdw).all():
            _sanitize_nonfinite_(self.dvtdw, name="dvtdw (replay mode)", action=_DVDW_NAN_ACTION)

        # Presynaptic view: the postplist columns take the values of the preplist columns.
        self.dvpretdw = torch.zeros_like(self.dvtdw)
        self.dvpretdw[:, self.postplist] = self.dvtdw[:, self.preplist]

        # dvtdw is computed on CPU; keep the rest consistent with original update_dvdw.
        self.dvtdw_split = torch.vsplit(self.dvtdw, self.K_gpu_n_start[1:])
        self.dvpretdw_split = torch.vsplit(self.dvpretdw, self.K_gpu_n_start[1:])
        for i in range(self.ngpu):
            with torch.cuda.stream(self.streams[i]):
                setattr(
                    self,
                    f'dVtdw{i}',
                    torch.dstack(
                        [
                            getattr(self, f'dVtdw{i}'),
                            self.dvtdw_split[i].to(torch.device(f'cuda:{i}'), non_blocking=True)[:, :, torch.newaxis],
                        ]
                    )[:, :, -self.K_len:],
                )
                setattr(
                    self,
                    f'dVpretdw{i}',
                    torch.dstack(
                        [
                            getattr(self, f'dVpretdw{i}'),
                            self.dvpretdw_split[i].to(torch.device(f'cuda:{i}'), non_blocking=True)[:, :, torch.newaxis],
                        ]
                    )[:, :, -self.K_len:],
                )

        # Record output/input sensitivities for this LR tick (CPU).
        if self._lr_valid_len is None or self._lr_buf_len is None:
            raise RuntimeError("lr buffers not initialized; call _reset_lr_records() first")
        if self._lr_valid_len < self._lr_buf_len:
            self.dVoutputtdw[:, :, self._lr_valid_len] = self.dvtdw[:, self.poutput]
            self.dVinputtdw[:, :, self._lr_valid_len] = self.dvtdw[:, self.pinput]
            # The counter only advances while the buffer still has room.
            self._lr_valid_len += 1
        

    def get_dw_dx(self, dLtdv, lr_start, lr_end):
        """Convert dL/dv on the output channels into weight and input gradients.

        Meant to be called after each run, once the sensitivity buffers
        ``dVoutputtdw`` / ``dVinputtdw`` have been filled.

        Parameters:
        - dLtdv: numpy array or torch tensor of shape
          (N_output, lr_end - lr_start). Every ``K_mul``-th column is used and
          non-finite entries are replaced by 0.
        - lr_start/lr_end: bounds of the slice taken along the last axis of
          the recorded buffers.

        Returns:
        - (dw, dx) as numpy arrays. ``dw`` has one entry per row of
          ``dVoutputtdw`` (entries selected by ``pwmaskall`` are zeroed);
          ``dx`` has one row per axis-1 entry of ``dVinputtdw`` and one column
          per time step actually used.

        Side effect: NaN/Inf entries of ``self.dVinputtdw`` are zeroed in place.
        """
        # called after each run
        assert dLtdv.shape == (self.N_output, lr_end - lr_start)

        def _zero_where(tensor, probe, message):
            # Zero the entries flagged by `probe`, optionally logging `message`.
            hits = probe(tensor)
            if torch.any(hits) and _nonfinite_log_enabled():
                _LOGGER.warning(message)
            tensor[hits] = 0.

        # dVoutputtdw/dVinputtdw can be on CPU (legacy) or GPU (fast replay).
        if hasattr(self.dVoutputtdw, "device"):
            rec_device = self.dVoutputtdw.device
        else:
            rec_device = torch.device("cpu")

        # Subsample in time, move to the record device and drop non-finite values.
        if isinstance(dLtdv, torch.Tensor):
            # Allow upstream to compute dL/dv on GPU (e.g. corr loss in torch).
            strided = dLtdv[:, :: self.K_mul].to(rec_device)
            dLtdv = torch.nan_to_num(strided, nan=0.0, posinf=0.0, neginf=0.0)
        else:
            strided = np.asarray(dLtdv[:, :: self.K_mul], dtype=np.float32)
            cleaned = np.nan_to_num(strided, nan=0.0, posinf=0.0, neginf=0.0)
            dLtdv = torch.from_numpy(cleaned).to(rec_device)

        # Only use the portion that was filled during update_dvdw().
        min_t_len = min(int(dLtdv.shape[-1]), int(self._lr_valid_len))
        out_sens = self.dVoutputtdw[:, :, lr_start: lr_end][:, :, :min_t_len]

        # Weight gradient: time-averaged, summed over output channels.
        dw = torch.sum(dLtdv[None, :, :min_t_len] * out_sens, dim=(1, 2)) / min_t_len

        # Ratio dV_output/dV_input, shape (N_input, N_output, min_t_len); it is
        # built from the *unsanitized* input sensitivities and cleaned afterwards.
        inv_in_sens = 1. / (self.dVinputtdw[:, :, lr_start: lr_end][:, :, :min_t_len] + 1e-6)
        dVoutputtdVinputt = oe.contract("ikt,ijt->jkt", out_sens, inv_in_sens, backend='torch')

        _zero_where(dVoutputtdVinputt, torch.isnan, "NaN in dVoutputtdVinputt")
        _zero_where(dVoutputtdVinputt, torch.isinf, "INF in dVoutputtdVinputt")
        _zero_where(self.dVinputtdw, torch.isnan, "NaN in dVinputtdw")
        _zero_where(self.dVinputtdw, torch.isinf, "INF in dVinputtdw")

        # Input gradient, using the now-sanitized input sensitivities.
        in_sens = self.dVinputtdw[:, :, lr_start: lr_end][:, :, :min_t_len]
        dx = oe.contract(
            "ot,iot,nit->it",
            dLtdv[:, :min_t_len],
            dVoutputtdVinputt[:, :, :min_t_len],
            in_sens,
            backend='torch',
        )
        if torch.any(torch.isnan(dx)) and _nonfinite_log_enabled():
            _LOGGER.warning("NaN in dx")
        _sanitize_nonfinite_(dx, name="dx", action=_GRAD_NAN_ACTION)

        # mask
        dw[self.pwmaskall] = 0.
        _sanitize_nonfinite_(dw, name="dw", action=_GRAD_NAN_ACTION)

        # Hand the results back as numpy arrays, copying off the GPU if needed.
        if rec_device.type != "cpu":
            dw, dx = dw.detach().cpu(), dx.detach().cpu()
        return dw.numpy(), dx.numpy()


    def update_weights(self, dw):
        """Add a masked increment to ``self.w``, clamp it and push it out.

        Parameters:
        - dw: numpy array or torch tensor of shape (N,). It is copied before
          masking, so the caller's array is never modified.

        Entries selected by ``self.pwmaskall`` are not updated. After the
        update the gap entries (``self.pgap``) and synapse entries
        (``self.psyn``) are clamped to their respective ranges, and
        ``self.set_weights()`` is called to propagate the result.
        """
        assert dw.shape == (self.N,)
        # copy=True: torch.asarray would otherwise share memory with a float32
        # input, and the masking below would leak into the caller's array.
        dw = torch.asarray(dw, dtype=torch.float32, copy=True)

        # mask
        dw[self.pwmaskall] = 0.

        self.w += dw

        # clip w (read both groups before writing either one back)
        w_gap = torch.clip(self.w[self.pgap], min=self.w_gap_min, max=self.w_gap_max)
        w_syn = torch.clip(self.w[self.psyn], min=self.w_syn_min, max=self.w_syn_max)
        self.w[self.pgap] = w_gap
        self.w[self.psyn] = w_syn

        self.set_weights()
    

    def set_weights(self, w=None):
        """Clamp gap/synapse weights into ``self.w`` and push them to every synapse.

        Parameters:
        - w: optional array/tensor of shape (N,). When omitted, the current
          ``self.w`` is re-clamped and re-applied. Only the entries indexed by
          ``self.pgap`` and ``self.psyn`` are taken from ``w``.

        Every ``syninfo.syn`` found in ``self.synlist`` for the simulated
        cells gets ``self.w[syninfo.p]`` as its ``w``; when a HELIOX backend is
        attached, the weights at its ``weight_p_indices`` are pushed to it too.
        """
        if w is not None:
            assert w.shape == (self.N,)
        source = torch.asarray(self.w if w is None else w, dtype=torch.float32)

        # Clamp each weight family to its own range before storing it.
        gap_vals = torch.clamp(source[self.pgap], min=self.w_gap_min, max=self.w_gap_max)
        syn_vals = torch.clamp(source[self.psyn], min=self.w_syn_min, max=self.w_syn_max)
        self.w[self.pgap] = gap_vals
        self.w[self.psyn] = syn_vals

        # Propagate the stored weights to every synapse object.
        for cell_id in self.cells_id_sim:
            for syninfos in self.synlist[cell_id].values():
                for syninfo in syninfos:
                    syninfo.syn.w = self.w[syninfo.p]

        backend = self._heliox_backend
        if backend is not None:
            selected = self.w[backend.weight_p_indices]
            backend.push_weights(selected.detach().cpu().numpy())
    

    def set_gap_weights(self, w):
        """Clamp and store the gap entries of ``w``, then push them to gap synapses.

        Parameters:
        - w: array/tensor of shape (N,); only the entries indexed by
          ``self.pgap`` are used.

        Only synapse objects whose ``hname()`` contains ``'gapjunction'`` are
        updated. When a HELIOX backend is attached, the weights at its
        ``weight_p_indices`` are pushed to it as well.
        """
        assert w.shape == (self.N,)
        candidate = torch.asarray(w, dtype=torch.float32)
        self.w[self.pgap] = torch.clamp(candidate[self.pgap], min=self.w_gap_min, max=self.w_gap_max)

        for cell_id in self.cells_id_sim:
            for syninfos in self.synlist[cell_id].values():
                for syninfo in syninfos:
                    target = syninfo.syn
                    if 'gapjunction' not in target.hname():
                        continue
                    target.w = self.w[syninfo.p]

        backend = self._heliox_backend
        if backend is not None:
            selected = self.w[backend.weight_p_indices]
            backend.push_weights(selected.detach().cpu().numpy())
    

    def set_syn_weights(self, w=None):
        """Clamp and store the synapse entries of ``w``, then push them to synapses.

        Parameters:
        - w: optional array/tensor of shape (N,); only the entries indexed by
          ``self.psyn`` are used. When omitted, the current ``self.w`` is
          re-clamped and re-applied.

        Only synapse objects whose ``hname()`` contains ``'syn'`` are updated.
        When a HELIOX backend is attached, the weights at its
        ``weight_p_indices`` are pushed to it as well.
        """
        # update weights to syns
        if w is None:
            # Honour the advertised default instead of failing on `None.shape`.
            w = self.w
        else:
            assert w.shape == (self.N,)
        w = torch.asarray(w, dtype=torch.float32)
        w_syn = torch.clip(w[self.psyn], min=self.w_syn_min, max=self.w_syn_max)
        self.w[self.psyn] = w_syn
        for cell_id in self.cells_id_sim:
            for syninfos in self.synlist[cell_id].values():
                for syninfo in syninfos:
                    syn = syninfo.syn
                    if 'syn' in syn.hname():
                        syn.w = self.w[syninfo.p]
        if self._heliox_backend is not None:
            idx = self._heliox_backend.weight_p_indices
            self._heliox_backend.push_weights(self.w[idx].detach().cpu().numpy())
    

    def save_weights(self, path):
        """Save ``self.w`` to ``path`` in NumPy ``.npy`` format.

        Parameters:
        - path: anything ``np.save`` accepts (file name or file object).

        The tensor is detached and moved to the CPU first, so saving also
        works when ``self.w`` requires grad or is not on the CPU.
        """
        np.save(path, self.w.detach().cpu().numpy())
            
