"""Multi-seed ablation on water-UNRESTRICTED mice: GRU vs MLP across k-hop radii."""
import sys
sys.path.insert(0, ".")

import torch
import numpy as np
import time

from src.utils import load_everything
from src.maxent_irl import MaxCausalEntIRL, train_irl, compute_log_likelihood_mdp
from src.observations import precompute_khop_masks, trajectories_to_obs_dataset
from src.gru_policy import GRUPolicy, train_gru_policy, MLPPolicy, train_mlp_policy
from src.evaluation import (
    compute_log_likelihood_gru, compute_prediction_accuracy,
    compute_log_likelihood_mlp, compute_prediction_accuracy_mlp,
    compute_per_node_accuracy, compute_per_node_accuracy_mlp,
    compute_behavioral_cloning_ll,
)

# Experiment configuration shared by every run in main().

# Random seeds; each (model type, k) combination is trained once per seed.
SEEDS = list(range(5))
# k values handed to precompute_khop_masks (the k-hop radii being ablated).
K_VALUES = [1, 2, 3, 6]
# Number of training epochs for both the GRU and the MLP policy.
N_EPOCHS = 150
# Constructor arguments shared by GRUPolicy and MLPPolicy.
HIDDEN_DIM = 128
OBS_DIM = 127  # NOTE(review): equals the n_states=127 used for IRL -- confirm they are tied
N_ACTIONS = 4


def set_seed(seed):
    """Seed every random number generator this script may draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    so that a run started with the same ``seed`` draws the same random
    numbers from all three sources.

    Args:
        seed: Integer seed applied to all three generators.
    """
    # Local import keeps this fix self-contained; the stdlib generator was
    # previously left unseeded, which breaks reproducibility for any code
    # that relies on it.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Policy families compared in the ablation, in the order they are trained.
_MODEL_TYPES = ("gru", "mlp")


def _train_and_evaluate(model_type, train_data, val_data):
    """Train one fresh policy of ``model_type`` and score it on ``val_data``.

    Args:
        model_type: ``"gru"`` selects GRUPolicy and its trainer/evaluators;
            any other value selects the MLP counterparts.
        train_data: Observation dataset passed to the policy trainer.
        val_data: Observation dataset passed to the evaluation functions.

    Returns:
        Tuple ``(ll, acc, per_node)`` holding the outputs of the matching
        log-likelihood, prediction-accuracy and per-node-accuracy evaluators.
    """
    if model_type == "gru":
        policy_cls, train_fn = GRUPolicy, train_gru_policy
        ll_fn = compute_log_likelihood_gru
        acc_fn = compute_prediction_accuracy
        per_node_fn = compute_per_node_accuracy
    else:
        policy_cls, train_fn = MLPPolicy, train_mlp_policy
        ll_fn = compute_log_likelihood_mlp
        acc_fn = compute_prediction_accuracy_mlp
        per_node_fn = compute_per_node_accuracy_mlp

    pol = policy_cls(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    # print_every=N_EPOCHS: presumably limits trainer logging to the final
    # epoch -- confirm against the trainer implementation.
    pol, _ = train_fn(pol, train_data, n_epochs=N_EPOCHS, lr=3e-4, print_every=N_EPOCHS)
    return ll_fn(pol, val_data), acc_fn(pol, val_data), per_node_fn(pol, val_data)


def _summarize_results(results):
    """Aggregate per-seed results into mean/std statistics per (model type, k).

    Args:
        results: Dict keyed by ``(model_type, k, seed)`` whose values hold
            ``"ll"``, ``"acc"`` and ``"per_node"``; ``per_node`` maps a node
            to a ``(c, t)`` pair whose ratio ``c / t`` is that node's accuracy.

    Returns:
        Dict keyed by ``(model_type, k)`` with the mean and standard deviation
        of the log-likelihood and accuracy across SEEDS, plus a ``per_node``
        dict mapping each node to ``(mean, std)`` of its accuracy over the
        seeds in which that node appears.
    """
    summary = {}
    for model_type in _MODEL_TYPES:
        for k in K_VALUES:
            runs = [results[(model_type, k, seed)] for seed in SEEDS]
            lls = [run["ll"] for run in runs]
            accs = [run["acc"] for run in runs]

            # Collect each node's accuracy across seeds; a node missing from
            # one seed's results simply contributes fewer samples.
            node_accs = {}
            for run in runs:
                for node, (c, t) in run["per_node"].items():
                    node_accs.setdefault(node, []).append(c / t)

            per_node_mean = {
                node: (np.mean(node_accs[node]), np.std(node_accs[node]))
                for node in sorted(node_accs)
            }

            summary[(model_type, k)] = {
                "ll_mean": np.mean(lls),
                "ll_std": np.std(lls),
                "acc_mean": np.mean(accs),
                "acc_std": np.std(accs),
                "per_node": per_node_mean,
            }
    return summary


def main():
    """Run the multi-seed GRU-vs-MLP ablation and save the results.

    Steps:
      1. Load the water-unrestricted dataset and print basic statistics.
      2. Fit MaxCausalEntIRL and report its validation log-likelihood next to
         the behavioral-cloning baseline.
      3. Build train/val observation datasets for every k in K_VALUES.
      4. Train and evaluate a GRU and an MLP policy for every (seed, k).
      5. Aggregate across seeds, print a summary, and save everything to
         ``checkpoints/ablation_unrestricted.pt``.
    """
    # Local import: only needed to prepare the checkpoint directory.
    import os

    t0 = time.time()

    # Create the output directory before any training so that a missing or
    # unwritable location fails immediately rather than after all runs finish.
    out_path = "checkpoints/ablation_unrestricted.pt"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    print("Loading water-unrestricted mouse data...")
    d = load_everything(restricted=False)
    G = d["G"]

    all_states = set()
    for traj in d["trajs"]:
        all_states.update(int(s) for s in traj["states"])
    # NOTE(review): the length of the first trajectory is used for all of
    # them -- confirm that every trajectory has the same length.
    traj_len = len(d["trajs"][0]["states"])
    print(f"Trajectories: {len(d['trajs'])} (train={len(d['train_trajs'])}, val={len(d['val_trajs'])})")
    print(f"Trajectory length: {traj_len}")
    print(f"Unique nodes visited: {len(all_states)}")
    print(f"Total state-action pairs: {len(d['trajs']) * traj_len}")

    print("\nTraining IRL...")
    set_seed(42)
    model = MaxCausalEntIRL(n_states=127, T=d["T"], gamma=0.99, n_vi_iters=200, l2_reg=0.01)
    irl_history = train_irl(model, d["train_sa"], d["val_sa"], n_epochs=300, lr=0.01, print_every=100)
    # Only the middle element of soft_vi's 3-tuple (Q) is needed downstream.
    _, Q, _ = model.soft_vi(model.reward_params)
    Q_det = Q.detach()
    rewards = model.reward_params.detach().numpy()

    maxent_ll = compute_log_likelihood_mdp(Q_det, d["val_sa"])
    bc_ll = compute_behavioral_cloning_ll(d["train_sa"], d["val_sa"])
    print(f"MaxEnt IRL LL = {maxent_ll:.4f}, BC LL = {bc_ll:.4f}")

    top_reward_idx = np.argsort(rewards)[::-1][:10]
    print("\n  Top 10 reward nodes:")
    for rank, idx in enumerate(top_reward_idx, start=1):
        print(f"{rank}. Node {idx}: reward={rewards[idx]:.4f}")

    print("\nPrecomputing observation masks...")
    # Datasets depend only on k, so build them once and reuse across seeds.
    data_cache = {}
    for k in K_VALUES:
        masks_k, _ = precompute_khop_masks(G, k)
        data_cache[k] = {
            "train": trajectories_to_obs_dataset(d["train_trajs"], masks_k, Q_det),
            "val": trajectories_to_obs_dataset(d["val_trajs"], masks_k, Q_det),
        }

    results = {}
    total_runs = len(SEEDS) * len(K_VALUES) * len(_MODEL_TYPES)
    run_idx = 0

    for seed in SEEDS:
        for k in K_VALUES:
            train_data = data_cache[k]["train"]
            val_data = data_cache[k]["val"]

            for model_type in _MODEL_TYPES:
                run_idx += 1
                print(f"\n{run_idx}/{total_runs}: {model_type.upper()} k={k} seed={seed}")
                # Re-seed before every run so both model types start from the
                # same RNG state for a given seed.
                set_seed(seed)

                ll, acc, per_node = _train_and_evaluate(model_type, train_data, val_data)
                results[(model_type, k, seed)] = {
                    "ll": ll,
                    "acc": acc,
                    "per_node": per_node,
                }
                print(f"LL={ll:.4f}, Acc={acc:.3f}")

    summary = _summarize_results(results)

    for model_type in _MODEL_TYPES:
        for k in K_VALUES:
            stats = summary[(model_type, k)]
            print(f"{model_type.upper()} k={k}: LL {stats['ll_mean']:.4f} +/- {stats['ll_std']:.4f}, "
                  f"acc {stats['acc_mean']:.3f} +/- {stats['acc_std']:.3f}")
    print(f"MaxEnt IRL LL={maxent_ll:.4f}, BC LL={bc_ll:.4f}")

    save_data = {
        "results": results,
        "summary": summary,
        "maxent_ll": maxent_ll,
        "bc_ll": bc_ll,
        "seeds": SEEDS,
        "k_values": K_VALUES,
        "n_epochs": N_EPOCHS,
        "irl_rewards": rewards,
        "irl_history": irl_history,
        "dataset_info": {
            "n_trajs": len(d["trajs"]),
            "n_train": len(d["train_trajs"]),
            "n_val": len(d["val_trajs"]),
            "traj_len": traj_len,
            "n_unique_nodes": len(all_states),
            "unique_nodes": sorted(all_states),
        },
    }
    torch.save(save_data, out_path)
    print(f"\nSaved to {out_path}")

    elapsed = time.time() - t0
    print(f"Total time: {elapsed:.0f}s ({elapsed/60:.1f}min)")


# Run the ablation only when this file is executed as a script, so that
# importing the module does not start training.
if __name__ == "__main__":
    main()
