"""Bootstrap CI on the GRU vs Bayes-optimal filter gap."""
import os
import sys

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
from collections import defaultdict
from sklearn.linear_model import LogisticRegression

from src.gru_policy import GRUPolicy
from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.stats import bca_bootstrap_ci, bca_bootstrap_ci_paired

sys.path.insert(0, os.path.join(PROJECT_ROOT, "scripts"))
from bayesian_filter_baseline import build_obs_class_map, build_transition_matrices

# Environment / model dimensions.
N_STATES = 2**7 - 1      # nodes of the heap-indexed binary tree (see build_structural_obs)
N_ACTIONS = 3
HIDDEN_DIM = 128         # GRU hidden size used to rebuild the saved checkpoints
OBS_DIM = 3 + 3 * 3      # own triplet + one triplet per destination (build_structural_obs)
MAX_SEQ_LEN = 200        # trajectories are processed in chunks of this many actions

# Evaluation settings.
SEEDS = list(range(5))
N_BOOTSTRAP = 9_999


def build_structural_obs(node, n_states=127):
    """Return the 12-feature structural observation of ``node`` (float32 tensor).

    Nodes are indexed heap-style: the children of ``i`` are ``2i+1`` and
    ``2i+2``, its parent is ``(i-1)//2``, and every index at or beyond
    ``(n_states - 1) // 2`` counts as a leaf. Degree is 2 for the root, 1 for
    leaves and 3 otherwise, and is reported divided by 3.

    Layout: the node's own ``[degree/3, is_root, is_leaf]``, then for each of
    the left-child, right-child and parent destinations
    ``[is_leaf, is_root, degree/3]``. A child index outside the tree resolves
    to the node itself, and the root's parent resolves to the root.
    """
    leaf_start = (n_states - 1) // 2

    def describe(v):
        # -> (is_root, is_leaf, degree / 3); the root test takes precedence.
        root_flag = float(v == 0)
        leaf_flag = float(v >= leaf_start)
        if root_flag:
            deg = 2.0
        elif leaf_flag:
            deg = 1.0
        else:
            deg = 3.0
        return root_flag, leaf_flag, deg / 3.0

    left = 2 * node + 1
    right = 2 * node + 2
    targets = (
        left if left < n_states else node,
        right if right < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    own_root, own_leaf, own_deg = describe(node)
    vec = [own_deg, own_root, own_leaf]
    for dest in targets:
        d_root, d_leaf, d_deg = describe(dest)
        # NOTE: destination triplets are ordered differently from the node's
        # own triplet; trained checkpoints depend on this exact layout.
        vec += [d_leaf, d_root, d_deg]
    return torch.tensor(vec, dtype=torch.float32)


def action_prediction_bootstrap():
    """Trajectory-level paired bootstrap on the GRU-minus-Bayes action-accuracy gap.

    Loads per-timestep records from ``checkpoints/disagreement_analysis.pt``.
    Each record must provide ``traj_idx``, ``gru_correct`` and
    ``bayes_correct``; the two ``*_correct`` fields are summed as counts. The
    records are aggregated per trajectory, and whole trajectories are then
    resampled with replacement. Both models are always scored on the same
    resampled set, which makes the bootstrap paired.

    The point estimate and every replicate are timestep-weighted means of
    per-trajectory accuracies. This equals pooled accuracy over the
    (resampled) timesteps.

    Returns:
        dict with keys:
        - ``gap_weighted``: the weighted gap.
        - ``ci_lo`` / ``ci_hi``: its 95% percentile CI.
        - ``p_one_sided``: bootstrap p-value for ``gap <= 0``.
        - ``gap_unweighted`` / ``se_unweighted``: unweighted per-trajectory
          mean gap and its standard error (NaN with fewer than two
          trajectories).
        - ``n_trajs``: number of trajectories.
        - ``n_gru_wins`` / ``n_bayes_wins`` / ``n_ties``: per-trajectory
          win/loss/tie counts.
        - ``gru_acc`` / ``bayes_acc``: pooled accuracies.
        - ``boot_diffs``: the raw bootstrap replicates.

    Raises:
        ValueError: if the checkpoint contains no records.
    """
    # Full unpickling (weights_only=False): only load trusted checkpoint files.
    ckpt = torch.load("checkpoints/disagreement_analysis.pt", weights_only=False)
    records = ckpt["records"]
    gru_overall = ckpt["gru_overall_accuracy"]
    bayes_overall = ckpt["bayes_overall_accuracy"]
    if len(records) == 0:
        raise ValueError("checkpoints/disagreement_analysis.pt contains no records")

    print(f"Loaded {len(records)} timesteps, {ckpt['n_val_trajs']} trajectories")
    print(f"GRU overall action accuracy: {100 * gru_overall:.1f}%")
    print(f"Bayes overall action accuracy: {100 * bayes_overall:.1f}%")
    print(f"Raw gap: {100 * (gru_overall - bayes_overall):.1f} pp\n")

    traj_gru = defaultdict(lambda: [0, 0])   # [correct, total]
    traj_bayes = defaultdict(lambda: [0, 0])

    for r in records:
        ti = r["traj_idx"]
        traj_gru[ti][0] += r["gru_correct"]
        traj_gru[ti][1] += 1
        traj_bayes[ti][0] += r["bayes_correct"]
        traj_bayes[ti][1] += 1

    traj_ids = sorted(traj_gru.keys())
    n_trajs = len(traj_ids)

    gru_accs = np.array([traj_gru[ti][0] / traj_gru[ti][1] for ti in traj_ids])
    bayes_accs = np.array([traj_bayes[ti][0] / traj_bayes[ti][1] for ti in traj_ids])
    weights = np.array([traj_gru[ti][1] for ti in traj_ids], dtype=float)
    diffs = gru_accs - bayes_accs

    gru_wmean = np.average(gru_accs, weights=weights)
    bayes_wmean = np.average(bayes_accs, weights=weights)
    diff_wmean = gru_wmean - bayes_wmean

    # Resample trajectories (not timesteps) so within-trajectory correlation
    # is respected; the same indices are used for both models (paired).
    rng = np.random.RandomState(42)
    boot_diffs = []
    for _ in range(N_BOOTSTRAP):
        idx = rng.choice(n_trajs, size=n_trajs, replace=True)
        w = weights[idx]
        g = np.average(gru_accs[idx], weights=w)
        b = np.average(bayes_accs[idx], weights=w)
        boot_diffs.append(g - b)
    boot_diffs = np.array(boot_diffs)

    ci_lo = np.percentile(boot_diffs, 2.5)
    ci_hi = np.percentile(boot_diffs, 97.5)

    # Monte Carlo p-value with the standard add-one correction: with B
    # resamples the smallest attainable p is 1 / (B + 1), never exactly 0.
    # The replicates are centred on the observed gap rather than on the null,
    # so this remains the usual percentile-style approximation.
    p_one_sided = (np.sum(boot_diffs <= 0) + 1) / (len(boot_diffs) + 1)

    unwt_diff = np.mean(diffs)
    # A sample standard deviation needs at least two trajectories.
    if n_trajs > 1:
        unwt_se = np.std(diffs, ddof=1) / np.sqrt(n_trajs)
    else:
        unwt_se = float("nan")

    n_gru_wins = np.sum(diffs > 0)
    n_bayes_wins = np.sum(diffs < 0)
    n_ties = np.sum(diffs == 0)

    print(f"paired bootstrap ({n_trajs} trajs): gap {100 * diff_wmean:.2f} pp, "
          f"95% CI [{100 * ci_lo:.2f}, {100 * ci_hi:.2f}], p={p_one_sided:.4f}")
    print(f"unweighted {100 * unwt_diff:.2f} pp (SE={100 * unwt_se:.2f}), "
          f"GRU/tie/Bayes wins: {n_gru_wins}/{n_ties}/{n_bayes_wins}")

    return {
        "gap_weighted": diff_wmean,
        "ci_lo": ci_lo,
        "ci_hi": ci_hi,
        "p_one_sided": p_one_sided,
        "gap_unweighted": unwt_diff,
        "se_unweighted": unwt_se,
        "n_trajs": n_trajs,
        "n_gru_wins": int(n_gru_wins),
        "n_bayes_wins": int(n_bayes_wins),
        "n_ties": int(n_ties),
        "gru_acc": gru_wmean,
        "bayes_acc": bayes_wmean,
        "boot_diffs": boot_diffs,
    }


def node_decoding_bootstrap():
    """Compute probe accuracy for each seed model, CI on gap vs Bayes.

    For every existing ``checkpoints/structural_gru_model_seed{seed}.pt``:
    - the GRU is run over the validation trajectories in chunks of
      ``MAX_SEQ_LEN`` actions;
    - a logistic-regression probe is fit on a random 80% of the per-timestep
      hidden states to predict the true node;
    - its accuracy is measured on the remaining 20%.

    The Bayes baseline is a forward filter:
    - It starts from the observation mask of the first state in each chunk.
    - It predicts with ``argmax(belief)``.
    - It propagates with ``M = sum_a diag(bc[:, a]) @ T_action[a]``. This is
      the BC-policy-averaged transition matrix, assuming ``bc`` rows are
      per-state action probabilities (TODO confirm).
    - It then multiplies by the next observation mask.

    Returns:
        dict with:
        - ``seed_probe_accs``: per-seed probe accuracies.
        - ``bayes_acc``: Bayes accuracy.
        - ``gap_mean``: mean per-seed gap.
        - ``ci_lo`` / ``ci_hi``: BCa bootstrap CI over the per-seed gaps.
        - ``all_above``: whether every loaded seed beats Bayes.

    Raises:
        FileNotFoundError: if none of the per-seed checkpoints exist. Without
            this, the mean would be NaN and ``all_above`` vacuously True.
    """

    d = load_rosenberg_everything()
    val_trajs = d["val_trajs"]
    train_sa = d["train_sa"]

    bc_policy = build_bc_targets(train_sa, n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)

    structural_obs = {}
    for s in range(N_STATES):
        structural_obs[s] = build_structural_obs(s, N_STATES)

    # Chunk each trajectory exactly like the Bayes loop below, so that the
    # GRU hidden states and Bayes predictions line up timestep by timestep.
    val_data = []
    for traj in val_trajs:
        states = traj["states"]
        actions = traj["actions"]
        T = len(actions)
        for start in range(0, T, MAX_SEQ_LEN):
            end = min(start + MAX_SEQ_LEN, T)
            chunk_states = states[start:end + 1]
            obs_seq = torch.stack([structural_obs[s] for s in chunk_states])
            val_data.append({
                "obs": obs_seq,
                "states": [int(s) for s in chunk_states],
            })

    print(f"{len(val_trajs)} val trajectories, {len(val_data)} chunks")

    print("Running Bayes obs-only filter...", flush=True)
    obs_class, obs_masks = build_obs_class_map()
    T_action = build_transition_matrices()
    bc_np = bc_policy.numpy()

    M = np.zeros((N_STATES, N_STATES))
    for a in range(N_ACTIONS):
        M += np.diag(bc_np[:, a]) @ T_action[a]

    bayes_preds_all = []  # (true_state, pred_state) per timestep
    for traj in val_trajs:
        states = traj["states"]
        actions = traj["actions"]
        T = len(actions)
        for start in range(0, T, MAX_SEQ_LEN):
            end = min(start + MAX_SEQ_LEN, T)
            cs = states[start:end + 1]
            ca = actions[start:end]

            c0 = obs_class[cs[0]]
            belief = obs_masks[c0].copy()
            belief /= belief.sum()

            for t in range(len(ca)):
                true_s = cs[t]
                pred_s = int(np.argmax(belief))
                bayes_preds_all.append((int(true_s), pred_s))

                new_belief = belief @ M
                if t + 1 < len(cs):
                    c_next = obs_class[cs[t + 1]]
                    new_belief *= obs_masks[c_next]
                    total = new_belief.sum()
                    if total > 1e-30:
                        new_belief /= total
                    else:
                        # Observation inconsistent with the propagated belief:
                        # restart from a uniform belief over the observed class.
                        new_belief = obs_masks[c_next].copy()
                        new_belief /= new_belief.sum()
                belief = new_belief

    bayes_correct = np.array([1 if t == p else 0
                              for t, p in bayes_preds_all])
    bayes_acc = bayes_correct.mean()
    print(f"Bayes obs-only node accuracy: {100 * bayes_acc:.1f}%")

    print("\nComputing per-seed GRU probe accuracy...", flush=True)
    seed_probe_accs = []

    for seed in SEEDS:
        model_path = f"checkpoints/structural_gru_model_seed{seed}.pt"
        if not os.path.exists(model_path):
            print(f"seed {seed}: model not found, skipping")
            continue

        policy = GRUPolicy(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM,
                           n_actions=N_ACTIONS)
        state_dict = torch.load(model_path, weights_only=False)
        policy.load_state_dict(state_dict)
        policy.eval()

        all_hidden = []
        all_states = []

        for chunk in val_data:
            obs = chunk["obs"].unsqueeze(0)
            states_chunk = chunk["states"]
            with torch.no_grad():
                encoded = policy.obs_encoder(obs)
                output, _ = policy.gru(encoded)
                h = output[0].numpy()  # one row per observation in the chunk
            # Keep only timesteps that have an action (drop the final state),
            # matching the timesteps scored by the Bayes filter.
            n_act = len(states_chunk) - 1
            all_hidden.append(h[:n_act])
            all_states.extend(states_chunk[:n_act])

        H = np.concatenate(all_hidden, axis=0)
        positions = np.array(all_states)

        # NOTE(review): this split is random at the timestep level, so test
        # timesteps share trajectories with training ones. The probe is scored
        # on this 20% subset while bayes_acc uses all timesteps, so the
        # per-seed gap is not a paired comparison.
        n = len(positions)
        rng = np.random.RandomState(seed)
        perm = rng.permutation(n)
        n_train = int(0.8 * n)
        train_idx = perm[:n_train]
        test_idx = perm[n_train:]

        clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
        clf.fit(H[train_idx], positions[train_idx])
        acc = clf.score(H[test_idx], positions[test_idx])
        seed_probe_accs.append(acc)

        print(f"seed {seed}: probe_acc = {100 * acc:.1f}%")

    if not seed_probe_accs:
        # np.all over an empty array is True, so without this check the
        # analysis would report "all seeds above Bayes" with no seeds at all.
        raise FileNotFoundError(
            f"no checkpoints/structural_gru_model_seed<N>.pt found for SEEDS={SEEDS}"
        )

    seed_probe_accs = np.array(seed_probe_accs)
    probe_mean = seed_probe_accs.mean()
    probe_std = seed_probe_accs.std()

    gaps = seed_probe_accs - bayes_acc
    gap_mean = gaps.mean()

    ci_lo, ci_hi = bca_bootstrap_ci(gaps, n_resamples=N_BOOTSTRAP)

    all_above = np.all(seed_probe_accs > bayes_acc)

    print(f"GRU probe {100*probe_mean:.1f}% +/- {100*probe_std:.1f}%, "
          f"Bayes {100*bayes_acc:.1f}%, gap {100*gap_mean:.1f} pp "
          f"95% CI [{100*ci_lo:.1f}, {100*ci_hi:.1f}], all seeds above: {all_above}")
    print(f"Min seed probe acc: {100 * seed_probe_accs.min():.1f}%")

    return {
        "seed_probe_accs": seed_probe_accs,
        "bayes_acc": bayes_acc,
        "gap_mean": gap_mean,
        "ci_lo": ci_lo,
        "ci_hi": ci_hi,
        "all_above": bool(all_above),
    }


def node_decoding_trajectory_bootstrap():
    """Per-trajectory node decoding accuracy, GRU probe vs Bayes belief.

    Processing steps:
    - Runs the GRU from ``checkpoints/structural_gru_model.pt`` and the Bayes
      obs-only filter side by side over each validation trajectory, in chunks
      of ``MAX_SEQ_LEN`` actions.
    - Fits a logistic-regression probe on a random 80% of all timesteps.
    - Scores both the probe and the Bayes argmax belief on the held-out 20%
      only, so both models are evaluated on exactly the same timesteps.
    - Resamples whole trajectories (paired) to get a 95% percentile CI and a
      one-sided p-value on the timestep-weighted gap.

    Trajectories with no actions, or with no held-out timesteps, do not
    contribute to the bootstrap.

    Returns:
        dict with:
        - ``gru_node_acc`` / ``bayes_node_acc``: weighted held-out accuracies.
        - ``gap_weighted``: the gap between them.
        - ``ci_lo`` / ``ci_hi``: the 95% CI.
        - ``p_one_sided``: the one-sided p-value.
        - ``n_trajs``: number of contributing trajectories.
        - ``n_test_timesteps``: number of held-out timesteps.

    Raises:
        ValueError: if there are too few timesteps to fit and evaluate the probe.
    """

    d = load_rosenberg_everything()
    val_trajs = d["val_trajs"]
    train_sa = d["train_sa"]

    bc_policy = build_bc_targets(train_sa, n_states=N_STATES,
                                 n_actions=N_ACTIONS, laplace=1.0)
    bc_np = bc_policy.numpy()

    structural_obs = {}
    for s in range(N_STATES):
        structural_obs[s] = build_structural_obs(s, N_STATES)

    obs_class, obs_masks = build_obs_class_map()
    T_action = build_transition_matrices()
    M = np.zeros((N_STATES, N_STATES))
    for a in range(N_ACTIONS):
        M += np.diag(bc_np[:, a]) @ T_action[a]

    policy = GRUPolicy(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM,
                       n_actions=N_ACTIONS)
    state_dict = torch.load("checkpoints/structural_gru_model.pt",
                            weights_only=False)
    policy.load_state_dict(state_dict)
    policy.eval()

    # One dict per non-empty trajectory with per-timestep arrays:
    # "hidden" (n_t, HIDDEN_DIM), "true_states" (n_t,), "bayes_correct" (n_t,).
    traj_records = []

    for ti, traj in enumerate(val_trajs):
        states = traj["states"]
        actions = traj["actions"]
        T = len(actions)
        if T == 0:
            # Nothing to decode; an empty hidden list would become a 1-D array
            # and break the concatenation below.
            continue

        traj_hidden = []
        traj_true_states = []
        traj_bayes_correct = []

        for start in range(0, T, MAX_SEQ_LEN):
            end = min(start + MAX_SEQ_LEN, T)
            cs = states[start:end + 1]
            ca = actions[start:end]

            obs_seq = torch.stack([structural_obs[s] for s in cs]).unsqueeze(0)
            with torch.no_grad():
                encoded = policy.obs_encoder(obs_seq)
                output, _ = policy.gru(encoded)
                h = output[0].numpy()

            c0 = obs_class[cs[0]]
            belief = obs_masks[c0].copy()
            belief /= belief.sum()

            for t in range(len(ca)):
                traj_hidden.append(h[t])
                traj_true_states.append(int(cs[t]))

                pred_s = int(np.argmax(belief))
                traj_bayes_correct.append(int(pred_s == int(cs[t])))

                new_belief = belief @ M
                if t + 1 < len(cs):
                    c_next = obs_class[cs[t + 1]]
                    new_belief *= obs_masks[c_next]
                    total = new_belief.sum()
                    if total > 1e-30:
                        new_belief /= total
                    else:
                        # Observation inconsistent with the propagated belief:
                        # restart from a uniform belief over the observed class.
                        new_belief = obs_masks[c_next].copy()
                        new_belief /= new_belief.sum()
                belief = new_belief

        traj_records.append({
            "traj_idx": ti,
            "hidden": np.array(traj_hidden),
            "true_states": np.array(traj_true_states),
            "bayes_correct": np.array(traj_bayes_correct),
        })

    if not traj_records:
        raise ValueError("no validation trajectories with at least one action")

    all_hidden = np.concatenate([tr["hidden"] for tr in traj_records])
    all_states = np.concatenate([tr["true_states"] for tr in traj_records])

    n = len(all_states)
    rng = np.random.RandomState(42)
    perm = rng.permutation(n)
    n_train = int(0.8 * n)
    if n_train == 0:
        raise ValueError(f"only {n} timestep(s); too few for an 80/20 probe split")
    train_idx = perm[:n_train]
    test_idx = perm[n_train:]

    clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
    clf.fit(all_hidden[train_idx], all_states[train_idx])

    # Score only held-out timesteps. Predicting on the probe's own training
    # data would inflate the GRU side of the comparison.
    is_test = np.zeros(n, dtype=bool)
    is_test[test_idx] = True
    probe_correct = np.zeros(n, dtype=bool)
    probe_correct[test_idx] = clf.predict(all_hidden[test_idx]) == all_states[test_idx]

    gru_traj_accs = []
    bayes_traj_accs = []
    traj_weights = []

    offset = 0
    for tr in traj_records:
        n_t = len(tr["true_states"])
        span = slice(offset, offset + n_t)
        offset += n_t
        held_out = is_test[span]
        n_held = int(held_out.sum())
        if n_held == 0:
            continue
        # Both models are scored on exactly the same held-out timesteps.
        gru_traj_accs.append(probe_correct[span][held_out].mean())
        bayes_traj_accs.append(tr["bayes_correct"][held_out].mean())
        traj_weights.append(n_held)

    gru_traj_accs = np.array(gru_traj_accs)
    bayes_traj_accs = np.array(bayes_traj_accs)
    traj_weights = np.array(traj_weights, dtype=float)
    n_trajs = len(gru_traj_accs)

    gru_wmean = np.average(gru_traj_accs, weights=traj_weights)
    bayes_wmean = np.average(bayes_traj_accs, weights=traj_weights)

    # Resample whole trajectories; the same indices are used for both models.
    rng = np.random.RandomState(42)
    boot_diffs = []
    for _ in range(N_BOOTSTRAP):
        idx = rng.choice(n_trajs, size=n_trajs, replace=True)
        w = traj_weights[idx]
        g = np.average(gru_traj_accs[idx], weights=w)
        b = np.average(bayes_traj_accs[idx], weights=w)
        boot_diffs.append(g - b)
    boot_diffs = np.array(boot_diffs)

    ci_lo = np.percentile(boot_diffs, 2.5)
    ci_hi = np.percentile(boot_diffs, 97.5)
    # Add-one Monte Carlo p-value: never exactly 0 with finitely many resamples.
    p_one_sided = (np.sum(boot_diffs <= 0) + 1) / (len(boot_diffs) + 1)

    diffs = gru_traj_accs - bayes_traj_accs
    n_gru_wins = np.sum(diffs > 0)
    n_bayes_wins = np.sum(diffs < 0)
    n_ties = np.sum(diffs == 0)

    print(f"GRU probe {100 * gru_wmean:.1f}% vs Bayes {100 * bayes_wmean:.1f}%, "
          f"gap {100 * (gru_wmean - bayes_wmean):.2f} pp, "
          f"95% CI [{100 * ci_lo:.2f}, {100 * ci_hi:.2f}], p={p_one_sided:.4f}")
    print(f"GRU/tie/Bayes wins: {n_gru_wins}/{n_ties}/{n_bayes_wins}")

    return {
        "gru_node_acc": gru_wmean,
        "bayes_node_acc": bayes_wmean,
        "gap_weighted": gru_wmean - bayes_wmean,
        "ci_lo": ci_lo,
        "ci_hi": ci_hi,
        "p_one_sided": p_one_sided,
        "n_trajs": n_trajs,
        "n_test_timesteps": int(len(test_idx)),
    }


def main():
    """Run all three GRU-vs-Bayes gap analyses, save them, and print a summary.

    The results dict is saved to ``checkpoints/bayes_gap_ci.pt``, relative to
    the current working directory. Every number in the summary comes from the
    analyses themselves; nothing is hard-coded.
    """
    print("Reviewer Concern #2: Statistical Reliability of GRU vs Bayes Gap \n")

    action_results = action_prediction_bootstrap()

    seed_results = node_decoding_bootstrap()

    traj_results = node_decoding_trajectory_bootstrap()

    os.makedirs("checkpoints", exist_ok=True)
    torch.save({
        "action_prediction": action_results,
        "node_decoding_seeds": seed_results,
        "node_decoding_traj": traj_results,
    }, "checkpoints/bayes_gap_ci.pt")

    print("\nSUMMARY FOR REBUTTAL ")

    print(f"\n1. Action prediction (trajectory-level paired bootstrap, "
          f"n={action_results['n_trajs']}):")
    # Absolute accuracies are printed only when the analysis returns them.
    # They must never be reconstructed from a hard-coded baseline value.
    gru_acc = action_results.get("gru_acc")
    bayes_acc = action_results.get("bayes_acc")
    if gru_acc is not None and bayes_acc is not None:
        print(f"GRU: {100 * gru_acc:.1f}%  Bayes: {100 * bayes_acc:.1f}%")
    print(f"Gap: {100 * action_results['gap_weighted']:.2f} pp  "
          f"95% CI: [{100 * action_results['ci_lo']:.2f}, "
          f"{100 * action_results['ci_hi']:.2f}]")
    print(f"p(gap<=0) = {action_results['p_one_sided']:.4f}")

    probe_accs = seed_results["seed_probe_accs"]
    n_seeds = len(probe_accs)
    print(f"\n2. Node decoding ({n_seeds} seeds, {HIDDEN_DIM}-dim GRU "
          f"vs deterministic Bayes):")
    print(f"GRU probe: {100 * probe_accs.mean():.1f}% "
          f"+/- {100 * probe_accs.std():.1f}%")
    print(f"Bayes: {100 * seed_results['bayes_acc']:.1f}%")
    print(f"Gap: {100 * seed_results['gap_mean']:.1f} pp  "
          f"95% CI: [{100 * seed_results['ci_lo']:.1f}, "
          f"{100 * seed_results['ci_hi']:.1f}]")
    print(f"All {n_seeds} seeds exceed Bayes: {seed_results['all_above']}")

    print(f"\n3. Node decoding (trajectory-level paired bootstrap, "
          f"n={traj_results['n_trajs']}):")
    print(f"GRU probe: {100 * traj_results['gru_node_acc']:.1f}%  "
          f"Bayes: {100 * traj_results['bayes_node_acc']:.1f}%")
    print(f"Gap: {100 * traj_results['gap_weighted']:.2f} pp  "
          f"95% CI: [{100 * traj_results['ci_lo']:.2f}, "
          f"{100 * traj_results['ci_hi']:.2f}]")
    print(f"p(gap<=0) = {traj_results['p_one_sided']:.4f}")

    print("\nSaved to checkpoints/bayes_gap_ci.pt")


if __name__ == "__main__":
    main()
