from pathlib import Path
import gc
import torch
import tqdm
import gzip
import pickle
from transformers import AutoTokenizer
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients, GradientShap
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

from brain_alignment import get_delay_idxs, k_fold_test_idxs, story_fold_test_idxs, BrainEncodingModel
from datasets import get_dataset
from models import *
from utils import *

class AttributionMethod:
    """Wrap a Captum layer-attribution method applied to the language model's input embeddings.

    Supported ``method_name`` values:
        * ``'gxi'`` -- ``LayerGradientXActivation``
        * ``'ig'``  -- ``LayerIntegratedGradients``

    Args:
        method_name: Name of the attribution method (see above).
        encoding_model: Object exposing ``forward`` and
            ``language_model.get_input_embeddings()``.

    Raises:
        ValueError: If ``method_name`` is not one of the supported methods.
    """

    SUPPORTED_METHODS = ('gxi', 'ig')
    # Number of integration steps for 'ig'. The same value is also passed as the last
    # additional forward argument (presumably so the forward function can account for
    # the expanded batch -- confirm against BrainEncodingModel.forward).
    IG_N_STEPS = 20
    IG_INTERNAL_BATCH_SIZE = 5

    def __init__(self, method_name, encoding_model):
        self.method_name = method_name
        self.encoding_model = encoding_model

        self.create_attr_method()

    def create_attr_method(self):
        """Instantiate the Captum attribution object for ``self.method_name``.

        Raises:
            ValueError: If ``self.method_name`` is not supported. Failing here avoids a
                confusing ``AttributeError`` on ``self.attr_method`` later on.
        """
        if self.method_name == 'gxi':
            attr_cls = LayerGradientXActivation
        elif self.method_name == 'ig':
            attr_cls = LayerIntegratedGradients
        else:
            raise ValueError(
                f"Unsupported attribution method {self.method_name!r}; "
                f"expected one of {self.SUPPORTED_METHODS}."
            )

        self.attr_method = attr_cls(
            forward_func=self.encoding_model.forward,
            layer=self.encoding_model.language_model.get_input_embeddings(),
        )
        self.interpretable_layer = None

    @staticmethod
    def _first_roi_mask(roi_masks):
        """Return the first ROI mask with a leading axis of size 1, or ``None`` if no masks are given."""
        if roi_masks is None:
            return None
        # Only the first mask is used (per the original note, this is the 'all' ROI).
        first_roi_mask = roi_masks[0] if isinstance(roi_masks, list) else list(roi_masks)[0]
        return np.expand_dims(first_roi_mask, axis=0)

    def compute_token_attrs(self, token_ids, attention_mask, layer_idx, token_idxs_to_avg, true_activity, roi_masks, args, device):
        """Compute one attribution score per input token.

        Args:
            token_ids: Tensor of input token IDs; moved to ``device`` before attribution.
            attention_mask: Attention mask tensor; moved to ``device`` and forwarded to
                the encoding model's forward function.
            layer_idx: Forwarded unchanged to the forward function.
            token_idxs_to_avg: Forwarded unchanged to the forward function.
            true_activity: Forwarded unchanged to the forward function.
            roi_masks: ``None`` or a collection of ROI masks; only the first is used.
            args: Unused. Kept for interface compatibility; dispatch is based on
                ``self.method_name`` so it always matches the Captum object that was built.
            device: Device on which the attribution is run.

        Returns:
            Tensor with a trailing singleton axis: the L2 norm over the last axis of the
            attributions for ``'gxi'``, or their sum over the last axis for ``'ig'``.

        Raises:
            ValueError: If ``self.method_name`` is not supported.
        """
        first_roi_mask = self._first_roi_mask(roi_masks)
        token_ids = token_ids.to(device)
        attention_mask = attention_mask.to(device)

        if self.method_name == 'gxi':
            token_attrs = self.attr_method.attribute(
                token_ids,
                additional_forward_args=(attention_mask, layer_idx, token_idxs_to_avg, true_activity, first_roi_mask),
            )
            return torch.norm(token_attrs, p=2, dim=-1).unsqueeze(-1)

        if self.method_name == 'ig':
            token_attrs = self.attr_method.attribute(
                token_ids,
                # Baseline is the all-zeros token-ID sequence.
                baselines=torch.zeros_like(token_ids),
                additional_forward_args=(
                    attention_mask, layer_idx, token_idxs_to_avg, true_activity, first_roi_mask, self.IG_N_STEPS,
                ),
                n_steps=self.IG_N_STEPS,
                internal_batch_size=self.IG_INTERNAL_BATCH_SIZE,
            )
            return token_attrs.sum(dim=-1).unsqueeze(-1)

        raise ValueError(
            f"Unsupported attribution method {self.method_name!r}; "
            f"expected one of {self.SUPPORTED_METHODS}."
        )

def _merge_word_attrs(combined_attrs, combined_end_idx, word_attrs, context_end_idx):
    """Fold one context's word attributions into the running extended-context attributions.

    Args:
        combined_attrs: Running array for the current test index, or ``None`` if no
            context has been merged yet.
        combined_end_idx: End word index of the extended context covered by
            ``combined_attrs`` (ignored when ``combined_attrs`` is ``None``).
        word_attrs: Per-word attributions of the new context; its rows are aligned with
            the *last* rows of the extended context.
        context_end_idx: End word index of the new context.

    Returns:
        Tuple ``(combined_attrs, combined_end_idx)`` after merging.
    """
    if combined_attrs is None:
        return word_attrs, context_end_idx

    new_words_count = context_end_idx - combined_end_idx
    if new_words_count < 0: # Harry Potter had <= 0
        new_words_count = 1
    if new_words_count > 0:
        # Grow the extended context by the words this context adds at the end.
        padding = np.zeros((new_words_count, word_attrs.shape[1]))
        combined_attrs = np.concatenate((combined_attrs, padding), axis=0) # (L_ext, R)
        combined_end_idx = context_end_idx
    combined_attrs[-word_attrs.shape[0]:] += word_attrs
    return combined_attrs, combined_end_idx


@measure_performance
def main():
    """Compute word-level attributions for the brain-encoding model and pickle them.

    For every selected subject, attribution layer and fold, the stored ridge weights are
    loaded, token attributions are computed for the contexts of each delayed TR of every
    test index, summed into word attributions, merged over the delays, and written to
    ``<experiment_dir>/alignment_attrs/<method>/subj_<s>_layer_<l>_fold_<f>.pkl``.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ### Set attribution method
    if args.attr_method is not None:
        args.attribution.method = args.attr_method

    ### Create output directory
    experiment_dir = create_output_directory(args)
    output_dir = experiment_dir / "alignment_attrs" / args.attribution.method
    os.makedirs(output_dir, exist_ok=True)

    if args.experiment.verbose:
        print("Outputs will be stored in directory ", str(output_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    ### Get the model
    model = BrainAlignLanguageModel(args.model.hugging_face_model_id, args.experiment.apply_pca, args.experiment.num_red_components, device, args.model.type, False)

    ### Get the dataset
    if args.experiment.verbose:
        print("Starting brain alignment pipeline.")
        print("Step 1: Extract contexts for each word and get word-TR correspondance.")

    dataset_dir = Path(args.data_root_dir) / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=args.experiment.verbose
    )

    # NOTE(review): the code distinguishes "MothRadioHour" from everything else in some
    # places and "HarryPotter" from everything else in others; a third dataset name
    # would hit undefined variables (story_idx / subject_rois). Confirm only these two exist.
    is_moth = args.dataset.name == "MothRadioHour"
    is_harry_potter = args.dataset.name == "HarryPotter"

    if args.experiment.subject_idx == '':
        subject_names = dataset.subject_idxs
        subject_idxs = list(range(len(subject_names)))
    ## Compute attributions for specific subjects only
    else:
        subject_names = args.experiment.subject_idx.split(',')
        subject_idxs = [dataset.subject_idxs.index(subject_name) for subject_name in subject_names]

    ### Get the delay indices (per-story ones are obtained per subject below)
    story_names = []
    if not is_moth:
        delay_idxs = get_delay_idxs(dataset.runs_cropped, args.experiment.num_delays)
    else:
        story_names = list(dataset.story_idx_to_name.values())

    # Define the encoding model and the attribution method
    encoding_model = BrainEncodingModel(model, args.model.type, args.model.name, device)
    if hasattr(encoding_model.language_model, 'gradient_checkpointing_enable'):
        encoding_model.language_model.gradient_checkpointing_enable()
    attr_method = AttributionMethod(args.attribution.method, encoding_model)

    # Feature dimension of the ridge weights; does not change across subjects/layers/folds.
    hidden_dim = args.experiment.num_red_components if args.experiment.apply_pca else args.model.hidden_dim

    if args.experiment.verbose:
        print("Step 2: Compute attributions.")

    for subject_idx in subject_idxs:
        if is_moth:
            fold_test_idxs, story_delay_idxs = story_fold_test_idxs(dataset.subjects, subject_idx, args.experiment.num_tr_trim, args.experiment.num_delays)
        else:
            fold_test_idxs = k_fold_test_idxs(dataset.subjects, subject_idx, args.experiment.num_folds, args.experiment.num_tr_trim)
            subject_rois = dataset.subject_rois[subject_idx]
        for layer_idx in args.attribution.layers:
            for fold_idx, test_idxs in enumerate(fold_test_idxs):
                print(f"Processing subject {subject_idx}, layer {layer_idx}, fold {fold_idx}...")
                # Load ridge weights
                with gzip.open(experiment_dir / "ridge_weights" / f"ridge_weights_subject_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pt.gz", "rb") as f:
                    ridge_weights = torch.load(f, map_location="cpu").float()
                # Split weights by delay
                ridge_weights = ridge_weights.reshape(args.experiment.num_delays, hidden_dim, -1) # (D, H, V)

                if is_moth:
                    delay_idxs = story_delay_idxs[fold_idx]
                    story_name = story_names[fold_idx]
                    story_idx = next(k for k, v in dataset.story_idx_to_name.items() if v == story_name)

                fold_word_attrs = []
                for test_idx in tqdm.tqdm(test_idxs):
                    # Running attributions over the extended context of this test index,
                    # and the end word index that extended context currently reaches.
                    # Tracked per test index so a skipped test index cannot shift the
                    # bookkeeping of the ones that follow.
                    word_combined_attrs, combined_end_idx = None, None
                    tr_idx = None
                    for d in range(args.experiment.num_delays - 1, -1, -1):
                        tr_idx = delay_idxs[test_idx][d]
                        if tr_idx == -1:
                            continue

                        # Use the ridge weights of the current delay in the encoding model
                        encoding_model.update_weights(ridge_weights[d].to(device))

                        # Gather the contexts of the words in the delayed TR
                        if is_moth:
                            tr_word_idxs = dataset.tr_to_word_idxs[story_name][tr_idx]
                            contexts = [dataset.contexts[story_name][word_idx] for word_idx in tr_word_idxs]
                            contexts_word_idxs = [dataset.context_word_idxs[story_name][word_idx] for word_idx in tr_word_idxs]
                        else:
                            tr_word_idxs = dataset.tr_to_word_idxs[tr_idx]
                            contexts = [dataset.contexts[word_idx] for word_idx in tr_word_idxs]
                            contexts_word_idxs = [dataset.context_word_idxs[word_idx] for word_idx in tr_word_idxs]

                        if len(contexts) == 0:
                            log_msg = f"{fold_idx}, {subject_idx}, {layer_idx}, {tr_idx} - No contexts found for this TR."
                            print(log_msg)
                            with open("./empty_contexts.log", "a") as log_file:
                                log_file.write(log_msg + "\n")
                            continue

                        ## For the context of each word in the TR, compute token attributions
                        # Extract token IDs for the contexts in the TR
                        data_loader, token_idxs_to_avg, word_idx_to_tok_idx = dataset.get_context_token_ids(len(contexts), contexts)
                        token_ids, attention_mask = next(iter(data_loader))

                        if is_harry_potter:
                            true_activity = dataset.subjects[subject_idx][tr_idx]
                            roi_masks = list(subject_rois.values())
                        else:
                            true_activity = dataset.subjects[subject_idx][f"story_{story_idx}"][tr_idx]
                            roi_masks = None

                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                            token_attrs = attr_method.compute_token_attrs(
                                token_ids,
                                attention_mask,
                                layer_idx,
                                token_idxs_to_avg,
                                true_activity,
                                roi_masks,
                                args,
                                device,
                            )
                        # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
                        token_attrs = token_attrs.view(len(contexts), -1, token_attrs.shape[-1]).detach().float().cpu().numpy() # (num_words_in_tr, num_tokens, R)

                        for context_idx, context in enumerate(contexts):
                            context_words = context.split()

                            # Word attributions are the sum of the attributions of the word's tokens
                            word_attrs = np.zeros((len(context_words), token_attrs.shape[-1])) # (L, R)
                            for word_idx, tok_idxs in word_idx_to_tok_idx[context_idx].items():
                                word_attrs[word_idx] = np.sum(token_attrs[context_idx, tok_idxs], axis=0)

                            # Combine the attributions for all words in the delayed TRs
                            word_combined_attrs, combined_end_idx = _merge_word_attrs(
                                word_combined_attrs,
                                combined_end_idx,
                                word_attrs,
                                contexts_word_idxs[context_idx][1],
                            )

                        del token_attrs, token_ids, attention_mask
                        torch.cuda.empty_cache()

                    # Every delay may have been skipped (tr_idx == -1 or no contexts); in
                    # that case there is nothing to report or store for this test index.
                    if word_combined_attrs is None:
                        continue

                    print(f"TR idx {tr_idx}, attrs: ", word_combined_attrs.shape)
                    # Store unnormalized attributions
                    fold_word_attrs.append(word_combined_attrs)
                    torch.cuda.empty_cache()

                # Store attributions for the fold
                with open(output_dir / f"subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pkl", 'wb') as f:
                    pickle.dump(fold_word_attrs, f)

                del fold_word_attrs
                gc.collect()
                torch.cuda.empty_cache()

# Run the attribution pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
