"""GRU vs MLP ablation with structural (aliased) observations."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import torch
import numpy as np
import time
import subprocess
import tempfile
import shutil

from src.rosenberg_data import (
    load_rosenberg_everything, build_bc_targets,
)
from src.evaluation import compute_behavioral_cloning_ll

# One GRU job and one MLP job are launched for every seed listed here.
SEEDS = list(range(5))
# Size of the action set passed to the behavioural-cloning helpers.
N_ACTIONS = 3
# Number of nodes; build_structural_obs indexes them as a binary heap
# (children of i are 2i+1 and 2i+2, root is node 0).
N_STATES = 127
# Upper bound on worker subprocesses running at the same time.
N_WORKERS = 6
# Log2-likelihood per decision of a uniform policy over three actions.
RANDOM_BASELINE = -np.log2(N_ACTIONS)
# Workers are started with the interpreter that runs this script.
PYTHON = sys.executable
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKER_SCRIPT = os.path.join(_SCRIPT_DIR, "_structural_worker.py")


def build_structural_obs(node, n_states=127):
    """Encode a node's local tree structure as a 12-element float32 tensor.

    Nodes are indexed as a binary heap: node 0 is the root, the children of
    ``i`` are ``2i + 1`` and ``2i + 2``, and every index at or beyond
    ``(n_states - 1) // 2`` counts as a leaf.

    Layout of the returned vector:
        [0:3]   the node itself      -> (degree / 3, is_root, is_leaf)
        [3:6]   "left" destination   -> (is_leaf, is_root, degree / 3)
        [6:9]   "right" destination  -> (is_leaf, is_root, degree / 3)
        [9:12]  "reverse" destination-> (is_leaf, is_root, degree / 3)

    Note that the neighbour triples use the reverse field order of the node's
    own triple. A child index that falls outside the tree maps back to the
    node itself, and the root's reverse destination is the root.

    Args:
        node: Integer node index.
        n_states: Total number of nodes in the tree.

    Returns:
        A 1-D ``torch.float32`` tensor with 12 entries.
    """
    leaf_start = (n_states - 1) // 2

    def describe(idx):
        """Return (degree / 3, is_root, is_leaf) for node ``idx``."""
        at_root = idx == 0
        at_leaf = idx >= leaf_start
        # The root check wins over the leaf check, as in a degenerate tree
        # where node 0 is also at or past ``leaf_start``.
        if at_root:
            deg = 2.0
        elif at_leaf:
            deg = 1.0
        else:
            deg = 3.0
        return deg / 3.0, float(at_root), float(at_leaf)

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    neighbours = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    values = list(describe(node))
    for target in neighbours:
        nb_degree, nb_root, nb_leaf = describe(target)
        values += [nb_leaf, nb_root, nb_degree]

    return torch.tensor(values, dtype=torch.float32)


def _run_worker(model_type, seed, ckpt_path, outdir):
    """Run a single worker subprocess to completion.

    Args:
        model_type: Value passed to the worker's ``--model`` flag.
        seed: Value passed to the worker's ``--seed`` flag.
        ckpt_path: Path of the shared input checkpoint (``--ckpt``).
        outdir: Directory the worker is told to write into (``--outdir``).

    Returns:
        ``(returncode, stdout, stderr)`` with both streams decoded and stripped.
    """
    cmd = [PYTHON, WORKER_SCRIPT,
           "--model", model_type, "--seed", str(seed),
           "--ckpt", ckpt_path, "--outdir", outdir]
    # subprocess.run drains both pipes while it waits, so a worker that prints
    # more than the OS pipe buffer holds cannot stall (polling a Popen with
    # unread PIPEs can deadlock).
    finished = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return (finished.returncode,
            finished.stdout.decode(errors="replace").strip(),
            finished.stderr.decode(errors="replace").strip())


def _run_jobs(jobs, ckpt_path, outdir):
    """Run all ``(model_type, seed)`` jobs, at most ``N_WORKERS`` at a time.

    Returns:
        ``(results, failed)``: ``results`` maps each successful job key to the
        object loaded from its result file; ``failed`` lists the job keys whose
        worker exited non-zero or left no result file.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    results = {}
    failed = []
    total = len(jobs)
    completed = 0

    # Threads only block on their subprocess, so they are a cheap way to cap
    # the number of concurrently running workers.
    with ThreadPoolExecutor(max_workers=N_WORKERS) as pool:
        futures = {
            pool.submit(_run_worker, model_type, seed, ckpt_path, outdir): (model_type, seed)
            for model_type, seed in jobs
        }
        for future in as_completed(futures):
            key = futures[future]
            model_type, seed = key
            ret, stdout, stderr = future.result()
            if ret != 0:
                failed.append(key)
                print(f"FAILED {model_type} seed={seed}: {stderr[-500:]}", flush=True)
                continue

            res_path = os.path.join(outdir, f"{model_type}_s{seed}.pt")
            try:
                # weights_only=False unpickles arbitrary objects; acceptable
                # here because the file was just written by our own worker.
                results[key] = torch.load(res_path, weights_only=False)
            except FileNotFoundError:
                failed.append(key)
                print(f"FAILED {model_type} seed={seed}: "
                      f"worker exited 0 but wrote no result file {res_path}", flush=True)
                continue
            completed += 1
            print(f"[{completed}/{total}] {stdout}", flush=True)

    return results, failed


def _summarize(results, model_type):
    """Aggregate LL, accuracy and per-node accuracy over the seeds that finished.

    Raises:
        RuntimeError: If no run of ``model_type`` produced a result.
    """
    runs = [results[(model_type, s)] for s in SEEDS if (model_type, s) in results]
    if not runs:
        raise RuntimeError(f"No successful runs for model '{model_type}'; cannot summarize")

    lls = [run["ll"] for run in runs]
    accs = [run["acc"] for run in runs]

    all_nodes = set()
    for run in runs:
        all_nodes.update(run["per_node"].keys())

    per_node_mean = {}
    for node in sorted(all_nodes):
        # Each per-node entry is unpacked as a pair and divided (c / t); a
        # node may be absent from some seeds' results.
        node_accs = []
        for run in runs:
            if node in run["per_node"]:
                c, t = run["per_node"][node]
                node_accs.append(c / t)
        if node_accs:
            per_node_mean[node] = (np.mean(node_accs), np.std(node_accs))

    return {
        "ll_mean": np.mean(lls),
        "ll_std": np.std(lls),
        "acc_mean": np.mean(accs),
        "acc_std": np.std(accs),
        "per_node": per_node_mean,
    }


def main():
    """Run the GRU-vs-MLP ablation on structural observations and save a summary.

    Steps: load the data, build behavioural-cloning targets and per-node
    structural observations, write them to a temporary checkpoint, train one
    GRU and one MLP worker per seed in parallel subprocesses, then aggregate,
    print and save the results to ``checkpoints/structural_obs_ablation.pt``
    (relative to the current working directory).

    Raises:
        RuntimeError: If every run of one model type failed.
    """
    t0 = time.time()

    print("Loading Rosenberg data...", flush=True)
    d = load_rosenberg_everything()
    n_trajs = len(d["trajs"])
    n_train = len(d["train_trajs"])
    n_val = len(d["val_trajs"])
    print(f"{n_trajs} bouts total ({n_train} train, {n_val} val)", flush=True)
    print(f"Random baseline: {RANDOM_BASELINE:.4f} bits/dec", flush=True)

    print("\nBuilding BC targets from training data...", flush=True)
    bc_policy = build_bc_targets(d["train_sa"], n_states=N_STATES, n_actions=N_ACTIONS, laplace=1.0)
    bc_ll = compute_behavioral_cloning_ll(d["train_sa"], d["val_sa"],
                                          n_states=N_STATES, n_actions=N_ACTIONS)
    # Clamp before the log so zero-probability entries do not produce NaNs.
    bc_entropy = -(bc_policy * torch.log2(bc_policy.clamp(min=1e-10))).sum(dim=-1)
    print(f"BC LL (full-info ceiling) = {bc_ll:.4f} bits/dec", flush=True)
    print(f"BC policy mean entropy = {bc_entropy.mean():.4f} bits (max={np.log2(3):.4f})",
          flush=True)

    print("\nBuilding structural observations...", flush=True)
    structural_obs = {s: build_structural_obs(s, N_STATES) for s in range(N_STATES)}
    obs_dim = len(structural_obs[0])
    print(f"obs_dim = {obs_dim}", flush=True)

    # Distinct observation vectors; fewer than N_STATES means nodes are aliased.
    unique_obs = {tuple(obs.tolist()) for obs in structural_obs.values()}
    print(f"Unique observations: {len(unique_obs)} (out of {N_STATES} nodes)", flush=True)

    # The temporary checkpoint lives under PROJECT_ROOT while the final summary
    # is written relative to the CWD, so make sure both directories exist.
    os.makedirs(os.path.join(PROJECT_ROOT, "checkpoints"), exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    ckpt_path = os.path.join(PROJECT_ROOT, "checkpoints", "structural_obs_tmp.pt")

    jobs = [(model_type, seed) for seed in SEEDS for model_type in ["gru", "mlp"]]
    total = len(jobs)

    outdir = tempfile.mkdtemp(prefix="structural_obs_")
    try:
        torch.save({
            "bc_policy": bc_policy,
            "structural_obs": structural_obs,
            "obs_dim": obs_dim,
            "train_trajs": d["train_trajs"],
            "val_trajs": d["val_trajs"],
        }, ckpt_path)

        print(f"\nLaunching {total} policy training jobs ({N_WORKERS} parallel)...", flush=True)
        results, failed = _run_jobs(jobs, ckpt_path, outdir)
    finally:
        # Always remove the scratch files, even if a worker or a load blew up.
        try:
            os.remove(ckpt_path)
        except FileNotFoundError:
            pass
        shutil.rmtree(outdir, ignore_errors=True)

    if failed:
        print(f"WARNING: {len(failed)}/{total} jobs failed: {failed}; "
              f"statistics below use the remaining seeds only", flush=True)

    summary = {model_type: _summarize(results, model_type) for model_type in ["gru", "mlp"]}

    print("\nStructural Observation Ablation: GRU vs MLP ", flush=True)
    print(f"{'Model':<6} {'LL (bits/dec)':<22} {'Accuracy':<20}", flush=True)
    for model_type in ["gru", "mlp"]:
        s = summary[model_type]
        print(f"{model_type.upper():<6} "
              f"{s['ll_mean']:.4f} +/- {s['ll_std']:.4f}     "
              f"{s['acc_mean']:.3f} +/- {s['acc_std']:.3f}", flush=True)
    print(f"BC LL (full-info ceiling) = {bc_ll:.4f}", flush=True)
    print(f"Random baseline           = {RANDOM_BASELINE:.4f}", flush=True)

    gru_ll = summary["gru"]["ll_mean"]
    mlp_ll = summary["mlp"]["ll_mean"]
    print(f"\nGRU - MLP gap: {gru_ll - mlp_ll:+.4f} bits/dec", flush=True)

    save_data = {
        "results": results,
        "summary": summary,
        "bc_ll": bc_ll,
        "random_baseline": RANDOM_BASELINE,
        "seeds": SEEDS,
        "failed_jobs": failed,
        "obs_dim": obs_dim,
        "n_unique_obs": len(unique_obs),
        # NOTE(review): hard-coded here; the worker's real epoch count is not
        # visible from this script — confirm the two agree.
        "n_epochs_policy": 200,
        "n_actions": N_ACTIONS,
        "n_trajs": n_trajs,
        "n_train": n_train,
        "n_val": n_val,
    }
    final_path = "checkpoints/structural_obs_ablation.pt"
    torch.save(save_data, final_path)
    print(f"\nSaved to {final_path}", flush=True)

    elapsed = time.time() - t0
    print(f"Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)


if __name__ == "__main__":
    main()
