import tqdm
import torch
import pickle
import numpy as np
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients

from .attribution_utils import normalize_attributions
from .plotting_functions import save_highlighted_html
from .brain_encoding_model import BrainEncodingModel

def perform_feature_attribution_analysis(context_feature_presence, word_combined_attrs, device):
    """Summarise, per ROI, the attribution carried by words that express a discourse feature.

    Args:
        context_feature_presence: Sequence with one tensor per context word.
            The tensors are stacked along dim 1, so each one is expected to be
            1-D with one strength value per feature.
        word_combined_attrs: Tensor of word attributions indexed as (word, ROI).
        device: Device the stacked feature strengths are moved to before they
            are multiplied with ``word_combined_attrs``.

    Returns:
        Tuple ``(total, positive, negative)`` of NumPy arrays with one entry
        per ROI. All three are filled with NaN when no feature value is
        strictly positive. Otherwise each is a feature-strength-weighted sum
        over the selected words, normalised as described inline and finally
        divided by the number of strictly positive feature values.
    """
    num_rois = word_combined_attrs.shape[-1]

    # Get indices of words corresponding to discourse feature
    non_zero_indices = [i for i, feature in enumerate(context_feature_presence) if torch.any(feature)]
    non_zero_count = sum((feature > 0).sum().item() for feature in context_feature_presence)

    if non_zero_count <= 0:
        return tuple(np.full(num_rois, np.nan) for _ in range(3))

    # Keep a trailing singleton ROI axis and let broadcasting expand it instead
    # of materialising an (f, L, R) copy of the strengths.
    feature_strength = torch.stack(context_feature_presence, dim=1)[:, non_zero_indices]
    feature_strength = feature_strength.unsqueeze(-1).to(device)  # (f, L_nz, 1)

    feature_word_attrs = word_combined_attrs[non_zero_indices]  # (L_nz, R)
    positive_part = feature_word_attrs * (feature_word_attrs > 0)
    negative_part = feature_word_attrs * (feature_word_attrs < 0)

    # Positive / negative scores are normalised by the positive / negative
    # attribution mass of the feature words themselves; 1e-10 avoids 0/0.
    positive_attributions = (
        (feature_strength * positive_part.unsqueeze(0)).sum(dim=0).sum(dim=0)
        / (positive_part.sum(dim=0) + 1e-10)
    )  # (R,)
    negative_attributions = (
        (feature_strength * negative_part.unsqueeze(0)).sum(dim=0).sum(dim=0)
        / (negative_part.abs().sum(dim=0) + 1e-10)
    )  # (R,), non-positive because the numerator keeps its sign

    # The total score is normalised by the absolute attribution mass of *all*
    # words, not only the feature words.
    total_attributions = (
        (feature_strength * feature_word_attrs.unsqueeze(0)).abs().sum(dim=0).sum(dim=0)
        / (word_combined_attrs.abs().sum(0) + 1e-10)
    )  # (R,)

    # non_zero_count is guaranteed to be > 0 here because of the early return.
    total_attributions /= non_zero_count
    positive_attributions /= non_zero_count
    negative_attributions /= non_zero_count

    return total_attributions.cpu().numpy(), positive_attributions.cpu().numpy(), negative_attributions.cpu().numpy()

def _build_attribution_method(method_name, encoding_model):
    """Create the Captum layer-attribution object for the input-embedding layer.

    ``'gxi'`` selects ``LayerGradientXActivation``; every other value selects
    ``LayerIntegratedGradients``.
    """
    attr_cls = LayerGradientXActivation if method_name == 'gxi' else LayerIntegratedGradients
    return attr_cls(
        forward_func=encoding_model.forward,
        layer=encoding_model.language_model.get_input_embeddings(),
    )


def _token_attributions_per_roi(args, attr_method, token_ids, shared_forward_args, subject_rois, device):
    """Run the attribution method once per ROI and collect one score per token.

    ``shared_forward_args`` are the extra forward arguments common to all ROIs;
    the ROI mask is appended as the last one. Returns a tensor indexed as
    (context, token, ROI).
    """
    token_attrs = torch.zeros((token_ids.shape[0], token_ids.shape[1], len(subject_rois)), device=device)
    use_gxi = args.attribution.method == 'gxi'
    for roi_idx, roi_mask in enumerate(subject_rois.values()):
        forward_args = shared_forward_args + (roi_mask,)
        if use_gxi:
            attrs = attr_method.attribute(token_ids, additional_forward_args=forward_args)
            # Collapse the embedding axis with an L2 norm -> (n_contexts, M)
            attrs = torch.norm(attrs, p=2, dim=-1)
        else:
            attrs = attr_method.attribute(
                token_ids,
                # All-zero token IDs as baseline; zeros_like already lives on
                # the same device as token_ids.
                baselines=torch.zeros_like(token_ids),
                additional_forward_args=forward_args,
                internal_batch_size=args.model.batch_size,
            )
            # Collapse the embedding axis by summation -> (n_contexts, M)
            attrs = attrs.sum(dim=-1)
        # detach() is a no-op for graph-free results and otherwise keeps the
        # autograd graph from being retained through token_attrs.
        token_attrs[:, :, roi_idx] = attrs.detach()
    return token_attrs


def compute_attribution(
        args, model, ridge_weights, dataset,
        subject_idx, fold_idx, layer_idx, delay_idxs, test_idxs,
        subject_rois, roi_names,
        attribution_dir_data, attribution_dir_plots, device):
    """Compute per-word, per-ROI attributions for every test TR of a fold and pickle them.

    For each test TR the delays are visited from the last index down to 0.
    Delays whose entry in ``delay_idxs`` is ``-1`` are skipped. For every
    remaining delayed TR an encoding model is built from the language model
    and that delay's ridge weights, token attributions are computed for each
    ROI, summed into word attributions, and stitched into one "extended
    context" that spans all delayed TRs of the test TR.

    Args:
        args: Configuration object; reads ``experiment.num_delays``,
            ``experiment.apply_pca``, ``experiment.num_red_components``,
            ``model.hidden_dim``, ``model.type``, ``model.name``,
            ``model.batch_size`` and ``attribution.method``.
        model: Language model handed to ``BrainEncodingModel``.
        ridge_weights: Regression weights, reshaped here to
            (num_delays, hidden_dim, -1).
        dataset: Provides ``contexts``, ``context_word_idxs``,
            ``tr_to_word_idxs``, ``subjects`` and ``get_context_token_ids``.
        subject_idx, fold_idx, layer_idx: Indices used for the forward pass
            and for the output file name.
        delay_idxs: Indexed as ``[test_idx, delay]``; ``-1`` marks a missing TR.
        test_idxs: Iterable of test TR indices.
        subject_rois: Mapping from ROI name to ROI mask.
        roi_names, attribution_dir_plots: Only needed by the plotting call
            that is currently disabled.
        attribution_dir_data: Directory (path object supporting ``/``) that
            receives ``subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pkl``.
        device: Torch device used for all tensors created here.

    Returns:
        None. The pickled object is a list with one (L_ext, R) NumPy array per
        entry of ``test_idxs``. A test TR for which no context was processed
        (e.g. all delays skipped) is stored as an empty (0, R) array so the
        list stays aligned with ``test_idxs``.
    """
    # Split weights by delay
    hidden_dim = args.experiment.num_red_components if args.experiment.apply_pca else args.model.hidden_dim
    num_delays = args.experiment.num_delays
    ridge_weights = ridge_weights.reshape(num_delays, hidden_dim, -1)  # (D, H, V)
    num_rois = len(subject_rois)

    fold_word_attrs = []
    for test_idx in tqdm.tqdm(test_idxs):
        # Per-TR accumulator state. Resetting it here keeps a TR without any
        # valid delay from reusing (or corrupting) another TR's attributions.
        word_combined_attrs = None  # (L_ext, R) once the first context is seen
        extended_context = []       # words covered by word_combined_attrs
        extended_end_idx = None     # end word index of the extended context

        for d in range(num_delays - 1, -1, -1):
            tr_idx = delay_idxs[test_idx, d]
            if tr_idx == -1:
                continue

            # Define full encoding model and attribution method. The model is
            # rebuilt for every delayed TR because each delay has its own
            # slice of the ridge weights.
            encoding_model = BrainEncodingModel(model, ridge_weights[d], args.model.type, args.model.name, device)
            attr_method = _build_attribution_method(args.attribution.method, encoding_model)

            # Get contexts for each word in the TR
            tr_word_idxs = dataset.tr_to_word_idxs[tr_idx]
            contexts = [dataset.contexts[word_idx] for word_idx in tr_word_idxs]
            contexts_word_idxs = [dataset.context_word_idxs[word_idx] for word_idx in tr_word_idxs]

            # Extract token IDs for the contexts in the TR (single batch)
            data_loader, token_idxs_to_avg, word_idx_to_tok_idx = dataset.get_context_token_ids(len(contexts), contexts)
            token_ids, attention_mask = next(iter(data_loader))
            token_ids = token_ids.to(device)
            attention_mask = attention_mask.to(device)

            # For the context of each word in the TR, compute token attributions for each ROI
            shared_forward_args = (attention_mask, layer_idx, token_idxs_to_avg, dataset.subjects[subject_idx][tr_idx])
            token_attrs = _token_attributions_per_roi(
                args, attr_method, token_ids, shared_forward_args, subject_rois, device
            )  # (n_contexts, M, R)

            for context_idx, context in enumerate(contexts):
                words = context.split(' ')

                # Word attribution = sum of the attributions of its tokens
                word_attrs = torch.zeros((len(words), token_attrs.shape[-1]), device=device)  # (L, R)
                for word_idx, tok_idxs in word_idx_to_tok_idx[context_idx].items():
                    word_attrs[word_idx] = token_attrs[context_idx, tok_idxs].sum(dim=0)

                context_end_idx = contexts_word_idxs[context_idx][1]

                # Combine the attributions for all words in the delayed TRs
                if word_combined_attrs is None:
                    # First context of this test TR starts the extended context
                    word_combined_attrs = word_attrs  # (L_ext, R)
                    extended_context = list(words)
                else:
                    # Grow the extended context by the words this context adds
                    # (at least one row is always added), then accumulate the
                    # new attributions on its trailing rows.
                    new_words_count = max(context_end_idx - extended_end_idx, 1)
                    padding = torch.zeros((new_words_count, word_attrs.shape[1]), device=device)
                    word_combined_attrs = torch.cat((word_combined_attrs, padding), dim=0)  # (L_ext, R)
                    word_combined_attrs[-word_attrs.shape[0]:] += word_attrs
                    extended_context += words[-new_words_count:]
                extended_end_idx = context_end_idx

        if word_combined_attrs is None:
            # Nothing was attributed for this test TR; store an empty
            # placeholder with the same dtype torch.zeros would have produced.
            fold_word_attrs.append(torch.zeros((0, num_rois)).numpy())
            continue

        # Store unnormalized attributions
        fold_word_attrs.append(word_combined_attrs.cpu().numpy())

        # Normalization and plotting are currently disabled:
        #word_combined_attrs = normalize_attributions(word_combined_attrs)
        #save_highlighted_html(word_combined_attrs.cpu().numpy(), extended_context, roi_names, attribution_dir_plots / f"subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}_tr_{test_idx}.html")

    # Store attributions for the fold
    with open(attribution_dir_data / f"subj_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}.pkl", 'wb') as f:
        pickle.dump(fold_word_attrs, f)