"""Architecture baselines: GRU vs LSTM vs GTrXL vs minGRU on structural observations."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import argparse
import torch
import numpy as np
import time
import subprocess
import tempfile
import shutil
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.evaluation import compute_behavioral_cloning_ll

# Architectures trained when --archs is not given on the command line.
DEFAULT_ARCHS = ["gru", "lstm", "gtrxl", "mingru"]
# Superset of DEFAULT_ARCHS that additionally lists "mamba".
ALL_ARCHS = ["gru", "lstm", "gtrxl", "mamba", "mingru"]
# Display names used for the summary table rows and the figure tick labels.
ARCH_LABELS = {"gru": "GRU", "lstm": "LSTM", "gtrxl": "GTrXL", "mamba": "Mamba", "mingru": "minGRU"}
# Seeds used for a full run; --quick restricts the run to seed 0 only.
SEEDS = [0, 1, 2, 3, 4]
# Passed as n_actions to the BC helpers.
# NOTE(review): presumably the three moves mirrored by build_structural_obs
# (left child, right child, reverse/parent) -- confirm against the data loader.
N_ACTIONS = 3
# Number of states; build_structural_obs indexes them as nodes of an
# array-backed binary tree (children of i are 2i+1 and 2i+2). 127 == 2**7 - 1.
N_STATES = 127
# Maximum number of worker subprocesses running at the same time.
N_WORKERS = 5
# Epochs passed to each worker in a full run (--quick uses 50 instead).
N_EPOCHS = 200
# log2 of a uniform probability over 3 choices (reported in bits/dec).
# Note the 3 is hard-coded here rather than derived from N_ACTIONS.
RANDOM_BASELINE = -np.log2(3)
# Interpreter used to launch the workers: the one running this script.
PYTHON = sys.executable
# Worker entry point, expected to live next to this file.
WORKER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_arch_worker.py")


def build_structural_obs(node, n_states=127):
    """Encode a tree node by its local structure instead of its identity.

    Nodes are indexed as in an array-backed binary tree: the children of
    ``i`` are ``2*i + 1`` and ``2*i + 2`` and its parent is ``(i - 1) // 2``.
    Any index ``>= (n_states - 1) // 2`` is treated as a leaf.

    The result is a float32 tensor with 12 entries:

    * ``[degree / 3, is_root, is_leaf]`` for ``node`` itself, followed by
    * ``[is_leaf, is_root, degree / 3]`` for each of the three destinations
      (left child, right child, parent), in that order.

    Note that the per-destination triple uses the reverse field order of the
    node's own triple. A child index that falls outside the tree maps back
    to ``node``; the root's parent is the root.
    """
    leaf_start = (n_states - 1) // 2

    def _node_kind(idx):
        # Returns (is_root, is_leaf, degree / 3). The root test wins over the
        # leaf test when choosing the degree.
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= leaf_start else 0.0
        if idx == 0:
            deg = 2.0
        elif idx >= leaf_start:
            deg = 1.0
        else:
            deg = 3.0
        return root_flag, leaf_flag, deg / 3.0

    # Destinations of the three moves; out-of-range children stay in place.
    left_child = 2 * node + 1
    if not left_child < n_states:
        left_child = node
    right_child = 2 * node + 2
    if not right_child < n_states:
        right_child = node
    parent = (node - 1) // 2 if node > 0 else 0

    own_root, own_leaf, own_degree = _node_kind(node)
    encoded = [own_degree, own_root, own_leaf]
    for target in (left_child, right_child, parent):
        t_root, t_leaf, t_degree = _node_kind(target)
        encoded += [t_leaf, t_root, t_degree]
    return torch.tensor(encoded, dtype=torch.float32)


def plot_arch_comparison(summary, bc_ll, save_path, archs=None):
    """Draw a 1x4 bar-chart comparison of the architectures and save it.

    Args:
        summary: dict mapping architecture name to a dict holding the keys
            ``ll_mean/ll_std``, ``acc_mean/acc_std``, ``probe_mean/probe_std``
            and ``rho_mean/rho_std``.
        bc_ll: value drawn as the dashed "BC ceiling" line on the
            log-likelihood panel.
        save_path: file path the figure is written to.
        archs: architectures to plot, in order. Defaults to every key of
            ``summary``. Names without an entry in ``summary`` are skipped
            instead of raising ``KeyError`` (e.g. architectures whose runs
            all failed).
    """
    if archs is None:
        archs = list(summary.keys())
    # Only plot architectures that actually have aggregated results.
    archs = [a for a in archs if a in summary]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))

    # Fall back to the raw name / a neutral colour for unlisted architectures.
    labels = [ARCH_LABELS.get(a, a) for a in archs]
    all_colors = {"gru": "#2196F3", "lstm": "#4CAF50", "gtrxl": "#FF9800",
                  "mamba": "#9C27B0", "mingru": "#F44336"}
    colors = [all_colors.get(a, "gray") for a in archs]
    x = np.arange(len(archs))

    # (mean key, std key, y-axis label, panel title) for each of the 4 panels.
    metrics = [
        ("ll_mean", "ll_std", "Log-Likelihood (bits/dec)", "LL"),
        ("acc_mean", "acc_std", "Action Accuracy", "Accuracy"),
        ("probe_mean", "probe_std", "Probe Accuracy (127-way)", "Probe Acc"),
        ("rho_mean", "rho_std", "|PC1-Depth| Spearman ρ", "PC1-Depth ρ"),
    ]

    for ax, (mean_key, std_key, ylabel, title) in zip(axes, metrics):
        means = [summary[a][mean_key] for a in archs]
        stds = [summary[a][std_key] for a in archs]
        ax.bar(x, means, yerr=stds, capsize=4, color=colors, alpha=0.85,
               edgecolor="black", linewidth=0.5)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        if mean_key == "ll_mean":
            # Reference lines, and a y-range tight enough that the
            # differences between architectures stay visible.
            ax.axhline(bc_ll, color="gray", linestyle="--", linewidth=1, label="BC ceiling")
            ax.axhline(RANDOM_BASELINE, color="red", linestyle=":", linewidth=1, label="Random")
            ax.legend(fontsize=7, loc="lower right")
            all_vals = means + [bc_ll, RANDOM_BASELINE]
            margin = 0.05 * (max(all_vals) - min(all_vals))
            ax.set_ylim(min(all_vals) - margin, max(all_vals) + margin)

    fig.suptitle("Architecture Baselines (hidden_dim=64, structural obs)", fontsize=13, y=1.02)
    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure: {save_path}", flush=True)


def _parse_archs(parser, raw_archs):
    """Turn the --archs value into a validated, de-duplicated list of names."""
    if not raw_archs:
        return list(DEFAULT_ARCHS)
    # dict.fromkeys drops duplicates while keeping the user's order.
    archs = list(dict.fromkeys(a.strip() for a in raw_archs.split(",") if a.strip()))
    unknown = [a for a in archs if a not in ARCH_LABELS]
    if not archs or unknown:
        # Fail before any training: an unlabeled arch would otherwise only
        # surface as a KeyError when the summary table is printed.
        parser.error(f"--archs must be a comma-separated subset of "
                     f"{sorted(ARCH_LABELS)}; got {raw_archs!r}")
    return archs


def _read_log(path):
    """Return the stripped text of a worker log file (undecodable bytes replaced)."""
    with open(path, "rb") as fh:
        return fh.read().decode(errors="replace").strip()


def _run_worker_jobs(jobs, ckpt_path, outdir, n_epochs):
    """Run one worker subprocess per (arch, seed), at most N_WORKERS at a time.

    Worker stdout/stderr are redirected to files inside ``outdir`` rather
    than to pipes: a pipe that nobody drains while the parent only polls
    would block a worker as soon as it fills the OS pipe buffer.

    Returns a dict mapping ``(arch, seed)`` to the object the worker saved
    at ``<outdir>/<arch>_s<seed>.pt``. Failed jobs are reported and omitted.
    """
    pending = list(jobs)
    total = len(pending)
    running = {}
    results = {}
    completed = 0

    try:
        while pending or running:
            while pending and len(running) < N_WORKERS:
                arch, seed = pending.pop(0)
                cmd = [PYTHON, WORKER_SCRIPT,
                       "--arch", arch, "--seed", str(seed),
                       "--ckpt", ckpt_path, "--outdir", outdir,
                       "--n_epochs", str(n_epochs)]
                out_path = os.path.join(outdir, f"{arch}_s{seed}.stdout.log")
                err_path = os.path.join(outdir, f"{arch}_s{seed}.stderr.log")
                # The child keeps its own copies of the descriptors, so the
                # parent's handles can be closed right after the spawn.
                with open(out_path, "wb") as out_f, open(err_path, "wb") as err_f:
                    proc = subprocess.Popen(cmd, stdout=out_f, stderr=err_f)
                running[proc] = (arch, seed, out_path, err_path)

            finished = []
            for proc in running:
                ret = proc.poll()
                if ret is not None:
                    finished.append((proc, ret))

            for proc, ret in finished:
                arch, seed, out_path, err_path = running.pop(proc)
                if ret != 0:
                    stderr = _read_log(err_path)
                    print(f"FAILED {arch} seed={seed}: {stderr[-500:]}", flush=True)
                    continue
                res_path = os.path.join(outdir, f"{arch}_s{seed}.pt")
                if not os.path.isfile(res_path):
                    print(f"FAILED {arch} seed={seed}: worker exited 0 but wrote no "
                          f"result file {res_path}", flush=True)
                    continue
                # weights_only=False unpickles arbitrary objects; acceptable
                # here because the file was just written by our own worker.
                results[(arch, seed)] = torch.load(res_path, weights_only=False)
                completed += 1
                print(f"[{completed}/{total}] {_read_log(out_path)}", flush=True)

            if running and not finished:
                time.sleep(1)
    finally:
        # On an error or interrupt, do not leave workers running against a
        # checkpoint/outdir that the caller is about to delete.
        for proc in running:
            if proc.poll() is None:
                proc.kill()
            proc.wait()

    return results


def _summarize_results(results, archs, seeds):
    """Aggregate per-seed worker results into mean/std statistics per arch."""
    summary = {}
    for arch in archs:
        arch_seeds = [s for s in seeds if (arch, s) in results]
        if not arch_seeds:
            print(f"WARNING: no results for {arch}", flush=True)
            continue

        runs = [results[(arch, s)] for s in arch_seeds]
        lls = [r["ll"] for r in runs]
        accs = [r["acc"] for r in runs]
        probes = [r["probe_acc"] for r in runs]
        rhos = [r["pc1_rho"] for r in runs]
        times = [r["train_time"] for r in runs]

        summary[arch] = {
            "ll_mean": np.mean(lls), "ll_std": np.std(lls),
            "acc_mean": np.mean(accs), "acc_std": np.std(accs),
            "probe_mean": np.mean(probes), "probe_std": np.std(probes),
            "rho_mean": np.mean(rhos), "rho_std": np.std(rhos),
            "time_mean": np.mean(times), "time_std": np.std(times),
            # Taken from the first successful seed of this architecture.
            "n_params": runs[0]["n_params"],
        }
    return summary


def main():
    """Train every requested architecture over all seeds and report the results.

    Steps: load the data, build the BC targets and structural observations,
    hand them to the worker subprocesses through a temporary checkpoint,
    aggregate the per-seed metrics, print a table, save a comparison figure
    and store everything in ``<PROJECT_ROOT>/checkpoints/arch_baselines.pt``.

    Command line:
        --quick   smoke test with seed 0 only and 50 epochs.
        --archs   comma-separated architecture names (default DEFAULT_ARCHS).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true", help="Smoke test: 1 seed, 50 epochs")
    parser.add_argument("--archs", type=str, default=None,
                        help="Comma-separated arch list (default: gru,lstm,gtrxl,mingru)")
    args = parser.parse_args()

    archs = _parse_archs(parser, args.archs)
    seeds = [0] if args.quick else SEEDS
    n_epochs = 50 if args.quick else N_EPOCHS
    t0 = time.time()

    print("Loading Rosenberg data...", flush=True)
    d = load_rosenberg_everything()
    n_trajs = len(d["trajs"])
    n_train = len(d["train_trajs"])
    n_val = len(d["val_trajs"])
    print(f"{n_trajs} bouts total ({n_train} train, {n_val} val)", flush=True)
    print(f"Random baseline: {RANDOM_BASELINE:.4f} bits/dec", flush=True)

    print("\nBuilding BC targets...", flush=True)
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES, n_actions=N_ACTIONS, laplace=1.0)
    bc_ll = compute_behavioral_cloning_ll(d["train_sa"], d["val_sa"],
                                          n_states=N_STATES, n_actions=N_ACTIONS)
    print(f"BC LL (full-info ceiling) = {bc_ll:.4f} bits/dec", flush=True)

    print("\nBuilding structural observations...", flush=True)
    structural_obs = {s: build_structural_obs(s, N_STATES) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])

    # Structural features are shared by many nodes; report how many distinct
    # observation vectors exist.
    unique_obs = set(tuple(structural_obs[s].tolist()) for s in range(N_STATES))
    print(f"obs_dim={obs_dim}, unique observations: {len(unique_obs)}/{N_STATES}", flush=True)

    # Create the directory under PROJECT_ROOT (where the files are written),
    # not relative to the current working directory.
    ckpt_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, "arch_baselines_tmp.pt")
    outdir = tempfile.mkdtemp(prefix="arch_baselines_")

    jobs = [(arch, seed) for arch in archs for seed in seeds]
    total = len(jobs)
    mode = "QUICK" if args.quick else "FULL"

    try:
        # Shared inputs every worker loads via --ckpt.
        torch.save({
            "bc_policy": bc_policy,
            "structural_obs": structural_obs,
            "obs_dim": obs_dim,
            "train_trajs": d["train_trajs"],
            "val_trajs": d["val_trajs"],
        }, ckpt_path)

        print(f"\n[{mode}] Launching {total} jobs ({len(archs)} archs x {len(seeds)} seeds, "
              f"{n_epochs} epochs, {N_WORKERS} parallel workers)...", flush=True)
        results = _run_worker_jobs(jobs, ckpt_path, outdir, n_epochs)
    finally:
        # Always remove the temporary artifacts, even when a job loop error
        # or Ctrl-C aborts the run.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
        shutil.rmtree(outdir, ignore_errors=True)

    summary = _summarize_results(results, archs, seeds)

    print("Architecture Baselines Comparison (hidden_dim=64, structural obs) ", flush=True)
    header = f"{'Arch':<8} {'LL (bits/dec)':<18} {'Accuracy':<16} {'Probe Acc':<16} {'PC1-ρ':<14} {'Time (s)':<12} {'Params':<8}"
    print(header, flush=True)

    for arch in archs:
        if arch not in summary:
            continue
        s = summary[arch]
        print(f"{ARCH_LABELS[arch]:<8} "
              f"{s['ll_mean']:+.4f}+/-{s['ll_std']:.4f}    "
              f"{s['acc_mean']:.3f}+/-{s['acc_std']:.3f}     "
              f"{s['probe_mean']:.3f}+/-{s['probe_std']:.3f}     "
              f"{s['rho_mean']:.3f}+/-{s['rho_std']:.3f}   "
              f"{s['time_mean']:6.0f}+/-{s['time_std']:.0f}   "
              f"{s['n_params']:>7d}", flush=True)

    print(f"BC LL (full-info ceiling) = {bc_ll:.4f}", flush=True)
    print(f"Random baseline           = {RANDOM_BASELINE:.4f}", flush=True)

    # Plot only architectures that produced results; passing a failed one
    # would make the plotting code look up a missing summary entry.
    plotted = [a for a in archs if a in summary]
    if plotted:
        fig_dir = os.path.join(PROJECT_ROOT, "figures")
        os.makedirs(fig_dir, exist_ok=True)
        fig_path = os.path.join(fig_dir, "arch_baselines.png")
        plot_arch_comparison(summary, bc_ll, fig_path, archs=plotted)
    else:
        print("WARNING: no successful runs; skipping figure", flush=True)

    save_data = {
        "results": results,
        "summary": summary,
        "bc_ll": bc_ll,
        "random_baseline": RANDOM_BASELINE,
        "seeds": seeds,
        "n_epochs": n_epochs,
        "obs_dim": obs_dim,
        "n_unique_obs": len(unique_obs),
        "n_actions": N_ACTIONS,
        "n_trajs": n_trajs,
        "n_train": n_train,
        "n_val": n_val,
        "hidden_dim": 64,
        "quick": args.quick,
    }
    ckpt_save = os.path.join(ckpt_dir, "arch_baselines.pt")
    torch.save(save_data, ckpt_save)
    print(f"\nSaved to {ckpt_save}", flush=True)

    elapsed = time.time() - t0
    print(f"Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)


if __name__ == "__main__":
    main()
