#!/usr/bin/env python3
"""
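HelioX training script: train a NEURON-based Sequential network on MNIST
(``mnist.npz``) with the HelioX backend, then report test accuracy and timings.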
"""

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path
from time import perf_counter
from typing import List, Dict, Tuple, Optional

import numpy as np
from neuron import h  # type: ignore

# Directory containing this script (symlinks resolved via Path.resolve()).
# NOTE(review): the HelioX bindings lookup below only consults $HELIOX_HOME;
# it does not use this value.
REPO_ROOT = Path(__file__).resolve().parent


def _ensure_heliox_on_syspath() -> None:
    """
    Make the HelioX Python wrapper and its compiled extension importable.

    Search roots: only ``$HELIOX_HOME``. Hardcoded fallback locations are
    deliberately not searched.

    For a root that exists as a directory, its ``python_lib/`` and ``build/``
    subdirectories (whichever exist) are resolved to absolute paths. Each is
    inserted at the front of ``sys.path`` unless it is already present.
    Because every insert goes to index 0, ``build/`` ends up ahead of
    ``python_lib/`` when both exist. Roots whose stat raises ``OSError`` are
    skipped silently.
    """
    home = os.environ.get("HELIOX_HOME")
    roots = [Path(home)] if home else []

    for root in roots:
        try:
            usable = root.is_dir()
        except OSError:
            usable = False
        if not usable:
            continue

        for subdir in ("python_lib", "build"):
            candidate = (root / subdir).resolve()
            if not candidate.is_dir():
                continue
            entry = str(candidate)
            if entry not in sys.path:
                sys.path.insert(0, entry)


# Runs at import time, before the HelioX imports below, so that
# $HELIOX_HOME/python_lib and $HELIOX_HOME/build are searched for them.
_ensure_heliox_on_syspath()

from hybrid_backend import BackendConfig, HybridBackend  # noqa: E402
from sequential import Sequential  # noqa: E402
from layers import InputLayer, DenseLayer, SoftMaxLayer  # noqa: E402
from optimizer import SGD, SGDMomentum, Adam  # noqa: E402
from connection_pattern import ConnectionPattern  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed options; each option's help text documents its meaning.
    """
    parser = argparse.ArgumentParser(description="Train Sequential model with HelioX.")
    add = parser.add_argument

    # Data volume and network shape.
    add("--epochs", type=int, default=1,
        help="Number of training epochs.")
    add("--train-samples", type=int, default=500,
        help="Number of training samples to use.")
    add("--test-samples", type=int, default=100,
        help="Number of test samples to evaluate.")
    add("--hidden-size", type=int, default=64,
        help="Hidden layer size (single hidden layer).")
    add("--hidden-sizes", type=str, default="",
        help="Comma-separated hidden layer sizes, e.g. 128,64 for 784-128-64-10.")
    add("--batch-size", type=int, default=1,
        help="Number of network replicas for batch training.")

    # Simulation settings.
    add("--dt", type=float, default=1.0,
        help="Simulation timestep (ms).")
    add("--sim-time", type=float, default=50.0,
        help="Simulation duration per sample (ms).")
    add("--v-init", type=float, default=0.0,
        help="Initial membrane potential (mV).")

    # Optimization.
    add("--optimizer", choices=["sgd", "momentum", "adam"], default="adam",
        help="Optimizer type.")
    add("--learning-rate", type=float, default=0.002,
        help="Optimizer learning rate.")
    add("--momentum", type=float, default=0.9,
        help="Momentum factor for SGDMomentum.")
    add("--optimizer-backend", choices=["auto", "heliox"], default="auto",
        help="Backend to use when collecting gradients.")
    add("--disable-heliox-optimizer", action="store_true",
        help="Disable HelioX's internal optimizer and use Python-side updates.")

    # HelioX backend.
    add("--export-path", type=str, default="heliox_export",
        help="Directory used for HelioX model export.")
    add("--heliox-device", type=str, default="gpu",
        help="HelioX device (cpu/gpu).")
    add("--heliox-permute", type=int, default=3,
        help="HelioX permute type.")

    # Reporting, scheduling and reproducibility.
    add("--timing-interval", type=int, default=100,
        help="Report finitialize/run timings every N samples (<=0 disables).")
    add("--progress-interval", type=int, default=1000,
        help="Print training progress every N samples (<=0 disables).")
    add("--shuffle", action="store_true",
        help="Shuffle training samples at the start of each epoch.")
    add("--record-time", type=float, default=0.0,
        help="Recording window for gradients in ms (0 uses --sim-time).")
    add("--debug-interval", type=int, default=0,
        help="Print debug stats every N samples (<=0 disables).")
    add("--lr-decay", type=float, default=1.0,
        help="Multiply learning rate by this factor after each epoch.")
    add("--seed", type=int, default=1234,
        help="Random seed for reproducibility.")
    add("--use-netstim", action="store_true",
        help="Use NetStim instead of VecStim for input layer.")

    parsed = parser.parse_args()
    return parsed


def _parse_hidden_sizes(spec: str, default: int) -> List[int]:
    """Turn the ``--hidden-sizes`` / ``--hidden-size`` options into layer widths.

    Args:
        spec: Comma-separated widths such as ``"128,64"``. Blank entries are
            ignored. An empty string means "use ``default``".
        default: Width of the single hidden layer used when ``spec`` is empty.

    Returns:
        Hidden-layer widths in input-to-output order.

    Raises:
        ValueError: If ``spec`` is non-empty but yields no entries, an entry is
            not an integer, or any width is not positive.
    """
    if spec:
        sizes: List[int] = []
        for part in spec.split(","):
            part = part.strip()
            if not part:
                continue
            try:
                sizes.append(int(part))
            except ValueError as err:
                raise ValueError(f"--hidden-sizes entry {part!r} is not an integer.") from err
        if not sizes:
            raise ValueError("--hidden-sizes provided but no valid sizes parsed.")
    else:
        sizes = [default]
    if any(size <= 0 for size in sizes):
        raise ValueError(f"Hidden layer sizes must be positive integers, got {sizes}.")
    return sizes


def build_model(args: argparse.Namespace) -> Dict[str, object]:
    """Build ``args.batch_size`` network replicas that start with identical weights.

    Architecture per replica:
    - a 784-unit input layer (VecStim inputs unless ``--use-netstim``);
    - one ``passive_hpc`` dense layer per hidden size;
    - a 10-unit ``point`` dense layer;
    - a softmax layer.

    Replica 0's weights are reset and copied into every other replica. With
    ``batch_size > 1`` all replicas share a single ``HybridBackend``, which is
    also given ``configure_batch_optimizer(networks)``. A lone replica is built
    with ``backend=None``.

    Args:
        args: Parsed CLI options (see ``parse_args``).

    Returns:
        ``{"networks": [Sequential, ...], "backend_config": BackendConfig}``.

    Raises:
        ValueError: If ``--batch-size`` is not positive, or the hidden-layer
            sizes are malformed or non-positive. Both checks run before any
            NEURON or HelioX state is created.
    """
    if args.batch_size <= 0:
        raise ValueError("Batch size must be a positive integer.")
    # Validate the architecture before touching NEURON so bad CLI input fails fast.
    hidden_sizes = _parse_hidden_sizes(args.hidden_sizes, args.hidden_size)

    h.Random().Random123_globalindex(args.seed)
    h.load_file('stdgui.hoc')
    pc = h.ParallelContext()
    pc.nthread(1, 0)

    backend_config = BackendConfig(
        enable_heliox=True,
        export_path=args.export_path,
        device=args.heliox_device,
        permute_type=args.heliox_permute,
    )

    connection_pattern = ConnectionPattern(seed=args.seed)
    shared_backend: Optional[HybridBackend] = None
    if args.batch_size > 1:
        shared_backend = HybridBackend(backend_config)

    def make_layers() -> List:
        """Create a fresh layer stack (each replica needs its own instances)."""
        layers: List = [InputLayer(784, use_vecstim=not args.use_netstim)]
        for size in hidden_sizes:
            layers.append(
                DenseLayer(
                    size,
                    neuron_type='passive_hpc',
                    morph_file='2013_03_06_cell11_1125_H41_06.asc'
                )
            )
        layers.extend([
            DenseLayer(10, neuron_type='point'),
            SoftMaxLayer(),
        ])
        return layers

    networks: List[Sequential] = []
    for _ in range(args.batch_size):
        model = Sequential(
            make_layers(),
            seed=args.seed,
            connection_pattern=connection_pattern,
            backend_config=backend_config,
            backend=shared_backend,
        )
        model.build()
        networks.append(model)

    # Every replica starts from replica 0's freshly reset weights.
    networks[0].reset_weights()
    for replica in networks[1:]:
        replica.sync_weights_from(networks[0])

    pc.setup_transfer()
    pc.set_maxstep(10)

    h.cvode.cache_efficient(1)
    primary_backend = networks[0].backend
    if hasattr(primary_backend, "set_optimizer_config"):
        # Adam hyper-parameters are fixed here; only momentum comes from the CLI.
        primary_backend.set_optimizer_config(
            optimizer_type=args.optimizer,
            momentum=args.momentum,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-8,
        )
    if shared_backend is not None:
        shared_backend.configure_batch_optimizer(networks)
    for net in networks:
        net.initialize_backends(dt=args.dt, v_init=args.v_init, export_path=args.export_path)
    primary_backend.set_dt(args.dt)

    return {
        "networks": networks,
        "backend_config": backend_config,
    }


def make_optimizer(args: argparse.Namespace):
    """Instantiate the optimizer selected by ``--optimizer``.

    Mapping:
    - ``"sgd"`` creates ``SGD``;
    - ``"momentum"`` creates ``SGDMomentum`` with ``--momentum``;
    - any other value creates ``Adam``.

    All three receive ``--learning-rate``. With ``--disable-heliox-optimizer``
    the result's ``use_heliox_optimizer`` is set to False, which per the CLI
    help selects Python-side updates.
    """
    lr = args.learning_rate
    kind = args.optimizer
    if kind == "momentum":
        optimizer = SGDMomentum(learning_rate=lr, momentum=args.momentum)
    elif kind == "sgd":
        optimizer = SGD(learning_rate=lr)
    else:
        optimizer = Adam(learning_rate=lr)

    if args.disable_heliox_optimizer:
        optimizer.use_heliox_optimizer = False
    return optimizer


def determine_optimizer_backend(opt, requested: str) -> str:
    """Select the gradient backend on ``opt`` and return its name.

    ``"auto"`` currently resolves to ``"heliox"``. Any other value is passed
    through unchanged to ``opt.set_gradient_backend``.
    """
    chosen = "heliox" if requested == "auto" else requested
    opt.set_gradient_backend(chosen)
    return chosen


def init_perf_stats(backend_config: BackendConfig) -> Dict[str, Dict[str, float]]:
    """Create zeroed per-stage timing accumulators.

    Returns ``{"heliox": {"sim", "io", "set_stim", "optimizer" -> 0.0}}`` when
    ``backend_config.enable_heliox`` is truthy, otherwise an empty dict. Each
    call returns new, independent dicts.
    """
    if not backend_config.enable_heliox:
        return {}
    return {"heliox": dict.fromkeys(("sim", "io", "set_stim", "optimizer"), 0.0)}


def report_perf(stage: str, perf_stats: Dict[str, Dict[str, float]], duration: float) -> None:
    """Print a per-backend timing breakdown for one stage.

    Args:
        stage: Label shown in the header (e.g. ``"Train"`` or ``"Eval"``).
        perf_stats: Per-backend accumulators in seconds, as built by
            ``init_perf_stats``. In this script the stimulus-setup time is added
            to both ``io`` and ``set_stim``, so ``set_stim`` is a sub-component
            of ``io``. It is displayed but not added to the totals, which would
            otherwise double-count it.
        duration: Measured wall-clock time of the stage in seconds.

    Nothing is printed when ``perf_stats`` is empty. Zero-valued metrics are
    omitted from the breakdown.
    """
    if not perf_stats:
        return
    print(f"\n⏱️ Performance ({stage}): wall={duration:.3f}s")
    estimated = 0.0
    for backend, metrics in perf_stats.items():
        label = backend.upper()
        parts: List[str] = []
        total = 0.0
        for key in ("sim", "io", "set_stim", "optimizer"):
            value = metrics.get(key, 0.0)
            if not value:
                continue
            parts.append(f"{key}={value:.3f}s")
            # set_stim is already contained in io; don't count it twice.
            if key != "set_stim":
                total += value

        estimated += total
        if not parts:
            print(f"  {label}: total=0.000s")
        else:
            print(f"  {label}: {', '.join(parts)}, total={total:.3f}s")
    print(f"  est_total={estimated:.3f}s (difference {duration - estimated:+.3f}s)")


def init_interval_timing(backend_config: BackendConfig) -> Dict[str, Dict[str, float]]:
    """Create zeroed finitialize/run timers for one reporting interval.

    Returns ``{"heliox": {"finit": 0.0, "run": 0.0}}`` when
    ``backend_config.enable_heliox`` is truthy, otherwise an empty dict.
    """
    return {"heliox": {"finit": 0.0, "run": 0.0}} if backend_config.enable_heliox else {}


def reset_interval_timing(interval_stats: Dict[str, Dict[str, float]]) -> None:
    """Zero the ``finit`` and ``run`` timers of the HelioX entry in place.

    Does nothing when there is no ``"heliox"`` entry. Other backends and other
    keys of the HelioX entry are left untouched.
    """
    heliox_timers = interval_stats.get("heliox")
    if heliox_timers is not None:
        heliox_timers.update(finit=0.0, run=0.0)


def maybe_report_interval(stage: str, interval_stats: Dict[str, Dict[str, float]],
                          count: int, last_idx: int, interval: int) -> int:
    """Print and reset the interval timers once ``count`` reaches ``interval``.

    Args:
        stage: Label printed in the report header.
        interval_stats: Timers as built by ``init_interval_timing``. The HelioX
            ``finit``/``run`` timers are zeroed in place after a report.
        count: Samples accumulated since the last report.
        last_idx: 1-based index of the most recently processed sample.
        interval: Report threshold in samples; ``<= 0`` disables reporting.

    Returns:
        ``0`` after a report (the new running count), otherwise ``count``.
    """
    if interval <= 0 or count < interval:
        return count
    # The timers cover all `count` samples since the last report. `count` can
    # overshoot `interval` when the batch size does not divide it, so the
    # reported range must be derived from `count`.
    start_idx = last_idx - count + 1
    print(f"[Timing][{stage}] Samples {start_idx}-{last_idx}:")
    for backend in ("heliox",):
        timers = interval_stats.get(backend)
        if timers is None:
            continue
        total = timers["finit"] + timers["run"]
        print(f"  {backend.upper()}: finit={timers['finit']:.3f}s, run={timers['run']:.3f}s, total={total:.3f}s")
        # Start the next interval from zero.
        timers["finit"] = 0.0
        timers["run"] = 0.0
    return 0


def _summarize_weight_stats(network: Sequential) -> Dict[str, Tuple[float, float]]:
    """Map each entry of ``network.weights`` to its mean |weight| and mean |bias|.

    Best-effort debug helper. Each value of ``network.weights`` is indexed with
    ``"weight"`` and ``"bias"``, and entries that raise for any reason are
    skipped. A network with no ``weights`` attribute yields an empty dict.
    """
    summary: Dict[str, Tuple[float, float]] = {}
    if not hasattr(network, "weights"):
        return summary

    for layer_name, entry in network.weights.items():
        try:
            weight = np.asarray(entry["weight"])
            bias = np.asarray(entry["bias"])
            pair = (float(np.mean(np.abs(weight))), float(np.mean(np.abs(bias))))
        except Exception:
            continue
        summary[layer_name] = pair
    return summary


def _sample_wrapper_weights(network: Sequential, max_items: int = 3) -> List[Tuple[str, str, float]]:
    """Collect a few synapse-wrapper values for debug printing.

    Walks ``network.layers`` in order. Within each layer it visits every entry
    of ``weight_synapse_wrappers`` (iterated as rows of wrappers, tagged
    ``"w"``), then every entry of ``bias_synapse_wrappers`` (tagged ``"b"``).
    ``None`` entries, and wrappers whose ``w`` cannot be converted to
    ``float``, are skipped.

    Results are ``(layer_name, kind, value)`` tuples, where ``layer_name``
    falls back to the layer's class name. Collection stops once the result
    length reaches ``max_items``. The check happens after each append, so
    ``max_items <= 0`` still returns the first available value.
    """
    def _candidates(layer):
        # Lazily yield (kind, wrapper) pairs: weights first, then biases.
        matrix = getattr(layer, "weight_synapse_wrappers", None)
        if matrix:
            for row in matrix:
                for wrapper in row:
                    yield "w", wrapper
        biases = getattr(layer, "bias_synapse_wrappers", None)
        if biases:
            for wrapper in biases:
                yield "b", wrapper

    collected: List[Tuple[str, str, float]] = []
    for layer in getattr(network, "layers", []):
        name = getattr(layer, "layer_name", layer.__class__.__name__)
        for kind, wrapper in _candidates(layer):
            if wrapper is None:
                continue
            try:
                value = float(wrapper.w)
            except Exception:
                continue
            collected.append((name, kind, value))
            if len(collected) >= max_items:
                return collected
    return collected


def _reached_multiple(processed: int, step: int, interval: int) -> bool:
    """Return True if growing the count to ``processed`` by ``step`` hit a multiple of ``interval``.

    Unlike ``processed % interval == 0``, this still fires when ``step`` (the
    batch size) does not divide ``interval``.
    """
    return processed // interval > (processed - step) // interval


def _print_debug_stats(primary: Sequential, stage_label: str, processed_samples: int) -> None:
    """Print ``primary``'s softmax range plus weight/wrapper samples (best effort)."""
    weight_stats = _summarize_weight_stats(primary)
    wrapper_samples = _sample_wrapper_weights(primary, max_items=3)
    try:
        softmax = primary.get_softmax_outputs(backend="heliox")
        softmax_arr = np.asarray(softmax, dtype=float)
        sm_info = f"softmax min/max={softmax_arr.min():.4f}/{softmax_arr.max():.4f}"
    except Exception as exc:
        # Debug output must never abort training; report the failure instead.
        sm_info = f"softmax read failed: {exc}"
    if weight_stats:
        w_lines = ", ".join(f"{k}: w|mean|={v[0]:.4g}, b|mean|={v[1]:.4g}"
                            for k, v in weight_stats.items())
    else:
        w_lines = "no weight stats"
    if wrapper_samples:
        w_samples = ", ".join(f"{name}.{kind}={val:.4g}" for name, kind, val in wrapper_samples)
    else:
        w_samples = "no wrapper samples"
    print(f"[Debug][{stage_label}] sample {processed_samples}: {sm_info}; {w_lines}; wrappers: {w_samples}")


def train(networks: List[Sequential], optimizer, backend_config: BackendConfig, args: argparse.Namespace,
          x_train, y_train) -> Tuple[Dict[str, Dict[str, float]], float]:
    """Train the replicas in ``networks`` for ``args.epochs`` epochs.

    Each step assigns one sample to each replica, so a step covers
    ``len(networks)`` samples. The step then runs one simulation via
    ``networks[0].backend`` (``finitialize`` + ``run``), presumably covering
    all replicas when the backend is shared. Finally it calls
    ``optimizer.step(networks)``.

    Samples that do not fill a whole batch are dropped each epoch. With
    ``--shuffle`` the full training set is permuted first, using a generator
    seeded from ``--seed``, so the dropped remainder differs between epochs.

    Args:
        networks: Replicas built by ``build_model``; ``networks[0]`` drives the backend.
        optimizer: Object exposing ``step(networks)`` and ``learning_rate``.
        backend_config: Decides which timing accumulators exist.
        args: Parsed CLI options (dt, v_init, sim_time, epochs, intervals, seed, ...).
        x_train: Training inputs. These are fancy-indexed when shuffling, so
            NumPy arrays are assumed.
        y_train: Training labels, indexed the same way as ``x_train``.

    Returns:
        ``(perf_stats, duration)``: per-backend timing accumulators and the total
        wall time in seconds (``0.0`` when training is skipped).

    Raises:
        ValueError: If ``networks`` is empty.
    """
    if not networks:
        raise ValueError("No networks provided for training")

    batch_size = len(networks)
    primary = networks[0]

    h.dt = args.dt
    primary.backend.set_dt(args.dt)

    for net in networks:
        net.is_train()

    perf_stats = init_perf_stats(backend_config)
    # Without HelioX accumulators, time into a scratch dict so the loop needs no
    # per-update guards and perf_stats stays empty (report_perf stays silent).
    heliox_perf = perf_stats.get("heliox")
    if heliox_perf is None:
        heliox_perf = {"sim": 0.0, "io": 0.0, "set_stim": 0.0, "optimizer": 0.0}

    stage_start = perf_counter()
    total_samples = len(x_train)
    total_effective = (total_samples // batch_size) * batch_size
    if total_effective == 0:
        print(f"Warning: training samples ({total_samples}) fewer than batch size ({batch_size}); skipping training.")
        return perf_stats, 0.0

    rng = np.random.default_rng(args.seed) if args.shuffle else None

    for epoch in range(args.epochs):
        interval_stats = init_interval_timing(backend_config) if args.timing_interval > 0 else None
        heliox_interval = interval_stats.get("heliox") if interval_stats is not None else None
        interval_count = 0
        stage_label = f"Train E{epoch + 1}"
        processed_samples = 0
        if rng is not None:
            # Permute the whole set (not just the first total_effective samples)
            # so every sample can be trained on across epochs.
            perm = rng.permutation(total_samples)[:total_effective]
            x_epoch = x_train[perm]
            y_epoch = y_train[perm]
        else:
            x_epoch = x_train
            y_epoch = y_train

        for start in range(0, total_effective, batch_size):
            end = start + batch_size
            batch_x = x_epoch[start:end]
            batch_y = y_epoch[start:end]

            io_start = perf_counter()
            for net, img, label in zip(networks, batch_x, batch_y):
                net.set_stim(img, int(label))
            io_elapsed = perf_counter() - io_start
            # Stimulus setup is timed once and counted under both io and set_stim.
            heliox_perf["io"] += io_elapsed
            heliox_perf["set_stim"] += io_elapsed

            finit_start = perf_counter()
            primary.backend.finitialize(args.v_init)
            finit_elapsed = perf_counter() - finit_start

            run_start = perf_counter()
            primary.backend.run(args.sim_time)
            run_elapsed = perf_counter() - run_start

            heliox_perf["sim"] += finit_elapsed + run_elapsed
            if heliox_interval is not None:
                heliox_interval["finit"] += finit_elapsed
                heliox_interval["run"] += run_elapsed

            processed_samples += batch_size

            opt_start = perf_counter()
            optimizer.step(networks)
            heliox_perf["optimizer"] += perf_counter() - opt_start

            if args.debug_interval > 0 and _reached_multiple(processed_samples, batch_size, args.debug_interval):
                _print_debug_stats(primary, stage_label, processed_samples)

            if interval_stats is not None:
                interval_count += batch_size
                interval_count = maybe_report_interval(stage_label, interval_stats, interval_count,
                                                       processed_samples, args.timing_interval)

            if args.progress_interval > 0 and (
                _reached_multiple(processed_samples, batch_size, args.progress_interval)
                or processed_samples == total_effective
            ):
                print(f"[Epoch {epoch + 1}] Processed sample {processed_samples}/{total_effective}")

        if args.lr_decay != 1.0:
            optimizer.learning_rate *= args.lr_decay
            print(f"[Epoch {epoch + 1}] Learning rate now {optimizer.learning_rate:.6g}")

    duration = perf_counter() - stage_start
    return perf_stats, duration


def evaluate(networks: List[Sequential], backend_config: BackendConfig, args: argparse.Namespace,
             x_test, y_test) -> Tuple[Dict[str, Dict[str, object]], Dict[str, Dict[str, float]], float]:
    """Classify ``x_test`` with the replicas and compute accuracy.

    Samples are processed in groups of ``len(networks)``, one per replica.
    Each group is simulated once via ``networks[0].backend``. Every replica's
    softmax output is then read back and its argmax taken as the prediction.
    Trailing samples that do not fill a whole group are not evaluated.

    Args:
        networks: Replicas built by ``build_model``.
        backend_config: Decides which timing accumulators exist.
        args: Parsed CLI options (dt, v_init, sim_time, timing_interval).
        x_test: Indexable test inputs.
        y_test: Test labels, indexed the same way as ``x_test``.

    Returns:
        ``(stats, perf_stats, duration)``:
        - ``stats["heliox"]`` holds ``"predictions"`` (list of int),
          ``"correct"`` (int) and ``"accuracy"`` (fraction of evaluated
          samples; NaN when nothing could be evaluated);
        - ``perf_stats`` holds the timing accumulators;
        - ``duration`` is the wall time in seconds (``0.0`` when nothing was
          evaluated).

    Raises:
        ValueError: If ``networks`` is empty.
    """
    if not networks:
        raise ValueError("No networks provided for evaluation")

    primary = networks[0]
    batch_size = len(networks)
    for net in networks:
        net.is_test()

    h.dt = args.dt
    primary.backend.set_dt(args.dt)

    requested_backends = ["heliox"]

    stats = {
        backend: {"predictions": [], "correct": 0} for backend in requested_backends
    }

    perf_stats = init_perf_stats(backend_config)
    # Without HelioX accumulators, time into a scratch dict so every update is
    # guarded the same way and perf_stats stays empty.
    heliox_perf = perf_stats.get("heliox")
    if heliox_perf is None:
        heliox_perf = {"sim": 0.0, "io": 0.0, "set_stim": 0.0, "optimizer": 0.0}
    stage_start = perf_counter()

    interval_stats = init_interval_timing(backend_config) if args.timing_interval > 0 else None
    heliox_interval = interval_stats.get("heliox") if interval_stats is not None else None
    interval_count = 0

    total_samples = len(x_test)
    total_effective = (total_samples // batch_size) * batch_size
    if total_effective == 0:
        print(f"Warning: test samples ({total_samples}) fewer than batch size ({batch_size}); skipping evaluation.")
        # Always provide "accuracy" so callers can report it without a KeyError.
        for backend in requested_backends:
            stats[backend]["accuracy"] = float("nan")
        return stats, perf_stats, 0.0

    for start in range(0, total_effective, batch_size):
        end = start + batch_size
        batch_x = x_test[start:end]
        batch_y = y_test[start:end]

        io_start = perf_counter()
        for net, img, label in zip(networks, batch_x, batch_y):
            net.set_stim(img, int(label))
        io_elapsed = perf_counter() - io_start
        # Stimulus setup is timed once and counted under both io and set_stim.
        heliox_perf["io"] += io_elapsed
        heliox_perf["set_stim"] += io_elapsed

        finit_start = perf_counter()
        primary.backend.finitialize(args.v_init)
        finit_elapsed = perf_counter() - finit_start

        run_start = perf_counter()
        primary.backend.run(args.sim_time)
        run_elapsed = perf_counter() - run_start

        heliox_perf["sim"] += finit_elapsed + run_elapsed
        if heliox_interval is not None:
            heliox_interval["finit"] += finit_elapsed
            heliox_interval["run"] += run_elapsed

        # Reading the outputs back is also counted as io.
        read_start = perf_counter()
        outputs = {"heliox": [net.get_softmax_outputs(backend="heliox") for net in networks]}
        heliox_perf["io"] += perf_counter() - read_start

        for backend in requested_backends:
            backend_stats = stats[backend]
            for probs, label in zip(outputs[backend], batch_y):
                pred = int(np.argmax(probs))
                backend_stats["predictions"].append(pred)
                if pred == int(label):
                    backend_stats["correct"] += 1

        if interval_stats is not None:
            interval_count += batch_size
            interval_count = maybe_report_interval("Eval", interval_stats, interval_count, end,
                                                   args.timing_interval)

    for backend in requested_backends:
        stats[backend]["accuracy"] = stats[backend]["correct"] / total_effective

    duration = perf_counter() - stage_start
    return stats, perf_stats, duration


def _load_mnist(args: argparse.Namespace):
    """Load the MNIST subsets requested by ``--train-samples`` / ``--test-samples``.

    Reads ``mnist.npz`` from the current working directory. If it is not found
    there, the copy next to this script (``REPO_ROOT``) is used. Images are
    divided by 255 and cast to float32 (presumably uint8 pixels mapped to
    [0, 1]; confirm against the dataset file).

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The test arrays are empty when
        ``--test-samples <= 0``.
    """
    path = Path('mnist.npz')
    if not path.exists() and (REPO_ROOT / 'mnist.npz').exists():
        path = REPO_ROOT / 'mnist.npz'
    with np.load(path) as data:
        x_train = (data['x_train'][:args.train_samples] / 255).astype(np.float32)
        y_train = data['y_train'][:args.train_samples]
        if args.test_samples > 0:
            x_test = (data['x_test'][:args.test_samples] / 255).astype(np.float32)
            y_test = data['y_test'][:args.test_samples]
        else:
            x_test = np.empty((0,), dtype=np.float32)
            y_test = np.empty((0,), dtype=np.int64)
    return x_train, y_train, x_test, y_test


def main():
    """Script entry point: parse options, build the model, train, then evaluate.

    The dataset is loaded before the model is built, so a missing file fails
    fast. Once the model exists, the NEURON ``ParallelContext`` is always
    released with ``done()``, including when training or evaluation raises.
    """
    args = parse_args()
    x_train, y_train, x_test, y_test = _load_mnist(args)

    resources = build_model(args)
    try:
        networks: List[Sequential] = resources["networks"]
        if not networks:
            raise RuntimeError("No networks were built")
        backend_config: BackendConfig = resources["backend_config"]

        optimizer = make_optimizer(args)
        optimizer.record_time = args.record_time if args.record_time > 0 else args.sim_time
        backend_used = determine_optimizer_backend(optimizer, args.optimizer_backend)
        print(f"Optimizer will collect gradients from backend: {backend_used}")
        if hasattr(networks[0], "backend"):
            ready = getattr(networks[0].backend, "optimizer_ready", False)
            print(f"HelioX optimizer ready: {ready} (python_fallback={not ready or not optimizer.use_heliox_optimizer})")

        print("=" * 60)
        print("HelioX Training Script")
        print("=" * 60)

        train_perf, train_duration = train(networks, optimizer, backend_config, args, x_train, y_train)
        report_perf("Train", train_perf, train_duration)

        if args.test_samples <= 0:
            print("\nSkipping evaluation (--test-samples <= 0).")
        else:
            print("\n📊 Evaluating...")
            eval_results, eval_perf, eval_duration = evaluate(networks, backend_config, args, x_test, y_test)

            for backend, info in eval_results.items():
                # Fall back to NaN if no accuracy was recorded (e.g. nothing could be scored).
                accuracy = info.get("accuracy", float("nan"))
                preds = info["predictions"][:10]
                print(f"[{backend.upper()}] Test Accuracy: {accuracy:.2%}")
                print(f"[{backend.upper()}] Sample predictions: {preds}")

            report_perf("Eval", eval_perf, eval_duration)

            print(f"Actual labels: {y_test[:10].tolist()}")
        print("\n✨ Done!")
    finally:
        h.ParallelContext().done()


# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()
