#!/usr/bin/env python3
"""
Minimal single-run NEURON vs HelioX comparison for L5PC.

Design goals:
- single-run only (avoid known multi-run reset issues)
- set stimulus parameters in NEURON before export (so they are carried into HelioX)
- record soma(0.5).v in both simulators and plot overlay + absolute error
"""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import numpy as np


def _configure_matplotlib_backend() -> None:
    """Pick a matplotlib backend unless ``MPLBACKEND`` already names one.

    A non-blank ``MPLBACKEND`` environment variable is left untouched.
    Otherwise ``QtAgg`` is selected when ``DISPLAY`` is set to a non-empty
    value, and ``Agg`` is selected when it is not.
    """
    import matplotlib

    if os.environ.get("MPLBACKEND", "").strip():
        # The user chose a backend explicitly; do not override it.
        return

    has_display = bool(os.environ.get("DISPLAY"))
    matplotlib.use("QtAgg" if has_display else "Agg")


def _ensure_heliox_on_path() -> None:
    """Make the HelioX Python library importable by prepending it to ``sys.path``.

    Candidate directories are tried in priority order:

    1. the ``HELIOX_PYTHON_LIB`` environment variable (``~`` and ``$VAR``
       references in it are expanded),
    2. ``<home>/heliox/python_lib``,
    3. ``<home>/Documents/heliox/python_lib``.

    The first candidate that is an existing directory wins. It is inserted at
    the front of ``sys.path`` unless it is already listed there; either way
    the search stops, so a lower-priority fallback never shadows it. If no
    candidate exists, ``sys.path`` is left unchanged.
    """
    # os.path.isdir() does no shell-style expansion, so the home directory has
    # to be resolved explicitly; a literal "$HOME/..." string never matches.
    home = os.path.expanduser("~")
    env_value = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
    candidates = [
        os.path.expandvars(os.path.expanduser(env_value)),
        os.path.join(home, "heliox", "python_lib"),
        os.path.join(home, "Documents", "heliox", "python_lib"),
    ]
    for path in candidates:
        if not path or not os.path.isdir(path):
            continue
        if path not in sys.path:
            sys.path.insert(0, path)
        return


def _disable_usetable() -> None:
    """Set the ``usetable_*`` globals of the listed mechanisms to 0.

    Keeps exports free of usetable-related globals to avoid precision drift.
    The globals are assigned in the order listed; a name that NEURON does not
    know aborts the loop with whatever error NEURON raises.
    """
    from neuron import h

    mechanism_suffixes = ("hh", "sca_la", "it2_la", "kdf_la", "kdr_la", "km_la")
    for suffix in mechanism_suffixes:
        setattr(h, f"usetable_{suffix}", 0)


def _build_model():
    """Load the L5PC hoc template, instantiate one cell and register it as gid 0.

    Returns:
        A ``(pc, cell)`` tuple: the ``h.ParallelContext`` configured here and
        the ``L5PClatemplate_record`` instance created from the template.
    """
    from neuron import h

    # Standard hoc libraries first, then the cell template itself.
    # NOTE(review): "L5PClatemplate_record.hoc" is given as a bare file name, so
    # it presumably has to be resolvable from the working directory - confirm.
    h.load_file("import3d.hoc")
    h.load_file("stdgui.hoc")
    h.load_file("L5PClatemplate_record.hoc")

    # Simulator-wide settings are applied before the cell exists.
    # NOTE(review): the relative order of these calls and of stdinit() versus
    # cell creation is kept as-is; confirm before rearranging.
    pc = h.ParallelContext()
    pc.nthread(1, 0)
    pc.set_maxstep(5)
    h.cvode.cache_efficient(1)
    h.stdinit()

    cell = h.L5PClatemplate_record()

    # Required for nrnbbcore_write: register at least one gid/cell.
    # The NetCon has no target (None); it only serves as the source handed to
    # pc.cell() for gid 0, which is assigned to this process (pc.id()).
    spike_detector = h.NetCon(cell.soma(0.5)._ref_v, None, sec=cell.soma)
    pc.set_gid2node(0, pc.id())
    pc.cell(0, spike_detector)

    return pc, cell


def _make_stimulus(cell, amp: float, delay: float, dur: float):
    """Legacy stub: always raises because stimulus setup now lives in ``main()``.

    The signature is retained only so that older notes referring to this
    helper still resolve to a name; none of the arguments are used.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "Use main() stimulus setup instead."
    raise NotImplementedError(message)


def _record_soma_v(cell):
    """Start recording ``cell.soma(0.5).v`` and return the recording object.

    The returned value is whatever ``Vector.record`` returns for a freshly
    created ``h.Vector``; the caller later reads it with ``as_numpy()``.
    """
    from neuron import h

    recorder = h.Vector()
    return recorder.record(cell.soma(0.5)._ref_v)


def _export_model(pc, export_dir: Path, dt: float, v_init: float) -> None:
    """Initialize NEURON and write the model files into *export_dir*.

    Args:
        pc: ParallelContext whose ``nrnbbcore_write`` performs the export.
        export_dir: Destination directory; created (with parents) if missing.
        dt: Value assigned to ``h.dt`` before initialization.
        v_init: Voltage passed to ``h.finitialize``.
    """
    from neuron import h

    # The destination must exist before the writer is called.
    export_dir.mkdir(parents=True, exist_ok=True)

    # Apply the timestep first, then initialize, then export.
    setattr(h, "dt", dt)
    h.finitialize(v_init)

    destination = str(export_dir)
    pc.nrnbbcore_write(destination)


# Number of dendritic IClamps placed by ``--stim dend_69`` (cell.dend[0..68]).
_DEND_STIM_COUNT = 69


def _dend_stim_amplitudes(
    pattern: str,
    amp: float,
    seed: int,
    random_max: float,
    count: int = _DEND_STIM_COUNT,
) -> np.ndarray:
    """Return one IClamp amplitude per stimulated dendrite for *pattern*.

    ``"constant"`` gives *amp* everywhere, ``"ramp"`` gives ``amp * (i + 1)``
    for dendrite ``i``, and any other value draws uniformly from
    ``[0, random_max]`` using a generator seeded with *seed*.
    """
    if pattern == "constant":
        return np.full(count, float(amp), dtype=float)
    if pattern == "ramp":
        return float(amp) * np.arange(1, count + 1, dtype=float)
    rng = np.random.default_rng(int(seed))
    return rng.uniform(0.0, float(random_max), size=count).astype(float)


def main() -> int:
    """Run the L5PC model once in NEURON and once in HelioX and compare soma voltage.

    Steps: parse and validate the command line, build the cell, attach the
    IClamp stimulus in NEURON, export the model, load the export into HelioX,
    run both simulators for ``--runtime`` ms, then write the traces
    (``repro_traces.npz``) and an overlay/absolute-error plot
    (``repro_overlay_and_abs_error.png``) into ``--out-dir``.

    Returns:
        0 on success.

    Raises:
        SystemExit: via argparse (exit status 2) when an argument or an
            environment-provided default is invalid.
        RuntimeError: when HelioX ``load_model()`` returns a non-zero code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--runtime", type=float, default=300.0, help="Runtime in ms")
    parser.add_argument("--dt", type=float, default=0.05, help="Timestep in ms")
    parser.add_argument("--v-init", type=float, default=-62.5, help="Initial voltage (mV)")
    parser.add_argument("--delay", type=float, default=20.0, help="IClamp delay (ms)")
    parser.add_argument("--dur", type=float, default=200.0, help="IClamp duration (ms)")
    parser.add_argument(
        "--stim",
        choices=["soma_single", "dend_69"],
        default="dend_69",
        help="Stimulus setup. dend_69 matches the stronger dendritic-stim style used in prior tests.",
    )
    parser.add_argument(
        "--stim-pattern",
        choices=["constant", "ramp", "random"],
        default="constant",
        help="How to assign amplitudes for dend_69.",
    )
    parser.add_argument(
        "--amp",
        type=float,
        default=0.02,
        help="Base IClamp amplitude. For dend_69+ramp, dend i uses amp*(i+1). For constant, all dend use amp.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1000,
        help="Random seed for stim-pattern=random (dend_69 only).",
    )
    parser.add_argument(
        "--random-max",
        type=float,
        default=0.04,
        help="Upper bound for random amplitudes (uniform in [0, random-max]).",
    )
    parser.add_argument("--device", choices=["cpu", "gpu"], default=os.environ.get("HELIOX_DEVICE", "cpu"))
    parser.add_argument(
        "--permute-type",
        type=int,
        # argparse runs string defaults through ``type``, so a malformed
        # HELIOX_PERMUTE_TYPE becomes a usage error instead of a traceback
        # raised while the parser is still being built.
        default=os.environ.get("HELIOX_PERMUTE_TYPE", "-1"),
        help="HelioX permute type (default: 0 for cpu, 3 for gpu)",
    )
    parser.add_argument("--export-dir", default="repro/coredat-repro", help="Export directory for coredat")
    parser.add_argument("--out-dir", default="repro/output", help="Output directory for plots/data")
    args = parser.parse_args()

    # Reject values that would hang or silently corrupt the run. The chained
    # comparisons are written so that NaN is rejected as well.
    if not 0.0 < args.dt < float("inf"):
        parser.error(f"--dt must be a positive, finite number of ms (got {args.dt})")
    if not 0.0 < args.runtime < float("inf"):
        parser.error(f"--runtime must be a positive, finite number of ms (got {args.runtime})")
    # argparse does not check defaults against ``choices``, so a bad
    # HELIOX_DEVICE would otherwise slip through and be treated as "gpu".
    if args.device not in ("cpu", "gpu"):
        parser.error(f"device must be 'cpu' or 'gpu' (got {args.device!r}; check HELIOX_DEVICE)")

    _configure_matplotlib_backend()
    _ensure_heliox_on_path()

    from neuron import h

    _disable_usetable()

    pc, cell = _build_model()
    v_neuron_vec = _record_soma_v(cell)

    # Create stimulus in NEURON BEFORE export, so it is carried into HelioX by the export.
    # This avoids any ambiguity about HelioX-side instance ordering / indices.
    # The IClamp objects are collected in ``stims`` so they stay referenced
    # for the whole run.
    stims = []
    if args.stim == "soma_single":
        stim = h.IClamp(cell.soma(0.5))
        stim.delay = args.delay
        stim.dur = args.dur
        stim.amp = args.amp
        stims.append(stim)
        stim_desc = f"soma_single amp={args.amp} delay={args.delay} dur={args.dur}"
    else:
        amps = _dend_stim_amplitudes(args.stim_pattern, args.amp, args.seed, args.random_max)
        for i, dend_amp in enumerate(amps):
            stim = h.IClamp(cell.dend[i](0.5))
            stim.delay = args.delay
            stim.dur = args.dur
            stim.amp = float(dend_amp)
            stims.append(stim)

        stim_desc = f"dend_69 pattern={args.stim_pattern} base_amp={args.amp} seed={args.seed} max={args.random_max} delay={args.delay} dur={args.dur}"

    export_dir = Path(args.export_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Export after stimulus is set: HelioX should inherit the exact parameters.
    _export_model(pc, export_dir=export_dir, dt=args.dt, v_init=args.v_init)

    # Setup HelioX
    import heliox
    from heliox_monitor import make_segment_monitor

    client = heliox.Sim()
    client.set_data_path(str(export_dir))
    client.set_device(args.device)

    # A negative permute type means "not specified": pick the per-device default.
    permute = args.permute_type
    if permute < 0:
        permute = 0 if args.device == "cpu" else 3
    client.set_permute_type(int(permute))

    v_mon = make_segment_monitor(cell.soma(0.5), "v", client)
    rc = client.load_model()
    if rc != 0:
        raise RuntimeError(f"HelioX load_model() failed with code {rc}")

    # Run both once
    h.dt = args.dt
    h.finitialize(args.v_init)
    client.set_dt(args.dt)
    client.finitialize(args.v_init)

    pc.psolve(args.runtime)
    client.run(args.runtime)

    v_neuron = np.asarray(v_neuron_vec.as_numpy(), dtype=float)
    v_heliox = np.asarray(v_mon.get_data(), dtype=float)

    # Make lengths consistent (HelioX may return exactly same, but be defensive)
    n = min(v_neuron.size, v_heliox.size)
    v_neuron = v_neuron[:n]
    v_heliox = v_heliox[:n]

    # The time axis is reconstructed from the sample index and dt.
    t = np.arange(n, dtype=float) * args.dt
    abs_err = np.abs(v_neuron - v_heliox)

    # Stats (NaN when there are no samples to compare)
    if abs_err.size:
        max_err = float(abs_err.max())
        mean_err = float(abs_err.mean())
        rms_err = float(np.sqrt(np.mean((v_neuron - v_heliox) ** 2)))
    else:
        max_err = mean_err = rms_err = float("nan")

    # Save raw traces for debugging
    np.savez(
        out_dir / "repro_traces.npz",
        t=t,
        v_neuron=v_neuron,
        v_heliox=v_heliox,
        abs_err=abs_err,
        dt=args.dt,
        runtime=args.runtime,
        device=args.device,
        permute_type=int(permute),
    )

    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    ax1.plot(t, v_neuron, label="NEURON", linewidth=1.5)
    ax1.plot(t, v_heliox, label="HelioX", linewidth=1.2, alpha=0.85)
    ax1.set_ylabel("V (mV)")
    ax1.set_title(
        f"L5PC soma.v single-run | device={args.device} permute={permute} | "
        f"{stim_desc} | "
        f"max={max_err:.3e} mean={mean_err:.3e} rms={rms_err:.3e} mV"
    )
    ax1.legend(loc="best")
    ax1.grid(True, alpha=0.25)

    ax2.plot(t, abs_err, color="crimson", linewidth=1.0)
    ax2.set_xlabel("Time (ms)")
    ax2.set_ylabel("|ΔV| (mV)")
    ax2.set_yscale("log")
    ax2.grid(True, which="both", alpha=0.25)

    fig.tight_layout()
    out_png = out_dir / "repro_overlay_and_abs_error.png"
    fig.savefig(out_png, dpi=200)
    # The figure is only written to disk, never shown; release it explicitly.
    plt.close(fig)

    print("OK: single-run comparison finished")
    print(f"  export_dir: {export_dir}")
    print(f"  plot: {out_png}")
    print(f"  npz:  {out_dir / 'repro_traces.npz'}")
    print(f"  max_abs_err_mV: {max_err:.12e}")
    return 0


# Script entry point: main()'s integer return value becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
