"""Radial arm maze: GRU vs MLP under structural aliasing."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import time
import json
import subprocess
import tempfile
import shutil

import src.radial_arm_env as ram

# Training seeds; every seed is run once per model type.
SEEDS = list(range(5))
# Maximum number of worker subprocesses running at the same time.
N_WORKERS = 6
# Launch workers with the same interpreter that runs this driver.
PYTHON = sys.executable
# The worker script lives next to this file.
_HERE = os.path.dirname(os.path.abspath(__file__))
WORKER_SCRIPT = os.path.join(_HERE, "_radial_arm_worker.py")


def main():
    t0 = time.time()

    print("Generating foraging trajectories...", flush=True)
    trajs = ram.generate_foraging_trajectories(
        n_trajs=1500, max_steps=300, optimal_prob=0.8, seed=42,
    )
    n_tips = [bin(t['visited_seq'][-1]).count('1') for t in trajs]
    lengths = [len(t['actions']) for t in trajs]
    print(f"{len(trajs)} trajectories, mean length {np.mean(lengths):.0f}, "
          f"mean tips {np.mean(n_tips):.1f}/8", flush=True)

    train_trajs, val_trajs = ram.split_trajectories(trajs, val_fraction=0.2, seed=42)
    train_sa = ram.trajectories_to_sa_pairs(train_trajs)
    val_sa = ram.trajectories_to_sa_pairs(val_trajs)
    print(f"{len(train_trajs)} train, {len(val_trajs)} val", flush=True)

    obs_encoding = ram.structural_obs_encoding()
    obs_dim = len(obs_encoding[0])

    unique_obs = set(tuple(obs_encoding[s].tolist()) for s in range(ram.N_NODES))
    print(f"obs_dim={obs_dim}, unique obs classes={len(unique_obs)}", flush=True)

    random_baseline = -np.log2(ram.N_ACTIONS)
    print(f"Random baseline: {random_baseline:.4f} bits/dec", flush=True)

    os.makedirs(os.path.join(PROJECT_ROOT, "checkpoints"), exist_ok=True)
    ckpt_path = os.path.join(PROJECT_ROOT, "checkpoints", "radial_arm_tmp.pt")
    torch.save({
        "obs_encoding": obs_encoding,
        "obs_dim": obs_dim,
        "train_trajs": train_trajs,
        "val_trajs": val_trajs,
    }, ckpt_path)

    outdir = tempfile.mkdtemp(prefix="radial_arm_")
    jobs = []
    for seed in SEEDS:
        for model_type in ["gru", "mlp"]:
            jobs.append((model_type, seed))

    total = len(jobs)
    print(f"\nLaunching {total} training jobs ({N_WORKERS} parallel)...", flush=True)

    running = {}
    pending = list(jobs)
    results = {}
    completed = 0

    while pending or running:
        while pending and len(running) < N_WORKERS:
            model_type, seed = pending.pop(0)
            cmd = [PYTHON, WORKER_SCRIPT,
                   "--model", model_type, "--seed", str(seed),
                   "--ckpt", ckpt_path, "--outdir", outdir]
            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            running[proc] = (model_type, seed)

        done = []
        for proc, key in running.items():
            ret = proc.poll()
            if ret is not None:
                done.append((proc, key, ret))

        for proc, key, ret in done:
            del running[proc]
            model_type, seed = key
            stdout = proc.stdout.read().decode().strip()
            if ret == 0:
                res_path = os.path.join(outdir, f"{model_type}_s{seed}.pt")
                res = torch.load(res_path, weights_only=False)
                results[key] = res
                completed += 1
                print(f"[{completed}/{total}] {stdout}", flush=True)
            else:
                stderr = proc.stderr.read().decode().strip()
                print(f"FAILED {model_type} seed={seed}: {stderr[-500:]}", flush=True)

        if running and not done:
            time.sleep(1)

    os.remove(ckpt_path)
    shutil.rmtree(outdir, ignore_errors=True)

    summary = {}
    for model_type in ["gru", "mlp"]:
        lls = [results[(model_type, s)]["ll"] for s in SEEDS]
        accs = [results[(model_type, s)]["acc"] for s in SEEDS]

        all_nodes = set()
        for s in SEEDS:
            all_nodes.update(results[(model_type, s)]["per_node"].keys())

        per_node_mean = {}
        for node in sorted(all_nodes):
            node_accs = []
            for s in SEEDS:
                pn = results[(model_type, s)]["per_node"]
                if node in pn:
                    c, t = pn[node]
                    node_accs.append(c / t)
            if node_accs:
                per_node_mean[node] = (np.mean(node_accs), np.std(node_accs))

        summary[model_type] = {
            "ll_mean": np.mean(lls),
            "ll_std": np.std(lls),
            "acc_mean": np.mean(accs),
            "acc_std": np.std(accs),
            "per_node": per_node_mean,
        }

    for model_type in ["gru", "mlp"]:
        s = summary[model_type]
        print(f"{model_type.upper()}: LL {s['ll_mean']:.4f} +/- {s['ll_std']:.4f}, "
              f"acc {s['acc_mean']:.3f} +/- {s['acc_std']:.3f}", flush=True)
    gru_ll = summary["gru"]["ll_mean"]
    mlp_ll = summary["mlp"]["ll_mean"]
    print(f"random baseline {random_baseline:.4f}, GRU-MLP gap {gru_ll - mlp_ll:+.4f} bits/dec", flush=True)

    save_data = {
        "results": results,
        "summary": summary,
        "random_baseline": random_baseline,
        "seeds": SEEDS,
        "obs_dim": obs_dim,
        "n_unique_obs": len(unique_obs),
        "n_epochs_policy": 50,
        "n_actions": ram.N_ACTIONS,
        "n_nodes": ram.N_NODES,
        "n_trajs": len(trajs),
        "n_train": len(train_trajs),
        "n_val": len(val_trajs),
        "optimal_prob": 0.8,
    }
    torch.save(save_data, os.path.join(PROJECT_ROOT, "checkpoints",
                                        "radial_arm_ablation.pt"))
    print(f"\nSaved to checkpoints/radial_arm_ablation.pt", flush=True)

    elapsed = time.time() - t0
    print(f"Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)


# Run the full ablation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
