"""Leave-one-mouse-out cross-validation for the IRL→GRU pipeline."""
import os
import sys
# Project root is the parent of the directory that contains this script. It is
# placed first on sys.path before the remaining imports run — presumably so the
# `src` package imported further down resolves when the script is run directly.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import argparse
import time
import subprocess
import tempfile
import shutil
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from src.rosenberg_data import (
    load_rosenberg_trajectories, split_rosenberg_trajectories,
    build_binary_tree_transition_tensor, REWARDED_ANIMALS,
)

# Animals that take part in the cross-validation (a list copy of REWARDED_ANIMALS).
ANIMALS = list(REWARDED_ANIMALS)  # ['B1','B2','B3','B4','C1','C3','C6','C7','C8','C9']
# Maximum number of worker subprocesses kept running at the same time.
N_WORKERS = 5
# NOTE: only recorded in the saved checkpoint; this script does not forward it
# to the worker command line.
HIDDEN_DIM = 128
# Epoch counts passed to each worker via --irl_epochs / --gru_epochs / --ft_epochs
# (main() substitutes smaller values when --quick is given).
IRL_EPOCHS = 300
GRU_EPOCHS = 200
FT_EPOCHS = 20
# Workers are launched with the same interpreter that runs this script.
PYTHON = sys.executable
# The worker script lives in the same directory as this file.
WORKER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_leave_one_out_worker.py")


def _label_animal_ticks(ax, positions, animals):
    """Place one rotated animal-name tick label at each bar position."""
    ax.set_xticks(positions)
    ax.set_xticklabels(animals, rotation=45, ha="right", fontsize=8)


def plot_leave_one_out(cross_results, within_results, save_path):
    """Draw the four-panel leave-one-mouse-out summary figure and save it.

    Panels: (a) cross vs within LL bars, (b) within − cross LL per animal,
    (c) fine-tuned − cross LL per animal, (d) scatter of within vs cross
    probe accuracy with an identity line.

    Only animals present in BOTH result dicts are plotted; animals that have
    just one fold type (e.g. because a worker job failed) are reported and
    skipped instead of raising ``KeyError``. If no animal has both fold
    types, nothing is drawn and no file is written.

    Args:
        cross_results: Mapping animal -> result dict. Reads the keys ``"ll"``
            and ``"probe_acc"``, plus ``"ft_ll"`` when present (falls back to
            ``"ll"`` otherwise).
        within_results: Mapping animal -> result dict. Reads ``"ll"`` and
            ``"probe_acc"``.
        save_path: Path of the image file to write.

    Returns:
        None.
    """
    cross_keys = set(cross_results)
    within_keys = set(within_results)
    animals = sorted(cross_keys & within_keys)

    unpaired = sorted(cross_keys ^ within_keys)
    if unpaired:
        print("Skipping animals without both cross and within results: "
              f"{', '.join(str(a) for a in unpaired)}", flush=True)
    if not animals:
        # Bail out before touching matplotlib: min()/max() below would raise
        # on empty sequences and an empty figure is of no use anyway.
        print("No animals with both cross and within results; figure not saved.",
              flush=True)
        return

    x = np.arange(len(animals))

    cross_lls = [cross_results[a]["ll"] for a in animals]
    within_lls = [within_results[a]["ll"] for a in animals]
    transfer_gaps = [w_ll - c_ll for w_ll, c_ll in zip(within_lls, cross_lls)]

    # Without a fine-tuned LL the improvement is defined as zero.
    ft_lls = [cross_results[a].get("ft_ll", cross_results[a]["ll"]) for a in animals]
    ft_improvements = [f_ll - c_ll for f_ll, c_ll in zip(ft_lls, cross_lls)]

    cross_probes = [cross_results[a]["probe_acc"] for a in animals]
    within_probes = [within_results[a]["probe_acc"] for a in animals]

    fig, axes = plt.subplots(1, 4, figsize=(18, 4.5))
    try:
        # (a) Side-by-side bars: cross-animal vs within-animal log-likelihood.
        ax = axes[0]
        w = 0.35
        ax.bar(x - w / 2, cross_lls, w, label="Cross (9→1)", color="#2196F3", alpha=0.85,
               edgecolor="black", linewidth=0.5)
        ax.bar(x + w / 2, within_lls, w, label="Within (80→20%)", color="#4CAF50", alpha=0.85,
               edgecolor="black", linewidth=0.5)
        _label_animal_ticks(ax, x, animals)
        ax.set_ylabel("LL (bits/dec)")
        ax.set_title("(a) Cross vs Within LL")
        ax.legend(fontsize=7)

        # (b) Transfer gap; red where within beats cross, green otherwise.
        ax = axes[1]
        colors_gap = ["#F44336" if g > 0 else "#4CAF50" for g in transfer_gaps]
        ax.bar(x, transfer_gaps, color=colors_gap, alpha=0.85, edgecolor="black",
               linewidth=0.5)
        ax.axhline(0, color="black", linewidth=0.5)
        mean_gap = np.mean(transfer_gaps)
        ax.axhline(mean_gap, color="gray", linestyle="--", linewidth=1,
                   label=f"mean={mean_gap:.4f}")
        _label_animal_ticks(ax, x, animals)
        ax.set_ylabel("Within − Cross LL")
        ax.set_title("(b) Transfer Gap")
        ax.legend(fontsize=7)

        # (c) Change in cross LL after fine-tuning.
        ax = axes[2]
        ax.bar(x, ft_improvements, color="#FF9800", alpha=0.85, edgecolor="black",
               linewidth=0.5)
        ax.axhline(0, color="black", linewidth=0.5)
        mean_ft = np.mean(ft_improvements)
        ax.axhline(mean_ft, color="gray", linestyle="--", linewidth=1,
                   label=f"mean={mean_ft:.4f}")
        _label_animal_ticks(ax, x, animals)
        ax.set_ylabel("FT − Cross LL")
        ax.set_title("(c) Fine-Tuning Improvement")
        ax.legend(fontsize=7)

        # (d) Probe accuracy scatter with square, equal-aspect axes.
        ax = axes[3]
        ax.scatter(within_probes, cross_probes, s=50, c="#9C27B0", alpha=0.85,
                   edgecolors="black", linewidth=0.5, zorder=3)
        for a, wp, cp in zip(animals, within_probes, cross_probes):
            ax.annotate(a, (wp, cp), fontsize=7,
                        textcoords="offset points", xytext=(4, 4))
        lim_lo = min(min(within_probes), min(cross_probes)) - 0.02
        lim_hi = max(max(within_probes), max(cross_probes)) + 0.02
        ax.plot([lim_lo, lim_hi], [lim_lo, lim_hi], "k--", linewidth=0.8, alpha=0.5,
                label="identity")
        ax.set_xlim(lim_lo, lim_hi)
        ax.set_ylim(lim_lo, lim_hi)
        ax.set_xlabel("Within Probe Acc")
        ax.set_ylabel("Cross Probe Acc")
        ax.set_title("(d) Probe Accuracy")
        ax.legend(fontsize=7)
        ax.set_aspect("equal")

        fig.suptitle("Leave-One-Mouse-Out Cross-Validation", fontsize=13, y=1.02)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        # Close this specific figure even if drawing or saving failed, so the
        # figure is not left registered with pyplot.
        plt.close(fig)
    print(f"Saved figure: {save_path}", flush=True)

def _read_log(path):
    """Return the stripped text of a worker log file, or "" if it cannot be read."""
    try:
        with open(path, "rb") as fh:
            return fh.read().decode(errors="replace").strip()
    except OSError:
        return ""


def _run_jobs(jobs, outdir, irl_epochs, gru_epochs, ft_epochs):
    """Run worker subprocesses for all jobs, at most N_WORKERS at a time.

    Args:
        jobs: List of ``(fold_type, animal, ckpt_path)`` tuples.
        outdir: Directory holding the job inputs; workers are told to write
            their results there and their stdout/stderr logs are kept there.
        irl_epochs, gru_epochs, ft_epochs: Epoch counts forwarded to the worker.

    Returns:
        Dict mapping ``(fold_type, animal)`` to the loaded result of every job
        that exited with status 0 and left a readable result file.
    """
    from collections import deque

    total_jobs = len(jobs)
    pending = deque(jobs)
    running = {}  # Popen -> (fold_type, animal, stdout_log, stderr_log)
    results = {}
    completed = 0

    try:
        while pending or running:
            while pending and len(running) < N_WORKERS:
                fold_type, animal, ckpt_path = pending.popleft()
                cmd = [PYTHON, WORKER_SCRIPT,
                       "--ckpt", ckpt_path, "--outdir", outdir,
                       "--irl_epochs", str(irl_epochs),
                       "--gru_epochs", str(gru_epochs),
                       "--ft_epochs", str(ft_epochs)]
                out_log = os.path.join(outdir, f"stdout_{fold_type}_{animal}.log")
                err_log = os.path.join(outdir, f"stderr_{fold_type}_{animal}.log")
                # Send worker output to files rather than pipes: an unread pipe
                # fills up, the worker blocks on write, and poll() would then
                # never report completion. The child keeps its own handles, so
                # ours can be closed right after the launch.
                with open(out_log, "wb") as out_f, open(err_log, "wb") as err_f:
                    proc = subprocess.Popen(cmd, stdout=out_f, stderr=err_f)
                running[proc] = (fold_type, animal, out_log, err_log)

            done = [(proc, info) for proc, info in running.items()
                    if proc.poll() is not None]

            for proc, (fold_type, animal, out_log, err_log) in done:
                del running[proc]
                if proc.returncode != 0:
                    stderr = _read_log(err_log)
                    print(f"FAILED {fold_type}/{animal}: {stderr[-500:]}", flush=True)
                    continue

                res_path = os.path.join(outdir, f"{fold_type}_{animal}.pt")
                try:
                    # weights_only=False unpickles arbitrary objects; acceptable
                    # here because the file was written by our own worker into
                    # a private temporary directory.
                    res = torch.load(res_path, weights_only=False)
                except Exception as exc:
                    # A missing or corrupt result file is a failure of this one
                    # job, not a reason to abort the remaining folds.
                    print(f"FAILED {fold_type}/{animal}: could not load "
                          f"{res_path}: {exc}", flush=True)
                    continue

                results[(fold_type, animal)] = res
                completed += 1
                print(f"[{completed}/{total_jobs}] {_read_log(out_log)}", flush=True)

            if running and not done:
                time.sleep(1)
    finally:
        # Reached with live children only on an exception or interrupt: make
        # sure no worker outlives the orchestrator.
        for proc in list(running):
            if proc.poll() is None:
                proc.kill()
            proc.wait()

    return results


def main():
    """Run leave-one-mouse-out cross-validation end to end.

    For every animal, two jobs are prepared: a "cross" fold (train on all
    other animals, test on this one) and a "within" fold (80/20 split of this
    animal's own bouts). Jobs are executed by worker subprocesses, then a
    summary table is printed, a figure is written to ``figures/`` and all
    results are saved to ``checkpoints/leave_one_out.pt`` under PROJECT_ROOT.

    ``--quick`` restricts the run to the first two animals with fewer epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true", help="Smoke test: 2 folds, fewer epochs")
    args = parser.parse_args()

    t0 = time.time()
    animals = ANIMALS[:2] if args.quick else ANIMALS
    irl_epochs = 50 if args.quick else IRL_EPOCHS
    gru_epochs = 50 if args.quick else GRU_EPOCHS
    ft_epochs = 5 if args.quick else FT_EPOCHS
    mode = "QUICK" if args.quick else "FULL"

    print("Loading Rosenberg trajectories...", flush=True)
    all_trajs = {}
    for animal in animals:
        trajs = load_rosenberg_trajectories(animals=[animal])
        all_trajs[animal] = trajs
        print(f"{animal}: {len(trajs)} bouts", flush=True)

    total_bouts = sum(len(v) for v in all_trajs.values())
    print(f"Total: {total_bouts} bouts from {len(animals)} mice", flush=True)

    T_tensor = build_binary_tree_transition_tensor()

    outdir = tempfile.mkdtemp(prefix="leave_one_out_")
    try:
        jobs = []  # (fold_type, held_out, ckpt_path)

        for animal in animals:
            # Cross fold: every other animal's bouts train, this animal tests.
            train_trajs = [traj for other in animals if other != animal
                           for traj in all_trajs[other]]
            test_trajs = all_trajs[animal]

            ckpt_path = os.path.join(outdir, f"input_cross_{animal}.pt")
            torch.save({
                "train_trajs": train_trajs,
                "test_trajs": test_trajs,
                "fold_type": "cross",
                "held_out": animal,
                "T_tensor": T_tensor,
            }, ckpt_path)
            jobs.append(("cross", animal, ckpt_path))

            # Within fold: split this animal's own bouts with a fixed seed.
            train_trajs_w, val_trajs_w = split_rosenberg_trajectories(
                all_trajs[animal], val_fraction=0.2, seed=42)

            ckpt_path_w = os.path.join(outdir, f"input_within_{animal}.pt")
            torch.save({
                "train_trajs": train_trajs_w,
                "test_trajs": val_trajs_w,
                "fold_type": "within",
                "held_out": animal,
                "T_tensor": T_tensor,
            }, ckpt_path_w)
            jobs.append(("within", animal, ckpt_path_w))

        total_jobs = len(jobs)
        print(f"\n[{mode}] Launching {total_jobs} jobs ({len(animals)} animals x 2 fold types, "
              f"IRL={irl_epochs}ep, GRU={gru_epochs}ep, FT={ft_epochs}ep, "
              f"{N_WORKERS} parallel workers)...", flush=True)

        results = _run_jobs(jobs, outdir, irl_epochs, gru_epochs, ft_epochs)
    finally:
        # Remove the scratch directory even when a fold preparation or a
        # worker run raised, so failed runs do not leak large temp files.
        shutil.rmtree(outdir, ignore_errors=True)

    cross_results = {}
    within_results = {}
    for (fold_type, animal), res in results.items():
        if fold_type == "cross":
            cross_results[animal] = res
        else:
            within_results[animal] = res

    # Animals for which both fold types finished; only these can be compared.
    paired = [a for a in animals if a in cross_results and a in within_results]

    print("Leave-One-Mouse-Out Cross-Validation ", flush=True)

    header = (f"{'Animal':<8} {'Cross LL':<12} {'Within LL':<12} {'Gap':<10} "
              f"{'FT LL':<12} {'FT delta':<10} {'Cross Acc':<11} {'Cross Probe':<13} "
              f"{'Cross ρ':<10}")
    print(header, flush=True)

    gaps, ft_deltas = [], []
    for animal in paired:
        cr = cross_results[animal]
        wr = within_results[animal]
        gap = wr["ll"] - cr["ll"]
        gaps.append(gap)
        ft_ll = cr.get("ft_ll", cr["ll"])
        ft_delta = ft_ll - cr["ll"]
        ft_deltas.append(ft_delta)

        # Column widths mirror the header so values line up under their titles.
        print(f"{animal:<8} {cr['ll']:<+12.4f} {wr['ll']:<+12.4f} {gap:<+10.4f} "
              f"{ft_ll:<+12.4f} {ft_delta:<+10.4f} {cr['acc']:<11.3f} "
              f"{cr['probe_acc']:<13.3f} {cr['pc1_rho']:<10.3f}", flush=True)

    if gaps:
        print(f"Transfer gap: {np.mean(gaps):+.4f} +/- {np.std(gaps):.4f} bits/dec", flush=True)
    if ft_deltas:
        print(f"Fine-tuning delta: {np.mean(ft_deltas):+.4f} +/- {np.std(ft_deltas):.4f} bits/dec",
              flush=True)

    os.makedirs(os.path.join(PROJECT_ROOT, "figures"), exist_ok=True)
    fig_path = os.path.join(PROJECT_ROOT, "figures", "leave_one_out.png")
    if paired:
        # Hand the plot only animals that have both folds so a single failed
        # job cannot make plotting fail before the results are saved below.
        plot_leave_one_out({a: cross_results[a] for a in paired},
                           {a: within_results[a] for a in paired},
                           fig_path)

    os.makedirs(os.path.join(PROJECT_ROOT, "checkpoints"), exist_ok=True)
    ckpt_save = os.path.join(PROJECT_ROOT, "checkpoints", "leave_one_out.pt")
    torch.save({
        "results": {f"{ft}_{a}": r for (ft, a), r in results.items()},
        "cross_results": cross_results,
        "within_results": within_results,
        "animals": animals,
        "gaps": gaps,
        "ft_deltas": ft_deltas,
        "irl_epochs": irl_epochs,
        "gru_epochs": gru_epochs,
        "ft_epochs": ft_epochs,
        "hidden_dim": HIDDEN_DIM,
        "quick": args.quick,
    }, ckpt_save)
    print(f"\nSaved to {ckpt_save}", flush=True)

    elapsed = time.time() - t0
    print(f"Total time: {elapsed:.0f}s ({elapsed / 60:.1f}min)", flush=True)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
