from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any

import numpy as np
from neuron import h

from worm_network import Network
from worm_data_utils import load_json


@dataclass(frozen=True)
class WormModelBuildResult:
    """
    Bundle of everything `build_worm_model` prepares: the loaded configuration,
    simulation timing, input/output cell names, target matrix, initial input
    currents and the instantiated `Network`.

    `frozen=True` only prevents rebinding the fields; the list / ndarray / Network
    values they hold remain mutable.
    """

    # Parsed contents of `config_file`.
    config: dict[str, Any]
    # Path of the circuit-search config JSON that was loaded.
    config_file: str
    # Path of the connection pickle handed to `Network.read_cells_neurite_connection`.
    connection_file: str

    # Simulation time step (taken from sim_config["dt"]; also assigned to NEURON's h.dt).
    dt: float
    # Initial membrane potential (taken from sim_config["v_init"]).
    v_r: float
    # Simulation stop time (tstop_override_ms if given, else sim_config["tstop"]).
    tstop: float

    # Names of the cells that receive input currents (one row of `input_is` each).
    input_names: list[str]
    # Output cells present both in the config and in the correlation data file,
    # in the data file's order.
    output_names: list[str]
    # Square correlation submatrix over `output_names`, in `output_names` order.
    target: np.ndarray
    # Initial input currents, shape (len(input_names), number of time steps).
    input_is: np.ndarray

    # Network instantiated from the config, with neurite connections already read.
    net: Network


def _load_mechanisms(mech_dir: str, logger=None) -> None:
    """
    Load the NEURON MOD mechanisms shipped under ``mech_dir`` when running via plain
    Python (not `special`).

    Steps:
    1. Probe whether the mechanisms are already available (e.g. preloaded via NRNMECH
       paths) by inserting ``cainternm_lr`` into a throwaway section; if that works,
       return immediately.
    2. If no compiled ``libnrnmech.so`` exists for this host's architecture (or the
       legacy ``x86_64`` directory), build one with ``nrnivmodl`` from
       ``components/mechanism/modfile``.
    3. Load the library with ``neuron.load_mechanisms``.

    NOTE: In some environments mechanisms may already be loaded, and re-loading the
    same lib errors with "The user defined name already exists". A RuntimeError whose
    message contains "already exists" is logged and ignored.

    Args:
        mech_dir: Directory containing ``components/mechanism/modfile``; ``nrnivmodl``
            is run with this as its working directory.
        logger: Optional logger for progress messages.

    Raises:
        RuntimeError: if ``nrnivmodl`` is needed but not on PATH, or if
            ``neuron.load_mechanisms`` reports that no compiled library was found.
        subprocess.CalledProcessError: if ``nrnivmodl`` itself fails.
    """
    import platform
    import shutil
    import subprocess
    import neuron

    # If mechanisms are already loaded, avoid calling `neuron.load_mechanisms` again:
    # re-loading can print noisy "already exists" diagnostics from NEURON.
    # The broad except is deliberate: any failure here just means "not loaded yet".
    try:
        probe = h.Section(name="__mech_probe__")
        probe.insert("cainternm_lr")
        return
    except Exception:
        pass

    # nrnivmodl writes into a directory named after the CPU architecture (x86_64,
    # arm64, aarch64, ...), with the library either directly inside it or under
    # `.libs/` depending on the NEURON version. Checking only `x86_64/libnrnmech.so`
    # forced a rebuild on every run on non-x86_64 hosts.
    # dict.fromkeys de-duplicates while keeping order; empty arch strings are dropped.
    arch_dirs = [arch for arch in dict.fromkeys((platform.machine(), "x86_64")) if arch]
    lib_candidates = [
        os.path.join(mech_dir, arch, *subdir, "libnrnmech.so")
        for arch in arch_dirs
        for subdir in ((), (".libs",))
    ]
    if not any(os.path.exists(path) for path in lib_candidates):
        nrnivmodl = shutil.which("nrnivmodl")
        if not nrnivmodl:
            raise RuntimeError(
                "Missing NEURON mod build tool 'nrnivmodl'. "
                "Install NEURON developer tools or ensure nrnivmodl is on PATH."
            )
        mod_dir = os.path.join(mech_dir, "components", "mechanism", "modfile")
        if logger:
            logger.info("Building NEURON MOD mechanisms via: %s %s", nrnivmodl, mod_dir)
        subprocess.check_call([nrnivmodl, mod_dir], cwd=mech_dir)

    try:
        loaded = neuron.load_mechanisms(mech_dir)
    except RuntimeError as e:
        if "already exists" not in str(e):
            raise
        if logger:
            logger.info(f"neuron.load_mechanisms({mech_dir}) skipped: already exists")
        return

    # `neuron.load_mechanisms` signals "no library found" by printing a message and
    # returning False rather than raising. Fail here with an actionable error instead
    # of letting the missing mechanisms surface later as an obscure NEURON error.
    if loaded is False:
        raise RuntimeError(
            f"NEURON mechanisms could not be loaded from {mech_dir!r}: no compiled "
            "libnrnmech library was found there. Build it by running "
            "'nrnivmodl components/mechanism/modfile' in that directory."
        )


def _build_target_corr(base_dir: str, output_names_config: list[str]) -> tuple[list[str], np.ndarray]:
    with open(os.path.join(base_dir, "components", "cb2022_data", "Ca_corr_mat_cell_name.txt")) as f:
        output_names_target = f.read().split("\t")
    ca_corr = np.loadtxt(os.path.join(base_dir, "components", "cb2022_data", "Ca_corr_mat.txt"))
    output_names: list[str] = []
    output_ids: list[int] = []
    for i, cn in enumerate(output_names_target):
        if cn in output_names_config:
            output_names.append(cn)
            output_ids.append(i)
    target = ca_corr[output_ids, :][:, output_ids]
    return output_names, target


def build_worm_model(
    *,
    output_path: str,
    random_seed: int,
    K_mul: int,
    ngpu: int,
    K_max_t_default_ms: float,
    K_nblock: int,
    w_gap_max: float | None,
    w_gap_min: float,
    w_syn_max: float | None,
    w_syn_min: float,
    k_len: int | None = None,
    k_max_t_ms: float | None = None,
    tstop_override_ms: float | None = None,
    logger=None,
) -> WormModelBuildResult:
    """
    Build the worm biological model with NEURON and export the data needed for training:
    - Read the config JSON and connection pickle produced by the circuit search pipeline
    - Load MOD mechanisms
    - Instantiate Network
    - Prepare initial input currents and the target

    Notes:
    - The training/simulation runtime is handled by the HELIOX backend
      (see worm_trainables/worm_network).
    - This function only does "model building + data preparation"; it does no learning.

    Args:
        output_path: Directory holding ``000_circuit_search_config.json`` and
            ``sample_#0_circuit_old.pkl``; the K file name is also placed here.
        random_seed: Seeds NumPy's *global* RNG (``np.random.seed``) and is passed
            to ``Network``.
        K_mul, ngpu, K_nblock, w_gap_max, w_gap_min, w_syn_max, w_syn_min:
            Forwarded to ``Network`` through ``lr_config``.
        K_max_t_default_ms: History horizon used when neither ``k_len`` nor
            ``k_max_t_ms`` is given; it also selects the un-suffixed K file name.
        k_len: If given, takes precedence over ``k_max_t_ms``. The horizon becomes
            ``k_len * dt * K_mul`` plus a 1e-6 nudge.
        k_max_t_ms: Explicit history horizon, used only when ``k_len`` is None.
        tstop_override_ms: Replaces ``sim_config["tstop"]`` when not None.
        logger: Optional logger for progress messages.

    Returns:
        WormModelBuildResult with the config, timing, cell names, target matrix,
        initial input currents and the instantiated Network.

    Side effects:
        Seeds ``np.random``, may build/load NEURON mechanisms, loads ``stdrun.hoc``
        and sets ``h.dt``.
    """
    np.random.seed(random_seed)

    config_file = os.path.join(output_path, "000_circuit_search_config.json")
    connection_file = os.path.join(output_path, "sample_#0_circuit_old.pkl")
    config = load_json(config_file)
    sim_config = config["sim_config"]

    dt = float(sim_config["dt"])
    v_r = float(sim_config["v_init"])
    tstop = float(tstop_override_ms) if tstop_override_ms is not None else float(sim_config["tstop"])

    base_dir = os.path.dirname(__file__)
    _load_mechanisms(base_dir, logger=logger)

    h.load_file("stdrun.hoc")
    h.dt = dt

    input_names = list(config["search_config"]["input_cell_names"])
    output_names_config = list(config["search_config"]["output_cell_names"])

    # Number of time steps covered by the input currents. A bare `int(tstop / dt)` can
    # lose a step when the quotient lands just below an integer through float rounding
    # (e.g. 0.3 / 0.1 == 2.9999999999999996). Snap to the nearest integer when within a
    # tiny relative tolerance; otherwise truncate exactly as before.
    steps_exact = tstop / dt
    steps_nearest = round(steps_exact)
    if abs(steps_exact - steps_nearest) <= 1e-9 * max(1.0, abs(steps_exact)):
        n_steps = int(steps_nearest)
    else:
        n_steps = int(steps_exact)

    # Legacy default: Gaussian noise (std 1e-3) plus a 0.03 DC offset, one row per input cell.
    input_is = np.random.normal(loc=0.0, scale=1e-3, size=(len(input_names), n_steps)) + 0.03

    # This demo only supports the correlation target (kept focused; no branches for old modes).
    output_names, target = _build_target_corr(base_dir, output_names_config)

    if k_len is not None:
        # Ensure K_len derived inside `worm_network.py` stays exactly `k_len` after:
        #   K_len = int(K_max_t / (dt * K_mul))
        # by nudging with a tiny epsilon to avoid float rounding down.
        K_max_t_ms = float(k_len) * dt * float(K_mul) + 1e-6
    elif k_max_t_ms is not None:
        K_max_t_ms = float(k_max_t_ms)
    else:
        K_max_t_ms = float(K_max_t_default_ms)

    net_config = config["config"]
    # NOTE(review): when `k_len` is set, K_max_t_ms carries the +1e-6 nudge, which is the
    # same size as the tolerance below. Whether the un-suffixed default name is chosen in
    # that case therefore depends on float rounding. Left unchanged so existing runs keep
    # resolving to the same K file names. Horizons that round to the same integer also
    # share the same `_t` suffix.
    if abs(K_max_t_ms - float(K_max_t_default_ms)) < 1e-6:
        K_filename = os.path.join(output_path, f"K_eworm_v4_x{K_mul}.npz")
    else:
        # Avoid collisions when users experiment with different history horizons.
        K_filename = os.path.join(output_path, f"K_eworm_v4_x{K_mul}_t{int(round(K_max_t_ms))}.npz")
    lr_config = {
        "v_r": v_r,
        "ngpu": ngpu,
        "K_max_t": K_max_t_ms,
        "K_filename": K_filename,
        "K_nblock": K_nblock,
        "K_mul": K_mul,
        "w_gap_max": w_gap_max,
        "w_gap_min": w_gap_min,
        "w_syn_max": w_syn_max,
        "w_syn_min": w_syn_min,
    }

    eworm_net = Network(net_config, lr_config, random_seed)
    eworm_net.read_cells_neurite_connection(connection_file, input_names)

    if logger:
        logger.info(f"model_builder: dt={dt}, v_r={v_r}, tstop={tstop}")
        logger.info(f"model_builder: N_input={len(input_names)} N_output={len(output_names)}")

    return WormModelBuildResult(
        config=config,
        config_file=config_file,
        connection_file=connection_file,
        dt=dt,
        v_r=v_r,
        tstop=tstop,
        input_names=input_names,
        output_names=output_names,
        target=target,
        input_is=input_is,
        net=eworm_net,
    )


def attach_heliox_backend(
    net: Network,
    output_names: list[str],
    *,
    dt: float,
    v_init: float,
    output_path: str,
    export_path: str | None = None,
) -> Any:
    """
    Create a `WormHelioXRuntime` for ``net`` and wire it in.

    The runtime is attached with ``net.attach_heliox_backend`` and then
    ``net.set_weights()`` is called.

    Args:
        net: The network to attach the runtime to.
        output_names: Output cell names handed to the runtime.
        dt: Time step handed to the runtime.
        v_init: Initial membrane potential handed to the runtime.
        output_path: Base directory used to derive the default export path.
        export_path: Explicit export directory; defaults to
            ``<output_path>/heliox_export`` when None.

    Returns:
        The constructed `WormHelioXRuntime` instance.
    """
    # Imported lazily so this module can be loaded without the runtime's dependencies.
    from worm_network import WormHelioXRuntime

    resolved_export = (
        os.path.join(output_path, "heliox_export") if export_path is None else export_path
    )
    runtime = WormHelioXRuntime(
        net,
        output_names,
        dt=dt,
        v_init=v_init,
        export_path=resolved_export,
    )
    net.attach_heliox_backend(runtime)
    net.set_weights()
    return runtime
