#!/usr/bin/env python3
"""
Run a single deterministic L5PC simulation and save the soma voltage trace.

This script is the "phase 1" of the workflow:
  - build L5PC in NEURON
  - apply a fixed stimulus (optionally random but with a fixed seed)
  - run ONE backend mode
  - write a compact .npz archive with trace + metadata

Supported modes:
  - neuron:          NEURON interpreter engine
  - coreneuron-cpu:  NEURON+CoreNEURON engine on CPU
  - coreneuron-gpu:  NEURON+CoreNEURON engine on GPU
  - heliox-cpu:      HelioX on CPU (loads exported bbcore)
  - heliox-gpu:      HelioX on GPU (loads exported bbcore)
  - all:             run all five modes above via subprocess isolation
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import numpy as np


# Backend identifiers accepted by the single-mode runner. The CLI additionally
# accepts "all", which runs each of these in its own subprocess.
Mode = Literal["neuron", "coreneuron-cpu", "coreneuron-gpu", "heliox-cpu", "heliox-gpu"]


def _ensure_heliox_on_path() -> None:
    """
    Make the HelioX Python bindings importable via HELIOX_PYTHON_LIB.

    The environment variable is stripped of surrounding whitespace. When it
    names an existing directory that is not yet on ``sys.path``, that directory
    is inserted at the front. In every other case (unset, blank, not a
    directory, already present) ``sys.path`` is left untouched, so an installed
    HelioX keeps working without any configuration.
    """
    lib_dir = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
    if not lib_dir:
        return
    if not os.path.isdir(lib_dir):
        return
    if lib_dir in sys.path:
        return
    sys.path.insert(0, lib_dir)


def _disable_usetable() -> None:
    """
    Set the ``usetable_<mech>`` globals to 0 for ``hh`` and the ``*_la`` mechanisms.

    NOTE(review): presumably this turns off NEURON's interpolation tables for
    these mechanisms so all backends evaluate rates the same way -- confirm
    against the MOD files.
    """
    from neuron import h

    for mech in ("hh", "sca_la", "it2_la", "kdf_la", "kdr_la", "km_la"):
        setattr(h, f"usetable_{mech}", 0)


def _build_model():
    """
    Load the hoc sources, configure a ParallelContext, and instantiate the cell.

    Returns:
        A ``(pc, cell)`` tuple: the ``h.ParallelContext`` used to drive the
        simulation and the ``L5PClatemplate_record`` cell instance, which has
        gid 0 registered on this rank.
    """
    from neuron import h

    # Standard NEURON libraries first, then the cell template itself.
    # NOTE(review): the template name is given without a directory, so this
    # presumably relies on the working directory -- confirm.
    h.load_file("import3d.hoc")
    h.load_file("stdgui.hoc")
    h.load_file("L5PClatemplate_record.hoc")

    # Solver setup: one thread, maxstep of 5, cache-efficient layout.
    pc = h.ParallelContext()
    pc.nthread(1, 0)
    pc.set_maxstep(5)
    h.cvode.cache_efficient(1)
    h.stdinit()

    cell = h.L5PClatemplate_record()

    # Required for nrnbbcore_write: register at least one gid/cell.
    # The NetCon has no target (None); it only serves as the spike source that
    # pc.cell() associates with gid 0.
    spike_detector = h.NetCon(cell.soma(0.5)._ref_v, None, sec=cell.soma)
    pc.set_gid2node(0, pc.id())
    pc.cell(0, spike_detector)

    return pc, cell


def _record_soma_v(cell):
    """Start recording the membrane potential at ``soma(0.5)`` and return the result of ``Vector.record``."""
    from neuron import h

    recorder = h.Vector()
    soma_v_ref = cell.soma(0.5)._ref_v
    return recorder.record(soma_v_ref)


@dataclass(frozen=True)
class StimulusSpec:
    """Immutable description of the IClamp stimulus applied to the cell."""

    # Stimulus layout: one clamp at the soma, or one clamp on each of dend[0..68].
    stim: Literal["soma_single", "dend_69"]
    # How per-dendrite amplitudes are generated (only used for "dend_69").
    stim_pattern: Literal["constant", "ramp", "random"]
    # Base clamp amplitude (soma amplitude, constant value, or ramp step).
    amp: float
    # IClamp onset and duration, in ms.
    delay: float
    dur: float
    # Seed for the "random" pattern's generator.
    seed: int
    # Upper bound of the uniform draw used by the "random" pattern.
    random_max: float


def _compute_dend_amps(spec: StimulusSpec) -> np.ndarray:
    """
    Return the 69 per-dendrite clamp amplitudes for a ``dend_69`` stimulus.

    Patterns:
      - ``constant``: every entry equals ``spec.amp``
      - ``ramp``:     entry ``i`` equals ``spec.amp * (i + 1)``
      - anything else (i.e. ``random``): uniform draws in
        ``[0, spec.random_max)`` from a generator seeded with ``spec.seed``

    Raises:
        ValueError: if ``spec.stim`` is not ``"dend_69"``.
    """
    if spec.stim != "dend_69":
        raise ValueError("dend_amps requested but stim != dend_69")

    n_sites = 69
    pattern = spec.stim_pattern

    if pattern == "constant":
        return np.full(n_sites, float(spec.amp), dtype=float)

    if pattern == "ramp":
        multipliers = np.arange(1, n_sites + 1, dtype=float)
        return float(spec.amp) * multipliers

    generator = np.random.default_rng(int(spec.seed))
    draws = generator.uniform(0.0, float(spec.random_max), size=n_sites)
    return draws.astype(float)


def _apply_stimulus(cell, spec: StimulusSpec) -> tuple[dict[str, Any], list[Any]]:
    """
    Attach the IClamp(s) described by ``spec`` to ``cell``.

    This MUST run before the bbcore export used by the HelioX modes, so the
    clamps are part of the exported model.

    Returns:
        ``(meta, stims)`` where ``meta`` is the stimulus metadata persisted to
        the output archive (it carries a ``stim_amps`` array for ``dend_69``)
        and ``stims`` is the list of created IClamp objects, which the caller
        should keep referenced.
    """
    from neuron import h

    def _clamp_at(segment, amplitude):
        # One current clamp sharing the spec's timing, with its own amplitude.
        clamp = h.IClamp(segment)
        clamp.delay = spec.delay
        clamp.dur = spec.dur
        clamp.amp = amplitude
        return clamp

    meta: dict[str, Any] = dict(
        stim=spec.stim,
        stim_pattern=spec.stim_pattern,
        amp=float(spec.amp),
        delay=float(spec.delay),
        dur=float(spec.dur),
        seed=int(spec.seed),
        random_max=float(spec.random_max),
    )

    if spec.stim == "soma_single":
        soma_clamp = _clamp_at(cell.soma(0.5), spec.amp)
        meta["stim_desc"] = f"soma_single amp={spec.amp} delay={spec.delay} dur={spec.dur}"
        return meta, [soma_clamp]

    # dend_69: one clamp in the middle of each of the 69 dendritic sections.
    amps = _compute_dend_amps(spec)
    dend_clamps = [
        _clamp_at(cell.dend[idx](0.5), float(amplitude))
        for idx, amplitude in enumerate(amps)
    ]
    meta["stim_desc"] = (
        f"dend_69 pattern={spec.stim_pattern} base_amp={spec.amp} seed={spec.seed} "
        f"max={spec.random_max} delay={spec.delay} dur={spec.dur}"
    )
    meta["stim_amps"] = amps
    return meta, dend_clamps


def _export_bbcore(pc, export_dir: Path, dt: float, v_init: float) -> None:
    """
    Write the current model to ``export_dir`` via ``pc.nrnbbcore_write``.

    Creates ``export_dir`` (and any missing parents), then sets ``h.dt`` and
    calls ``h.finitialize(v_init)`` before writing, so the export reflects the
    requested timestep and initial voltage.
    """
    from neuron import h

    export_dir.mkdir(parents=True, exist_ok=True)
    h.dt = dt
    h.finitialize(v_init)
    pc.nrnbbcore_write(str(export_dir))


def _prepare_fresh_dir(path: Path) -> None:
    """
    Leave ``path`` as an existing, empty directory.

    Anything already at ``path`` is removed first:
      - a directory is deleted recursively;
      - a regular file or a symlink (including a dangling one) is unlinked.
        Only the link itself is removed, never the link's target.

    Missing parent directories are created.
    """
    if path.is_symlink() or path.is_file():
        # shutil.rmtree refuses symlinks and non-directories, and a dangling
        # symlink is invisible to exists() yet still blocks mkdir().
        path.unlink()
    elif path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


def _run_neuron(pc, runtime: float, dt: float, v_init: float) -> None:
    """
    Run the simulation with the NEURON engine.

    Sets ``h.dt``, initializes with ``h.finitialize(v_init)`` and calls
    ``pc.psolve(runtime)``. Results are read from the recording vectors the
    caller set up beforehand.
    """
    from neuron import h

    h.dt = dt
    h.finitialize(v_init)
    pc.psolve(runtime)


def _run_coreneuron(pc, runtime: float, dt: float, v_init: float, *, gpu: bool) -> None:
    """
    Run the simulation through NEURON's embedded CoreNEURON engine.

    Args:
        pc: the ParallelContext driving the run.
        runtime: value passed to ``pc.psolve``.
        dt: timestep assigned to ``h.dt``.
        v_init: initial voltage passed to ``h.finitialize``.
        gpu: enable CoreNEURON GPU execution (also fixes ``cell_permute`` to 2).

    Environment:
        CORENEURON_VERBOSE: integer verbosity (default 0).
        CORENEURON_NUM_GPUS: number of GPUs when ``gpu`` is true (default 1).

    Raises:
        RuntimeError: if no CoreNEURON mechanism library is found relative to
            the current working directory.
    """
    import platform

    from neuron import coreneuron, h

    # CoreNEURON needs a CoreNEURON mechanism library for custom MOD files.
    # If it is missing, NEURON may abort inside the embedded CoreNEURON runner,
    # which is hard to debug. Fail fast with a helpful message.
    #
    # The build output directory is named after the machine architecture, so
    # besides the historical "x86_64" also look under platform.machine()
    # (e.g. aarch64 / arm64), and accept the macOS ".dylib" suffix.
    arch_dirs = [name for name in dict.fromkeys(("x86_64", platform.machine())) if name]
    lib_names = ("libcorenrnmech.so", "libcorenrnmech.dylib")
    candidates = [
        location
        for arch in arch_dirs
        for lib in lib_names
        for location in (Path(arch) / lib, Path(arch) / ".libs" / lib)
    ]
    if not any(p.exists() for p in candidates):
        raise RuntimeError(
            "CoreNEURON mechanism library not found (expected one of: "
            + ", ".join(str(p) for p in candidates)
            + ").\n"
            "Build it with:\n"
            "  nrnivmodl -coreneuron mod"
        )

    coreneuron.enable = True
    coreneuron.verbose = int(os.environ.get("CORENEURON_VERBOSE", "0").strip() or "0")
    coreneuron.gpu = bool(gpu)
    if gpu:
        coreneuron.num_gpus = int(os.environ.get("CORENEURON_NUM_GPUS", "1").strip() or "1")
        # For GPU execution, CoreNEURON requires cell_permute 1 or 2. Fix to 2 for determinism.
        # (This avoids CoreNEURON auto-downgrading it to 1 with a warning.)
        coreneuron.cell_permute = 2

    h.dt = dt
    h.finitialize(v_init)
    pc.psolve(runtime)


def _run_heliox(
    cell,
    export_dir: Path,
    runtime: float,
    dt: float,
    v_init: float,
    *,
    device: Literal["cpu", "gpu"],
    permute_type: int,
) -> np.ndarray:
    """
    Run the exported model in HelioX and return the soma voltage trace.

    Args:
        cell: NEURON cell; only used to identify ``soma(0.5)`` for the monitor.
        export_dir: directory holding the bbcore export to load.
        runtime, dt, v_init: forwarded to the HelioX client.
        device: ``"cpu"`` or ``"gpu"``.
        permute_type: HelioX permute type; a negative value selects the
            default (0 on cpu, 3 otherwise).

    Returns:
        The monitored ``v`` samples as a float array.

    Raises:
        RuntimeError: if ``load_model()`` returns a non-zero code.
    """
    _ensure_heliox_on_path()

    import heliox
    from heliox_monitor import make_segment_monitor

    sim = heliox.Sim()
    sim.set_data_path(str(export_dir))
    sim.set_device(device)

    resolved_permute = permute_type
    if resolved_permute < 0:
        resolved_permute = 3 if device != "cpu" else 0
    sim.set_permute_type(int(resolved_permute))

    # The monitor is registered before the model is loaded.
    soma_monitor = make_segment_monitor(cell.soma(0.5), "v", sim)
    status = sim.load_model()
    if status != 0:
        raise RuntimeError(f"HelioX load_model() failed with code {status}")

    sim.set_dt(dt)
    sim.finitialize(v_init)
    sim.run(runtime)

    samples = soma_monitor.get_data()
    return np.asarray(samples, dtype=float)


def _default_case_name(args: argparse.Namespace) -> str:
    """
    Derive a stable, greppable case name from the stimulus and timing arguments.

    Layout: ``<stim>[_<pattern>[_seed<seed>_max<random_max>]]_amp<amp>_dt<dt>_T<runtime>``.
    The pattern part appears only for ``dend_69`` and the seed/max part only
    for the ``random`` pattern.
    """
    pieces = [f"{args.stim}"]
    if args.stim == "dend_69":
        pieces.append(f"{args.stim_pattern}")
        if args.stim_pattern == "random":
            pieces.append(f"seed{args.seed}")
            pieces.append(f"max{args.random_max}")
    pieces.append(f"amp{args.amp}")
    pieces.append(f"dt{args.dt}")
    pieces.append(f"T{args.runtime}")
    return "_".join(pieces)


def _out_path(out_dir: Path, case: str, mode: Mode) -> Path:
    """Return ``<out_dir>/<case>__<mode>.npz`` with the mode's dashes turned into underscores."""
    mode_token = "_".join(mode.split("-"))
    return out_dir.joinpath(f"{case}__{mode_token}.npz")


def _atomic_savez(path: Path, **kwargs: Any) -> None:
    """
    Save a .npz robustly, even if the destination file already exists and is not
    directly writable (e.g., copied from another machine/user with odd ACLs).

    Implementation: write to a temp file in the same directory, then atomically
    replace the destination.

    Args:
        path: destination file; missing parent directories are created.
        **kwargs: arrays to store, forwarded to ``np.savez``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            prefix=path.stem + ".tmp.",
            suffix=path.suffix,
            dir=str(path.parent),
            delete=False,
        ) as f:
            tmp_path = Path(f.name)
            # Important: hand numpy the open file object, not the filename.
            # Given a filename that does not end in ".npz", numpy silently
            # appends ".npz"; we would then replace the destination with an
            # empty temp file and leave the real archive behind under another
            # name. Writing through the handle works for any destination name.
            np.savez(f, **kwargs)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is already gone; this only
        # removes leftovers from a failed write. Cleanup stays best-effort.
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass


def _run_one_mode(args: argparse.Namespace, mode: Mode) -> Path:
    """
    Build the model, run it under a single backend, and write the trace archive.

    The archive contains the time axis ``t``, the soma voltage ``v``, a JSON
    metadata string ``meta_json`` and, when the stimulus provides them, the
    numeric ``stim_amps`` array.

    Returns:
        Path of the written .npz file.

    Raises:
        ValueError: for an unknown ``mode``, or an empty ``--export-dir`` in a
            heliox mode.
    """
    from neuron import h  # noqa: F401

    _disable_usetable()
    pc, cell = _build_model()
    v_vec = _record_soma_v(cell)

    # The clamps are created before any export; the returned list is bound to
    # a local so the clamp objects stay referenced for the rest of the run.
    stim_meta, _stims_keepalive = _apply_stimulus(
        cell,
        StimulusSpec(
            stim=args.stim,
            stim_pattern=args.stim_pattern,
            amp=float(args.amp),
            delay=float(args.delay),
            dur=float(args.dur),
            seed=int(args.seed),
            random_max=float(args.random_max),
        ),
    )

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    case = args.case or _default_case_name(args)
    out_path = _out_path(out_dir, case, mode)

    dt = float(args.dt)
    run_kwargs = {"runtime": float(args.runtime), "dt": dt, "v_init": float(args.v_init)}

    heliox_device: Literal["cpu", "gpu"] | None = None
    heliox_permute_type_used = int(args.heliox_permute_type)

    # Run
    if mode == "neuron":
        _run_neuron(pc, **run_kwargs)
        v = np.asarray(v_vec.as_numpy(), dtype=float)
    elif mode in ("coreneuron-cpu", "coreneuron-gpu"):
        _run_coreneuron(pc, gpu=(mode == "coreneuron-gpu"), **run_kwargs)
        v = np.asarray(v_vec.as_numpy(), dtype=float)
    elif mode in ("heliox-cpu", "heliox-gpu"):
        if str(args.export_dir).strip() == "":
            raise ValueError("--export-dir must not be empty")
        # AUTO keeps the export next to the outputs, one directory per case.
        if args.export_dir == "AUTO":
            export_dir = out_dir / f"coredat__{case}"
        else:
            export_dir = Path(args.export_dir)
        if args.clean_export:
            _prepare_fresh_dir(export_dir)
        _export_bbcore(pc, export_dir=export_dir, dt=dt, v_init=run_kwargs["v_init"])

        heliox_device = "cpu" if mode == "heliox-cpu" else "gpu"
        if heliox_permute_type_used < 0:
            heliox_permute_type_used = 0 if heliox_device == "cpu" else 3
        v = _run_heliox(
            cell,
            export_dir=export_dir,
            device=heliox_device,
            permute_type=int(heliox_permute_type_used),
            **run_kwargs,
        )
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # Time axis: derive from dt and returned vector length (be defensive about slight mismatches).
    n_samples = int(v.size)
    t = np.arange(n_samples, dtype=float) * dt

    is_cnrn_gpu = mode == "coreneuron-gpu"
    meta = {
        "mode": mode,
        "runtime_ms": float(args.runtime),
        "dt_ms": dt,
        "v_init_mV": float(args.v_init),
        "case": case,
        "export_dir": str(args.export_dir),
        "coreneuron_gpu": is_cnrn_gpu,
        "coreneuron_cell_permute": 2 if is_cnrn_gpu else None,
        # For heliox modes, persist the *resolved* permute type (0 for cpu, 3 for gpu),
        # not the CLI placeholder (-1).
        "heliox_device": heliox_device,
        "heliox_permute_type": int(heliox_permute_type_used) if heliox_device else None,
    }
    # Merge the stimulus metadata into the JSON payload. The amplitude array is
    # stored separately below, and any other 69-element 1-D array is converted
    # to a list so it can be JSON-serialised.
    for key, value in stim_meta.items():
        if key == "stim_amps":
            continue
        if isinstance(value, np.ndarray) and value.ndim == 1 and value.size == 69:
            value = value.tolist()
        meta[key] = value

    # Save. Keep stim_amps as a numeric array if present.
    npz_kwargs: dict[str, Any] = {
        "t": t,
        "v": v.astype(float, copy=False),
        "meta_json": json.dumps(meta, ensure_ascii=False, sort_keys=True),
    }
    if "stim_amps" in stim_meta:
        npz_kwargs["stim_amps"] = np.asarray(stim_meta["stim_amps"], dtype=float)

    _atomic_savez(out_path, **npz_kwargs)
    return out_path


def _run_all_modes(args: argparse.Namespace) -> tuple[list[Path], list[tuple[Mode, int]]]:
    """
    Run every backend mode, each in its own subprocess of this script.

    The parsed arguments are re-serialised for the children, so stimulus
    presets already applied to ``args`` carry over. All modes share one case
    name, and therefore one AUTO export directory.

    Returns:
        ``(out_paths, failures)``: the expected archive path of each mode that
        exited with 0, and ``(mode, returncode)`` for each mode that did not.
        Unless ``args.continue_on_error`` is set, the loop stops at the first
        failure.
    """
    case = args.case or _default_case_name(args)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    modes: list[Mode] = ["neuron", "coreneuron-cpu", "coreneuron-gpu", "heliox-cpu", "heliox-gpu"]
    out_paths: list[Path] = []
    failures: list[tuple[Mode, int]] = []

    base_argv = [sys.executable, str(Path(__file__).resolve())]
    # Reconstruct argv from args (avoid relying on sys.argv parsing quirks).
    forwarded = {
        "--runtime": args.runtime,
        "--dt": args.dt,
        "--v-init": args.v_init,
        "--delay": args.delay,
        "--dur": args.dur,
        "--stim": args.stim,
        "--stim-pattern": args.stim_pattern,
        "--amp": args.amp,
        "--seed": args.seed,
        "--random-max": args.random_max,
        "--out-dir": args.out_dir,
        "--case": case,
        "--export-dir": args.export_dir,
        "--heliox-permute-type": args.heliox_permute_type,
    }
    # Use the "--flag=value" form: a separate value token that starts with "-"
    # but is not a plain decimal (e.g. str(-1e-05) == "-1e-05", or a case name
    # with a leading dash) would be mistaken for an option by argparse.
    shared_flags = [f"{flag}={value}" for flag, value in forwarded.items()]
    # Forward the export-cleaning choice too; otherwise --no-clean-export is
    # lost and every child falls back to its default of wiping the export dir.
    shared_flags.append(
        "--clean-export" if getattr(args, "clean_export", True) else "--no-clean-export"
    )

    for mode in modes:
        cmd = base_argv + ["--mode", mode] + shared_flags
        print(f"[all] running: {' '.join(cmd)}", flush=True)
        proc = subprocess.run(cmd)
        if proc.returncode != 0:
            failures.append((mode, int(proc.returncode)))
            if not args.continue_on_error:
                break
        else:
            out_paths.append(_out_path(out_dir, case, mode))

    return out_paths, failures


def main() -> int:
    """
    Parse the command line and run one backend mode, or all of them.

    Environment:
        HELIOX_PERMUTE_TYPE: default for ``--heliox-permute-type``. Blank or
            unset means -1 (auto). A non-integer value is reported as a normal
            argparse usage error.

    Returns:
        0 on success; 2 when ``--mode all`` had at least one failing mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        choices=["neuron", "coreneuron-cpu", "coreneuron-gpu", "heliox-cpu", "heliox-gpu", "all"],
        required=True,
        help="Which backend mode to run and save.",
    )
    parser.add_argument("--case", default="", help="Case name used in output filenames (default: auto).")
    parser.add_argument("--out-dir", default="output", help="Output directory for .npz archives.")
    parser.add_argument(
        "--export-dir",
        default="AUTO",
        help="Export directory for bbcore (HelioX modes only). Use AUTO to place under out-dir.",
    )
    parser.add_argument(
        "--clean-export",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="For HelioX modes: delete export-dir before writing (default: enabled).",
    )

    # Simulation params
    parser.add_argument("--runtime", type=float, default=300.0, help="Runtime in ms")
    parser.add_argument("--dt", type=float, default=0.05, help="Timestep in ms")
    parser.add_argument("--v-init", type=float, default=-62.5, help="Initial voltage (mV)")

    # Stimulus params
    parser.add_argument(
        "--stim-preset",
        choices=["", "repro_moderate"],
        default="",
        help="Convenience preset.",
    )
    parser.add_argument("--delay", type=float, default=20.0, help="IClamp delay (ms)")
    parser.add_argument("--dur", type=float, default=200.0, help="IClamp duration (ms)")
    parser.add_argument("--stim", choices=["soma_single", "dend_69"], default="dend_69")
    parser.add_argument("--stim-pattern", choices=["constant", "ramp", "random"], default="constant")
    parser.add_argument("--amp", type=float, default=0.02, help="Base IClamp amplitude")
    parser.add_argument("--seed", type=int, default=1000, help="Random seed for stim-pattern=random")
    parser.add_argument("--random-max", type=float, default=0.04, help="Upper bound for random amplitudes")

    # Backend knobs
    # The environment default is handed over as a *string*: argparse applies
    # ``type`` to string defaults at parse time, so a malformed value becomes a
    # usage error instead of a traceback while the parser is being built. A
    # blank value falls back to -1, like the CORENEURON_* variables do.
    parser.add_argument(
        "--heliox-permute-type",
        type=int,
        default=os.environ.get("HELIOX_PERMUTE_TYPE", "").strip() or "-1",
        help="HelioX permute type (default: 0 for cpu, 3 for gpu).",
    )
    parser.add_argument(
        "--coreneuron-gpu",
        action="store_true",
        help="Deprecated. Use --mode coreneuron-gpu (kept for older notes/scripts).",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="With --mode all: keep going even if a mode fails (still prints failures).",
    )

    args = parser.parse_args()

    # Apply stimulus presets by overriding args fields (CLI still prints chosen values via metadata).
    if args.stim_preset == "repro_moderate":
        args.stim = "dend_69"
        args.stim_pattern = "random"
        args.amp = 0.02
        args.seed = 1000
        args.random_max = 0.02
        args.delay = 20.0
        args.dur = 200.0

    if args.mode == "all":
        out_paths, failures = _run_all_modes(args)
        if out_paths:
            print("\nSaved:")
            for p in out_paths:
                print(f"  - {p}")
        if failures:
            print("\nFailed:")
            for mode, rc in failures:
                print(f"  - {mode}: exit code {rc}")
            return 2
        return 0

    # Back-compat: allow `--mode coreneuron-cpu` but user passed legacy `--coreneuron-gpu`.
    # Also allow `--mode coreneuron-gpu` regardless of the legacy flag.
    mode = str(args.mode)
    if mode == "coreneuron-cpu" and args.coreneuron_gpu:
        mode = "coreneuron-gpu"

    out_path = _run_one_mode(args, mode=mode)  # type: ignore[arg-type]
    print(f"Saved: {out_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
