"""GRU vs Bayesian filter disagreement analysis by region, ambiguity, and phase."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap
from collections import defaultdict

from src.rosenberg_data import load_rosenberg_everything, build_bc_targets
from src.gru_policy import GRUPolicy
from src.utils import trajectories_to_sa_pairs

sys.path.insert(0, os.path.join(PROJECT_ROOT, "scripts"))
from bayesian_filter_baseline import build_obs_class_map, build_transition_matrices

# Number of tree nodes (states). build_structural_obs treats them as a
# heap-indexed binary tree: node 0 is the root, children of i are 2i+1 / 2i+2.
N_STATES = 127
# Number of discrete actions the policies choose between.
N_ACTIONS = 3
# Trajectories are cut into chunks of at most this many actions.
MAX_SEQ_LEN = 200
# Hidden size handed to GRUPolicy.
HIDDEN_DIM = 128
# Length of the feature vector built by build_structural_obs (3 + 3 * 3).
OBS_DIM = 12

# Plot styling shared by all figures written by this script.
COLOR_GRU = "#4393c3"
COLOR_BAYES = "#d6604d"
DPI = 200
FONTSIZE = 11


def build_structural_obs(node, n_states=127):
    """Encode the local tree structure around ``node`` as a float32 vector.

    Nodes are heap-indexed: the children of ``i`` are ``2i + 1`` and
    ``2i + 2`` and its parent is ``(i - 1) // 2``. A child index that falls
    outside the tree is replaced by ``node`` itself, and the root's parent is
    the root.

    Layout of the 12 returned features:
      * ``[degree / 3, is_root, is_leaf]`` for ``node``;
      * ``[is_leaf, is_root, degree / 3]`` for the left, right and reverse
        (parent) destinations, in that order. Note the reversed field order
        relative to the first triple.
    """
    leaf_start = (n_states - 1) // 2

    def describe(idx):
        # Returns (is_root, is_leaf, degree); the root check wins over the
        # leaf check when deciding the degree.
        root_flag = 1.0 if idx == 0 else 0.0
        leaf_flag = 1.0 if idx >= leaf_start else 0.0
        if idx == 0:
            deg = 2.0
        elif idx >= leaf_start:
            deg = 1.0
        else:
            deg = 3.0
        return root_flag, leaf_flag, deg

    left_child = 2 * node + 1
    right_child = 2 * node + 2
    destinations = (
        left_child if left_child < n_states else node,
        right_child if right_child < n_states else node,
        (node - 1) // 2 if node > 0 else 0,
    )

    own_root, own_leaf, own_degree = describe(node)
    vec = [own_degree / 3.0, own_root, own_leaf]
    for target in destinations:
        t_root, t_leaf, t_degree = describe(target)
        vec += [t_leaf, t_root, t_degree / 3.0]
    return torch.tensor(vec, dtype=torch.float32)


def build_obs_dataset_structural(trajs, structural_obs, bc_policy, max_len=MAX_SEQ_LEN):
    """Split trajectories into chunks of at most ``max_len`` actions.

    Each trajectory is a mapping with ``"states"`` and ``"actions"``. A chunk
    covering actions ``[start, end)`` takes states ``[start, end]`` (one extra
    state when the trajectory provides it), so consecutive chunks share their
    boundary state.

    Returns a list of dicts with:
      * ``"obs"``     - stacked ``structural_obs[s]`` for each chunk state;
      * ``"targets"`` - stacked ``bc_policy[s]`` for each chunk state;
      * ``"actions"`` - long tensor of the chunk's actions;
      * ``"states"``  - the chunk's states as plain ints.
    """
    chunks = []
    for trajectory in trajs:
        all_states = trajectory["states"]
        all_actions = trajectory["actions"]
        n_steps = len(all_actions)
        for lo in range(0, n_steps, max_len):
            hi = min(lo + max_len, n_steps)
            window_states = all_states[lo:hi + 1]
            window_actions = all_actions[lo:hi]
            chunks.append({
                "obs": torch.stack([structural_obs[s] for s in window_states]),
                "targets": torch.stack([bc_policy[s] for s in window_states]),
                "actions": torch.tensor([int(a) for a in window_actions],
                                        dtype=torch.long),
                "states": [int(s) for s in window_states],
            })
    return chunks


def get_depth(node):
    """Count the parent hops from ``node`` up to the root (root has depth 0).

    Uses heap indexing, where the parent of ``i`` is ``(i - 1) // 2``.
    """
    hops = 0
    current = node
    while True:
        if not current > 0:
            return hops
        current = (current - 1) // 2
        hops += 1


def depth_to_region(d):
    """Map a tree depth to its region label.

    Depths up to 2 are "Root", depths 3-4 are "Mid-tree", and anything deeper
    is "Periphery".
    """
    if d <= 2:
        return "Root (d0-2)"
    return "Mid-tree (d3-4)" if d <= 4 else "Periphery (d5-6)"


def obs_class_to_ambiguity(c):
    """Map an observation-class id to its ambiguity label.

    Class 0 is "Unique", class 1 is "Low", and every other class is "High".
    """
    if c == 0:
        return "Unique (1)"
    return "Low (2)" if c == 1 else "High (28+)"


def phase_label(timestep, chunk_len):
    """Label a timestep by its relative position inside a chunk.

    The first quarter is "Early", the last quarter (from 0.75 on) is "Late",
    and everything in between is "Middle". ``chunk_len`` must be non-zero.
    """
    progress = timestep / chunk_len
    if progress < 0.25:
        return "Early"
    return "Middle" if progress < 0.75 else "Late"


def bca_ci(traj_accs_by_traj, n_resamples=9999, confidence=0.95):
    """Bootstrap CI for the count-weighted mean accuracy across trajectories.

    Trajectories (not timesteps) are the resampling unit. The point estimate
    is the mean of per-trajectory accuracies weighted by their timestep
    counts, i.e. total correct / total timesteps.

    Args:
        traj_accs_by_traj: mapping ``traj_id -> (n_correct, n_total)``.
        n_resamples: number of bootstrap resamples.
        confidence: confidence level of the interval (e.g. 0.95).

    Returns:
        ``(mean, ci_low, ci_high)``. With fewer than two non-empty
        trajectories no interval can be estimated and the pooled accuracy is
        returned for all three values (0.0 when there is no data at all).
        The interval is BCa when SciPy can compute it, otherwise a plain
        percentile bootstrap interval.
    """
    import warnings

    # Trajectories without timesteps carry zero weight and an undefined
    # accuracy (0 / 0), so they are dropped up front.
    counts = [traj_accs_by_traj[key] for key in sorted(traj_accs_by_traj)
              if traj_accs_by_traj[key][1] > 0]
    if len(counts) < 2:
        total_c = sum(v[0] for v in counts)
        total_n = sum(v[1] for v in counts)
        mean = total_c / total_n if total_n > 0 else 0.0
        return mean, mean, mean

    n_trajs = len(counts)
    traj_accs = np.array([v[0] / v[1] for v in counts])
    traj_weights = np.array([v[1] for v in counts], dtype=float)

    def weighted_mean(indices):
        # SciPy hands a non-vectorized statistic one resample as a 1-D array
        # of trajectory positions; the statistic must use *all* of them.
        idx = np.asarray(indices).astype(int)
        return np.average(traj_accs[idx], weights=traj_weights[idx])

    overall_mean = np.average(traj_accs, weights=traj_weights)

    try:
        with warnings.catch_warnings():
            # BCa emits warnings on (near-)degenerate data; those cases are
            # handled below through the NaN check.
            warnings.simplefilter("ignore")
            res = bootstrap(
                (np.arange(n_trajs),),
                weighted_mean,
                n_resamples=n_resamples,
                method='BCa',
                confidence_level=confidence,
                random_state=42,
            )
        lo, hi = res.confidence_interval.low, res.confidence_interval.high
    except Exception:
        # Deliberate best-effort: any BCa failure falls back to the
        # percentile interval rather than aborting the analysis.
        lo = hi = np.nan

    if not (np.isnan(lo) or np.isnan(hi)):
        return overall_mean, lo, hi

    # Percentile-bootstrap fallback, computed only when BCa is unavailable.
    rng = np.random.RandomState(42)
    boot_means = np.empty(n_resamples)
    for b in range(n_resamples):
        idx = rng.choice(n_trajs, size=n_trajs, replace=True)
        boot_means[b] = np.average(traj_accs[idx], weights=traj_weights[idx])
    alpha = (1 - confidence) / 2
    return (overall_mean,
            np.percentile(boot_means, 100 * alpha),
            np.percentile(boot_means, 100 * (1 - alpha)))


def compute_breakdown(records, category_fn, category_order):
    """Per-category GRU and Bayes accuracy with trajectory-level CIs.

    For every label in ``category_order`` the records whose
    ``category_fn(record)`` equals that label are grouped by ``'traj_idx'``
    into ``(n_correct, n_total)`` pairs and passed to ``bca_ci``.

    Returns a dict ``label -> {'gru': (mean, lo, hi), 'bayes': (mean, lo, hi),
    'n': number of matching records}``. Labels with no matching records are
    still present (with ``n == 0``).
    """
    def tally(rows, flag):
        # Accumulate [correct, total] per trajectory for one model's flag.
        per_traj = defaultdict(lambda: [0, 0])
        for row in rows:
            cell = per_traj[row['traj_idx']]
            cell[0] += row[flag]
            cell[1] += 1
        return {traj: tuple(pair) for traj, pair in per_traj.items()}

    summary = {}
    for label in category_order:
        matching = [rec for rec in records if category_fn(rec) == label]
        gru_counts = tally(matching, 'gru_correct')
        bayes_counts = tally(matching, 'bayes_correct')

        summary[label] = {
            'gru': tuple(bca_ci(gru_counts)),
            'bayes': tuple(bca_ci(bayes_counts)),
            'n': sum(pair[1] for pair in gru_counts.values()),
        }
    return summary


def plot_contingency(contingency, save_path):
    """Draw the 2x2 GRU-vs-Bayes correctness table as an annotated heatmap.

    Rows are GRU correct / incorrect, columns are Bayes correct / incorrect.
    Each cell shows its count and its share of the grand total. The figure is
    written to ``save_path`` and closed.
    """
    grand_total = contingency.sum()

    def cell_text(row, col):
        count = contingency[row, col]
        return f"{count}\n({100 * count / grand_total:.1f}%)"

    annotations = np.array([[cell_text(row, col) for col in range(2)]
                            for row in range(2)])

    fig, ax = plt.subplots(figsize=(5, 4))
    sns.heatmap(
        contingency,
        annot=annotations,
        fmt='',
        cmap='Blues',
        xticklabels=['Bayes correct', 'Bayes incorrect'],
        yticklabels=['GRU correct', 'GRU incorrect'],
        ax=ax,
        cbar_kws={'label': 'Count'},
        annot_kws={'fontsize': FONTSIZE},
    )
    ax.set_title('GRU vs Bayesian Filter Agreement', fontsize=FONTSIZE + 1)
    ax.tick_params(labelsize=FONTSIZE)
    fig.tight_layout()
    fig.savefig(save_path, dpi=DPI)
    plt.close(fig)
    print(f"Saved {save_path}")


def plot_grouped_bars(breakdown, category_order, title, xlabel, save_path):
    """Plot side-by-side GRU and Bayes accuracy bars with CI error bars.

    ``breakdown[cat][model]`` is a ``(mean, ci_low, ci_high)`` triple for
    ``model`` in ``'gru'`` / ``'bayes'``. Error-bar lengths are the distances
    from the mean to each CI bound, clamped at zero. A dashed line marks
    chance level (1/3). The figure is written to ``save_path`` and closed.
    """
    def series(model):
        # Split one model's triples into means and (lower, upper) bar lengths.
        means, below, above = [], [], []
        for cat in category_order:
            stats = breakdown[cat][model]
            means.append(stats[0])
            below.append(max(0, stats[0] - stats[1]))
            above.append(max(0, stats[2] - stats[0]))
        return means, below, above

    gru_means, gru_below, gru_above = series('gru')
    bayes_means, bayes_below, bayes_above = series('bayes')

    positions = np.arange(len(category_order))
    bar_width = 0.35
    error_style = {'linewidth': 1.2}

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(positions - bar_width / 2, gru_means, bar_width,
           yerr=[gru_below, gru_above], label='GRU', color=COLOR_GRU,
           capsize=4, error_kw=dict(error_style))
    ax.bar(positions + bar_width / 2, bayes_means, bar_width,
           yerr=[bayes_below, bayes_above], label='Bayes', color=COLOR_BAYES,
           capsize=4, error_kw=dict(error_style))

    ax.axhline(1 / 3, color='gray', linestyle='--', linewidth=1, label='Chance (1/3)')
    ax.set_xlabel(xlabel, fontsize=FONTSIZE)
    ax.set_ylabel('Action Prediction Accuracy', fontsize=FONTSIZE)
    ax.set_title(title, fontsize=FONTSIZE + 1)
    ax.set_xticks(positions)
    ax.set_xticklabels(category_order, fontsize=FONTSIZE - 1)
    ax.tick_params(labelsize=FONTSIZE - 1)
    ax.legend(fontsize=FONTSIZE - 1)

    # Leave 30% headroom above the tallest bar for the legend and error bars.
    tallest = max(max(gru_means), max(bayes_means))
    ax.set_ylim(0, tallest * 1.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=DPI)
    plt.close(fig)
    print(f"Saved {save_path}")


def _print_breakdown(breakdown, category_order):
    """Print GRU and Bayes accuracy with CI bounds for each category, in order."""
    for cat in category_order:
        g = breakdown[cat]
        print(f"{cat}: GRU={100*g['gru'][0]:.1f}% [{100*g['gru'][1]:.1f}-{100*g['gru'][2]:.1f}], "
              f"Bayes={100*g['bayes'][0]:.1f}% [{100*g['bayes'][1]:.1f}-{100*g['bayes'][2]:.1f}] (n={g['n']})")


def main():
    """Compare GRU and Bayesian-filter action predictions on validation bouts.

    Steps:
      1. Load the data and build the behaviour-cloning policy table.
      2. Build structural observations, chunk the validation trajectories and
         load the GRU checkpoint.
      3. For every timestep record whether the GRU and the Bayesian filter
         predicted the taken action.
      4. Print overall, per-region, per-ambiguity and per-phase accuracies,
         write the figures under ``figures/`` and save all results to
         ``checkpoints/disagreement_analysis.pt``.

    All file paths are relative to the current working directory.

    Raises:
        RuntimeError: if the validation data yields no timesteps.
    """
    data = load_rosenberg_everything()
    val_trajs = data['val_trajs']
    train_sa = data['train_sa']
    print(f"{len(data['trajs'])} total bouts, {len(val_trajs)} validation", flush=True)

    bc_policy = build_bc_targets(train_sa, n_states=N_STATES, n_actions=N_ACTIONS, laplace=1.0)
    bc_policy_np = bc_policy.numpy()

    obs_class, obs_masks = build_obs_class_map()
    T_action = build_transition_matrices()

    # Policy-marginalised transition matrix: each T_action[a] is row-scaled
    # by the BC probability of action a, then summed over actions.
    # NOTE(review): assumes T_action[a] is (N_STATES, N_STATES) with rows as
    # source states - confirm against build_transition_matrices.
    M = np.zeros((N_STATES, N_STATES))
    for a in range(N_ACTIONS):
        M += np.diag(bc_policy_np[:, a]) @ T_action[a]

    structural_obs = {s: build_structural_obs(s, N_STATES) for s in range(N_STATES)}

    val_data = build_obs_dataset_structural(val_trajs, structural_obs, bc_policy)
    print(f"{len(val_data)} validation chunks", flush=True)

    policy = GRUPolicy(obs_dim=OBS_DIM, hidden_dim=HIDDEN_DIM, n_actions=N_ACTIONS)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a trusted source; if this file is a plain state_dict,
    # weights_only=True should be preferred.
    state_dict = torch.load("checkpoints/structural_gru_model.pt", weights_only=False)
    policy.load_state_dict(state_dict)
    policy.eval()
    print("GRU model loaded", flush=True)

    # Chunk index -> trajectory index, mirroring the chunking performed by
    # build_obs_dataset_structural (one chunk per MAX_SEQ_LEN actions).
    traj_chunk_map = [
        ti
        for ti, traj in enumerate(val_trajs)
        for _ in range(0, len(traj['actions']), MAX_SEQ_LEN)
    ]
    assert len(traj_chunk_map) == len(val_data), \
        f"Chunk count mismatch: {len(traj_chunk_map)} vs {len(val_data)}"

    records = []
    for ci, chunk in enumerate(val_data):
        traj_idx = traj_chunk_map[ci]
        states = chunk['states']
        actions = chunk['actions']
        obs = chunk['obs']
        chunk_len = len(actions)

        with torch.no_grad():
            obs_input = obs.unsqueeze(0)
            logits, _ = policy(obs_input)
            gru_preds = logits[0].argmax(dim=-1).numpy()

        # The filter restarts in every chunk from a uniform belief over the
        # states compatible with the first observation class.
        c0 = obs_class[states[0]]
        belief = obs_masks[c0].copy()
        belief /= belief.sum()

        for t in range(chunk_len):
            true_s = states[t]
            true_a = int(actions[t])
            c = obs_class[true_s]
            depth = get_depth(true_s)

            gru_pred = int(gru_preds[t])
            gru_correct = int(gru_pred == true_a)

            # Bayes prediction: BC policy averaged under the current belief.
            bayes_pi = belief @ bc_policy_np
            bayes_pred = int(np.argmax(bayes_pi))
            bayes_correct = int(bayes_pred == true_a)

            records.append({
                'traj_idx': traj_idx,
                'chunk_idx': ci,
                'timestep': t,
                'true_state': true_s,
                'true_action': true_a,
                'gru_pred': gru_pred,
                'gru_correct': gru_correct,
                'bayes_pred': bayes_pred,
                'bayes_correct': bayes_correct,
                'obs_class': c,
                'depth': depth,
                'chunk_len': chunk_len,
            })

            # Predict one step ahead, then condition on the next observation.
            new_belief = belief @ M
            if t + 1 < len(states):
                c_next = obs_class[states[t + 1]]
                new_belief *= obs_masks[c_next]
                total = new_belief.sum()
                if total > 1e-30:
                    new_belief /= total
                else:
                    # The belief vanished under the mask: restart from a
                    # uniform belief over the observed class.
                    new_belief = obs_masks[c_next].copy()
                    new_belief /= new_belief.sum()
            belief = new_belief

    if not records:
        raise RuntimeError("No validation timesteps to analyse")

    total_timesteps = len(records)
    gru_overall = sum(r['gru_correct'] for r in records) / total_timesteps
    bayes_overall = sum(r['bayes_correct'] for r in records) / total_timesteps
    print(f"{total_timesteps} timesteps collected", flush=True)
    print(f"GRU overall accuracy: {100 * gru_overall:.1f}%", flush=True)
    print(f"Bayes overall accuracy: {100 * bayes_overall:.1f}%", flush=True)

    # Rows: GRU correct / incorrect; columns: Bayes correct / incorrect.
    contingency = np.zeros((2, 2), dtype=int)
    for r in records:
        contingency[1 - r['gru_correct'], 1 - r['bayes_correct']] += 1

    n = contingency.sum()
    agree = contingency[0,0] + contingency[1,1]
    print(f"agreement {100*agree/n:.1f}%, GRU only {100*contingency[0,1]/n:.1f}%, Bayes only {100*contingency[1,0]/n:.1f}%")

    region_order = ["Root (d0-2)", "Mid-tree (d3-4)", "Periphery (d5-6)"]
    ambiguity_order = ["Unique (1)", "Low (2)", "High (28+)"]
    phase_order = ["Early", "Middle", "Late"]

    region_breakdown = compute_breakdown(
        records, lambda r: depth_to_region(r['depth']), region_order)
    _print_breakdown(region_breakdown, region_order)

    ambiguity_breakdown = compute_breakdown(
        records, lambda r: obs_class_to_ambiguity(r['obs_class']), ambiguity_order)
    _print_breakdown(ambiguity_breakdown, ambiguity_order)

    phase_breakdown = compute_breakdown(
        records, lambda r: phase_label(r['timestep'], r['chunk_len']), phase_order)
    _print_breakdown(phase_breakdown, phase_order)

    os.makedirs("figures", exist_ok=True)

    plot_contingency(contingency, "figures/disagreement_contingency.png")
    plot_grouped_bars(region_breakdown, region_order,
                      "Action Accuracy by Tree Region", "Region",
                      "figures/disagreement_by_region.png")
    plot_grouped_bars(ambiguity_breakdown, ambiguity_order,
                      "Action Accuracy by Observation Ambiguity", "Ambiguity",
                      "figures/disagreement_by_ambiguity.png")
    plot_grouped_bars(phase_breakdown, phase_order,
                      "Action Accuracy by Trial Phase", "Phase",
                      "figures/disagreement_by_phase.png")

    os.makedirs("checkpoints", exist_ok=True)
    save_data = {
        'records': records,
        'contingency': contingency,
        'region_breakdown': region_breakdown,
        'ambiguity_breakdown': ambiguity_breakdown,
        'phase_breakdown': phase_breakdown,
        'gru_overall_accuracy': gru_overall,
        'bayes_overall_accuracy': bayes_overall,
        'total_timesteps': total_timesteps,
        'n_val_trajs': len(val_trajs),
        'n_chunks': len(val_data),
        'max_seq_len': MAX_SEQ_LEN,
    }
    torch.save(save_data, "checkpoints/disagreement_analysis.pt")
    print("Saved checkpoints/disagreement_analysis.pt")

    print(f"GRU {100*gru_overall:.1f}%, Bayes {100*bayes_overall:.1f}%, "
          f"agreement {100*(contingency[0,0]+contingency[1,1])/total_timesteps:.1f}%, "
          f"n={total_timesteps}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
