#!/usr/bin/env python3
"""
Phase 2: load saved traces and compare against NEURON baseline.

Expected input archives come from run_save_traces.py and contain:
  - t: time (ms)
  - v: soma voltage trace (mV)
  - meta_json: JSON metadata string

This script:
  - loads the baseline trace (NEURON by default; selectable with --baseline)
  - loads up to 4 other traces (CoreNEURON CPU/GPU, HelioX CPU/GPU)
  - aligns each trace to baseline (optional integer-sample lag search)
  - plots overlay + absolute error vs baseline
  - saves a summary .npz with per-mode errors and stats
"""

from __future__ import annotations

import argparse
import json
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import numpy as np


# Keys of the comparison modes; the same four strings are offered as the
# ``--modes`` choices in main() and appear as keys in _standard_paths().
ModeName = Literal["coreneuron_cpu", "coreneuron_gpu", "heliox_cpu", "heliox_gpu"]


@dataclass(frozen=True)
class Trace:
    """One saved trace as loaded from an ``.npz`` archive by _load_trace()."""

    # Time samples (ms, per the archive layout described in the module docstring).
    t: np.ndarray
    # Soma voltage samples (mV); _load_trace() trims ``t`` and ``v`` to equal length.
    v: np.ndarray
    # Metadata decoded from the archive's ``meta_json`` string.
    meta: dict[str, Any]
    # Archive file this trace was read from.
    path: Path


def _load_trace(path: Path) -> Trace:
    """
    Load one saved trace archive.

    The archive must provide the entries ``t``, ``v`` and ``meta_json``.
    ``t`` and ``v`` are flattened to 1-D float arrays and trimmed to their
    common length; ``meta_json`` is decoded with ``json.loads``.

    Args:
        path: Location of the ``.npz`` archive.

    Returns:
        A ``Trace`` holding the arrays, the decoded metadata and *path*.

    Raises:
        KeyError: if one of the expected entries is missing from the archive.
        json.JSONDecodeError: if ``meta_json`` is not valid JSON.
    """
    # np.load() returns an NpzFile that keeps the underlying zip file open.
    # Read everything we need inside the context manager so the handle is
    # released immediately instead of lingering until garbage collection.
    with np.load(path, allow_pickle=False) as data:
        t = np.asarray(data["t"], dtype=float).ravel()
        v = np.asarray(data["v"], dtype=float).ravel()
        meta_json = str(data["meta_json"])
    meta = json.loads(meta_json)
    # Guard against archives whose time and voltage vectors differ in length.
    n = min(t.size, v.size)
    return Trace(t=t[:n], v=v[:n], meta=meta, path=path)


def _rms(x: np.ndarray) -> float:
    """Return the root-mean-square of *x*, or NaN when *x* has no elements."""
    samples = np.asarray(x, dtype=float)
    if not samples.size:
        return float("nan")
    mean_square = np.mean(np.square(samples))
    return float(np.sqrt(mean_square))


def _align_by_best_lag(
    t_ref: np.ndarray,
    v_ref: np.ndarray,
    t_x: np.ndarray,
    v_x: np.ndarray,
    *,
    dt: float,
    max_lag_ms: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, float]:
    """
    Find the integer sample shift of ``x`` that best matches ``ref``.

    Every shift in ``[-max_lag, +max_lag]`` samples is tried in ascending
    order and scored by the RMS difference over the overlapping samples; the
    first shift with the strictly smallest score wins.

    Args:
        t_ref, v_ref: Reference time / voltage arrays.
        t_x, v_x: Time / voltage arrays of the trace being aligned.
        dt: Sample spacing, in the same unit as *max_lag_ms*.
        max_lag_ms: Half-width of the search window; it is divided by *dt*
            and rounded to a sample count (negative values act as zero).

    Returns:
        ``(t_ref2, v_ref2, t_x2, v_x2, lag_samples, best_rms)`` where the four
        arrays are restricted to the overlap of the winning shift.
        Convention: ``lag > 0`` means x is delayed relative to ref.
    """
    window = max(0, int(round(float(max_lag_ms) / float(dt))))

    # Restrict all four arrays to the common leading sample count.
    common = min(v_ref.size, v_x.size)
    t_ref, v_ref = t_ref[:common], v_ref[:common]
    t_x, v_x = t_x[:common], v_x[:common]

    if common == 0 or window == 0:
        return t_ref, v_ref, t_x, v_x, 0, _rms(v_ref - v_x)

    # Never shift so far that no samples overlap.
    window = min(window, common - 1)

    winner_lag = 0
    winner_rms = float("inf")
    # Fallback used when no candidate scores below +inf (e.g. NaN data).
    winner_ref = slice(None)
    winner_x = slice(None)

    for shift in range(-window, window + 1):
        overlap = common - abs(shift)
        if overlap == 0:
            continue
        # A negative shift drops leading ref samples; a positive shift drops
        # leading x samples. Either way the two windows have equal length.
        ref_start = max(0, -shift)
        x_start = max(0, shift)
        ref_win = slice(ref_start, ref_start + overlap)
        x_win = slice(x_start, x_start + overlap)

        score = _rms(v_ref[ref_win] - v_x[x_win])
        if score < winner_rms:
            winner_rms = score
            winner_lag = shift
            winner_ref = ref_win
            winner_x = x_win

    return (
        t_ref[winner_ref],
        v_ref[winner_ref],
        t_x[winner_x],
        v_x[winner_x],
        int(winner_lag),
        float(winner_rms),
    )


def _standard_paths(out_dir: Path, case: str) -> dict[str, Path]:
    """Map each simulator key to its ``<case>__<key>.npz`` archive path under *out_dir*."""
    keys = ("neuron", "coreneuron_cpu", "coreneuron_gpu", "heliox_cpu", "heliox_gpu")
    return {key: out_dir / f"{case}__{key}.npz" for key in keys}


def _atomic_savez(path: Path, **kwargs: Any) -> None:
    """
    Write an ``.npz`` archive to *path* via a temporary file and ``os.replace``.

    The archive is first written to a temporary file in the destination
    directory and then moved over *path*, so readers never observe a partially
    written file at *path*.

    Args:
        path: Final archive location; missing parent directories are created.
        **kwargs: Named arrays, forwarded unchanged to ``np.savez``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # np.savez appends ".npz" to any filename that lacks it. The temp file must
    # therefore ALWAYS end in ".npz" -- reusing path.suffix is wrong when the
    # target has another (or no) extension: the data would land in
    # "<tmp>.npz" and the still-empty temp file would be moved into place.
    with tempfile.NamedTemporaryFile(
        prefix=path.stem + ".tmp.",
        suffix=".npz",
        dir=str(path.parent),
        delete=False,
    ) as f:
        tmp_path = Path(f.name)
    try:
        np.savez(tmp_path, **kwargs)
        os.replace(tmp_path, path)
    finally:
        # Best-effort cleanup: after a successful replace the temp name is
        # already gone; after a failure we do not want to leave it behind.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass


def main() -> int:
    """
    Compare saved voltage traces against a baseline; write summary and plots.

    Steps:
      1. Parse the command line and locate ``<out-dir>/<case>__<mode>.npz``.
      2. Load the baseline and every requested mode (the baseline itself is
         skipped if it is also listed in ``--modes``).
      3. Optionally align each mode to the baseline with an integer-sample
         lag search and compute absolute-error statistics.
      4. Save a summary ``.npz`` and render the combined figure, plus the
         optional overlay-only / error-only figures.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        ValueError: if ``--abs-error-unit`` is not a positive finite number.
        FileNotFoundError: if the baseline or a requested mode archive is missing.
        RuntimeError: if no positive time step can be determined.
    """
    mode_choices = ["coreneuron_cpu", "coreneuron_gpu", "heliox_cpu", "heliox_gpu"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", default="output", help="Directory holding saved .npz archives.")
    parser.add_argument("--case", required=True, help="Case name used in output filenames.")
    parser.add_argument(
        "--baseline",
        default="neuron",
        choices=["neuron", *mode_choices],
        help="Which saved trace to use as baseline for comparison (default: neuron).",
    )
    parser.add_argument("--no-align", action="store_true", help="Disable lag estimation/alignment.")
    parser.add_argument("--max-lag-ms", type=float, default=5.0, help="Alignment search window (ms).")
    parser.add_argument(
        "--log-abs-error",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Plot abs error on a log scale (default: disabled).",
    )
    parser.add_argument(
        "--abs-error-unit",
        type=float,
        default=1e-8,
        help="Scale factor for abs error plot (default: 1e-8, i.e. plot abs_err / 1e-8).",
    )
    parser.add_argument(
        "--abs-error-ymax",
        type=float,
        default=1.5,
        help="Y-axis upper bound for the scaled abs error plot (linear scale only).",
    )
    parser.add_argument(
        "--modes",
        nargs="*",
        default=list(mode_choices),
        choices=mode_choices,
        help="Which non-baseline modes to include (files must exist).",
    )
    parser.add_argument(
        "--png",
        default="",
        help="Output path of the combined 2-row figure (default: out-dir/<case>__compare.png).",
    )
    parser.add_argument(
        "--overlay-png",
        default="",
        help="Optional: also write overlay-only PNG path (default: disabled).",
    )
    parser.add_argument(
        "--error-png",
        default="",
        help="Optional: also write abs-error-only PNG path (default: disabled).",
    )
    parser.add_argument("--npz", default="", help="Output NPZ path (default: out-dir/<case>__errors.npz).")
    parser.add_argument(
        "--fig-width",
        type=float,
        default=4.8,
        help="Figure width in inches (default: 4.8; slightly wider for 2-row layout).",
    )
    parser.add_argument(
        "--fig-height",
        type=float,
        default=3.4,
        help="Figure height in inches (default: 3.4; 2-row layout).",
    )
    parser.add_argument("--dpi", type=int, default=300, help="Figure DPI (default: 300).")
    args = parser.parse_args()

    # Validate plot options up front so a bad value fails before any file is written.
    unit = float(args.abs_error_unit)
    if not np.isfinite(unit) or unit <= 0:
        raise ValueError("--abs-error-unit must be a positive finite number")

    out_dir = Path(args.out_dir)
    paths = _standard_paths(out_dir, str(args.case))

    base_key = str(args.baseline)
    base_path = paths[base_key]
    if not base_path.exists():
        raise FileNotFoundError(f"Missing baseline trace file '{base_key}': {base_path}")
    base = _load_trace(base_path)
    dt = float(base.meta.get("dt_ms", float("nan")))
    if not np.isfinite(dt) or dt <= 0:
        # Fallback: compute from time vector.
        dt = float(np.median(np.diff(base.t))) if base.t.size >= 2 else 0.0
    if dt <= 0:
        raise RuntimeError("Could not determine dt from baseline trace.")

    others: dict[str, Trace] = {}
    for m in args.modes:
        if m == base_key:
            # Comparing the baseline with itself only yields an all-zero row
            # and a duplicate curve, so it is left out.
            continue
        p = paths[m]
        if not p.exists():
            raise FileNotFoundError(f"Missing trace file for mode '{m}': {p}")
        others[m] = _load_trace(p)

    # Compare every mode against the baseline.
    compare_rows: list[dict[str, Any]] = []
    aligned: dict[str, dict[str, Any]] = {}

    for name, tr in others.items():
        if args.no_align:
            n = min(base.v.size, tr.v.size)
            t_ref = base.t[:n]
            v_ref = base.v[:n]
            v_x = tr.v[:n]
            lag = 0
            best_rms = _rms(v_ref - v_x)
        else:
            t_ref, v_ref, _t_x, v_x, lag, best_rms = _align_by_best_lag(
                base.t,
                base.v,
                tr.t,
                tr.v,
                dt=dt,
                max_lag_ms=float(args.max_lag_ms),
            )

        abs_err = np.abs(v_ref - v_x)
        row = {
            "mode": name,
            "n": int(abs_err.size),
            "lag_samples": int(lag),
            "lag_ms": float(lag) * dt,
            "max_abs_err_mV": float(abs_err.max()) if abs_err.size else float("nan"),
            "mean_abs_err_mV": float(abs_err.mean()) if abs_err.size else float("nan"),
            "rms_err_mV": float(_rms(v_ref - v_x)),
            "best_rms_mV": float(best_rms),
        }
        compare_rows.append(row)
        # Errors are plotted against the baseline's (aligned) time axis.
        aligned[name] = {
            "t": t_ref,
            "v_base": v_ref,
            "v": v_x,
            "abs_err": abs_err,
            "stats": row,
        }

    # Persist summary for downstream use.
    npz_path = Path(args.npz) if args.npz else out_dir / f"{args.case}__errors.npz"
    npz_payload: dict[str, Any] = {
        "case": str(args.case),
        "dt_ms": float(dt),
        "baseline_path": str(base.path),
        "compare_json": json.dumps(compare_rows, ensure_ascii=False, sort_keys=True),
        "baseline_v": base.v,
        "baseline_t": base.t,
    }
    for name, d in aligned.items():
        npz_payload[f"{name}_t"] = d["t"]
        npz_payload[f"{name}_v"] = d["v"]
        npz_payload[f"{name}_abs_err"] = d["abs_err"]
        npz_payload[f"{name}_lag_samples"] = np.array([d["stats"]["lag_samples"]], dtype=int)

    _atomic_savez(npz_path, **npz_payload)

    # Plot (publication-friendly): combined 2-row figure (overlay + abs error).
    # Imported lazily so the numeric part above does not depend on matplotlib.
    import matplotlib.pyplot as plt

    display_names = {
        "neuron": "NEURON",
        "coreneuron_cpu": "CoreNEURON (CPU)",
        "coreneuron_gpu": "CoreNEURON (GPU)",
        "heliox_cpu": "HelioX (CPU)",
        "heliox_gpu": "HelioX (GPU)",
    }

    def _display_name(mode: str) -> str:
        """Return the legend label for *mode* (the key itself if unknown)."""
        return display_names.get(mode, mode)

    colors = {
        "coreneuron_cpu": "#1f77b4",
        "coreneuron_gpu": "#17becf",
        "heliox_cpu": "#ff7f0e",
        "heliox_gpu": "#2ca02c",
    }
    linestyles = {
        "coreneuron_cpu": (0, (5, 2)),  # dashed
        "coreneuron_gpu": (0, (1, 2)),  # dotted
        "heliox_cpu": (0, (3, 1, 1, 1)),  # dash-dot-ish
        "heliox_gpu": (0, (7, 2, 2, 2)),  # long-short
    }

    fig_w = float(args.fig_width)
    fig_h = float(args.fig_height)
    dpi = int(args.dpi)
    base_label = _display_name(base_key)
    # Smallest value drawn on a log axis (zero cannot be shown there).
    log_floor = 1e-18

    combined_default = out_dir / f"{args.case}__compare.png"
    combined_png = Path(args.png) if args.png else combined_default
    overlay_png = Path(args.overlay_png) if args.overlay_png else None
    err_png = Path(args.error_png) if args.error_png else None

    def _draw_overlay(ax: Any) -> None:
        """Draw the baseline and every mode's raw voltage trace on *ax*."""
        ax.plot(base.t, base.v, label=base_label, linewidth=0.9, color="black", alpha=0.6, zorder=1)
        # Distinct dash patterns keep overlapping traces distinguishable.
        for name, tr in others.items():
            ax.plot(
                tr.t,
                tr.v,
                label=_display_name(name),
                linewidth=1.0,
                alpha=0.95,
                color=colors.get(name, None),
                linestyle=linestyles.get(name, "--"),
                zorder=3,
            )
        ax.set_ylabel("V (mV)")
        ax.grid(True, alpha=0.25)
        ax.legend(loc="best", fontsize=7, frameon=False, handlelength=2.8)

    def _annotate_category(ax: Any, category: str, candidates: list[str], text_dx: int, text_dy: int) -> None:
        """Mark on *ax* the largest abs error among the *candidates* modes."""
        best_name = None
        best_max = -1.0
        best_idx = 0
        for name in candidates:
            if name not in aligned:
                continue
            y = np.asarray(aligned[name]["abs_err"], dtype=float) / unit
            if y.size == 0:
                continue
            idx = int(np.argmax(y))
            m = float(y[idx])
            if m > best_max:
                best_max = m
                best_name = name
                best_idx = idx
        if best_name is None:
            return
        t_arr = np.asarray(aligned[best_name]["t"], dtype=float)
        abs_err_mV_arr = np.asarray(aligned[best_name]["abs_err"], dtype=float)
        x0 = float(t_arr[best_idx])
        y0_mV = float(abs_err_mV_arr[best_idx])
        y0 = y0_mV / unit
        if args.log_abs_error:
            # Keep the arrow tip on the curve, which is clamped the same way.
            y0 = max(y0, log_floor)
        # Label uses the *original* physical units (mV), not the scaled plot units.
        # Also put the category on its own line for cleaner layout.
        label = f"{category}\nMAX diff\n{y0_mV:.2e} mV"
        ax.annotate(
            label,
            xy=(x0, y0),
            xytext=(text_dx, text_dy),
            textcoords="offset points",
            ha="left" if text_dx >= 0 else "right",
            va="bottom" if text_dy >= 0 else "top",
            fontsize=9,
            color="black",
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec=colors.get(best_name, "black"), lw=0.8, alpha=0.9),
            arrowprops=dict(arrowstyle="->", lw=0.9, color=colors.get(best_name, "black")),
        )

    def _draw_abs_error(ax: Any) -> None:
        """Draw the scaled abs-error curves and the max-error annotations on *ax*."""
        for name, d in aligned.items():
            y = np.asarray(d["abs_err"], dtype=float) / unit
            if args.log_abs_error:
                y = np.maximum(y, log_floor)
            ax.plot(
                d["t"],
                y,
                label=_display_name(name),
                linewidth=1.1,
                alpha=0.95,
                color=colors.get(name, None),
            )
        ax.set_ylabel(f"abs error (×{unit:.0e} mV)")
        if args.log_abs_error:
            ax.set_yscale("log")
        else:
            ymax = float(args.abs_error_ymax)
            if np.isfinite(ymax) and ymax > 0:
                ax.set_ylim(0.0, ymax)
        ax.grid(True, alpha=0.25)
        ax.legend(loc="best", fontsize=7, frameon=False)
        # Annotate on the axis being drawn (not on a previously closed figure).
        _annotate_category(ax, "CoreNEURON", ["coreneuron_cpu", "coreneuron_gpu"], text_dx=16, text_dy=16)
        _annotate_category(ax, "HelioX", ["heliox_cpu", "heliox_gpu"], text_dx=-16, text_dy=18)

    def _save_figure(figure: Any, target: Path) -> None:
        """Write *figure* to *target* (creating its directory) and close it."""
        try:
            target.parent.mkdir(parents=True, exist_ok=True)
            figure.savefig(target, dpi=dpi)
        finally:
            plt.close(figure)

    # Combined figure: overlay on top, abs error below, shared time axis.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_w, fig_h), dpi=dpi, sharex=True)
    ax1.set_title("Voltage traces", fontsize=9)
    _draw_overlay(ax1)
    ax2.set_title(f"Absolute error vs {base_label}", fontsize=9)
    _draw_abs_error(ax2)
    ax2.set_xlabel("t (ms)")
    fig.tight_layout(pad=0.25)
    _save_figure(fig, combined_png)
    print(f"Saved plot: {combined_png}")

    # Optional: also write split figures if user asked for them explicitly.
    if overlay_png is not None:
        fig_o, ax_o = plt.subplots(1, 1, figsize=(fig_w, fig_h / 2.0), dpi=dpi)
        _draw_overlay(ax_o)
        ax_o.set_xlabel("t (ms)")
        fig_o.tight_layout(pad=0.2)
        _save_figure(fig_o, overlay_png)
        print(f"Saved overlay: {overlay_png}")

    if err_png is not None:
        fig_e, ax_e = plt.subplots(1, 1, figsize=(fig_w, fig_h / 2.0), dpi=dpi)
        _draw_abs_error(ax_e)
        ax_e.set_xlabel("t (ms)")
        fig_e.tight_layout(pad=0.2)
        _save_figure(fig_e, err_png)
        print(f"Saved abs error: {err_png}")
    print(f"Saved errors: {npz_path}")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
