#!/usr/bin/env python3
"""
Simplified L5PC comparison between NEURON and HelioX frameworks.
Builds the model once, then runs both frameworks and compares results.
"""

import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from neuron import h, gui

# Add HelioX python_lib path (prefer repo checkout, fallback to older path).
# os.path.isdir() performs no shell-style expansion, so "~" and "$VAR" are
# expanded explicitly here; an unexpanded "$HOME/..." string would never match
# an existing directory and the fallbacks would be silently skipped.
_HELIOX_PYLIB_CANDIDATES = [
    os.path.expanduser(os.path.expandvars(_candidate))
    for _candidate in (
        # Explicit override; an empty string when the variable is unset.
        os.environ.get("HELIOX_PYTHON_LIB", "").strip(),
        "~/heliox/python_lib",
        "~/Documents/heliox/python_lib",
    )
]
# Put the first existing candidate that is not already on sys.path at the front.
for _p in _HELIOX_PYLIB_CANDIDATES:
    if _p and os.path.isdir(_p) and _p not in sys.path:
        sys.path.insert(0, _p)
        break

import heliox
from heliox_monitor import make_segment_monitor

# Disable lookup tables for accurate comparison: set the usetable_<suffix>
# flag to 0 for every mechanism suffix listed here, in this order.
for _suffix in ("hh", "sca_la", "it2_la", "kdf_la", "kdr_la", "km_la"):
    setattr(h, f"usetable_{_suffix}", 0)
del _suffix  # keep the loop variable out of the module namespace

# Alignment knobs (env vars):
# - ALIGN_TRACES=0 to disable lag estimation / alignment
# - ALIGN_MAX_LAG_MS=5.0 sets ± search window in ms

class L5PC_Comparator:
    """Compare L5PC neuron simulation between NEURON and HelioX"""

    def __init__(self, num_cells=1):
        """Load the hoc files, configure NEURON and build ``num_cells`` cells.

        Args:
            num_cells: number of L5PC cells to create via create_cell().
        """
        print("Initializing L5PC model...")

        # Hoc libraries first, the cell template last
        for hoc_file in ("import3d.hoc", 'stdgui.hoc', "L5PClatemplate_record.hoc"):
            h.load_file(hoc_file)

        # Parallel context used for gid bookkeeping, psolve() and the model dump
        self.pc = h.ParallelContext()
        context = self.pc
        context.nthread(1, 0)
        context.set_maxstep(5)

        # Only the cache-efficient layout is switched on; the CoreNEURON
        # settings below are intentionally left disabled.
        h.cvode.cache_efficient(1)
        #from neuron import coreneuron
        #coreneuron.enable = True
        #coreneuron.verbose = 0
        #coreneuron.gpu = True
        #coreneuron.num_gpus = 1
        #coreneuron.cell_permute = 2

        h.stdinit()

        # Cell bookkeeping; Ngid is the next gid create_cell() will hand out
        self.num_cells = num_cells
        self.cells = []
        self.Ngid = 0
        self.cells.extend(self.create_cell() for _ in range(num_cells))

        # HelioX side of the comparison
        self.heliox_client = heliox.Sim()

        print(f"Created {num_cells} L5PC cell(s)")

    @staticmethod
    def _rms(x: np.ndarray) -> float:
        x = np.asarray(x, dtype=float)
        return float(np.sqrt(np.mean(x * x)))

    @classmethod
    def _estimate_best_lag_samples(
        cls,
        v1: np.ndarray,
        v2: np.ndarray,
        max_lag_samples: int,
    ):
        """
        Estimate the integer sample lag between two 1D traces by minimizing RMS error.

        Convention:
        - lag > 0 means v2 is delayed (shifted right) relative to v1, so v1[:-lag]
          is compared with v2[lag:].
        - lag < 0 means v2 is advanced relative to v1, so v1[-lag:] is compared
          with the start of v2.

        The traces may differ in length: every candidate window is trimmed to the
        number of samples both traces provide for that lag.

        Args:
            v1: reference trace (flattened to 1D float).
            v2: trace to align against v1 (flattened to 1D float).
            max_lag_samples: search window, lags in [-max, +max] are tried. It is
                clamped so that at least one sample of each trace remains.

        Returns:
            Tuple ``(best_lag, v1_window, v2_window, best_rms)`` where the two
            windows have equal length and ``best_rms`` is the RMS of their
            difference.

        Raises:
            ValueError: if a trace is empty or max_lag_samples is negative.
            RuntimeError: if no lag produced a usable window (e.g. every RMS
                compared as not-less-than infinity).
        """
        v1 = np.asarray(v1, dtype=float).ravel()
        v2 = np.asarray(v2, dtype=float).ravel()

        if v1.size == 0 or v2.size == 0:
            raise ValueError("Cannot estimate lag with empty traces.")

        max_lag_samples = int(max_lag_samples)
        if max_lag_samples < 0:
            raise ValueError("max_lag_samples must be >= 0.")

        max_lag_samples = min(max_lag_samples, v1.size - 1, v2.size - 1)
        if max_lag_samples == 0:
            # Nothing to search: compare the common prefix directly.
            n = min(v1.size, v2.size)
            a, b = v1[:n], v2[:n]
            return 0, a, b, cls._rms(a - b)

        best_lag = 0
        best_rms = float("inf")
        best_a = None
        best_b = None

        # Brute force is fine here (typically just a few ms / few hundred samples).
        for lag in range(-max_lag_samples, max_lag_samples + 1):
            if lag < 0:
                a = v1[-lag:]
                b = v2
            elif lag > 0:
                a = v1[:-lag]
                b = v2[lag:]
            else:
                a = v1
                b = v2

            # Trim both windows to their common length. Without this, traces of
            # different lengths yield windows of different sizes and `a - b`
            # fails with a shape mismatch.
            overlap = min(a.size, b.size)
            if overlap == 0:
                continue
            a = a[:overlap]
            b = b[:overlap]

            # Compute RMS difference for this overlap window
            rms = cls._rms(a - b)
            if rms < best_rms:
                best_rms = rms
                best_lag = lag
                best_a = a
                best_b = b

        if best_a is None or best_b is None:
            raise RuntimeError("Lag estimation failed to find any overlapping window.")

        return best_lag, best_a, best_b, float(best_rms)

    def create_cell(self):
        """Instantiate one L5PC cell and everything attached to it.

        Creates the template instance, a soma voltage recorder, one IClamp at
        the midpoint of each of the first 69 dendrites, and a NetCon spike
        detector registered with the parallel context under the next free gid.

        Returns:
            dict with keys 'cell', 'dendnumber', 'v_neuron', 'iClamps', 'gid'
            and 'spike_detector'.
        """
        template = h.L5PClatemplate_record()
        soma_mid = template.soma(0.5)
        dend_count = 69

        # Soma membrane potential recorder for the NEURON run
        voltage_rec = h.Vector().record(soma_mid._ref_v)

        # One current clamp per dendrite; amplitudes can be overwritten by set_input()
        clamps = []
        for dend_idx in range(dend_count):
            stim = h.IClamp(template.dend[dend_idx](0.5))
            stim.delay, stim.dur, stim.amp = 20, 200, 0.01
            clamps.append(stim)

        # Spike detector on the soma, bound to a fresh gid on this rank
        detector = h.NetCon(soma_mid._ref_v, None, sec=template.soma)
        gid = self.Ngid
        self.Ngid = gid + 1
        self.pc.set_gid2node(gid, self.pc.id())
        self.pc.cell(gid, detector)

        return {
            'cell': template,
            'dendnumber': dend_count,
            'v_neuron': voltage_rec,
            'iClamps': clamps,
            'gid': gid,
            'spike_detector': detector,
        }

    def dump_and_load_model(self, model_path="coredat-comparison"):
        """Write the NEURON model to ``model_path`` and load it into HelioX.

        Environment variables read here:
            HELIOX_DEVICE: 'gpu' (default) or 'cpu'.
            HELIOX_PERMUTE_TYPE: integer; defaults to 3 on gpu and 0 on cpu.
            HELIOX_MONITOR_DEBUG: '1'/'true'/'yes'/'on' prints monitor details.

        Side effect: stores a soma voltage monitor under ``cell['v_monitor']``
        for every cell.

        Raises:
            ValueError: if HELIOX_DEVICE is neither 'gpu' nor 'cpu'.
            RuntimeError: if the client's load_model() returns non-zero.
        """
        print(f"\nDumping model to {model_path}...")

        # Initialize NEURON before writing the model files
        h.dt = 0.05
        h.finitialize(-62.5)
        self.pc.nrnbbcore_write(model_path)

        print("Loading model into HelioX...")
        client = self.heliox_client
        client.set_data_path(model_path)

        device = os.environ.get("HELIOX_DEVICE", "gpu").strip().lower()
        if not (device == "gpu" or device == "cpu"):
            raise ValueError(f"Invalid HELIOX_DEVICE={device!r}; expected 'gpu' or 'cpu'.")
        client.set_device(device)

        # Permutation type: environment override, else a per-device default
        fallback_permute = {"gpu": "3", "cpu": "0"}[device]
        client.set_permute_type(int(os.environ.get("HELIOX_PERMUTE_TYPE", fallback_permute)))

        debug_flag = os.environ.get("HELIOX_MONITOR_DEBUG", "0").strip().lower()
        monitor_debug = debug_flag in ("1", "true", "yes", "on")

        # Monitors are registered before load_model() is called
        for cell in self.cells:
            monitor = make_segment_monitor(cell['cell'].soma(0.5), "v", self.heliox_client)
            cell['v_monitor'] = monitor
            if not monitor_debug:
                continue
            # NOTE(review): this reads the private `_mech_name` while run_heliox()
            # reads `mech_name` -- confirm which attribute is intended.
            print(
                f"HelioX monitor init: mech={monitor._mech_name}, var={monitor.var_name}, "
                f"idx={monitor.node_or_mech_idx}, pending_handle={monitor.monitor_id}"
            )

        status = client.load_model()
        if rc_failed := (status != 0):
            raise RuntimeError(f"HelioX load_model() failed with code {status}.")
        print("Model loaded successfully")

    def set_input(self, input_currents):
        """Apply per-dendrite clamp amplitudes to both simulators.

        Args:
            input_currents: numpy array of shape (num_cells, 69), indexed as
                ``input_currents[cell_idx, dend_idx]``.
        """
        push_value = self.heliox_client.set_variable_value
        for cell_idx, cell in enumerate(self.cells):
            clamps = cell['iClamps']
            for dend_idx in range(cell['dendnumber']):
                amplitude = input_currents[cell_idx, dend_idx]
                clamp = clamps[dend_idx]

                # NEURON side: write straight to the IClamp
                clamp.amp = amplitude

                # HelioX side: address the same IClamp through its object id
                clamp_id = int(h.object_id(clamp, 1))
                push_value(amplitude, "IClamp", "amp", clamp_id)

    def run_neuron(self, runtime=300, dt=0.05):
        """Run the model in NEURON and return the recorded soma voltages.

        Args:
            runtime: value passed to psolve() (printed as ms).
            dt: value assigned to h.dt (printed as ms).

        Returns:
            numpy array built from one recorded trace per cell.
        """
        import time

        print(f"\nRunning NEURON simulation (runtime={runtime}ms, dt={dt}ms)...")

        h.dt = dt
        h.finitialize(-62.5)

        # Wall-clock timing covers psolve() only
        t_begin = time.time()
        self.pc.psolve(runtime)
        wall_seconds = time.time() - t_begin

        print(f"NEURON simulation completed in {wall_seconds:.2f}s")

        # np.array() copies the recorder contents into a standalone array
        return np.array([cell['v_neuron'].as_numpy() for cell in self.cells])

    def run_heliox(self, runtime=300, dt=0.05):
        """Run the model in HelioX and return the recorded soma voltages.

        Args:
            runtime: value passed to the client's run() (printed as ms).
            dt: value passed to the client's set_dt() (printed as ms).

        Returns:
            float numpy array with one row per cell, built from each cell's
            'v_monitor' data.

        Raises:
            RuntimeError: if any cell has no 'v_monitor' (checked before the
                simulation starts), or if a monitor returns no samples.
        """
        import time

        # Fail fast: without monitors the run would produce nothing to collect,
        # so check before spending time on the simulation.
        if any('v_monitor' not in cell for cell in self.cells):
            raise RuntimeError("HelioX monitors not initialized. Call dump_and_load_model() first.")

        print(f"\nRunning HelioX simulation (runtime={runtime}ms, dt={dt}ms)...")

        self.heliox_client.set_dt(dt)
        self.heliox_client.finitialize(-62.5)

        # perf_counter() is monotonic, so the elapsed time cannot go negative
        # if the system clock is adjusted during the run.
        start_time = time.perf_counter()
        self.heliox_client.run(runtime)
        elapsed = time.perf_counter() - start_time

        print(f"HelioX simulation completed in {elapsed:.2f}s")

        # Collect voltage traces
        monitor_debug = os.environ.get("HELIOX_MONITOR_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}
        v_traces = []
        for cell in self.cells:
            mw = cell['v_monitor']
            v_data = mw.get_data()
            # Use an explicit None/len() check: `not v_data` raises for a numpy
            # array with more than one element.
            if v_data is None or len(v_data) == 0:
                raise RuntimeError(
                    f"HelioX monitor returned no data (mech={mw.mech_name}, var={mw.var_name}, "
                    f"idx={mw.node_or_mech_idx}, handle={mw.monitor_id}). "
                    "The simulation likely ran without recording, or the monitor is still pending/invalid."
                )
            if monitor_debug:
                print(
                    f"HelioX monitor resolved: mech={mw.mech_name}, var={mw.var_name}, "
                    f"idx={mw.node_or_mech_idx}, handle={mw.monitor_id}, n={len(v_data)}"
                )
            v_traces.append(np.asarray(v_data, dtype=float))

        return np.asarray(v_traces, dtype=float)

    def compare_and_plot(self, v_neuron, v_heliox, dt=0.05, output_file="comparison_result.png"):
        """Compare the first cell's NEURON and HelioX traces, print stats and plot.

        Only index 0 of each input is compared. Statistics are computed on the
        raw (length-matched) traces and, unless disabled, again after shifting
        the traces by the integer lag that minimizes their RMS difference.

        Environment variables read here:
            ALIGN_TRACES: '0'/'false'/'no'/'off' disables lag estimation.
            ALIGN_MAX_LAG_MS: +/- search window in ms (default 5.0).

        Args:
            v_neuron: sequence of NEURON voltage traces (one per cell).
            v_heliox: sequence of HelioX voltage traces (one per cell).
            dt: sample spacing in ms; must be positive.
            output_file: file name of the figure, written inside ``output/``.

        Returns:
            dict with 'max_diff', 'mean_diff', 'rms_diff', 'max_diff_idx' (on the
            possibly aligned traces), 'aligned', 'lag_samples', 'lag_ms' and the
            'unaligned_*' counterparts of the three difference statistics.

        Raises:
            ValueError: if a trace collection is None or empty, if the first
                traces have no samples, or if dt is not positive.
        """
        print("\nComparing results...")

        if v_neuron is None or v_heliox is None:
            raise ValueError("Missing voltage traces: v_neuron and v_heliox must both be provided.")
        if len(v_neuron) == 0 or len(v_heliox) == 0:
            raise ValueError(
                f"Empty voltage trace arrays: len(v_neuron)={len(v_neuron)}, len(v_heliox)={len(v_heliox)}. "
                "This usually means HelioX monitor data was not captured."
            )

        # dt is a divisor below and scales the time axis, so reject zero,
        # negative and NaN values up front with a clear message.
        dt = float(dt)
        if not dt > 0:
            raise ValueError(f"dt must be a positive time step in ms, got {dt!r}.")

        # Use first cell for comparison
        v1 = np.asarray(v_neuron[0], dtype=float).ravel()
        v2 = np.asarray(v_heliox[0], dtype=float).ravel()

        # Ensure same length
        min_len = min(len(v1), len(v2))
        if min_len == 0:
            raise ValueError(
                f"Cannot compare empty traces: len(NEURON)={len(v1)}, len(HelioX)={len(v2)}. "
                "Check HelioX monitor setup and that run() produced recorded samples."
            )
        v1_0 = v1[:min_len]
        v2_0 = v2[:min_len]

        # Calculate baseline (no alignment) statistics
        diff0 = np.abs(v1_0 - v2_0)
        max_diff_idx0 = int(np.argmax(diff0))
        max_diff_value0 = float(diff0[max_diff_idx0])
        mean_diff0 = float(np.mean(diff0))
        rms_diff0 = float(np.sqrt(np.mean(diff0**2)))

        # Optional lag estimation to detect sequence misalignment
        align = os.environ.get("ALIGN_TRACES", "1").strip().lower() not in {"0", "false", "no", "off"}
        align_max_lag_ms = float(os.environ.get("ALIGN_MAX_LAG_MS", "5.0"))
        max_lag_samples = int(round(align_max_lag_ms / dt))
        do_align = align and max_lag_samples > 0

        v1_aligned = v1_0
        v2_aligned = v2_0
        best_lag = 0
        rms_aligned = rms_diff0

        if do_align:
            # Search on the length-matched traces: the two simulators may
            # return a different number of samples, and unequal lengths must
            # not reach the window arithmetic of the lag search.
            best_lag, v1_aligned, v2_aligned, rms_aligned = self._estimate_best_lag_samples(
                v1_0, v2_0, max_lag_samples=max_lag_samples
            )

        # Stats on (possibly) aligned traces
        min_len_aligned = min(v1_aligned.size, v2_aligned.size)
        v1_aligned = v1_aligned[:min_len_aligned]
        v2_aligned = v2_aligned[:min_len_aligned]
        diff = np.abs(v1_aligned - v2_aligned)
        max_diff_idx = int(np.argmax(diff))
        max_diff_value = float(diff[max_diff_idx])
        mean_diff = float(np.mean(diff))
        rms_diff = float(np.sqrt(np.mean(diff**2)))

        print("\nComparison Statistics:")
        print(f"  (Unaligned) Max difference:  {max_diff_value0:.6e} mV at index {max_diff_idx0}")
        print(f"  (Unaligned) Mean difference: {mean_diff0:.6e} mV")
        print(f"  (Unaligned) RMS difference:  {rms_diff0:.6e} mV")
        if do_align:
            lag_ms = best_lag * dt
            print(f"  Estimated lag: {best_lag} samples ({lag_ms:+.4f} ms), searched ±{max_lag_samples} samples")
            print(f"  (Aligned)   Max difference:  {max_diff_value:.6e} mV at index {max_diff_idx}")
            print(f"  (Aligned)   Mean difference: {mean_diff:.6e} mV")
            print(f"  (Aligned)   RMS difference:  {rms_diff:.6e} mV (best RMS during search: {rms_aligned:.6e})")
        else:
            print(f"  Max difference:  {max_diff_value:.6e} mV at index {max_diff_idx}")
            print(f"  Mean difference: {mean_diff:.6e} mV")
            print(f"  RMS difference:  {rms_diff:.6e} mV")

        # Create time axis in the NEURON trace's time base. For a negative lag
        # the NEURON window starts at sample |lag| rather than sample 0, so the
        # axis is offset accordingly (sample 0 is taken to be t = 0).
        first_sample = max(0, -int(best_lag))
        time_axis = (first_sample + np.arange(min_len_aligned)) * dt
        max_diff_time = time_axis[max_diff_idx]

        # Create plot
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

        # Plot voltage traces
        ax1.plot(time_axis, v1_aligned, label='NEURON', color='blue', linewidth=1.5)
        ax1.plot(time_axis, v2_aligned, label='HelioX', color='red', linewidth=1.5, alpha=0.8)
        ax1.set_ylabel('Voltage (mV)', fontsize=12)
        title = 'L5PC Soma Voltage: NEURON vs HelioX'
        if do_align and best_lag != 0:
            title += f' (aligned, lag={best_lag} samples)'
        ax1.set_title(title, fontsize=14, fontweight='bold')
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Plot difference
        ax2.plot(time_axis, diff, color='green', linewidth=1.5)
        ax2.scatter(max_diff_time, max_diff_value, color='red', s=100, marker='*', zorder=5)
        # A real newline here: an escaped backslash would print "\n" literally
        # in the annotation box.
        ax2.annotate(f'Max diff: {max_diff_value:.6e} mV\nTime: {max_diff_time:.2f} ms',
                    (max_diff_time, max_diff_value),
                    xytext=(max_diff_time + 20, max_diff_value * 1.2),
                    arrowprops=dict(facecolor='red', shrink=0.05, width=1.5),
                    fontsize=10,
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.7))
        ax2.set_xlabel('Time (ms)', fontsize=12)
        ax2.set_ylabel('Absolute Difference (mV)', fontsize=12)
        ax2.set_title('Absolute Difference Between Frameworks', fontsize=14, fontweight='bold')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save figure (explicitly this figure, not whichever is "current")
        os.makedirs('output', exist_ok=True)
        output_path = os.path.join('output', output_file)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to: {output_path}")

        # Show plot
        plt.show()

        return {
            'max_diff': max_diff_value,
            'mean_diff': mean_diff,
            'rms_diff': rms_diff,
            'max_diff_idx': max_diff_idx,
            'aligned': bool(do_align),
            'lag_samples': int(best_lag),
            'lag_ms': float(best_lag * dt),
            'unaligned_max_diff': max_diff_value0,
            'unaligned_mean_diff': mean_diff0,
            'unaligned_rms_diff': rms_diff0,
        }


def main():
    """Run the NEURON-vs-HelioX comparison end to end with fixed settings."""
    banner = "=" * 60
    print(banner)
    print("L5PC Neuron Comparison: NEURON vs HelioX")
    print(banner)

    # Fixed settings for this run (runtime and dt are in ms)
    num_cells, runtime, dt = 1, 300, 0.05
    input_scale = 0.02

    # Build the cells, then export/load the model a single time
    comparator = L5PC_Comparator(num_cells=num_cells)
    comparator.dump_and_load_model("coredat-comparison")

    # Same amplitude on every one of the 69 clamps of every cell
    input_currents = np.full((num_cells, 69), input_scale)
    print(f"\nSetting input currents (scale={input_scale})...")
    comparator.set_input(input_currents)

    # NEURON first, HelioX second
    v_neuron = comparator.run_neuron(runtime=runtime, dt=dt)
    v_heliox = comparator.run_heliox(runtime=runtime, dt=dt)

    # Statistics are printed and the figure written inside compare_and_plot()
    comparator.compare_and_plot(v_neuron, v_heliox, dt=dt)

    print("\n" + banner)
    print("Comparison Complete!")
    print(banner)


# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
