#!/usr/bin/env python3
"""
Run a passive HPC (PassiveHPC template) simulation and save a soma voltage trace.

This script is adapted from a multi-cell PassiveHPC workflow,
but rewritten to:
  - avoid hard-coded absolute paths
  - support HelioX (via bbcore export + heliox loader)
  - write the same .npz format expected by plot_compare_to_neuron.py

Outputs (by default under output/; one file per mode that is run):
  - <case>__neuron.npz
  - <case>__coreneuron_gpu.npz
  - <case>__heliox_cpu.npz
  - <case>__heliox_gpu.npz
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Any, Literal

import numpy as np
import shutil

# Values accepted by --mode. "all" is not a backend of its own: it runs each
# of the other modes in turn, each in a separate subprocess.
Mode = Literal["neuron", "coreneuron-gpu", "heliox-cpu", "heliox-gpu", "all"]


def _ensure_neuron_importable() -> None:
    """
    Try to make `import neuron` work by adding a likely install path.

    On this machine, NEURON is installed under $HOME/nrn/install/lib/python.
    For portability, allow overriding via NRN_PYTHON_LIB.
    """
    try:
        import neuron  # noqa: F401

        return
    except Exception:
        pass

    candidates = [
        os.environ.get("NRN_PYTHON_LIB", "").strip(),
        os.path.join(os.path.expanduser("~"), "nrn", "install", "lib", "python"),
    ]
    for p in candidates:
        if p and os.path.isdir(p) and p not in sys.path:
            sys.path.insert(0, p)
            break


def _ensure_heliox_on_path() -> None:
    """
    HelioX bindings are not assumed installed system-wide; use HELIOX_PYTHON_LIB.
    """
    if "HELIOX_PYTHON_LIB" in os.environ:
        p = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
        if p and os.path.isdir(p) and p not in sys.path:
            sys.path.insert(0, p)


def _atomic_savez(path: Path, **kwargs: Any) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with tempfile.NamedTemporaryFile(prefix=path.stem + ".tmp.", suffix=path.suffix, dir=str(path.parent), delete=False) as f:
        tmp_path = Path(f.name)
    try:
        np.savez(tmp_path, **kwargs)
        os.replace(tmp_path, path)
    finally:
        try:
            tmp_path.unlink(missing_ok=True)
        except Exception:
            pass


def _find_nrngui_or_nrnivmodl() -> str | None:
    candidates = [
        os.environ.get("NRNIVMODL", "").strip(),
        os.path.join(os.path.expanduser("~"), "nrn", "install", "bin", "nrnivmodl"),
        "nrnivmodl",
    ]
    for c in candidates:
        if not c:
            continue
        if os.path.isabs(c) and os.path.exists(c):
            return c
        if not os.path.isabs(c):
            from shutil import which

            w = which(c)
            if w:
                return w
    return None


def _ensure_hpc_mechanisms(hpc_root: Path, *, verbose: bool) -> None:
    """
    Make sure the VecStim mechanism is available in this NEURON session.

    Older NEURON builds may need VecStim from `vecevent.mod`, which this
    repository ships as `hpc_assets/mod/vecevent.mod` and can compile into
    `hpc_assets/x86_64`. Steps, in order:

      1. If `h` already exposes VecStim, do nothing.
      2. If `<hpc_root>/x86_64` exists, load the compiled mechanisms from `hpc_root`.
      3. Otherwise run `nrnivmodl mod` inside `hpc_root`, then load the result.

    Args:
        hpc_root: Directory holding `mod/` (sources) and, once built, `x86_64/`.
        verbose: Print progress messages when true.

    Raises:
        RuntimeError: `nrnivmodl` cannot be found, or loading fails for a reason
            other than VecStim already being defined.
        FileNotFoundError: `<hpc_root>/mod` does not exist.
        subprocess.CalledProcessError: the `nrnivmodl` build fails.
    """
    from neuron import h, load_mechanisms

    def _load_compiled() -> None:
        # Load the compiled library. A "VecStim ... already exists" error means
        # the class is already defined, which is the state we want.
        if verbose:
            print(f"[hpc_assets] loading mechanisms from: {hpc_root}", flush=True)
        try:
            load_mechanisms(str(hpc_root))
        except RuntimeError as err:
            text = str(err)
            if not ("VecStim" in text and "already exists" in text):
                raise

    # Newer NEURON builds may already ship VecStim. Loading a local
    # vecevent.mod on top of that would redefine VecStim and error.
    try:
        vecstim_present = hasattr(h, "VecStim")
    except Exception:
        vecstim_present = False
    if vecstim_present:
        return

    # A previous build is present: just load it.
    if (hpc_root / "x86_64").exists():
        _load_compiled()
        return

    nrnivmodl = _find_nrngui_or_nrnivmodl()
    if not nrnivmodl:
        raise RuntimeError("Could not find nrnivmodl; set NRNIVMODL or ensure it is on PATH.")
    if verbose:
        print(f"[hpc_assets] building mechanisms via: {nrnivmodl} mod", flush=True)
    if not (hpc_root / "mod").exists():
        raise FileNotFoundError(f"Expected mod/ under: {hpc_root}")
    subprocess.check_call([nrnivmodl, "mod"], cwd=str(hpc_root))
    _load_compiled()


def _disable_usetable() -> None:
    """Set the HOC variable `usetable_hh` to 0, as the old scripts did; failures are ignored."""
    from neuron import h

    try:
        setattr(h, "usetable_hh", 0)
    except Exception:
        # Best effort: if the variable cannot be assigned, carry on without it.
        pass


def _setup_hpc_cell(template_name: str, morph_asc: str):
    """
    Instantiate a HOC cell template and populate it from a Neurolucida morphology.

    Args:
        template_name: Name of a HOC template already loaded into `h`.
        morph_asc: Path of the .asc morphology file handed to the Import3d reader.

    Returns:
        The cell object, after the template's own setup procedures have run.
    """
    from neuron import h

    cell = getattr(h, template_name)()

    reader = h.Import3d_Neurolucida3()
    reader.quiet = 1
    reader.input(morph_asc)
    importer = h.Import3d_GUI(reader, 0)
    importer.instantiate(cell)

    # Template-defined setup procedures; the call order is kept exactly as in
    # the legacy workflow.
    cell.indexSections(importer)
    for step in ("geom_nsec", "geom_nseg", "delete_axon", "insertChannel", "init_rc", "biophys"):
        getattr(cell, step)()
    return cell


def _attach_vecstim_inputs(cell, *, nstim: int, freq_hz: float, tstart: float, tend: float, dt: float, seed: int):
    """
    Drive `cell` with `nstim` random spike trains delivered through Exp2Syn synapses.

    The window [tstart, tend) is cut into bins of width `dt`. Each bin of each
    train holds a spike with probability `freq_hz / 1000 * dt`. Every train is
    played by a VecStim into an Exp2Syn placed on a dendritic segment drawn at
    random (with replacement) from all segments of `cell.dend`. Behaviour is
    kept close to the legacy workflow.

    Args:
        cell: Cell object exposing an iterable `dend` of sections.
        nstim: Number of independent spike trains / synapses.
        freq_hz: Spike rate per train, used as rate per 1000 time units of `dt`.
        tstart: Start of the stimulation window.
        tend: End of the stimulation window.
        dt: Bin width used to discretise the window.
        seed: Seed for the NumPy random generator, so runs are reproducible.

    Returns:
        A list of `(Vector, VecStim, NetCon, Exp2Syn)` tuples. The caller must
        keep it referenced for as long as the stimulation is needed.

    Raises:
        RuntimeError: `cell.dend` contains no segments.
    """
    from neuron import h

    rng = np.random.default_rng(int(seed))

    # Draw all spike trains first, then the target segments: this order of
    # RNG use determines the generated input for a given seed.
    n_bins = int((tend - tstart) / dt)
    p_spike = float(freq_hz) / 1000.0 * float(dt)
    fired = rng.random((nstim, n_bins)) < p_spike
    spike_trains: list[np.ndarray] = [
        np.flatnonzero(fired[row]).astype(float) * float(dt) + float(tstart)
        for row in range(nstim)
    ]

    segments = [seg for sec in cell.dend for seg in sec]
    if not segments:
        raise RuntimeError("No dendritic segments found on PassiveHPC cell.")

    targets = rng.choice(len(segments), size=nstim, replace=True)

    keepalive = []
    for train, seg_index in zip(spike_trains, targets):
        source = h.VecStim()
        times = h.Vector(train.shape[0])
        times.from_python(train)
        source.play(times)

        synapse = h.Exp2Syn(segments[int(seg_index)])
        link = h.NetCon(source, synapse)
        link.threshold = 0.1
        link.weight[0] = 0.3
        keepalive.append((times, source, link, synapse))

    return keepalive


def _build_network(
    *,
    passivehpc_root: Path,
    ncell: int,
    seed: int,
    dt: float,
    tstop: float,
    v_init: float,
    vecstim_n: int,
    vecstim_freq_hz: float,
    vecstim_tstart: float,
    vecstim_tend: float,
    iclamp_amp: float,
    iclamp_delay: float,
    iclamp_dur: float,
):
    """
    Build `ncell` PassiveHPC cells, stimulate cell 0, and record its soma voltage.

    The HOC files and the morphology are loaded by relative name, so the
    current working directory must already be the PassiveHPC asset directory.
    Each cell is registered with the ParallelContext under gid 0..ncell-1 on
    rank 0. Cell 0 receives the VecStim-driven synapses and a somatic IClamp.

    Args:
        passivehpc_root: Asset directory. Not read here; the caller is expected
            to have made it the working directory.
        ncell: Number of cells to instantiate.
        seed: Seed for the random synaptic input.
        dt: Integration time step, also the bin width of the random input.
        tstop: Simulation end time, stored in `h.tstop`.
        v_init: Initial membrane potential passed to `h.finitialize`.
        vecstim_n, vecstim_freq_hz, vecstim_tstart, vecstim_tend: Parameters of
            the random input, forwarded to `_attach_vecstim_inputs`.
        iclamp_amp, iclamp_delay, iclamp_dur: Amplitude, delay and duration of
            the current clamp on cell 0's soma.

    Returns:
        `(pc, cell0, vrec, keepalive)`: the ParallelContext, cell 0, the Vector
        recording cell 0's soma voltage, and a list of objects the caller must
        keep referenced for the lifetime of the model.
    """
    from neuron import h

    # NOTE: `sthcell.hoc` is shipped alongside the PassiveHPC assets and loaded
    # from CWD (the caller chdirs to `passivehpc_root` before calling this).
    for hoc_file in ("nrngui.hoc", "sthcell.hoc", "import3d.hoc", "PassiveHPC.hoc"):
        h.load_file(hoc_file)

    pc = h.ParallelContext()
    cell_list = []
    nc_list = []  # spike-detector NetCons; only needed while registering gids

    for gid in range(int(ncell)):
        cell = _setup_hpc_cell("PassiveHPC", "2013_03_06_cell11_1125_H41_06.asc")
        nc = h.NetCon(cell.soma[0](0.5)._ref_v, None, sec=cell.soma[0])
        nc.threshold = 0.1
        cell_list.append(cell)
        nc_list.append(nc)
        pc.set_gid2node(gid, 0)
        pc.cell(gid, nc)

    # External input on cell 0
    _stim_keepalive = []
    _stim_keepalive.extend(
        _attach_vecstim_inputs(
            cell_list[0],
            nstim=int(vecstim_n),
            freq_hz=float(vecstim_freq_hz),
            tstart=float(vecstim_tstart),
            tend=float(vecstim_tend),
            dt=float(dt),
            seed=int(seed),
        )
    )

    iclamp = h.IClamp(cell_list[0].soma[0](0.5))
    iclamp.dur = float(iclamp_dur)
    iclamp.amp = float(iclamp_amp)
    iclamp.delay = float(iclamp_delay)
    _stim_keepalive.append(iclamp)

    # Only cell 0 is returned directly. Hand the full cell list to the caller
    # as well, so that cells 1..ncell-1 do not lose their last Python
    # reference (and get freed by NEURON) once this function returns.
    _stim_keepalive.append(cell_list)

    h.cvode.cache_efficient(1)
    pc.setup_transfer()
    pc.set_maxstep(10)
    h.dt = float(dt)
    h.tstop = float(tstop)
    h.stdinit()
    h.finitialize(float(v_init))

    vrec = h.Vector().record(cell_list[0].soma[0](0.5)._ref_v)
    return pc, cell_list[0], vrec, _stim_keepalive


def _export_bbcore(pc, export_dir: Path, *, dt: float, v_init: float) -> None:
    """
    Write the current model to `export_dir` with `pc.nrnbbcore_write`.

    WARNING: an existing `export_dir` is deleted recursively first, so the
    export always starts from an empty directory.

    Args:
        pc: ParallelContext of the model to export.
        export_dir: Target directory; removed and recreated.
        dt: Time step assigned to `h.dt` before the export.
        v_init: Initial voltage used to re-initialise the model before writing.
    """
    from neuron import h

    target = str(export_dir)
    if os.path.exists(target):
        shutil.rmtree(target)
    os.makedirs(target, exist_ok=True)

    h.dt = float(dt)
    h.finitialize(float(v_init))
    pc.nrnbbcore_write(target)


def _run_neuron(pc, *, tstop: float, v_init: float) -> None:
    """Initialise the model at `v_init` and integrate to `tstop` with `pc.psolve`."""
    from neuron import h

    start_voltage = float(v_init)
    end_time = float(tstop)
    h.finitialize(start_voltage)
    pc.psolve(end_time)


def _run_coreneuron_gpu(pc, *, tstop: float, v_init: float) -> None:
    """
    Run the model to `tstop` with CoreNEURON enabled in GPU mode.

    The `coreneuron` settings are changed globally and are not restored afterwards.
    """
    from neuron import coreneuron, h

    # For GPU execution, CoreNEURON requires cell_permute 1 or 2; 2 matches
    # the prior convention.
    for option, value in (("enable", True), ("gpu", True), ("cell_permute", 2)):
        setattr(coreneuron, option, value)

    h.finitialize(float(v_init))
    pc.psolve(float(tstop))


def _run_heliox(
    cell0,
    export_dir: Path,
    *,
    tstop: float,
    dt: float,
    v_init: float,
    device: Literal["cpu", "gpu"],
    permute_type: int,
) -> np.ndarray:
    """
    Load a bbcore export into HelioX, run it, and return cell 0's soma voltage.

    Args:
        cell0: NEURON cell whose `soma[0](0.5)` segment is monitored.
        export_dir: Directory containing the bbcore export to load.
        tstop: Simulation end time passed to `run`.
        dt: Time step passed to `set_dt`.
        v_init: Initial voltage passed to `finitialize`.
        device: "cpu" or "gpu".
        permute_type: HelioX permute type. A negative value selects the
            default: 0 on CPU, 3 on GPU.

    Returns:
        The monitored voltage samples as a float array.

    Raises:
        RuntimeError: `load_model()` returns a non-zero code.
    """
    _ensure_heliox_on_path()
    import heliox
    from heliox_monitor import make_segment_monitor

    sim = heliox.Sim()
    sim.set_data_path(str(export_dir))
    sim.set_device(device)

    if permute_type >= 0:
        effective_permute = int(permute_type)
    elif device == "cpu":
        effective_permute = 0
    else:
        effective_permute = 3
    sim.set_permute_type(effective_permute)

    # The monitor is created before the model is loaded, as in the original workflow.
    soma_monitor = make_segment_monitor(cell0.soma[0](0.5), "v", sim)
    rc = sim.load_model()
    if rc != 0:
        raise RuntimeError(f"HelioX load_model() failed with code {rc}")

    sim.set_dt(float(dt))
    sim.finitialize(float(v_init))
    sim.run(float(tstop))

    samples = soma_monitor.get_data()
    return np.asarray(samples, dtype=float)


def _save_trace(out_path: Path, *, t: np.ndarray, v: np.ndarray, meta: dict[str, Any]) -> None:
    """
    Store a voltage trace and its metadata in an .npz file.

    The archive holds three entries: `t` and `v` as float arrays, and
    `meta_json`, the metadata serialised as JSON with sorted keys.
    """
    payload = {
        "t": np.asarray(t, dtype=float),
        "v": np.asarray(v, dtype=float),
        "meta_json": json.dumps(meta, ensure_ascii=False, sort_keys=True),
    }
    _atomic_savez(out_path, **payload)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser; defaults follow the legacy workflow."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["neuron", "coreneuron-gpu", "heliox-cpu", "heliox-gpu", "all"], required=True)
    parser.add_argument("--case", default="hpc_passive", help="Case name used in output filenames.")
    parser.add_argument("--out-dir", default="output", help="Output directory for .npz files and exports.")
    parser.add_argument("--export-dir", default="AUTO", help="bbcore export directory (HelioX modes).")
    parser.add_argument(
        "--passivehpc-root",
        default=os.environ.get("PASSIVEHPC_ROOT", "").strip(),
        help="Path to a directory containing PassiveHPC.hoc + sthcell.hoc + morphology .asc (optional).",
    )
    parser.add_argument("--verbose", action="store_true")

    # Simulation params (legacy defaults)
    parser.add_argument("--ncell", type=int, default=1)
    parser.add_argument("--dt", type=float, default=0.025)
    parser.add_argument("--tstop", type=float, default=500.0)
    parser.add_argument("--v-init", type=float, default=-65.0)
    parser.add_argument("--seed", type=int, default=1)

    # Inputs
    parser.add_argument("--vecstim-n", type=int, default=5)
    parser.add_argument("--vecstim-freq-hz", type=float, default=5.0)
    parser.add_argument("--vecstim-tstart", type=float, default=100.0)
    parser.add_argument("--vecstim-tend", type=float, default=500.0)
    parser.add_argument("--iclamp-amp", type=float, default=5.0)
    parser.add_argument("--iclamp-delay", type=float, default=200.0)
    parser.add_argument("--iclamp-dur", type=float, default=100.0)

    # HelioX knobs
    parser.add_argument("--heliox-permute-type", type=int, default=-1)
    return parser


def _run_all_modes(
    args: argparse.Namespace,
    *,
    case: str,
    out_dir: Path,
    export_dir: Path,
    passivehpc_root: Path,
) -> None:
    """Run every backend in its own subprocess, forwarding this invocation's options."""
    # Paths are forwarded already resolved, so the children do not depend on
    # their own working directory.
    forwarded = (
        ("--case", case),
        ("--out-dir", out_dir),
        ("--export-dir", export_dir),
        ("--passivehpc-root", passivehpc_root),
        ("--ncell", args.ncell),
        ("--dt", args.dt),
        ("--tstop", args.tstop),
        ("--v-init", args.v_init),
        ("--seed", args.seed),
        ("--vecstim-n", args.vecstim_n),
        ("--vecstim-freq-hz", args.vecstim_freq_hz),
        ("--vecstim-tstart", args.vecstim_tstart),
        ("--vecstim-tend", args.vecstim_tend),
        ("--iclamp-amp", args.iclamp_amp),
        ("--iclamp-delay", args.iclamp_delay),
        ("--iclamp-dur", args.iclamp_dur),
        ("--heliox-permute-type", args.heliox_permute_type),
    )
    script = str(Path(__file__).resolve())
    # Keep 'all' for small runs; for large networks prefer running specific modes.
    for mode in ("neuron", "coreneuron-gpu", "heliox-cpu", "heliox-gpu"):
        cmd = [sys.executable, script, "--mode", mode]
        for flag, value in forwarded:
            cmd.extend((flag, str(value)))
        if args.verbose:
            cmd.append("--verbose")
        print(f"[all] running: {' '.join(cmd)}", flush=True)
        subprocess.check_call(cmd)


def main() -> int:
    """
    Command-line entry point: simulate the PassiveHPC case and save the soma trace.

    For a single mode the model is built in this process, run on the chosen
    backend, and written to `<out-dir>/<case>__<tag>.npz`. For `--mode all`
    each of the four backends is run in a separate subprocess instead, which
    avoids multi-run reset issues.

    Returns:
        0 on success. Failures propagate as exceptions.

    Raises:
        FileNotFoundError: The PassiveHPC asset directory does not exist.
        subprocess.CalledProcessError: A child run in `--mode all` fails.
    """
    args = _build_arg_parser().parse_args()

    _ensure_neuron_importable()
    from neuron import h  # noqa: F401

    _disable_usetable()

    repo_root = Path(__file__).resolve().parent
    hpc_root = (repo_root / "hpc_assets").resolve()
    passivehpc_root = (
        Path(args.passivehpc_root).resolve()
        if str(args.passivehpc_root).strip()
        else (hpc_root / "passivehpc").resolve()
    )
    if not passivehpc_root.exists():
        raise FileNotFoundError(
            f"PassiveHPC assets not found: {passivehpc_root} (set PASSIVEHPC_ROOT or --passivehpc-root)"
        )

    # Resolve ALL output paths before we chdir (the model load uses relative
    # hoc paths). A relative --export-dir resolved after the chdir would land
    # under the asset directory instead of the caller's working directory.
    old_cwd = Path.cwd()
    out_dir = Path(args.out_dir)
    if not out_dir.is_absolute():
        out_dir = (old_cwd / out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    case = str(args.case)
    if args.export_dir == "AUTO":
        export_dir = out_dir / f"coredat__{case}"
    else:
        export_dir = Path(args.export_dir).resolve()

    # Ensure mod mechanisms (VecStim) exist for this checkout
    _ensure_hpc_mechanisms(hpc_root, verbose=bool(args.verbose))

    if args.mode == "all":
        # Run the four modes via subprocess isolation (avoids multi-run reset issues).
        _run_all_modes(
            args,
            case=case,
            out_dir=out_dir,
            export_dir=export_dir,
            passivehpc_root=passivehpc_root,
        )
        return 0

    # Use a stable CWD so hoc/morph references are relative (and don't leak absolute paths).
    os.chdir(str(passivehpc_root))
    try:
        pc, cell0, vrec, _keepalive = _build_network(
            passivehpc_root=passivehpc_root,
            ncell=int(args.ncell),
            seed=int(args.seed),
            dt=float(args.dt),
            tstop=float(args.tstop),
            v_init=float(args.v_init),
            vecstim_n=int(args.vecstim_n),
            vecstim_freq_hz=float(args.vecstim_freq_hz),
            vecstim_tstart=float(args.vecstim_tstart),
            vecstim_tend=float(args.vecstim_tend),
            iclamp_amp=float(args.iclamp_amp),
            iclamp_delay=float(args.iclamp_delay),
            iclamp_dur=float(args.iclamp_dur),
        )

        tag: str
        v: np.ndarray
        # Metadata common to every mode; mode-specific keys are added below.
        meta: dict[str, Any] = {
            "case": case,
            "dt_ms": float(args.dt),
            "runtime_ms": float(args.tstop),
            "v_init_mV": float(args.v_init),
            "ncell": int(args.ncell),
        }

        if args.mode == "neuron":
            _run_neuron(pc, tstop=float(args.tstop), v_init=float(args.v_init))
            v = np.asarray(vrec.as_numpy(), dtype=float)
            tag = "neuron"
            meta["mode"] = "neuron"
        elif args.mode == "coreneuron-gpu":
            _run_coreneuron_gpu(pc, tstop=float(args.tstop), v_init=float(args.v_init))
            v = np.asarray(vrec.as_numpy(), dtype=float)
            tag = "coreneuron_gpu"
            meta.update(
                {
                    "mode": "coreneuron-gpu",
                    "coreneuron_gpu": True,
                    "coreneuron_cell_permute": 2,
                }
            )
        else:
            # Export after setting up the model + stimuli, so HelioX inherits them.
            _export_bbcore(pc, export_dir=export_dir, dt=float(args.dt), v_init=float(args.v_init))
            device = "cpu" if args.mode == "heliox-cpu" else "gpu"
            v = _run_heliox(
                cell0,
                export_dir=export_dir,
                tstop=float(args.tstop),
                dt=float(args.dt),
                v_init=float(args.v_init),
                device=device,
                permute_type=int(args.heliox_permute_type),
            )
            tag = "heliox_cpu" if device == "cpu" else "heliox_gpu"
            # Record the permute type actually used: a negative request means
            # "device default" (0 on CPU, 3 on GPU).
            requested_permute = int(args.heliox_permute_type)
            if requested_permute < 0:
                requested_permute = 0 if device == "cpu" else 3
            meta.update(
                {
                    "mode": f"heliox-{device}",
                    "export_dir": str(export_dir),
                    "heliox_device": device,
                    "heliox_permute_type": requested_permute,
                }
            )

        # The time axis is reconstructed from the sample count and dt.
        t = np.arange(v.size, dtype=float) * float(args.dt)
        trace_path = out_dir / f"{case}__{tag}.npz"
        _save_trace(trace_path, t=t, v=v, meta=meta)
        print(f"Saved: {trace_path}")
    finally:
        # Restore the caller's working directory even if the run fails.
        os.chdir(str(old_cwd))
    return 0


# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
