import os
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
import math
import pickle
from statistics import mean
import numpy as np
from scipy import stats
from tqdm.auto import tqdm
from scipy.special import softmax
from transformers import AutoTokenizer
from datasets import load_dataset
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification


class XAIEvaluator:
    def __init__(self, model_name, dataset_name, num_samples=5000, text_column='text', label_column='label', random_seed=42, device='cpu'):
        """
        Load the model, tokenizer and dataset, then draw the evaluation samples.

        Args:
        model_name (str): Identifier passed to ``from_pretrained`` for the model and the tokenizer.
        dataset_name (str): Identifier passed to ``load_dataset``.
        num_samples (int): Number of test samples to draw for evaluation.
        text_column (str): Dataset column holding the input text.
        label_column (str): Dataset column holding the target label.
        random_seed (int): Seed used when drawing the samples.
        device (str): Device the model and the input tensors are placed on.
        """
        self.device = device
        self.model = self._load_model(model_name)
        self.tokenizer = self._load_tokenizer(model_name)
        self.dataset = self._load_dataset(dataset_name)
        # Column names and seed must be in place before _prepare_samples reads them.
        self.text_column = text_column
        self.label_column = label_column
        self.random_seed = random_seed
        self.selected_samples = self._prepare_samples(num_samples)
        # Forward the device explicitly: Generator defaults to "cpu" and moves the
        # shared model to its own device, which would pull it off self.device.
        self.explanations = Generator(self.model, device=self.device)
        # Count classes through the configured label column, not a hard-coded 'label'.
        self.num_classes = len({sample[self.label_column] for sample in self.selected_samples})

    def _load_model(self, model_name):
        """
        Load the pretrained classifier, place it on ``self.device`` and put it in eval mode.

        Args:
        model_name (str): Identifier passed to ``BertForSequenceClassification.from_pretrained``.

        Returns:
        The loaded model.
        """
        classifier = BertForSequenceClassification.from_pretrained(model_name)
        classifier = classifier.to(self.device)
        classifier.eval()
        return classifier

    def _load_tokenizer(self, model_name):
        """
        Load the tokenizer that belongs to ``model_name``.

        Args:
        model_name (str): Identifier passed to ``AutoTokenizer.from_pretrained``.

        Returns:
        The loaded tokenizer.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer

    def _load_dataset(self, dataset_name):
        """
        Load the dataset identified by ``dataset_name``.

        Args:
        dataset_name (str): Identifier passed to ``load_dataset``.

        Returns:
        The object returned by ``load_dataset``.
        """
        dataset = load_dataset(dataset_name)
        return dataset

    def _prepare_samples(self, num_samples):
        """
        Prepare samples for evaluation, considering the specified text and label columns.

        The test split is divided into texts shorter than the most frequent text length
        and texts at least that long; half of the requested samples is drawn without
        replacement from each group, shorter texts first.

        Args:
        num_samples (int): Number of samples to prepare (an odd value is rounded down).

        Returns:
        numpy.ndarray: Object array holding the selected dataset rows.

        Raises:
        ValueError: If the test split is empty or a group holds fewer rows than requested.
        """
        half_sample = int(num_samples / 2)
        test_split = self.dataset['test']

        # Measure every text once from the column view instead of re-indexing rows.
        lengths = np.array([len(text) for text in test_split[self.text_column]], dtype=np.int64)
        if lengths.size == 0:
            raise ValueError("The test split is empty; no samples can be selected.")

        # Most frequent length, smallest value on ties. Computed with numpy so the result
        # does not depend on the array-vs-scalar return shape of scipy.stats.mode.
        unique_lengths, counts = np.unique(lengths, return_counts=True)
        mode_length = unique_lengths[np.argmax(counts)]

        shorter_idx = np.flatnonzero(lengths < mode_length)
        longer_idx = np.flatnonzero(lengths >= mode_length)
        for group_name, group in (('shorter', shorter_idx), ('longer', longer_idx)):
            if len(group) < half_sample:
                raise ValueError(
                    f"Cannot draw {half_sample} samples from the {group_name} group: "
                    f"only {len(group)} texts available."
                )

        # A local generator gives the same draws as seeding the global one, without
        # altering numpy's global random state for the rest of the program.
        rng = np.random.RandomState(self.random_seed)
        chosen = np.concatenate((
            rng.choice(shorter_idx, half_sample, replace=False),
            rng.choice(longer_idx, half_sample, replace=False),
        ))

        # Fetch only the chosen rows; fill an object array so rows stay intact dicts.
        selected = np.empty(len(chosen), dtype=object)
        for position, row_idx in enumerate(chosen):
            selected[position] = test_split[int(row_idx)]
        return selected

    def preprocess_sample(self, input_str):
        """
        Tokenize a text and build the tensors used by the model and the explainers.

        Args:
        input_str (str): Raw input text.

        Returns:
        tuple: (text_ids, att_mask, text_words) where text_ids and att_mask are tensors
        with a leading batch dimension of 1 on ``self.device`` and text_words is what
        the tokenizer returns for the ids. The mask is 0 at special-token positions
        and 1 everywhere else.
        """
        special_ids = self.tokenizer.all_special_ids
        encoding = self.tokenizer(input_str, add_special_tokens=True, truncation=True)
        input_ids = encoding['input_ids']

        text_ids = torch.tensor([input_ids]).to(self.device)
        text_words = self.tokenizer.convert_ids_to_tokens(text_ids[0])

        # Positions of special tokens are masked out; every other position is kept.
        special_positions = {
            position for position, token_id in enumerate(input_ids) if token_id in special_ids
        }
        mask_row = [
            int(position not in special_positions)
            for position in range(len(encoding['attention_mask']))
        ]
        att_mask = torch.tensor([mask_row]).to(self.device)

        return text_ids, att_mask, text_words

    def predict(self, text_ids, target, att_mask=None, seg_ids=None):
        out = self.model(text_ids, attention_mask=att_mask, token_type_ids=seg_ids)
        logits = out[0]
        pred_class = torch.argmax(logits, axis=1).cpu().detach().numpy()
        pred_class_prob = softmax(logits.cpu().detach().numpy(), axis=1)
        return pred_class[0], pred_class_prob[:, target][0]

    def truncate_words(self, sorted_idx, text_words, text_ids, replaced_num, seg_ids=None):
        to_be_replaced_idx = []
        i = 0
        special_tokens = self.tokenizer.all_special_tokens

        while len(to_be_replaced_idx) < replaced_num and i != len(text_words) - 1:
            current_idx = sorted_idx[i]
            if text_words[current_idx] not in special_tokens:
                to_be_replaced_idx.append(current_idx)
            i += 1

        remaining_idx = sorted(list(set(sorted_idx) - set(to_be_replaced_idx)))
        truncated_text_ids = text_ids[0, np.array(remaining_idx)]

        if seg_ids is not None:
            seg_ids = seg_ids[0, np.array(remaining_idx)]

        truncated_text_words = np.array(text_words)[remaining_idx]
        return truncated_text_ids.unsqueeze(0), truncated_text_words, seg_ids

    def replace_words(self, sorted_idx, text_words, text_ids, replaced_num):
        to_be_replaced_idx = []
        i = 0
        special_tokens = self.tokenizer.all_special_tokens

        while len(to_be_replaced_idx) < replaced_num and i != len(text_words) - 1:
            current_idx = sorted_idx[i]
            if text_words[current_idx] not in special_tokens:
                to_be_replaced_idx.append(current_idx)
            i += 1

        mask_token = self.tokenizer.mask_token
        mask_token_id = self.tokenizer.mask_token_id

        replaced_text_ids = text_ids.clone()
        replaced_text_ids[0, to_be_replaced_idx] = mask_token_id
        replaced_text_words = np.copy(text_words)
        replaced_text_words[to_be_replaced_idx] = mask_token

        return replaced_text_ids, replaced_text_words

    def evaluate(self, file_path_prefix, degrade_step=10, seg_ids=None):
        """
        Run the evaluation and pickle the aggregated results to disk.

        Args:
        file_path_prefix (str): Directory that receives ``result_info.pkl``; created if missing.
        degrade_step (int): Number of perturbation levels, forwarded to ``test``.
        seg_ids (Tensor, optional): Token type ids, forwarded to ``test``.

        Returns:
        tuple: (result_info, degradation_info, del_info) as returned by ``test``.
        """
        # Create the output directory before the long-running evaluation so that an
        # unusable path fails immediately; exist_ok avoids the check-then-create race.
        os.makedirs(file_path_prefix, exist_ok=True)

        result_info, degradation_info, del_info = self.test(degrade_step, seg_ids)

        # Only the aggregated metrics are persisted.
        with open(os.path.join(file_path_prefix, 'result_info.pkl'), 'wb') as f:
            pickle.dump(result_info, f, pickle.HIGHEST_PROTOCOL)

        return result_info, degradation_info, del_info

    def test(self, degrade_step=10, seg_ids=None):
        """
        Run the perturbation evaluation over all selected samples.

        For every sample and explanation method, tokens are ranked by attribution
        (highest first). At each perturbation level the top-ranked tokens are either
        removed (truncation) or replaced by the mask token, and the model is queried
        again on the perturbed input.

        Args:
        degrade_step (int): Number of perturbation levels; level i targets i/degrade_step of the tokens.
        seg_ids (Tensor, optional): Token type ids forwarded to the model.

        Returns:
        tuple: (result_info, degradation_results, del_results). result_info maps each
        method (and ``"<method>_del"`` for the replacement variant) to its AOPC,
        log-odds and classification metrics; the other two hold the raw per-sample
        probabilities and predictions for truncation and replacement respectively.
        """
        methods = self.explanations.methods
        original_probs = []
        degradation_results = {method: {'probs': [], 'predictions': []} for method in methods}
        del_results = {method: {'probs': [], 'predictions': []} for method in methods}

        true_labels = []

        # The perturbation fractions are the same for every sample.
        granularity = [i / degrade_step for i in range(1, degrade_step + 1)]

        for test_instance in tqdm(self.selected_samples, total=len(self.selected_samples)):
            text = test_instance[self.text_column]
            target = test_instance[self.label_column]
            true_labels.append(target)

            text_ids, att_mask, text_words = self.preprocess_sample(text)
            total_len = len(text_words)

            # Short texts map several fractions to the same count; drop the repeats
            # while keeping the order.
            trunc_words_num = [int(g * total_len) for g in granularity]
            trunc_words_num = list(dict.fromkeys(trunc_words_num))

            # Same segment ids as for the perturbed predictions, so both probabilities
            # are comparable.
            _, original_prob = self.predict(text_ids, target, seg_ids=seg_ids)

            expln_info = self.explanations.generate_all_explanations(text_ids, att_mask, target)

            for method in methods:
                attribution = expln_info[method].cpu().detach().tolist()
                sorted_idx = sorted(range(len(attribution)), key=lambda k: attribution[k], reverse=True)

                instance_degradation_probs = []
                instance_degradation_predictions = []
                instance_replace_probs = []
                instance_replace_predictions = []

                for num in trunc_words_num:
                    truncated_text_ids, _, truncated_seg_ids = self.truncate_words(
                        sorted_idx, text_words, text_ids, num, seg_ids)
                    replaced_text_ids, _ = self.replace_words(sorted_idx, text_words, text_ids, num)

                    # The truncated input is shorter than the original, so it needs the
                    # truncated segment ids; truncate_words returns them without the
                    # batch dimension.
                    if truncated_seg_ids is not None and truncated_seg_ids.dim() == 1:
                        truncated_seg_ids = truncated_seg_ids.unsqueeze(0)

                    trunc_class, trunc_prob = self.predict(truncated_text_ids, target, seg_ids=truncated_seg_ids)
                    rep_class, rep_prob = self.predict(replaced_text_ids, target, seg_ids=seg_ids)

                    instance_degradation_probs.append(trunc_prob)
                    instance_degradation_predictions.append(trunc_class)
                    instance_replace_probs.append(rep_prob)
                    instance_replace_predictions.append(rep_class)

                degradation_results[method]['probs'].append(instance_degradation_probs)
                degradation_results[method]['predictions'].append(instance_degradation_predictions)
                del_results[method]['probs'].append(instance_replace_probs)
                del_results[method]['predictions'].append(instance_replace_predictions)

            original_probs.append(original_prob)

        result_info = {}
        for method in methods:
            # Calculate AOPC and LogOdds
            aopc, aopc_result = self.cal_aopc(original_probs, degradation_results[method]['probs'])
            logodds, logodds_result = self.cal_logodds(original_probs, degradation_results[method]['probs'])
            aopc_del, aopc_del_result = self.cal_aopc(original_probs, del_results[method]['probs'])
            logodds_del, logodds_del_result = self.cal_logodds(original_probs, del_results[method]['probs'])

            # Calculate accuracy, precision, recall, and F1 score
            degradation_metrics = self.calculate_metrics(true_labels, degradation_results[method]['predictions'])
            del_metrics = self.calculate_metrics(true_labels, del_results[method]['predictions'])

            result_info[method] = {
                'aopc': aopc,
                'aopc_result': aopc_result,
                'logodds': logodds,
                'logodds_result': logodds_result,
                'degradation_metrics': degradation_metrics
            }
            result_info[f"{method}_del"] = {
                'aopc': aopc_del,
                'aopc_result': aopc_del_result,
                'logodds': logodds_del,
                'logodds_result': logodds_del_result,
                'del_metrics': del_metrics
            }

        return result_info, degradation_results, del_results
    
    def calculate_metrics(self, true_labels, predictions):
        """
        Compute accuracy, precision, recall and F1 at every degradation step.

        Precision, recall and F1 use binary averaging when ``self.num_classes`` is 2
        and weighted averaging otherwise.

        Args:
        true_labels (list): True label of each sample.
        predictions (list): For each sample, the predicted label at each degradation step.

        Returns:
        dict: Maps 'accuracy', 'precision', 'recall' and 'f1' to a list with one value
        per degradation step. Steps are aligned with ``zip``, so only as many steps
        as the shortest per-sample list are evaluated.
        """
        averaging = 'binary' if self.num_classes == 2 else 'weighted'
        metrics = {name: [] for name in ('accuracy', 'precision', 'recall', 'f1')}

        # zip(*predictions) regroups the per-sample lists into one list per step.
        for step_predictions in map(list, zip(*predictions)):
            precision, recall, f1, _ = precision_recall_fscore_support(
                true_labels, step_predictions, average=averaging)

            metrics['accuracy'].append(accuracy_score(true_labels, step_predictions))
            metrics['precision'].append(precision)
            metrics['recall'].append(recall)
            metrics['f1'].append(f1)

        return metrics
    def cal_aopc(self, original_probs, degradation_probs):
        """
        Compute the Area Over the Perturbation Curve (AOPC).

        For each sample the absolute difference between the original probability and
        every degraded probability is taken. Differences are averaged per step over
        the samples that reach that step, and the step means are averaged again.

        Args:
        original_probs (list): Original probability of each sample.
        degradation_probs (list): For each sample, the probability at each degradation step.

        Returns:
        tuple: (AOPC value, list of mean differences per degradation step).
        """
        per_sample = [
            [abs(base_prob - step_prob) for step_prob in curve]
            for base_prob, curve in zip(original_probs, degradation_probs)
        ]

        longest = max(map(len, per_sample))
        step_means = []
        for step in range(longest):
            # Samples with fewer steps simply do not contribute to later steps.
            column = [row[step] for row in per_sample if step < len(row)]
            step_means.append(mean(column))

        return mean(step_means), step_means

    def cal_logodds(self, original_probs, degradation_probs):
        """
        Compute the mean log ratio between degraded and original probabilities.

        For each sample and step, ``log((degraded + eps) / (original + eps))`` is
        taken. Values are averaged per step over the samples that reach that step,
        and the step means are averaged again.

        Args:
        original_probs (list): Original probability of each sample.
        degradation_probs (list): For each sample, the probability at each degradation step.

        Returns:
        tuple: (log-odds value, list of mean log ratios per degradation step).
        """
        epsilon = 1e-10  # keeps both the division and the logarithm defined at 0

        per_sample = [
            [math.log((step_prob + epsilon) / (base_prob + epsilon)) for step_prob in curve]
            for base_prob, curve in zip(original_probs, degradation_probs)
        ]

        longest = max(map(len, per_sample))
        step_means = []
        for step in range(longest):
            # Samples with fewer steps simply do not contribute to later steps.
            column = [row[step] for row in per_sample if step < len(row)]
            step_means.append(mean(column))

        return mean(step_means), step_means



class Generator:
    def __init__(self, model, device="cpu"):
        self.device = device
        self.model = model.to(self.device)
        self.model.eval()
        self.methods = ['AttCAT', 'CAT', 'TransAtt', 'FullLRP', 'PartialLRP', 'RawAtt', 'Rollout', 'AttGrads', 'Grads']

    def generate_all_explanations(self, input_ids, attention_mask, index=None):
        """
        Run every explanation method on one input.

        Args:
            input_ids (Tensor): The input tensor of token indices.
            attention_mask (Tensor): The attention mask for the input.
            index (int, optional): Class index to explain; forwarded to the methods that take one.

        Returns:
            dict: Maps each name in ``self.methods`` to that method's attribution tensor.
        """
        # Entries of a dict display are evaluated top to bottom, so the methods run
        # in the order listed here.
        return {
            'AttCAT': self.generate_TransCAM(input_ids, attention_mask, index, cat_flag=False),
            'CAT': self.generate_TransCAM(input_ids, attention_mask, index, cat_flag=True),
            'TransAtt': self.generate_LRP(input_ids, attention_mask, index, start_layer=0)[0],
            'FullLRP': self.generate_full_lrp(input_ids, attention_mask, index)[0],
            'PartialLRP': self.generate_LRP_last_layer(input_ids, attention_mask, index)[0],
            'RawAtt': self.generate_attn_last_layer(input_ids, attention_mask)[0],
            'Rollout': self.generate_rollout(input_ids, attention_mask, start_layer=0)[0],
            'AttGrads': self.generate_attn_gradcam(input_ids, attention_mask, index, grad_flag=True)[0],
            'Grads': self.generate_attn_gradcam(input_ids, attention_mask, index, grad_flag=False)[0],
        }

    def forward(self, input_ids, attention_mask):
        """
        Call the wrapped model with the given inputs.

        Args:
            input_ids (Tensor): The input tensor of token indices.
            attention_mask (Tensor): The attention mask for the input.

        Returns:
            Whatever the model returns.
        """
        outputs = self.model(input_ids, attention_mask)
        return outputs

    def generate_TransCAM(self, input_ids, attention_mask, index=None, cat_flag=False):
        """
        Generate the Transformer Class Activation Maps (TransCAM) for a given input.

        For every encoder block the gradient of the selected logit with respect to the
        block's input hidden state is multiplied with that hidden state and summed over
        the feature dimension; the per-block scores are then added up.

        Args:
            input_ids (Tensor): The input tensor of token indices.
            attention_mask (Tensor): The attention mask for the input.
            index (int, optional): The class index to explain. Defaults to the predicted class.
            cat_flag (bool, optional): When False (default) each block's score is additionally
                weighted by that block's averaged attention; when True the attention
                weighting is skipped.

        Returns:
            Tensor: The TransCAM explanation tensor.
        """
        result = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        logits = result.get("logits")
        hs = result.get("hidden_states")

        kwargs = {"alpha": 1}

        blocks = self.model.bert.encoder.layer

        # Hidden states are intermediate tensors; retain_grad keeps their gradient
        # available after backward.
        for blk_id in range(len(blocks)):
            hs[blk_id].retain_grad()

        if index is None:
            index = np.argmax(logits.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, logits.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        # The selector must live on the same device as the logits it is multiplied with.
        one_hot = torch.from_numpy(one_hot_vector).to(logits.device)
        target_score = torch.sum(one_hot * logits)

        self.model.zero_grad()
        target_score.backward(retain_graph=True)

        # NOTE(review): the return value is unused here; kept in case relprop has
        # side effects on the model's internal state — confirm.
        self.model.relprop(torch.tensor(one_hot_vector).to(logits.device), **kwargs)

        cams = {}

        for blk_id in range(len(blocks)):
            hs_grads = hs[blk_id].grad

            # Two successive means over the leading dimension reduce the attention to
            # one weight per token (presumably heads first, then query positions —
            # depends on what get_attn returns).
            att = blocks[blk_id].attention.self.get_attn().squeeze(0)
            att = att.mean(dim=0)
            att = att.mean(dim=0)

            cat = (hs_grads * hs[blk_id]).sum(dim=-1).squeeze(0)
            if not cat_flag:
                cat = cat * att

            cams[blk_id] = cat

        trans_expln = sum(cams.values())

        return trans_expln

    def generate_LRP(self, input_ids, attention_mask, index=None, start_layer=11):
        """
        Generate an LRP-based rollout explanation for the given input.

        For every encoder block the attention relevance map is multiplied with the
        attention gradients, clamped at zero and averaged; the per-block maps are then
        combined with ``compute_rollout_attention``.

        Args:
            input_ids (Tensor): The input tensor of token indices.
            attention_mask (Tensor): The attention mask tensor.
            index (int, optional): The class index to explain. Defaults to the predicted class.
            start_layer (int, optional): The starting layer for computing the rollout attention. Defaults to 11.

        Returns:
            Tensor: The first row of the rollout matrix, with its own first entry set to 0.
        """
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = result.get('logits')

        kwargs = {"alpha": 1}

        if index is None:
            index = np.argmax(logits.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, logits.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        # The selector must live on the same device as the logits it is multiplied with.
        one_hot = torch.from_numpy(one_hot_vector).to(logits.device)
        target_score = torch.sum(one_hot * logits)

        self.model.zero_grad()
        target_score.backward(retain_graph=True)

        # Relevance propagation; the per-block maps are read back through get_attn_cam.
        self.model.relprop(torch.tensor(one_hot_vector).to(logits.device), **kwargs)

        cams = []
        blocks = self.model.bert.encoder.layer

        for blk in blocks:
            grad = blk.attention.self.get_attn_gradients()
            cam = blk.attention.self.get_attn_cam()
            # Flatten everything in front of the last two (square) dimensions.
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cams.append(cam.unsqueeze(0))

        rollout = self.compute_rollout_attention(cams, start_layer=start_layer)
        # Drop the first token's relevance to itself.
        rollout[:, 0, 0] = 0

        return rollout[:, 0]

    def generate_LRP_last_layer(self, input_ids, attention_mask, index=None):
        """
        Generate the last layer relevance propagation (LRP) for the given input_ids and attention_mask.

        Args:
            input_ids (torch.Tensor): The input tensor representing the tokenized input sequence.
            attention_mask (torch.Tensor): The attention mask tensor for the input sequence.
            index (int, optional): The index for selecting the output neuron.
                If not provided, the index of the maximum value in the output tensor will be used.

        Returns:
            torch.Tensor: The first row of the last block's relevance map, clamped at
            zero and averaged over its leading dimension, with its own first entry set to 0.
        """
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = result.get('logits')
        kwargs = {"alpha": 1}

        if index is None:
            index = np.argmax(logits.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, logits.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        # The selector must live on the same device as the logits it is multiplied with.
        one_hot = torch.from_numpy(one_hot_vector).to(logits.device)
        target_score = torch.sum(one_hot * logits)

        self.model.zero_grad()
        target_score.backward(retain_graph=True)

        # Relevance propagation; the last block's map is read back through get_attn_cam.
        self.model.relprop(torch.tensor(one_hot_vector).to(logits.device), **kwargs)

        cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0]
        cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
        # Drop the first token's relevance to itself.
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_full_lrp(self, input_ids, attention_mask, index=None):
        """
        Generate full LRP for the given input_ids and attention_mask.

        Args:
            input_ids: The input IDs for the model.
            attention_mask: The attention mask for the model.
            index: The class index to explain. Defaults to the predicted class.

        Returns:
            The relevance returned by the model's ``relprop``, summed over dimension 2,
            with the first position set to 0.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        kwargs = {"alpha": 1}

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        # The selector must live on the same device as the output it is multiplied with.
        one_hot = torch.from_numpy(one_hot_vector).to(output.device)
        target_score = torch.sum(one_hot * output)

        self.model.zero_grad()
        target_score.backward(retain_graph=True)

        cam = self.model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
        cam = cam.sum(dim=2)
        # Zero the relevance of the first position.
        cam[:, 0] = 0
        return cam

    def generate_attn_last_layer(self, input_ids, attention_mask):
        """
        Return the attention of the last encoder block as an explanation.

        Args:
            input_ids: The input tensor for the model.
            attention_mask: The attention mask tensor.

        Returns:
            Tensor: The first row of the last block's attention, averaged over its
            leading dimension, with its own first entry set to 0.
        """
        # The forward pass is only needed to populate the stored attention.
        self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        last_block = self.model.bert.encoder.layer[-1]
        stored_attention = last_block.attention.self.get_attn()[0]

        cam = stored_attention.mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_rollout(self, input_ids, attention_mask, start_layer=0):
        """
        Compute attention rollout over the encoder blocks.

        Args:
            input_ids (Tensor): The input tensor containing the token indices.
            attention_mask (Tensor): The attention mask forwarded to the model.
            start_layer (int, optional): The block at which the rollout starts. Defaults to 0.

        Returns:
            Tensor: The first row of the rollout matrix, with its own first entry set to 0.
        """
        self.model.zero_grad()
        # The forward pass is only needed to populate each block's stored attention.
        self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        all_layer_attentions = []
        for encoder_block in self.model.bert.encoder.layer:
            attn_heads = encoder_block.attention.self.get_attn()
            # Average over dimension 1 and cut the result out of the autograd graph.
            head_average = attn_heads.sum(dim=1) / attn_heads.shape[1]
            all_layer_attentions.append(head_average.detach())

        rollout = self.compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
        rollout[:, 0, 0] = 0

        return rollout[:, 0]

    def generate_attn_gradcam(self, input_ids, attention_mask, index=None, grad_flag:bool=True):
        """
        Generate attention GradCAM for a given input using the model.

        Parameters:
        - input_ids: The input IDs for the model
        - attention_mask: The attention mask for the model
        - index: Index value for selecting the class (default is the predicted class)
        - grad_flag: When True (default) the last block's attention is weighted by its
          mean gradient before averaging; when False the attention is averaged unweighted

        Returns:
        - cam: The first row of the min-max normalised map, with its own first entry set to 0
        """
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = result.get('logits')
        kwargs = {"alpha": 1}

        if index is None:
            index = np.argmax(logits.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, logits.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        # The selector must live on the same device as the logits it is multiplied with.
        one_hot = torch.from_numpy(one_hot_vector).to(logits.device)
        target_score = torch.sum(one_hot * logits)

        self.model.zero_grad()
        target_score.backward(retain_graph=True)

        # NOTE(review): the return value is unused here; kept in case relprop has
        # side effects on the model's internal state — confirm.
        self.model.relprop(torch.tensor(one_hot_vector).to(logits.device), **kwargs)

        cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()
        grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients()

        # Flatten everything in front of the last two (square) dimensions, then reduce
        # each gradient map to one scalar weight.
        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        grad = grad.mean(dim=[1, 2], keepdim=True)

        if grad_flag:
            cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
        else:
            cam = cam.mean(0).clamp(min=0).unsqueeze(0)

        # Min-max normalisation; a constant map has zero range and would divide 0 by 0,
        # so it is mapped to all zeros instead of NaN.
        cam_min = cam.min()
        value_range = cam.max() - cam_min
        if value_range == 0:
            cam = torch.zeros_like(cam)
        else:
            cam = (cam - cam_min) / value_range
        cam[:, 0, 0] = 0

        return cam[:, 0]

    def compute_rollout_attention(self, all_layer_matrices, start_layer=0):
        """
        Compute the rollout attention for a given list of layer matrices.
        
        Args:
            all_layer_matrices (list): A list of layer matrices.
            start_layer (int, optional): The index of the starting layer. Defaults to 0.
        
        Returns:
            torch.Tensor: The joint attention matrix.
        """
        
        # adding residual consideration 
        num_tokens = all_layer_matrices[0].shape[1]
        batch_size = all_layer_matrices[0].shape[0]

        eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(self.device)
        all_layer_matrices_res = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]

        matrices_aug = [all_layer_matrices_res[i] / all_layer_matrices_res[i].sum(dim=-1, keepdim=True) for i in range(len(all_layer_matrices_res))]

        joint_attention = matrices_aug[start_layer]

        for i in range(start_layer+1, len(matrices_aug)):
            joint_attention = matrices_aug[i].bmm(joint_attention)
            
        return joint_attention