import argparse
import gzip
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from brain_alignment import get_delay_idxs, k_fold_test_idxs
from datasets import get_dataset
from models import *
from utils import *


def load_pickle(path: "Path | str"):
    """Load a single pickled object from ``path``.

    Files whose name ends in ``.gz`` are opened with ``gzip``; everything
    else is opened as a plain binary file.

    Parameters
    ----------
    path : Path or str
        Location of the pickle file. Strings are converted to ``Path`` so
        the suffix check works for both kinds of argument.

    Returns
    -------
    object
        Whatever object was stored in the file.
    """
    path = Path(path)
    open_fn = gzip.open if path.suffix == ".gz" else open
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open_fn(path, "rb") as f:
        return pickle.load(f)

def _top_mass_indices(vec, mass):
    """Return the smallest index set of ``vec`` whose values sum to >= ``mass`` * total.

    ``vec`` is a 1-D array of non-negative magnitudes. An all-zero (or empty)
    vector carries no attribution mass, so it yields an empty set.
    """
    if vec.sum() == 0:
        return set()
    order = np.argsort(vec)[::-1]  # indices from largest to smallest value
    cumsum = np.cumsum(vec[order])
    # Position of the first running sum that reaches the target; +1 converts
    # that position into the number of entries to keep.
    k = np.searchsorted(cumsum, mass * cumsum[-1]) + 1
    return set(order[:k].tolist())


def compute_important_word_idxs(brain_align_attrs, next_word_attrs, t: float = 0.75):
    """
    Select, per context, the words that carry the top ``t`` share of
    absolute attribution mass for each objective.

    Parameters
    ----------
    brain_align_attrs : list[np.ndarray]
        One array per context, shape (n_words, n_rois) **or** (n_words,).
        A 1-D array is treated as a single ROI.
    next_word_attrs   : list[np.ndarray]
        Same length as brain_align_attrs; one 1-D array of shape (n_words,)
        per context.
    t : float
        Cumulative-mass threshold (0 < t ≤ 1). 0.98 → keep the fewest words
        whose absolute attribution sums to at least 98 % of the total.

    Returns
    -------
    brain_sets : list[list[set[int]]]
        For every context, one set of selected word indices per ROI
        (brain-alignment objective).
    lm_sets    : list[set[int]]
        Selected word indices for next-word prediction (one set per context).
    intersect_sets : list[list[set[int]]]
        For every context, the per-ROI intersection of the brain set with
        the next-word set.
    attr_lengths : list[int]
        Number of words in each context's brain attribution array.

    Notes
    -----
    Contexts are paired with ``zip``, so if the two input lists differ in
    length the surplus entries of the longer one are ignored.
    """
    brain_sets, lm_sets, intersect_sets, attr_lengths = [], [], [], []
    for brain_attr, lm_attr in zip(brain_align_attrs, next_word_attrs):
        # Only the magnitude of an attribution matters, not its sign.
        a_b = np.abs(brain_attr)
        a_l = np.abs(lm_attr)

        # Promote (n_words,) to (n_words, 1) so the per-ROI code below works
        # for both documented input shapes.
        if a_b.ndim == 1:
            a_b = a_b[:, None]

        # Important word indices for each ROI (column) of the brain attributions
        roi_sets = [_top_mass_indices(a_b[:, roi_idx], t) for roi_idx in range(a_b.shape[1])]
        brain_sets.append(roi_sets)

        # Important word indices for next-word prediction
        keep_l = _top_mass_indices(a_l, t)
        lm_sets.append(keep_l)

        # Words that are important for both objectives, per ROI
        intersect_sets.append([roi_set & keep_l for roi_set in roi_sets])

        attr_lengths.append(a_b.shape[0])

    return brain_sets, lm_sets, intersect_sets, attr_lengths


def _int_list(text):
    """Parse one comma-separated CLI token, e.g. ``"4,9,12"`` -> ``[4, 9, 12]``."""
    return [int(tok) for tok in text.split(',') if tok.strip()]


def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate the result."""
    parser = argparse.ArgumentParser(description='Attribution analysis')
    parser.add_argument('--dataset', type=str, default='HarryPotter', help='Dataset name')
    parser.add_argument('--context_length', type=int, default=640, help='Context length')
    # Other model names seen in this script: 'falcon3', 'gemma-it', 'gemma-7B', 'mamba', 'zamba'
    parser.add_argument('--models', type=str, nargs='+', default=['llama3.2-1B', 'gemma'],
                        help='Space-separated model names')
    parser.add_argument('--attr-method', type=str, default='ig', help='Attribution method')
    parser.add_argument("--top-mass", type=int, nargs='+', default=[10, 60, 80], help="Top mass percentages")
    # `type=list` would split a token into single characters, so each token is
    # parsed as a comma-separated list of ints instead.
    # Other layer sets seen in this script: [4,9,12,15], [5,9,12,17], [5,6,12,17],
    # [5,11,12,17], [4,16,18,27], [15,29,32,47], [11,17,30,37]
    parser.add_argument("--layers", type=_int_list, nargs='+', default=[[4, 9, 12], [5, 9, 12]],
                        help="Layer indices: one comma-separated list per model, "
                             "e.g. --layers 4,9,12 5,9,12")
    parser.add_argument('--num-folds', type=int, default=4, help='Number of CV folds')
    parser.add_argument('--num-delays', type=int, default=4, help='Number of delays')
    parser.add_argument('--num-tr-trim', type=int, default=5, help='Number of training TRs to trim')
    args = parser.parse_args()

    # Layers are looked up by model position, so every model needs its own list.
    if len(args.layers) < len(args.models):
        parser.error(
            f"--layers needs one list per model: got {len(args.layers)} list(s) "
            f"for {len(args.models)} model(s)"
        )
    return args


def _collect_feature_rows(args, dataset, story_features, feature_names):
    """Return one record per (model, threshold, layer, subject, fold, test context, feature).

    Each record stores the percentage of a feature's total tag count that
    falls on the words selected as important for brain alignment (BA), for
    next-word prediction (NWP) and for their intersection (INT).
    """
    delay_idxs = get_delay_idxs(dataset.runs_cropped, args.num_delays)
    fold_test_idxs = k_fold_test_idxs(dataset.subjects, 0, args.num_folds, args.num_tr_trim)

    df_data = []
    for model_idx, model_name in enumerate(args.models):
        experiment_dir = Path(f'/BRAIN/ssms2/work/outputs/{args.dataset}_{model_name}_{args.context_length}')
        brain_align_dir = experiment_dir / 'alignment_attrs' / args.attr_method
        next_word_dir = experiment_dir / 'next_word_attrs' / args.attr_method
        for threshold in args.top_mass:
            for layer_idx in args.layers[model_idx]:
                for subject_idx in range(len(dataset.subjects)):
                    for fold_idx, test_idxs in enumerate(fold_test_idxs):
                        print(f"Model: {model_name}, Layer: {layer_idx}, Subject: {subject_idx}, Fold: {fold_idx}, Threshold: {threshold}")
                        p_brain = brain_align_dir / f"subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pkl"
                        p_lm = next_word_dir / f"fold_{fold_idx}.pkl"
                        if not p_brain.exists():
                            print("SKIP (missing)", p_brain)
                            continue
                        if not p_lm.exists():
                            print("SKIP (missing)", p_lm)
                            continue

                        # 1) Load attributions (one array per test context)
                        A_brain = load_pickle(p_brain)
                        A_lm = load_pickle(p_lm)

                        # 2) Indices of the important words in the extended contexts
                        brain_sets, lm_sets, intersect_sets, num_words = compute_important_word_idxs(
                            A_brain, A_lm, t=threshold / 100
                        )

                        # 3) Rebuild the word span of every extended context
                        for k, test_idx in enumerate(test_idxs):
                            tr_word_idxs = []
                            # Walk the delays from the largest index down to 0
                            for d in range(args.num_delays - 1, -1, -1):
                                tr_idx = delay_idxs[test_idx, d]
                                if tr_idx == -1:  # -1 marks a delay with no TR
                                    continue
                                # Gather the contexts of every word in this TR
                                tr_word_idxs.extend(
                                    dataset.context_word_idxs[word_idx]
                                    for word_idx in dataset.tr_to_word_idxs[tr_idx]
                                )

                            if not tr_word_idxs:
                                # No delay contributed a context, so there is no word span.
                                print(f"SKIP (no valid delayed TRs) for test index {test_idx}")
                                continue

                            # Span from the start of the first context to the end of the last one
                            span_start, span_stop = tr_word_idxs[0][0], tr_word_idxs[-1][1]

                            for df_name in feature_names:
                                tags = story_features[df_name]['feature_tags']
                                # Per word: summed tag values of this feature family
                                context_df = np.array(
                                    [tags[word_idx].sum().item() for word_idx in range(span_start, span_stop)]
                                )  # (num_words,)

                                if context_df.shape[0] != num_words[k]:
                                    print(f"SKIP (mismatch): {context_df.shape[0]} != {num_words[k]}")
                                    print(f"Context word indices: {tr_word_idxs}")
                                    continue

                                # 4) Share of the feature's total tag count carried by the important words
                                total_df_words = context_df.sum()
                                if total_df_words > 0:
                                    # NWP uses the same denominator as BA and INT so the three
                                    # bars are comparable (INT can then never exceed NWP).
                                    df_lm_words_proportion = context_df[list(lm_sets[k])].sum() / total_df_words * 100
                                else:
                                    df_lm_words_proportion = 0

                                # Only the first ROI column is analysed.
                                # NOTE(review): presumably column 0 is the whole-brain "all" ROI — confirm.
                                for roi_set, roi_intersect in zip(brain_sets[k][:1], intersect_sets[k][:1]):
                                    if total_df_words > 0:
                                        df_brain_words_proportion = context_df[list(roi_set)].sum() / total_df_words * 100
                                        df_intersect_words_proportion = context_df[list(roi_intersect)].sum() / total_df_words * 100
                                    else:
                                        df_brain_words_proportion = 0
                                        df_intersect_words_proportion = 0

                                    df_data.append({
                                        "model": model_name,
                                        "layer": layer_idx,
                                        "subject": subject_idx,
                                        "threshold": threshold,
                                        "feature": df_name,
                                        "BA_prop": df_brain_words_proportion,
                                        "NWP_prop": df_lm_words_proportion,
                                        "INT_prop": df_intersect_words_proportion,
                                    })
    return df_data


def _summarize_and_plot(thresh, grp, feature_names, output_dir):
    """Aggregate the rows of one threshold, then save the summary pickle and the bar plot."""
    PASTEL_RED = '#FF9999'
    PASTEL_GREEN = '#99CC99'
    PASTEL_ORANGE = '#FDBE87'

    # Mean over all rows (subjects, folds, contexts) per feature/layer/model
    agg_subjects = (grp
                    .groupby(["feature", "layer", "model"])
                    .agg(BA_mean=("BA_prop", "mean"),
                         NWP_mean=("NWP_prop", "mean"),
                         INT_mean=("INT_prop", "mean"))
                    .reset_index())

    # Mean over layers per feature/model
    agg_layers = (agg_subjects
                  .groupby(["feature", "model"])
                  .agg(BA_mean=("BA_mean", "mean"),
                       NWP_mean=("NWP_mean", "mean"),
                       INT_mean=("INT_mean", "mean"))
                  .reset_index())

    # Mean and standard error over models per feature
    agg_models = (agg_layers
                  .groupby("feature")
                  .agg(BA_mean=("BA_mean", "mean"),
                       NWP_mean=("NWP_mean", "mean"),
                       INT_mean=("INT_mean", "mean"),
                       BA_sem=("BA_mean", "sem"),
                       NWP_sem=("NWP_mean", "sem"),
                       INT_sem=("INT_mean", "sem"))
                  .reset_index())

    # Put the rows in feature_names order; a feature without any data becomes
    # a NaN row instead of raising an IndexError.
    ordered = agg_models.set_index("feature").reindex(feature_names)
    bars_arr = ordered[["BA_mean", "NWP_mean", "INT_mean"]].to_numpy()  # (F, 3)
    sems_arr = ordered[["BA_sem", "NWP_sem", "INT_sem"]].to_numpy()     # (F, 3)

    payload = {
        "threshold": thresh,
        "feature_names": feature_names,           # explicit order
        "bars": bars_arr,                         # (F, 3)
        "sems": sems_arr,                         # (F, 3)
        "per_model_layer": agg_layers.to_dict(),  # per-model values before the final mean
    }
    with open(output_dir / f"thr-{thresh}_summary.pkl", "wb") as f:
        pickle.dump(payload, f)

    x = np.arange(len(feature_names))
    width = .25

    plt.figure(figsize=(2 * len(feature_names), 4))
    plt.bar(x - width, bars_arr[:, 0], width,
            yerr=sems_arr[:, 0], capsize=3,
            color=PASTEL_RED, label="Brain alignment")
    plt.bar(x, bars_arr[:, 1], width,
            yerr=sems_arr[:, 1], capsize=3,
            color=PASTEL_GREEN, label="Next‑word prediction")
    plt.bar(x + width, bars_arr[:, 2], width,
            yerr=sems_arr[:, 2], capsize=3,
            color=PASTEL_ORANGE, label="Intersection")

    plt.xticks(x, feature_names, fontsize=16)
    plt.yticks(fontsize=14)
    plt.ylabel("Percentage of important words", fontsize=16)
    current_ylim = plt.gca().get_ylim()[1]  # current upper y-limit
    plt.ylim(0, current_ylim + 5)           # leave headroom above the bars
    if thresh == 10:
        # The legend is drawn on the threshold-10 figure only.
        plt.legend(loc="upper left", fontsize=16, frameon=False)
    plt.tight_layout()

    plt.savefig(output_dir / f"thr-{thresh}_mean_across_models.png", dpi=300)
    plt.savefig(output_dir / f"thr-{thresh}_mean_across_models.pdf", dpi=1200)
    plt.close()


@measure_performance
def main():
    """Run the attribution analysis end to end.

    Parses the CLI arguments, loads the dataset and its story features,
    collects the per-context feature percentages for every configured
    model/threshold/layer/subject/fold, and writes one summary pickle and
    one bar plot (PNG and PDF) per threshold under
    ``./plots/attribution_analysis/df_analysis_ig``.
    """
    args = _parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---------------------------------------------------------------------
    # Prepare output directories
    # ---------------------------------------------------------------------
    output_dir = Path('./plots') / 'attribution_analysis/df_analysis_ig'
    output_dir.mkdir(parents=True, exist_ok=True)

    # ---------------------------------------------------------------------
    # Load dataset
    # ---------------------------------------------------------------------
    dataset_dir = Path('./data') / args.dataset
    dataset = get_dataset(
        args.dataset,
        dataset_dir,
        tokenizer=None,
        device=device,
        context_length=args.context_length,
        remove_format_chars=True,
        remove_punc_spacing=True,
    )
    story_features, _ = dataset.get_story_features()

    # Diagnostics: tag array shapes (a recorded run showed 5176 words)
    print(story_features["syntactic"]["feature_tags"].shape)
    print(story_features["semantic"]["feature_tags"].shape)
    print(story_features["discourse"]["feature_tags"].shape)
    # Diagnostics: how many words carry at least one / several non-zero tags
    for feat_type in ["syntactic", "semantic", "discourse"]:
        tags = story_features[feat_type]["feature_tags"]  # indexed (word, feature)
        nonzero_per_word = (tags != 0).sum(axis=1)
        num_words_with_any_feature = (nonzero_per_word > 0).sum()
        num_words_with_multiple_features = (nonzero_per_word > 1).sum()
        print(f"{feat_type.capitalize()} features:")
        print(f"  Words with at least one feature != 0: {num_words_with_any_feature} / {tags.shape[0]}")
        print(f"  Words with multiple features != 0: {num_words_with_multiple_features} / {tags.shape[0]}")
    # Output recorded from an earlier run:
    #   Syntactic features:
    #     Words with at least one feature != 0: 4969 / 5176
    #     Words with multiple features != 0: 956 / 5176
    #   Semantic features:
    #     Words with at least one feature != 0: 3144 / 5176
    #     Words with multiple features != 0: 2449 / 5176

    # The "visual" feature family is excluded from the analysis
    story_features.pop("visual", None)
    feature_names = list(story_features.keys())

    df_data = _collect_feature_rows(args, dataset, story_features, feature_names)
    if not df_data:
        # Without rows the DataFrame has no "threshold" column to group by.
        print("No attribution data was processed; nothing to aggregate or plot.")
        return

    # Convert the records to a DataFrame for aggregation and plotting
    df = pd.DataFrame(df_data)
    for thresh, grp in df.groupby("threshold"):
        _summarize_and_plot(thresh, grp, feature_names, output_dir)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

