from __future__ import annotations

import argparse
import gzip
import pickle
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import cm
from scipy import stats
import pingouin as pg

from brain_alignment import k_fold_test_idxs, story_fold_test_idxs
from datasets import get_dataset
from models import *
from utils import *

# -----------------------------------------------------------------------------
# Helper utilities
# -----------------------------------------------------------------------------

# Short model keys (as used on the command line and in result rows) mapped to
# the display names shown in figure legends and titles.
model_names = {
    "mamba": "Mamba-1.4B",
    "falcon3": "Falcon3-1B",
    "llama3.2-1B": "Llama3.2-1B",
    "qwen2": "Qwen2-1.5B",
    "gemma": "Gemma-2B",
    "gemma-it": "Gemma-2B-Instruct",
    "gemma-7B": "Gemma-7B",
    "zamba": "Zamba2-1.2B",
}

def load_pickle(path: Path):
    open_fn = gzip.open if path.suffix == '.gz' else open
    with open_fn(path, 'rb') as f:
        return pickle.load(f)


def top_mass_indices(vec: np.ndarray, mass: float):
    """Return the smallest set of indices whose values sum to ≥ mass·total(vec).

    Indices are taken in order of decreasing value. An all-zero-sum vector
    yields an empty set.
    """
    if vec.sum() == 0:
        return set()

    # Rank entries from largest to smallest and accumulate their mass.
    ranked = np.argsort(vec)[::-1]
    running_mass = np.cumsum(vec[ranked])
    target = mass * running_mass[-1]

    # First position at which the running mass reaches the target; +1 turns
    # that position into a count of entries to keep.
    n_keep = np.searchsorted(running_mass, target) + 1
    return set(ranked[:n_keep])


def compute_important_word_idxs(a_b: np.ndarray, keep_l: set[int], t: float = 0.75, num_random_trials: int = 100):
    """Compare the brain top-mass word set against a given LM word set.

    Args:
        a_b: Brain attribution vector; its top-``t`` mass indices form the
            brain set.
        keep_l: Already-selected LM index set.
        t: Mass fraction passed to ``top_mass_indices`` for ``a_b``.
        num_random_trials: Number of random index-set pairs drawn for the
            chance-level IoU.

    Returns:
        Tuple ``(keep_b, keep_both, brain_prop, lm_prop, iou_prop, rand_iou,
        cnt_b_only, cnt_l_only)``. The three proportions are relative to the
        size of the union (all 0.0 when the union is empty). ``rand_iou`` is
        the mean IoU of random sets of matching sizes, or ``None`` when either
        set is empty.
    """
    keep_b = top_mass_indices(a_b, t)
    keep_both = keep_b & keep_l
    n_union = len(keep_b | keep_l)

    if n_union == 0:
        cnt_b_only = cnt_l_only = 0
        brain_prop = lm_prop = iou_prop = 0.0
    else:
        cnt_b_only = len(keep_b - keep_l)
        cnt_l_only = len(keep_l - keep_b)
        brain_prop = cnt_b_only / n_union
        lm_prop = cnt_l_only / n_union
        iou_prop = len(keep_both) / n_union

    # Chance baseline: random index sets of the same sizes as the real ones.
    n_words = a_b.shape[0]
    size_l, size_b = len(keep_l), len(keep_b)
    rand_iou = None
    if size_l > 0 and size_b > 0:
        trial_ious = []
        for _ in range(num_random_trials):
            # Draw order (LM first, then brain) matters for RNG reproducibility.
            fake_l = set(np.random.choice(n_words, size=size_l, replace=False))
            fake_b = set(np.random.choice(n_words, size=size_b, replace=False))
            trial_ious.append(len(fake_l & fake_b) / len(fake_l | fake_b))
        rand_iou = np.mean(trial_ious)

    return (
        keep_b,
        keep_both,
        brain_prop,
        lm_prop,
        iou_prop,
        rand_iou,
        cnt_b_only,
        cnt_l_only,
    )


def nansem(arr: np.ndarray, axis: int = 0):
    """Standard error of the mean along ``axis``, skipping NaN entries.

    Uses the sample standard deviation (``ddof=1``) divided by the square root
    of the number of non-NaN values; that count is floored at 1 to avoid a
    division by zero.
    """
    n_valid = np.sum(~np.isnan(arr), axis=axis)
    spread = np.nanstd(arr, axis=axis, ddof=1)
    denom = np.sqrt(np.maximum(n_valid, 1))
    return spread / denom

# -----------------------------------------------------------------------------
# Main analysis pipeline
# -----------------------------------------------------------------------------

def _parse_layer_group(text: str) -> list[int]:
    """Parse one ``--layers`` token such as ``'4,9,12'`` into ``[4, 9, 12]``."""
    return [int(tok) for tok in text.split(',') if tok.strip()]


@measure_performance  # type: ignore  (decorator provided in utils.py)
def main() -> None:  # noqa: C901 – function is necessarily long‑winded
    """Run the attribution-overlap analysis and render all figures.

    Steps:
      1. Parse the command-line options (dataset, models, per-model layer
         groups, top-mass thresholds, cross-validation settings).
      2. Create the output directories and load the dataset.
      3. For every model x threshold x layer x subject x fold, load the
         next-word and brain-alignment attribution pickles, keep the contexts
         whose first dimension equals ``context_length + 15``, and compute
         overlap statistics between the two top-mass word sets.
      4. Collect the rows into DataFrames and pass them to ``create_figures``.

    ``--layers`` takes one comma-separated group per model, in the same order
    as ``--models`` (for example ``--layers 4,9,12 5,9,12``).
    """
    parser = argparse.ArgumentParser(description='Attribution analysis')
    parser.add_argument('--dataset', type=str, default='HarryPotter', help='Dataset name') # 'HarryPotter', 'MothRadioHour'
    parser.add_argument('--context_length', type=int, default=640, help='Context length')
    # nargs='+' means the names are separated by spaces on the command line.
    parser.add_argument('--models', type=str, nargs='+', default=['llama3.2-1B', 'falcon3', 'gemma', 'mamba','zamba'], help='Space-separated model names') # ['qwen2', 'llama3.2-1B', 'falcon3', 'gemma', 'gemma-it', 'gemma-7B', 'mamba','zamba']
    parser.add_argument('--attr-method', type=str, default='gxi', help='Attribution method') # 'ig', 'gxi'
    parser.add_argument("--top-mass", type=int, nargs='+', default=[1,5,10,20,30,40,50,60,70,80,90,95,98], help="Top mass percentages")
    # ``type=list`` would split every token into single characters, so each
    # token is parsed as a comma-separated group of integers instead.
    parser.add_argument("--layers", type=_parse_layer_group, nargs='+', default=[[4,9,12], [5,9,12], [5,6,12], [15,29,32], [11,17,30]], help="Layer indices: one comma-separated group per model, in --models order (e.g. 4,9,12 5,9,12)") # [[7,15,27], [4,9,12,15], [5,9,12,17], [5,6,12,17], [5,11,12,17], [4,16,18,27], [15,29,32,47], [11,17,30,37]]
    parser.add_argument('--num-folds', type=int, default=4, help='Number of CV folds')
    parser.add_argument('--num-tr-trim', type=int, default=5, help='Number of training TRs to trim')
    args = parser.parse_args()

    # Layer groups are looked up by model position below, so fail early with a
    # clear message instead of an IndexError deep inside the loop.
    if len(args.layers) < len(args.models):
        parser.error(
            f'--layers needs one group per model: got {len(args.layers)} group(s) '
            f'for {len(args.models)} model(s)'
        )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---------------------------------------------------------------------
    # Prepare output directories
    # ---------------------------------------------------------------------
    # NOTE(review): the directory name says "moth" regardless of --dataset;
    # confirm whether it should depend on the dataset.
    output_dir = Path('./plots') / 'attribution_analysis_moth'
    output_dir.mkdir(parents=True, exist_ok=True)
    # create_figures writes into figure1 ... figure7, so create all of them.
    for fig_num in range(1, 8):
        (output_dir / f'figure{fig_num}').mkdir(exist_ok=True)

    # ---------------------------------------------------------------------
    # Load dataset
    # ---------------------------------------------------------------------
    if args.dataset == 'MothRadioHour':
        data_root_dir = Path('/BRAIN/ssms/work')
    else:
        data_root_dir = Path('./data')
    dataset_dir = Path(data_root_dir) / args.dataset
    dataset = get_dataset(
        args.dataset,
        dataset_dir,
        tokenizer=None,
        device=device,
        context_length=args.context_length,
        remove_format_chars=True,
        remove_punc_spacing=True,
    )

    if args.dataset == 'HarryPotter':
        roi_names = ['all'] + dataset.roi_names
        n_rois = len(roi_names)
    else:
        n_rois = 1
        story_idxs = list(dataset.story_idx_to_name.keys())

    # ---------------------------------------------------------------------
    # Collect raw statistics in python lists
    # ---------------------------------------------------------------------
    diff_rows: list[dict] = []
    diff_rows_lm: list[dict] = []
    dist_rows: list[dict] = []

    for model_idx, model_name in enumerate(args.models):
        experiment_dir = Path(f'/BRAIN/ssms2/work/outputs/{args.dataset}_{model_name}_{args.context_length}')
        brain_align_dir = experiment_dir / 'alignment_attrs' / args.attr_method
        next_word_dir = experiment_dir / 'next_word_attrs' / args.attr_method
        for threshold in args.top_mass:
            print(f'Processing {model_name} t = {threshold}%')
            t_mass = threshold / 100.0
            for layer_idx in args.layers[model_idx]:
                # Next-word attributions are loaded while processing subject 0
                # and reused for the remaining subjects. They are keyed by
                # fold index so that a missing fold file cannot shift the
                # entries of the folds that follow it.
                A_lm, fold_keep_idxs_nwp, keep_ls = {}, {}, {}
                for subject_idx in range(len(dataset.subjects)):
                    if args.dataset == "MothRadioHour":
                        fold_test_idxs, _ = story_fold_test_idxs(dataset.subjects, subject_idx, args.num_tr_trim, None)
                    else:
                        fold_test_idxs = k_fold_test_idxs(dataset.subjects, subject_idx, args.num_folds, args.num_tr_trim)
                    for fold_idx, test_idxs in enumerate(fold_test_idxs):
                        print(f'Processing fold {fold_idx} for subject {subject_idx}, layer {layer_idx}')
                        # Load next-word attribution (shared across subjects)
                        if subject_idx == 0:
                            p_lm = next_word_dir / f'fold_{fold_idx}.pkl'
                            if not p_lm.exists():
                                print(f"File {p_lm} does not exist, skipping...")
                                continue
                            # Load and filter A_lm: keep only arrays with standard extended context length
                            loaded_a = [np.abs(a) for a in load_pickle(p_lm)]
                            keep_idxs_nwp = [i for i, a in enumerate(loaded_a) if a.shape[0] == args.context_length + 15]
                            A_lm[fold_idx] = [loaded_a[i] for i in keep_idxs_nwp]
                            keep_ls[fold_idx] = [top_mass_indices(a, t_mass) for a in A_lm[fold_idx]]
                            fold_keep_idxs_nwp[fold_idx] = keep_idxs_nwp
                        elif fold_idx not in A_lm:
                            # The next-word file for this fold was missing when
                            # subject 0 was processed, so there is nothing to
                            # compare against.
                            print(f"No next-word attributions loaded for fold {fold_idx}, skipping...")
                            continue

                        # Load brain alignment attribution (subject-specific)
                        p_brain = brain_align_dir / f'subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pkl'
                        if not p_brain.exists():
                            print(f"File {p_brain} does not exist, skipping...")
                            continue
                        A_brain = [np.abs(a) for a in load_pickle(p_brain)]
                        keep_idxs_brain = [i for i, a in enumerate(A_brain) if a.shape[0] == args.context_length + 15]
                        A_brain = [A_brain[i] for i in keep_idxs_brain]

                        assert keep_idxs_brain == fold_keep_idxs_nwp[fold_idx], \
                            f"Test indices do not match: {A_brain}, {A_lm[fold_idx]}"

                        # Check for shape mismatches between LM and brain attributions.
                        # Both lists were filtered to the same first dimension above,
                        # so this is only a defensive diagnostic.
                        lm_shapes = [a.shape[0] for a in A_lm[fold_idx]]
                        brain_shapes = [a.shape[0] for a in A_brain]
                        if lm_shapes != brain_shapes:
                            print(f"Shape mismatch for model={model_name}, subject={subject_idx}, layer={layer_idx}, fold={fold_idx}")
                            # Print tr_idx (test_idxs) and mismatched values
                            for i, (lm_shape, brain_shape) in enumerate(zip(lm_shapes, brain_shapes)):
                                if lm_shape != brain_shape:
                                    print(f"tr_idx: {test_idxs[i]}, LM shape: {lm_shape}, Brain shape: {A_brain[i].shape}")

                        # Restrict the fold's test indices to the kept contexts.
                        kept_test_idxs = [test_idxs[i] for i in keep_idxs_brain]
                        for k, test_idx in enumerate(kept_test_idxs):
                            a_l = A_lm[fold_idx][k]
                            a_b = A_brain[k]

                            keep_l = keep_ls[fold_idx][k]

                            n_words = a_l.shape[0]
                            # Distance of each position from the last position.
                            distances = n_words - 1 - np.arange(n_words)
                            counts_total = np.ones(n_words, dtype=int)

                            for r in range(n_rois):
                                if r > 0:
                                    break  # Only process the first ROI ('all')
                                a_b_r = a_b[:,r]
                                (
                                    keep_b,
                                    keep_both,
                                    brain_prop,
                                    lm_prop,
                                    iou_prop,
                                    rand_iou,
                                    brain_cnt,
                                    lm_cnt,
                                ) = compute_important_word_idxs(a_b_r, keep_l, t=t_mass)

                                # Build exclusive masks ----------------------------------------------------
                                brain_only_mask = np.zeros(n_words, dtype=int)
                                lm_only_mask = np.zeros(n_words, dtype=int)
                                both_mask = np.zeros(n_words, dtype=int)
                                for idx in keep_b - keep_l:
                                    brain_only_mask[idx] = 1
                                for idx in keep_l - keep_b:
                                    lm_only_mask[idx] = 1
                                for idx in keep_both:
                                    both_mask[idx] = 1

                                # Histogram of selected words by distance.
                                brain_only_counts = np.bincount(distances, brain_only_mask, minlength=n_words)
                                lm_only_counts = np.bincount(distances, lm_only_mask, minlength=n_words)
                                both_counts = np.bincount(distances, both_mask, minlength=n_words)

                                # Attribution-weighted mean distance of each set.
                                # NOTE(review): unlike com_inter, the first two have no
                                # epsilon and evaluate to NaN when the set is empty.
                                com_brain = (
                                    np.dot(a_b_r[list(keep_b)], distances[list(keep_b)])
                                    / a_b_r[list(keep_b)].sum()
                                )
                                com_lm = (
                                    np.dot(a_l[list(keep_l)], distances[list(keep_l)])
                                    / a_l[list(keep_l)].sum()
                                )
                                idx = list(keep_both)
                                com_inter = (
                                    np.dot(a_b_r[idx], distances[idx])
                                    / (a_b_r[idx].sum()+1e-10)
                                )

                                dist_rows.append({
                                    "model": model_name,
                                    "subject": subject_idx,
                                    "layer": layer_idx,
                                    "context_idx": test_idx,
                                    "threshold": threshold,
                                    "brain_counts": brain_only_counts,
                                    "lm_counts": lm_only_counts,
                                    "intersect_counts": both_counts,
                                    "total_counts": counts_total,
                                    "brain_com": com_brain,
                                    "lm_com": com_lm,
                                    "intersect_com": com_inter,
                                })

                            # The values below come from the single ROI processed above.
                            diff_rows.append(
                                {
                                    'model': model_name,
                                    'subject': subject_idx,
                                    'layer': layer_idx,
                                    'context_idx': test_idx,
                                    'threshold': threshold,
                                    'brain_unique_prop': brain_prop,
                                    'lm_unique_prop': lm_prop,
                                    'iou_prop': iou_prop,
                                    'rand_iou_prop': rand_iou,
                                    'brain_unique_cnt': brain_cnt,
                                }
                            )

                            if subject_idx == 0:
                                diff_rows_lm.append(
                                    {
                                        'model': model_name,
                                        'layer': layer_idx,
                                        'context_idx': test_idx,
                                        'threshold': threshold,
                                        'lm_unique_prop': lm_prop,
                                        'lm_unique_cnt': lm_cnt,
                                    }
                                )

    # ---------------------------------------------------------------------
    # Post-processing: build DataFrames and create figures
    # ---------------------------------------------------------------------
    df_diff = pd.DataFrame(diff_rows)
    df_diff_lm = pd.DataFrame(diff_rows_lm)
    df_dist = pd.DataFrame(dist_rows)

    print('Plotting...')
    create_figures(df_diff, df_diff_lm, df_dist, output_dir)


# -----------------------------------------------------------------------------
# Figure‑generation helpers
# -----------------------------------------------------------------------------

# Shared pastel palette for the figure helpers in this module.
PASTEL_ORANGE = '#FDBE87'
PASTEL_RED = '#FF9999'
PASTEL_GREEN = '#99CC99'


def sem(series: pd.Series) -> float:
    """Standard error of the mean (sample std, ddof=1); 0.0 for fewer than two values."""
    n_obs = len(series)
    if n_obs <= 1:
        return 0.0
    return series.std(ddof=1) / np.sqrt(n_obs)


def compute_auc(xs: np.ndarray, ys: np.ndarray) -> float:
    """Area under the curve ``ys(xs)`` using the trapezoidal rule.

    Args:
        xs: Sample positions.
        ys: Values at those positions.

    Returns:
        The integral of ``ys`` over ``xs``.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back on older NumPy releases.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(ys, xs)


def create_figures(df_diff: pd.DataFrame, df_diff_lm: pd.DataFrame, df_dist: pd.DataFrame, out_dir: Path) -> None:
    """Render figures 1-7, each into its own ``figureN`` sub-directory of ``out_dir``."""
    plt.rcParams['figure.dpi'] = 150

    # One target directory per figure number.
    fig_dirs = {num: out_dir / f'figure{num}' for num in range(1, 8)}

    # Figures driven by the overlap statistics only.
    _figure1(df_diff, fig_dirs[1])
    _figure2(df_diff, fig_dirs[2])

    # Figures that also use the LM-only rows.
    _figure3(df_diff, df_diff_lm, fig_dirs[3])
    _figure4(df_diff, df_diff_lm, fig_dirs[4])
    _figure5(df_diff, df_diff_lm, fig_dirs[5])

    # Figures driven by the distance rows.
    _figure6(df_dist, fig_dirs[6])
    _figure7(df_dist, fig_dirs[7])


# -----------------------------------------------------------------------------
# Figure 1 – IoU vs. threshold, all models aggregated
# -----------------------------------------------------------------------------
def _figure1(df_diff: pd.DataFrame, fig_dir: Path) -> None:
    """Plot IoU against the top-mass threshold, pooled over models.

    The line is the mean of the per-model mean IoUs, the hollow markers are
    the individual models, and the dashed gray line is the random baseline
    (mean with NaN-aware SEM error bars). The per-model points are also saved
    as ``<model>_iou_data.npz`` in ``fig_dir``; the figure is written as PNG
    and PDF.
    """
    fig_dir.mkdir(exist_ok=True)

    # Mean IoU for every (model, threshold) pair; reused for both the
    # per-model markers and the across-model average.
    iou_by_model = df_diff.groupby(['model', 'threshold'])['iou_prop'].mean()
    model_df = iou_by_model.reset_index()
    mean_df = (
        iou_by_model.groupby('threshold')
        .mean()
        .reset_index()
        .sort_values('threshold')
    )

    # Thresholds are drawn at evenly spaced positions, not at their values.
    thresholds = mean_df['threshold'].to_numpy()
    x_pos = np.arange(len(thresholds))

    # Random baseline per threshold: mean and SEM, both ignoring NaNs.
    rand_df = (
        df_diff.groupby('threshold')['rand_iou_prop']
        .agg(['mean', nansem])
        .reset_index()
        .sort_values('threshold')
    )
    rand_means = rand_df['mean'].to_numpy()
    rand_sems = rand_df['nansem'].to_numpy()

    # One marker shape per model, assigned in sorted model order.
    marker_cycle = ['o', 's', 'D', '^', 'v', 'P', 'X', '*', '+']
    marker_map = {}
    for pos, name in enumerate(sorted(df_diff['model'].unique())):
        marker_map[name] = marker_cycle[pos % len(marker_cycle)]

    fig, ax = plt.subplots(figsize=(9, 4))

    # Mean curve across models.
    ax.plot(
        x_pos,
        mean_df['iou_prop'],
        '-',
        color=PASTEL_ORANGE,
        linewidth=2,
        label='Mean across models',
    )

    # Individual models as hollow markers.
    for model, df_m in model_df.groupby('model'):
        xs = [np.flatnonzero(thresholds == thr)[0] for thr in df_m['threshold']]
        ys = df_m['iou_prop'].to_numpy()
        np.savez(fig_dir / f'{model}_iou_data.npz', x=xs, y=ys)

        ax.scatter(
            xs,
            ys,
            marker=marker_map[model],
            s=25,
            edgecolor='darkgray',
            facecolor='none',
            linewidth=0.8,
            label=model_names.get(model, model),
        )

    # Random baseline with error bars.
    ax.errorbar(
        x_pos,
        rand_means,
        yerr=rand_sems,
        linestyle='--',
        color='gray',
        capsize=3,
        label='Random baseline',
    )

    # Axes labels and formatting.
    ax.set_xlabel('Top % attribution')
    ax.set_ylabel('Intersection over Union (IoU)')
    ax.set_ylim(0, 1)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(thresholds)
    ax.grid(True, ls=':')

    # Legend sits to the right of the axes.
    ax.legend(loc='center left', fontsize=14, frameon=False, bbox_to_anchor=(1, 0.5), ncol=1)

    fig.tight_layout()
    stem = fig_dir / 'figure1_iou_vs_threshold'
    fig.savefig(stem.with_name(stem.name + '.png'), bbox_inches='tight')
    fig.savefig(stem.with_name(stem.name + '.pdf'), dpi=1200, bbox_inches='tight')
    plt.close(fig)

# -----------------------------------------------------------------------------
# Figure 2 – IoU vs. threshold per model, layer‑wise curves
# -----------------------------------------------------------------------------

def _figure2(df_diff: pd.DataFrame, fig_dir: Path) -> None:
    """Plot, for each model, IoU against threshold with one curve per layer.

    IoU is first averaged within each (threshold, subject) pair; each curve
    then shows the mean and SEM of those per-subject values. One PNG and one
    PDF are written per model into ``fig_dir``.
    """
    fig_dir.mkdir(exist_ok=True)
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the pyplot
    # accessor is available across versions.
    cmap = plt.get_cmap('Oranges')

    for model, df_m in df_diff.groupby('model'):
        layers = sorted(df_m['layer'].unique())
        # Spread the layers over the mid range of the colormap.
        color_norm = np.linspace(0.3, 0.9, len(layers))
        thresholds = np.sort(df_m['threshold'].unique())
        # Thresholds are drawn at evenly spaced positions, not at their values.
        x_pos = np.arange(len(thresholds))
        idx_map = {thr: i for i, thr in enumerate(thresholds)}

        fig, ax = plt.subplots(figsize=(5, 3))
        for l_idx, layer in enumerate(layers):
            df_layer = (
                df_m[df_m['layer'] == layer]
                .groupby(['threshold', 'subject'])['iou_prop']
                .mean()
                .groupby('threshold')
                .agg(['mean', sem])
                .reset_index()
            )
            x = df_layer['threshold'].map(idx_map)
            ax.errorbar(
                x,
                df_layer['mean'],
                yerr=df_layer['sem'],
                label=f'Layer {layer}',
                capsize=2,
                color=cmap(color_norm[l_idx]), markersize=3,
            )

        ax.set_xlabel('Top % attribution')
        ax.set_ylabel('Intersection over Union (IoU)')
        ax.set_ylim(0, 1)
        # Fall back to the raw key for models without a display name, as the
        # legend in figure 1 does.
        ax.set_title(model_names.get(model, model))
        ax.set_xticks(x_pos)
        ax.set_xticklabels(thresholds)
        ax.grid(True, ls=':')
        ax.legend(title='', loc='upper left')
        fig.tight_layout()
        fig.savefig(fig_dir / f'figure2_{model}_layers.png')
        fig.savefig(fig_dir / f'figure2_{model}_layers.pdf', dpi=1200)
        plt.close(fig)


# -----------------------------------------------------------------------------
# Figure 3 – Layer‑wise unique counts & AUC bars per model+layer
# -----------------------------------------------------------------------------
# NOTE: this statement runs at import time and changes Matplotlib's global
# defaults, so it applies to figures created after this module is loaded,
# not only to Figure 3.
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 18,
    "axes.labelsize": 14,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 14,
})

def _annotate_significance(ax, x_pos: float, bar_height: float, p_val: float):
    """Write significance stars on ``ax`` when ``p_val`` is below 0.05.

    One star for p < 0.05, two for p < 0.01, three for p < 0.001. The text is
    centred on ``x_pos`` at 90 % of ``bar_height``; nothing is drawn otherwise.
    """
    if not p_val < 0.05:
        return

    if p_val < 0.001:
        stars = "∗∗∗"
    elif p_val < 0.01:
        stars = "∗∗"
    else:
        stars = "∗"

    y_pos = bar_height * 0.9
    ax.text(x_pos, y_pos, stars, ha="center", va="bottom", fontsize=14)

def _figure3(df_brain: pd.DataFrame, df_lm: pd.DataFrame, path: Path):
    """Plot per-layer unique-word curves and AUC bars for every model.

    For each model and layer this draws (left) the mean and SEM of
    ``brain_unique_cnt`` and ``lm_unique_cnt`` per threshold and (right) the
    mean and SEM of the per-context area under those curves, annotated with
    the result of a paired t-test between the two AUC lists. The intermediate
    results are also pickled to ``figure3_<model>_caches.pkl``.

    Layers for which one of the two AUC lists is empty get no t-test and are
    not plotted.

    Args:
        df_brain: Rows with ``model``, ``layer``, ``subject``, ``context_idx``,
            ``threshold`` and ``brain_unique_cnt`` columns.
        df_lm: Rows with ``model``, ``layer``, ``context_idx``, ``threshold``
            and ``lm_unique_cnt`` columns.
        path: Output directory (created if missing).

    NOTE(review): the brain AUC is integrated over all rows of a context,
    which pools every subject's rows at each threshold; confirm whether the
    counts should be averaged across subjects first.
    """
    path.mkdir(exist_ok=True)
    for model in df_brain["model"].unique():
        df_m_br = df_brain[df_brain["model"] == model]
        df_m_lm = df_lm[df_lm["model"] == model]
        layers = df_m_br["layer"].unique()

        global_ymax = 0.0
        global_auc_ymax = 0.0
        curves_cache, aucs_cache, t_res_cache = {}, {}, {}
        for layer in layers:
            df_b = df_m_br[df_m_br["layer"] == layer]
            df_l = df_m_lm[df_m_lm["layer"] == layer]
            # mean ± SEM per threshold -------------------------------------------
            curves_b = (
                df_b.groupby(["threshold", "subject", "context_idx"])["brain_unique_cnt"]
                .mean()
                .reset_index()
                .groupby("threshold")["brain_unique_cnt"]
                .agg(["mean", sem])
            )
            curves_l = (
                df_l.groupby(["threshold", "context_idx"])["lm_unique_cnt"]
                .mean()
                .reset_index()
                .groupby("threshold")["lm_unique_cnt"]
                .agg(["mean", sem])
            )
            curves_cache[layer] = (curves_b, curves_l)

            # AUC per context ------------------------------------------------------
            aucs_br, aucs_lm = [], []
            for _, d_ctx in df_b.groupby("context_idx"):
                d_sort = d_ctx.sort_values("threshold")
                aucs_br.append(compute_auc(d_sort["threshold"], d_sort["brain_unique_cnt"]))
            for _, d_ctx in df_l.groupby("context_idx"):
                d_sort = d_ctx.sort_values("threshold")
                aucs_lm.append(compute_auc(d_sort["threshold"], d_sort["lm_unique_cnt"]))
            auc_stats = {
                "brain": (np.mean(aucs_br), stats.sem(aucs_br) if len(aucs_br) > 1 else 0.0),
                "lm": (np.mean(aucs_lm), stats.sem(aucs_lm) if len(aucs_lm) > 1 else 0.0),
            }
            aucs_cache[layer] = auc_stats

            # paired t-test --------------------------------------------------------
            min_len = min(len(aucs_br), len(aucs_lm))
            if min_len == 0:
                # Nothing to compare; the plotting loop below skips this layer.
                continue
            # The p-value is not corrected for multiple comparisons: the
            # ``correction`` argument of ``pg.ttest`` only controls the Welch
            # correction of unpaired tests, so it is not passed here.
            t_res = pg.ttest(aucs_br[:min_len], aucs_lm[:min_len], paired=True)
            t_res_cache[layer] = t_res

            # Get global max for line plot
            global_ymax = max(global_ymax, curves_b["mean"].max(), curves_l["mean"].max())

            # Get global max for AUC bar plot
            global_auc_ymax = max(global_auc_ymax, auc_stats["brain"][0], auc_stats["lm"][0])

        global_ymax *= 1.1     # small head-room
        global_auc_ymax *= 1.2

        # Save caches and global values to a file for later use
        cache_file = path / f"figure3_{model}_caches.pkl"
        with open(cache_file, "wb") as f:
            pickle.dump(
                dict(curves_cache=curves_cache,
                     aucs_cache=aucs_cache,
                     t_res_cache=t_res_cache,
                     global_ymax=global_ymax,
                     global_auc_ymax=global_auc_ymax),
                f)

        for layer in layers:
            if layer not in t_res_cache:
                # No t-test was possible for this layer (see above); looking
                # it up would raise a KeyError.
                print(f"Model {model} – Layer {layer} – no paired AUCs, skipping plot")
                continue

            curves_b, curves_l = curves_cache[layer]

            t_res = t_res_cache[layer]
            p_val = t_res["p-val"].iloc[0]

            auc_stats = aucs_cache[layer]
            print(f"Model {model} – Layer {layer} – AUC: brain {auc_stats['brain'][0]:.3f}, lm {auc_stats['lm'][0]:.3f}, p-value: {p_val:.3f}")

            # plot ----------------------------------------------------------------
            fig, (ax_line, ax_bar) = plt.subplots(
                ncols=2, figsize=(7.5, 4), gridspec_kw={"width_ratios": [3, 1]}
            )
            # Use each curve's own threshold index as x so the lengths always
            # match, even when a layer lacks some thresholds.
            ax_line.errorbar(
                curves_b.index.to_numpy(), curves_b["mean"], yerr=curves_b["sem"],
                color=PASTEL_RED, label="Brain alignment", lw=1,
                ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2
            )
            ax_line.errorbar(
                curves_l.index.to_numpy(), curves_l["mean"], yerr=curves_l["sem"],
                color=PASTEL_GREEN, label="Next‑word prediction", lw=1,
                ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2
            )

            ax_line.set_xlabel("Top % attribution", fontsize=18)
            ax_line.set_ylabel("# unique important words", fontsize=18)
            ax_line.grid(ls=":")
            ax_line.legend(loc="upper left", frameon=False)
            ax_line.set_ylim(0, global_ymax)
            ax_line.tick_params(axis='x')
            ax_line.tick_params(axis='y')

            bars = ax_bar.bar(
                [0, 1],
                [auc_stats["brain"][0], auc_stats["lm"][0]],
                yerr=[auc_stats["brain"][1], auc_stats["lm"][1]],
                color=[PASTEL_RED, PASTEL_GREEN],
                width=0.6, linewidth=1.2, capsize=4
            )
            ax_bar.set_ylim(0, global_auc_ymax)
            ax_bar.set_xticks([0, 1], ["BA", "NWP"], fontsize=18)
            ax_bar.set_ylabel("AUC", fontsize=18)
            ax_bar.grid(True, axis="y", ls=":")

            # annotate significance ----------------------------------------------
            higher_idx = int(auc_stats["brain"][0] < auc_stats["lm"][0])  # 0 if brain higher else 1
            _annotate_significance(ax_bar, higher_idx, global_auc_ymax, p_val)

            fig.tight_layout()
            fig.savefig(path / f"figure3_{model}_layer{layer}.png")
            fig.savefig(path / f"figure3_{model}_layer{layer}.pdf", dpi=1200)
            plt.close(fig)

# -----------------------------------------------------------------------------
# Figure 4 – curves, bars, and a p‑value heat‑map across layers × thresholds
# -----------------------------------------------------------------------------
def _figure4(df_brain: pd.DataFrame,
             df_lm:    pd.DataFrame,
             fig_dir:  Path) -> None:
    """
    Plot Early / Middle / Late‑layer unique‑word curves and their AUC bars.

    For each ordinal position (0 = Early, 1 = Middle, 2 = Late) the layer at
    that position in every model's sorted layer list is selected, then:

    • brain counts are averaged across models per (subject, context, threshold)
      and LM counts per (context, threshold);
    • mean ± SEM curves over thresholds are computed for both;
    • one AUC per (subject, context) is computed for the brain curves and one
      per context for the LM curves (replicated for every subject so the two
      series can be paired);
    • a paired t‑test compares the brain and LM AUCs.

    The curves, AUC statistics, p‑values and the shared axis limits are pickled
    to ``figure4_caches.pkl``; one PNG and one PDF are written per ordinal that
    has data.  Ordinals for which no model has enough layers are skipped.

    Parameters
    ----------
    df_brain : DataFrame with columns ``model``, ``layer``, ``subject``,
        ``context_idx``, ``threshold`` and ``brain_unique_cnt``.
    df_lm : DataFrame with columns ``model``, ``layer``, ``context_idx``,
        ``threshold`` and ``lm_unique_cnt``.
    fig_dir : output directory (created, including parents, if missing).
    """
    fig_dir.mkdir(parents=True, exist_ok=True)
    ORDINAL_NAMES = {0: 'Early', 1: 'Middle', 2: 'Late'}

    global_ymax = 0.0
    global_auc_ymax = 0.0
    curves_cache, aucs_cache, p_val_cache = {}, {}, {}

    for ordinal in ORDINAL_NAMES:
        # ------------------------------------------------ pick one “ordinal” layer per model
        chosen = {}
        for model, d in df_brain.groupby("model"):
            layers = sorted(d["layer"].unique())
            if len(layers) > ordinal:
                chosen[model] = layers[ordinal]
        if not chosen:
            continue  # nothing for this ordinal (it is also skipped when plotting)

        # ------------------------------------------------ subset data
        df_b_sel = pd.concat([df_brain[(df_brain["model"] == m) &
                                       (df_brain["layer"] == lyr)]
                              for m, lyr in chosen.items()], ignore_index=True)
        df_l_sel = pd.concat([df_lm[(df_lm["model"] == m) &
                                    (df_lm["layer"] == lyr)]
                              for m, lyr in chosen.items()], ignore_index=True)

        # ------------------------------------------------ average across models for each subject×context
        brain_ctx = (df_b_sel.groupby(["subject", "context_idx", "threshold"])
                               ["brain_unique_cnt"].mean().reset_index())
        lm_ctx    = (df_l_sel.groupby(["context_idx", "threshold"])
                               ["lm_unique_cnt"].mean().reset_index())

        # ------------------------------------------------ curves (mean ± SEM per threshold)
        curves_brain = brain_ctx.groupby("threshold")["brain_unique_cnt"].agg(["mean", sem])
        curves_lm    = lm_ctx.groupby("threshold")["lm_unique_cnt"].agg(["mean", sem])
        curves_cache[ordinal] = (curves_brain, curves_lm)
        global_ymax = max(global_ymax, curves_brain["mean"].max(), curves_lm["mean"].max())

        # ------------------------------------------------ AUCs across the whole threshold range
        aucs_brain, aucs_lm = [], []

        # BA – one value per (subject, context)
        for (subj, ctx), df_pair in brain_ctx.groupby(["subject", "context_idx"]):
            d_sort = df_pair.sort_values("threshold")
            aucs_brain.append((subj, ctx, compute_auc(d_sort["threshold"],
                                                    d_sort["brain_unique_cnt"])))

        # NWP – one value per context, replicated for every subject so it can
        # be merged against the brain AUCs.  The LM SEM further down is
        # therefore computed over these replicated values.
        subjects = brain_ctx["subject"].unique()
        for ctx, df_pair in lm_ctx.groupby("context_idx"):
            d_sort = df_pair.sort_values("threshold")
            auc_val = compute_auc(d_sort["threshold"], d_sort["lm_unique_cnt"])
            aucs_lm.extend((subj, ctx, auc_val) for subj in subjects)

        auc_vals_b = [t[2] for t in aucs_brain]
        auc_vals_l = [t[2] for t in aucs_lm]

        auc_stats = {
            "brain": (np.mean(auc_vals_b),
                    stats.sem(auc_vals_b) if len(auc_vals_b) > 1 else 0.0),
            "lm":    (np.mean(auc_vals_l),
                    stats.sem(auc_vals_l) if len(auc_vals_l) > 1 else 0.0),
        }
        aucs_cache[ordinal] = auc_stats
        global_auc_ymax = max(global_auc_ymax, auc_stats["brain"][0], auc_stats["lm"][0])

        # ------------------------------------------------ paired t‑test on the AUCs
        df_auc_b = pd.DataFrame(aucs_brain, columns=["subject", "context", "auc"])
        df_auc_l = pd.DataFrame(aucs_lm,   columns=["subject", "context", "auc"])

        merged_auc = pd.merge(df_auc_b, df_auc_l,
                            on=["subject", "context"],
                            suffixes=("_b", "_l"))

        # NOTE(review): pingouin documents `correction` as the Welch switch for
        # unpaired tests; 'fdr_bh' looks like a multiple‑comparison method name
        # and probably has no effect on this paired test — confirm.
        p_auc = (pg.ttest(merged_auc["auc_b"],
                        merged_auc["auc_l"],
                        paired=True,
                        correction='fdr_bh')["p-val"]
                .iloc[0])
        p_val_cache[ordinal] = p_auc

    # Head‑room so curves / bars never touch the top of the shared axes.
    global_ymax *= 1.1
    global_auc_ymax *= 1.2

    # Save caches and global values to a file for later use
    cache_file = fig_dir / "figure4_caches.pkl"
    with open(cache_file, "wb") as f:
        pickle.dump(
            dict(curves_cache=curves_cache,
                aucs_cache=aucs_cache,
                p_val_cache=p_val_cache,
                global_ymax=global_ymax,
                global_auc_ymax=global_auc_ymax),
            f)

    # Only ordinals that produced data are plotted; indexing the caches with a
    # skipped ordinal would raise KeyError.
    for ordinal in sorted(curves_cache):
        curves_brain, curves_lm = curves_cache[ordinal]
        p_val = p_val_cache[ordinal]
        auc_stats = aucs_cache[ordinal]

        # ------------------------------------------------ bar‑plus‑curve panel
        fig, (ax_line, ax_bar) = plt.subplots(
            ncols=2, figsize=(7.5, 4), gridspec_kw={"width_ratios": [3, 1]})

        # Each curve is plotted against its own threshold index so x and y
        # always have matching lengths.
        ax_line.errorbar(curves_brain.index.to_numpy(), curves_brain["mean"],
                        yerr=curves_brain["sem"],
                        color=PASTEL_RED,   label="Brain alignment", lw=1.5,
                        ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2)
        ax_line.errorbar(curves_lm.index.to_numpy(), curves_lm["mean"],
                        yerr=curves_lm["sem"],
                        color=PASTEL_GREEN, label="Next‑word prediction", lw=1.5,
                        ecolor="black", elinewidth=1.2, capsize=4, capthick=1.2)

        ax_line.set_xlabel("Top % attribution", fontsize=18)
        ax_line.set_ylabel("# unique important words", fontsize=18)
        ax_line.grid(ls=":")
        ax_line.legend(loc="upper left", frameon=False)
        ax_line.set_ylim(0, global_ymax)

        # Plot AUC bars
        ax_bar.bar([0, 1],
                   [auc_stats["brain"][0], auc_stats["lm"][0]],
                   yerr=[auc_stats["brain"][1], auc_stats["lm"][1]],
                   color=[PASTEL_RED, PASTEL_GREEN], width=0.6,
                   linewidth=1.2, capsize=4)
        ax_bar.set_ylim(0, global_auc_ymax)
        ax_bar.set_xticks([0, 1], ["BA", "NWP"], fontsize=18)
        ax_bar.set_ylabel("AUC", fontsize=18)
        ax_bar.grid(True, axis="y", ls=":")

        # 0 if the brain bar is higher (or equal), else 1
        higher_idx = int(auc_stats["brain"][0] < auc_stats["lm"][0])
        _annotate_significance(ax_bar, higher_idx,
                    global_auc_ymax, p_val)

        fig.tight_layout()
        fig.savefig(fig_dir / f"figure4_{ORDINAL_NAMES[ordinal]}_layer.png")
        fig.savefig(fig_dir / f"figure4_{ORDINAL_NAMES[ordinal]}_layer.pdf", dpi=1200)
        plt.close(fig)

# -----------------------------------------------------------------------------
# Figure 5 – layers‑aggregated unique counts & significance
# -----------------------------------------------------------------------------

def _figure5(df_brain: pd.DataFrame, df_lm: pd.DataFrame, fig_dir: Path) -> None:
    """
    Plot, per model, layer‑aggregated unique‑word curves and their AUC bars.

    For every model found in ``df_brain``:

    • a layer × threshold matrix of mean unique‑word counts is built for brain
      alignment and for next‑word prediction;
    • the across‑layer mean of each threshold column is plotted, with error
      bars obtained by applying the module's ``sem`` helper column‑wise;
    • one AUC per layer is computed with ``compute_auc`` and the brain and LM
      AUCs are compared with a paired t‑test;
    • ``figure5_<model>_layers_aggregated.png`` / ``.pdf`` are written.

    Parameters
    ----------
    df_brain : DataFrame with columns ``model``, ``layer``, ``threshold`` and
        ``brain_unique_cnt``.
    df_lm : DataFrame with columns ``model``, ``layer``, ``threshold`` and
        ``lm_unique_cnt``.
    fig_dir : output directory (created, including parents, if missing).
    """
    fig_dir.mkdir(parents=True, exist_ok=True)
    thresholds = np.sort(df_brain["threshold"].unique())

    for model in df_brain["model"].unique():
        df_b_m = df_brain[df_brain["model"] == model]
        df_l_m = df_lm[df_lm["model"] == model]

        # rows = layers, columns = thresholds (in ascending order)
        brain_mat = (
            df_b_m.groupby(["layer", "threshold"])["brain_unique_cnt"].mean().unstack(level="threshold").reindex(columns=thresholds)
        )
        lm_mat = (
            df_l_m.groupby(["layer", "threshold"])["lm_unique_cnt"].mean().unstack(level="threshold").reindex(columns=thresholds)
        )

        brain_mean = brain_mat.mean(axis=0).values
        brain_sem = brain_mat.apply(sem, axis=0).values
        lm_mean = lm_mat.mean(axis=0).values
        lm_sem = lm_mat.apply(sem, axis=0).values

        # AUC per layer ---------------------------------------------------------
        aucs_brain = [compute_auc(thresholds, row.values) for _, row in brain_mat.iterrows()]
        aucs_lm = [compute_auc(thresholds, row.values) for _, row in lm_mat.iterrows()]

        if len(aucs_brain) and len(aucs_lm):
            # Pairing is positional (row order of the two matrices); this
            # assumes both frames cover the same layers — TODO confirm.
            # NOTE(review): pingouin documents `correction` as the Welch switch
            # for unpaired tests; 'fdr_bh' probably has no effect here — confirm.
            t_res = pg.ttest(aucs_brain, aucs_lm, paired=True, correction='fdr_bh')
            print(f"Figure 5: {t_res}")
            p_val = t_res["p-val"].iloc[0]
        else:
            p_val = 1.0  # no data to compare → report as non‑significant

        auc_stats = {
            "brain": (np.mean(aucs_brain), stats.sem(aucs_brain) if len(aucs_brain) > 1 else 0.0),
            "lm": (np.mean(aucs_lm), stats.sem(aucs_lm) if len(aucs_lm) > 1 else 0.0),
        }
        print(f"Model {model} – AUC: brain {auc_stats['brain'][0]:.3f}, lm {auc_stats['lm'][0]:.3f}, p-value: {p_val:.3f}")

        fig, (ax_line, ax_bar) = plt.subplots(ncols=2, figsize=(6.5, 3), gridspec_kw={"width_ratios": [3, 1]})
        ax_line.errorbar(thresholds, brain_mean, yerr=brain_sem, fmt="-o", color=PASTEL_RED, label="Brain alignment")
        ax_line.errorbar(thresholds, lm_mean, yerr=lm_sem, fmt="-o", color=PASTEL_GREEN, label="Next‑word prediction")
        ax_line.set_xlabel("Top % attribution")
        ax_line.set_ylabel("# unique important words")
        # Fall back to the raw model id so a model without a display name does
        # not abort the figure with a KeyError.
        ax_line.set_title(f"{model_names.get(model, model)} – average across layers")
        ax_line.grid(True, ls=":")
        ax_line.legend()

        bars = ax_bar.bar(
            [0, 1],
            [auc_stats["brain"][0], auc_stats["lm"][0]],
            yerr=[auc_stats["brain"][1], auc_stats["lm"][1]],
            color=[PASTEL_RED, PASTEL_GREEN],
            width=0.6,
        )
        ax_bar.set_xticks([0, 1], ["BA", "NWP"])
        ax_bar.set_ylabel("AUC")
        ax_bar.grid(True, axis="y", ls=":")

        # 0 if the brain bar is higher (or equal), else 1
        higher_idx = int(auc_stats["brain"][0] < auc_stats["lm"][0])
        _annotate_significance(ax_bar, higher_idx, bars[higher_idx].get_height(), p_val)

        fig.tight_layout()
        fig.savefig(fig_dir / f"figure5_{model}_layers_aggregated.png")
        fig.savefig(fig_dir / f"figure5_{model}_layers_aggregated.pdf", dpi=1200)
        plt.close(fig)


# -----------------------------------------------------------------------------
# Figure 6 – Distance distributions (layer-averaged + per-subject panels)
# -----------------------------------------------------------------------------

def _pad(v: np.ndarray, target: int) -> np.ndarray:
    """Right‑pad *v* with zeros up to length *target*; longer inputs are returned untouched."""
    missing = target - v.shape[0]
    if missing <= 0:
        # Already long enough: hand back the very same array object.
        return v
    return np.pad(v, (0, missing), constant_values=0)

def _sum_padded(arrays: list[np.ndarray]) -> np.ndarray:
    """Zero‑pad every array to the longest length and add them element‑wise."""
    if not arrays:
        # Nothing to add up: an empty array keeps downstream arithmetic valid.
        return np.array([])
    width = max(arr.shape[0] for arr in arrays)
    stacked = []
    for arr in arrays:
        stacked.append(_pad(arr, width))
    return np.sum(stacked, axis=0)

def _figure6(df_dist: pd.DataFrame, fig_dir: Path) -> None:
    """
    Plot distance distributions of important words, per subject and averaged.

    For every model and every threshold in (10, 60, 80) that has data:

    • per subject, the ``*_counts`` arrays are summed within each layer,
      divided by the summed ``total_counts`` to get proportions, averaged
      across layers and smoothed with a centred 5‑bin rolling mean;
    • a three‑panel bar figure (brain alignment / next‑word prediction /
      intersection) is written per subject, together with an ``.npz`` dump of
      the plotted arrays;
    • the per‑subject curves are then averaged across subjects, smoothed once
      more, plotted and dumped the same way.

    The dashed vertical line in each panel marks the mean centre of mass taken
    from the ``brain_com`` / ``lm_com`` / ``intersect_com`` columns.

    Parameters
    ----------
    df_dist : DataFrame with columns ``model``, ``threshold``, ``subject``,
        ``layer``, ``brain_counts``, ``lm_counts``, ``intersect_counts``,
        ``total_counts``, ``brain_com``, ``lm_com`` and ``intersect_com``.
    fig_dir : output directory (created, including parents, if missing).
    """
    fig_dir.mkdir(parents=True, exist_ok=True)

    # helper to dump the plotted data for later regeneration ------------------
    def _save_npz(fname: Path,
                  x:  np.ndarray,
                  brain: np.ndarray,
                  lm:    np.ndarray,
                  inter: np.ndarray,
                  com_b: float,
                  com_l: float,
                  com_i: float) -> None:
        np.savez(fname,
                 x=x,
                 brain=brain, lm=lm, inter=inter,
                 com_brain=com_b, com_lm=com_l, com_int=com_i)

    # helper: centred 5‑bin rolling mean (edge bins use whatever neighbours exist)
    def _smooth(v: np.ndarray) -> np.ndarray:
        return pd.Series(v).rolling(5, center=True, min_periods=1).mean().to_numpy()

    for model, df_model in df_dist.groupby("model"):
        for thr in [10, 60, 80]:
            df_thr = df_model[df_model["threshold"] == thr]
            if df_thr.empty:
                continue

            # ──────────────────────────────────────────────────────────────
            # 0)  PRE-COMPUTE SUBJECT-LEVEL distributions
            # ──────────────────────────────────────────────────────────────
            subj_info = {}   # subj → (brain_mean, lm_mean, int_mean, com_b, com_l, com_i)
            for subj, df_s in df_thr.groupby("subject"):

                # ---------- combine layers within this subject ----------
                b_sum, l_sum, i_sum, t_sum = [], [], [], []
                for _, df_l in df_s.groupby("layer"):
                    b_sum.append(_sum_padded(list(df_l["brain_counts"])))
                    l_sum.append(_sum_padded(list(df_l["lm_counts"])))
                    i_sum.append(_sum_padded(list(df_l["intersect_counts"])))
                    t_sum.append(_sum_padded(list(df_l["total_counts"])))

                # pad to common length and average across layers; distances with
                # a zero total yield NaN/inf, so silence the numpy warnings the
                # same way _figure7 does.
                with np.errstate(divide="ignore", invalid="ignore"):
                    brain  = _sum_padded([b/t for b, t in zip(b_sum, t_sum)]) / len(b_sum)
                    lm     = _sum_padded([l/t for l, t in zip(l_sum, t_sum)]) / len(l_sum)
                    inter  = _sum_padded([i/t for i, t in zip(i_sum, t_sum)]) / len(i_sum)

                # quick smoothing
                brain  = _smooth(brain)
                lm     = _smooth(lm)
                inter  = _smooth(inter)

                subj_info[subj] = (brain,
                                   lm,
                                   inter,
                                   np.nanmean(df_s["brain_com"]),
                                   np.nanmean(df_s["lm_com"]),
                                   np.nanmean(df_s["intersect_com"]))

            # ──────────────────────────────────────────────────────────────
            # 1)  PER-SUBJECT PLOTS  (and .npz dumps)
            # ──────────────────────────────────────────────────────────────
            for subj, (brain, lm, inter, cb, cl, ci) in subj_info.items():
                x      = np.arange(len(brain))
                dists  = [("Brain alignment", brain,  cb, PASTEL_RED),
                          ("Next-word prediction", lm, cl, PASTEL_GREEN),
                          ("Intersection",        inter, ci, PASTEL_ORANGE)]

                # --- dump data for fast redraw later
                fstem = f"figure6_{model}_subj{subj}_thr{thr}"
                _save_npz(fig_dir / f"{fstem}.npz",
                          x, brain, lm, inter, cb, cl, ci)

                fig, axes = plt.subplots(nrows=3, figsize=(10, 8), sharex=False)
                for ax, (title, y, com, col) in zip(axes, dists):
                    ax.bar(x, y, color=col, alpha=0.9, label=title)
                    ax.axvline(com, ls='--', color='k')
                    ax.text(com + 2, 0.66, f"{com:.1f}", rotation=90, fontsize=18)
                    ax.set_ylim(0, 1)   # proportions: fixed axis, comparable across figures
                    ax.grid(ls=':')
                    ax.legend(loc="upper right", frameon=False, fontsize=18)
                    # set the label font size up front; the texts are filled in below
                    ax.set_xlabel(ax.get_xlabel(), fontsize=20)
                    ax.set_ylabel(ax.get_ylabel(), fontsize=20)
                    ax.tick_params(axis='both', labelsize=18)
                axes[1].set_ylabel("Proportion of important words")
                axes[-1].set_xlabel("Distance from most recent word")
                fig.tight_layout()

                fig.savefig(fig_dir / f"{fstem}.png")
                fig.savefig(fig_dir / f"{fstem}.pdf", dpi=1200)
                plt.close(fig)

            # ──────────────────────────────────────────────────────────────
            # 2)  SUBJECT-AVERAGED PLOT (layers already averaged per subject)
            # ──────────────────────────────────────────────────────────────
            subj_props = [v[:3] for v in subj_info.values()]  # unpack arrays only
            brain_mean = _sum_padded([p[0] for p in subj_props]) / len(subj_props)
            lm_mean    = _sum_padded([p[1] for p in subj_props]) / len(subj_props)
            int_mean   = _sum_padded([p[2] for p in subj_props]) / len(subj_props)

            brain_mean = _smooth(brain_mean)
            lm_mean    = _smooth(lm_mean)
            int_mean   = _smooth(int_mean)

            # nanmean (as in _figure7) so one subject without a valid centre of
            # mass does not blank the marker for the whole group
            com_brain = np.nanmean([v[3] for v in subj_info.values()])
            com_lm    = np.nanmean([v[4] for v in subj_info.values()])
            com_int   = np.nanmean([v[5] for v in subj_info.values()])

            dists = [("Brain alignment", brain_mean, com_brain, PASTEL_RED),
                     ("Next-word prediction", lm_mean, com_lm, PASTEL_GREEN),
                     ("Intersection", int_mean, com_int, PASTEL_ORANGE)]

            # --- dump the averaged data
            fstem = f"figure6_{model}_thr{thr}"
            _save_npz(fig_dir / f"{fstem}_dists.npz",
                      np.arange(len(brain_mean)),
                      brain_mean, lm_mean, int_mean,
                      com_brain, com_lm, com_int)

            fig, axes = plt.subplots(nrows=3, figsize=(10, 8), sharex=False)

            for ax, (title, y, com, col) in zip(axes, dists):
                x = np.arange(len(y))
                ax.bar(x, y, color=col, label=title)
                ax.legend(loc="upper right", frameon=False, fontsize=18)
                ax.axvline(com, ls="--", color="k")
                ax.text(com+2, 0.66, f"{com:.1f}", rotation=90, fontsize=18)
                ax.set_ylim(0, 1)
                ax.grid(ls=":")
                ax.tick_params(axis='both', labelsize=18)
                ax.set_xlabel(ax.get_xlabel(), fontsize=20)
                ax.set_ylabel(ax.get_ylabel(), fontsize=20)

            axes[1].set_ylabel("Proportion of important words")
            axes[-1].set_xlabel("Distance from most recent word")
            fig.tight_layout()

            fig.savefig(fig_dir / f"{fstem}.png")
            fig.savefig(fig_dir / f"{fstem}.pdf", dpi=1200)
            plt.close(fig)
    
# -----------------------------------------------------------------------------
# Figure 7 – Distance distributions per ordinal layer (no layer averaging)
# -----------------------------------------------------------------------------
def _figure7(df_dist: pd.DataFrame, fig_dir: Path) -> None:
    """
    Plot distance distributions of important words per ordinal layer.

    Same kind of three‑panel figure as Figure 6 (brain alignment / next‑word
    prediction / intersection), except:
      • no averaging across layers;
      • one figure per Early / Middle / Late layer for each model, where those
        are the first three layers in the model's sorted layer list;
      • thresholds 10, 60, 80, 90 and 98 are plotted, the y‑axis is scaled to
        the data, and no ``.npz`` dump is written.

    Parameters
    ----------
    df_dist : DataFrame with columns ``model``, ``threshold``, ``subject``,
        ``layer``, ``brain_counts``, ``lm_counts``, ``intersect_counts``,
        ``total_counts``, ``brain_com``, ``lm_com`` and ``intersect_com``.
    fig_dir : output directory (created, including parents, if missing).
    """
    fig_dir.mkdir(parents=True, exist_ok=True)
    ORD_NAMES = {0: "Early", 1: "Middle", 2: "Late"}

    for model, df_model in df_dist.groupby("model"):
        # figure‑out ordinal mapping for this model (layer → 0 / 1 / 2)
        layers_sorted = sorted(df_model["layer"].unique())
        ordinal_map   = {layer: idx for idx, layer in enumerate(layers_sorted[:3])}

        for thr in [10, 60, 80, 90, 98]:
            df_thr = df_model[df_model["threshold"] == thr]
            if df_thr.empty:
                continue

            for layer, ord_idx in ordinal_map.items():
                df_layer = df_thr[df_thr["layer"] == layer]
                if df_layer.empty:
                    continue

                # ---------- aggregate within each subject (sum counts per distance)
                subj_props = []
                com_b_s, com_l_s, com_i_s = [], [], []

                for _, df_s in df_layer.groupby("subject"):
                    b_sum = _sum_padded(list(df_s["brain_counts"]))
                    l_sum = _sum_padded(list(df_s["lm_counts"]))
                    i_sum = _sum_padded(list(df_s["intersect_counts"]))
                    t_sum = _sum_padded(list(df_s["total_counts"]))

                    # distances with a zero total give NaN/inf; keep numpy quiet
                    with np.errstate(divide="ignore", invalid="ignore"):
                        subj_props.append((b_sum / t_sum, l_sum / t_sum, i_sum / t_sum))

                    com_b_s.append(np.nanmean(df_s["brain_com"]))
                    com_l_s.append(np.nanmean(df_s["lm_com"]))
                    com_i_s.append(np.nanmean(df_s["intersect_com"]))

                # average across subjects
                brain_mean = _sum_padded([p[0] for p in subj_props]) / len(subj_props)
                lm_mean    = _sum_padded([p[1] for p in subj_props]) / len(subj_props)
                int_mean   = _sum_padded([p[2] for p in subj_props]) / len(subj_props)

                # smooth (centred 5‑bin rolling mean)
                brain_mean = pd.Series(brain_mean).rolling(5, center=True, min_periods=1).mean().to_numpy()
                lm_mean    = pd.Series(lm_mean   ).rolling(5, center=True, min_periods=1).mean().to_numpy()
                int_mean   = pd.Series(int_mean  ).rolling(5, center=True, min_periods=1).mean().to_numpy()

                com_brain = np.nanmean(com_b_s)
                com_lm    = np.nanmean(com_l_s)
                com_int   = np.nanmean(com_i_s)

                dists = [("Brain alignment",  brain_mean, com_brain, PASTEL_RED),
                         ("Next‑word prediction", lm_mean, com_lm, PASTEL_GREEN),
                         ("Intersection",        int_mean, com_int, PASTEL_ORANGE)]

                # ---------------- plotting -----------------------------------
                fig, axes = plt.subplots(nrows=3, figsize=(8, 6), sharex=False)
                # shared y‑limit across the three panels, with 5 % head‑room
                y_max = np.nanmax([brain_mean, lm_mean, int_mean]) * 1.05

                for ax, (title, y, com, col) in zip(axes, dists):
                    x = np.arange(len(y))
                    ax.bar(x, y, color=col, alpha=0.8)
                    ax.axvline(com, ls="--", color="k")
                    ax.text(com + 0.85, y_max * 0.7, f"{com:.1f}", rotation=90)
                    ax.set_ylim(0, y_max)
                    ax.grid(ls=":")

                axes[1].set_ylabel("Proportion of important words")
                axes[-1].set_xlabel("Distance from most recent word")
                # Fall back to the raw model id so a model without a display
                # name does not abort the figure with a KeyError.
                fig.suptitle(f"{model_names.get(model, model)} – {ORD_NAMES[ord_idx]} layer – Top {thr}% attribution")
                fig.tight_layout()

                fstem = f"figure7_{model}_{ORD_NAMES[ord_idx]}_thr{thr}"
                fig.savefig(fig_dir / f"{fstem}.png")
                fig.savefig(fig_dir / f"{fstem}.pdf", dpi=1200)
                plt.close(fig)


# -----------------------------------------------------------------------------

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

