# %%
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from typing import List
import cvxpy as cp
import numpy as np
import torch
import copy
import pickle
import os
from IPython.display import clear_output
import json
from datasets import load_dataset

# %%
import sys
# Make the project root importable so the `Code.*` modules imported below resolve.
# NOTE(review): '../../' is resolved against the current working directory, not this
# file's location, so this assumes the notebook is run from two levels below the root.
sys.path.append('../../')


# %%
from Code.nlp_utils import SequenceClassifier as nlp_SequenceClassifier
from Code.nlp_utils import get_special_tokens, get_modified_input_tokens
from Code.result_utils import get_all_info
from Code.captum_attr_utils import SequenceClassifier as captum_SequenceClassifier
from Code.captum_attr_utils import get_all_info_sc, get_clean_attr
from captum.attr import LayerIntegratedGradients, KernelShap, Lime
from Code.post_processing_utils import get_normalized_shap_vals


# %%
dataset = load_dataset("imdb")

# Full text and label columns of the IMDB test split (used for the statistics below).
init_reviews, init_labels = dataset['test']['text'], dataset['test']['label']


# %%
# Summary statistics (mean, max, min, mode) of the review lengths in characters.
# The length list is computed once and reused for every statistic.
review_lengths = [len(review) for review in init_reviews]

print(np.mean(review_lengths))
print(np.max(review_lengths))
print(np.min(review_lengths))
# mode (printed explicitly: a bare expression is only displayed as the last line of a
# notebook cell and is silently discarded when this file runs as a script)
from scipy import stats
print(stats.mode(review_lengths))

# %%
num_samples = 5000
half_sample = num_samples // 2

# Mode of the review lengths (in characters); it is the threshold that splits the
# test split into shorter and longer reviews.
lengths = [len(text) for text in dataset['test']['text']]
mode_length = stats.mode(lengths)[0]
print(mode_length)

# Split the dataset. The length test reuses ``lengths`` instead of fetching each row
# to measure it, so every row of the test split is read exactly once here.
test_split = dataset['test']
less_than_mode = [test_split[i] for i, length in enumerate(lengths) if length < mode_length]
greater_than_mode = [test_split[i] for i, length in enumerate(lengths) if length >= mode_length]

# fix seed for reproducibility
np.random.seed(42)

# Randomly select ``half_sample`` reviews (without replacement) from each group
selected_less = np.random.choice(less_than_mode, half_sample, replace=False)
selected_greater = np.random.choice(greater_than_mode, half_sample, replace=False)

# Combine the selected samples
selected_samples = np.concatenate((selected_less, selected_greater))

print(selected_samples[0])
print(len(selected_samples))

reviews = [sample['text'] for sample in selected_samples]
labels = [sample['label'] for sample in selected_samples]

input_str_list = reviews

# %%
# Load the sequence classifier (asking it to also return attentions and hidden
# states) and the tokenizer from the same checkpoint; everything runs on CPU.
device = "cpu"
hf_model = AutoModelForSequenceClassification.from_pretrained(
    'fabriceyhc/bert-base-uncased-imdb',
    output_attentions=True,
    output_hidden_states=True,
)
hf_model.to(device).eval()
hf_model.zero_grad()

hf_tokenizer = AutoTokenizer.from_pretrained('fabriceyhc/bert-base-uncased-imdb')

# %%
# Captum attribution classes to compare. The two kwargs lists are index-aligned with
# ``captum_models`` and are forwarded to ``get_all_info_sc``; ``attr_models_kwargs``
# carries attribution options (IG's n_steps, KernelShap/Lime's n_samples).
captum_models = [LayerIntegratedGradients, KernelShap, Lime]
captum_models_kwargs = [{} for _ in range(3)]
attr_models_kwargs = [
    {'n_steps': 50, 'return_convergence_delta': True},
    {'n_samples': 50},
    {'n_samples': 50},
]

# Keyword arguments for ``get_all_info``; each attention method gets its own
# (non-shared) parameter dict.
model_kwargs = {
    'model_params': {
        'input_str': input_str_list,
        'logits_indices': [0, 1],
    },
    **{
        f'{method}_shap_params': {'mu': 1e-1, 'solver': cp.ECOS_BB}
        for method in ('attention_base', 'attention_grad', 'attention_grad_base')
    },
}

logit_indices = [0, 1]
removed_indices = []


# %%
class SequenceClassifierEvaluator:
    """Compute token attributions for a sequence classifier and evaluate them by masking.

    Typical flow:
      1. ``get_whole_info`` runs the attention-based methods (``get_all_info``) and the
         captum methods (``get_all_info_sc``) on every input string and pickles the result.
      2. ``get_modifed_tokenized_inputs_info`` builds, for every method, an input in which
         only the top-k% attributed tokens are kept (every other position is replaced by
         the mask token and excluded from attention) and re-runs the model on it.
      3. ``get_final_sc_results`` collects original vs. modified logits/probabilities.
    """

    def __init__(self, model, tokenizer, removed_indices, logit_indices,
                 model_kwargs, captum_models, captum_models_kwargs, attr_models_kwargs,
                 save_path:str=None):
        """
        Initializes a new instance of the SequenceClassifierEvaluator class.

        Args:
            model (PreTrainedModel): Classifier to explain; it is also called directly on
                the modified inputs.
            tokenizer (PreTrainedTokenizer): Tokenizer matching ``model``; its mask token
                and special-token ids are used when masking.
            removed_indices (list): Forwarded to ``get_all_info``.
            logit_indices (list[int]): Output logits to compute attributions for.
            model_kwargs (dict): Keyword arguments for ``get_all_info``; must contain a
                ``'model_params'`` dict whose ``'input_str'`` is set per input.
            captum_models (list): Captum attribution classes to run.
            captum_models_kwargs (list[dict]): Per-captum-model keyword arguments,
                forwarded to ``get_all_info_sc``.
            attr_models_kwargs (list[dict]): Per-captum-model attribution keyword
                arguments, forwarded to ``get_all_info_sc``.
            save_path (str, optional): Pickle file written by ``get_whole_info`` and read
                by ``load_info_pickle``. Defaults to ``sc_info_models.pickle`` in the
                current working directory.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.removed_indices = removed_indices
        self.logit_indices = logit_indices
        self.model_kwargs = model_kwargs
        self.captum_models = captum_models
        self.captum_models_kwargs = captum_models_kwargs
        self.attr_models_kwargs = attr_models_kwargs

        if save_path is not None:
            self.save_path = save_path
        else:
            self.save_path = os.path.join(os.getcwd(), 'sc_info_models.pickle')

        self.sc_info = None
        self.input_str_list = None

    def get_input_strings(self, input_str_list):
        """Return ``input_str_list`` unchanged."""
        return input_str_list

    @staticmethod
    def _attention_attr_info(attention_info, logit_indices):
        """Collect the attention-based attributions for every logit index.

        Each value is row 0 of ``normalized_shapley_vals_layerwise`` rounded to 4 decimals
        (presumably one score per token -- TODO confirm the layout ``get_all_info``
        returns). The same ``attention_base`` values are stored under every
        ``logit_index_<j>`` key.
        """
        def first_row(method_info):
            return method_info['bw_shap_info']['normalized_shapley_vals_layerwise'].round(4)[0, :]

        return {
            'attention_base': {
                f'logit_index_{j}': first_row(attention_info['attention_base'])
                for j in logit_indices
            },
            'attention_grad': {
                f'logit_index_{j}': first_row(attention_info['attention_grad'][f'attention_grad_logit_{j}'])
                for j in logit_indices
            },
            'attention_grad_base': {
                f'logit_index_{j}': first_row(attention_info['attention_grad_base'][f'attention_grad_base_logit_{j}'])
                for j in logit_indices
            },
        }

    @staticmethod
    def _output_info(model_info):
        """Select the original-prediction fields kept per sample from ``get_all_info``'s model info."""
        return {
            'predicted_class_id': model_info['predicted_class_id'],
            'predicted_class': model_info['predicted_class'],
            'class_probabilities': model_info['class_probabilities'][0].detach().cpu().numpy(),
            'class_logits': model_info['class_logits'][0].detach().cpu().numpy(),
            'tokenized_inputs': model_info['tokenized_inputs'],
            'input_tokens': model_info['input_tokens'],
            'modified_input_tokens': model_info['modified_input_tokens'],
            'removed_indices': model_info['removed_indices'],
        }

    def get_whole_info(self, input_str_list):
        """Run every attribution method on each input string and pickle the results.

        Inputs that raise are reported and skipped, so the returned keys
        (``'sample_<i>'``, with ``i`` the position in ``input_str_list``) may have gaps.

        Args:
            input_str_list (list[str]): Texts to explain.

        Returns:
            dict: Per-sample info (input string, raw captum info, per-method token
            attributions keyed ``logit_index_<j>``, and the original prediction). It is
            also stored on ``self.sc_info`` and pickled to ``self.save_path``.
        """
        hf_model = self.model
        hf_tokenizer = self.tokenizer
        logit_indices = self.logit_indices

        sc_info = {}
        for i, input_str in enumerate(input_str_list):
            print(f'Running Input {i+1} of {len(input_str_list)}')
            try:
                # Per-input copy so ``self.model_kwargs`` (and the caller's dict) is not
                # overwritten with the current input string.
                sample_model_kwargs = dict(self.model_kwargs)
                sample_model_kwargs['model_params'] = {**self.model_kwargs['model_params'],
                                                       'input_str': input_str}

                nlp_sc_model = nlp_SequenceClassifier(hf_model, hf_tokenizer)
                attention_info, model_info = get_all_info(model=nlp_sc_model,
                                                          removed_indices=self.removed_indices,
                                                          **sample_model_kwargs)
                input_tokens_attr_info = self._attention_attr_info(attention_info, logit_indices)

                captum_sc_model = captum_SequenceClassifier(hf_model, hf_tokenizer)
                captum_models_attr_info = get_all_info_sc(sc_model=captum_sc_model,
                                                          captum_models=self.captum_models,
                                                          input_str=input_str,
                                                          logit_indices=logit_indices,
                                                          captum_models_kwargs=self.captum_models_kwargs,
                                                          attr_models_kwargs=self.attr_models_kwargs)

                for model_name_attr_info, model_attr_info in captum_models_attr_info.items():
                    model_name = model_name_attr_info.split('_attr_info')[0]
                    print(model_name)
                    input_tokens_attr_info[model_name] = {
                        f'logit_index_{j}': get_normalized_shap_vals(
                            model_attr_info[f'model_attr_embeddings_logit_index_{j}'].squeeze(0).detach().numpy()
                        ).round(4)
                        for j in logit_indices
                    }

                # Stored once, after every method succeeded, so a failure part-way through
                # the captum models cannot leave a partial sample behind.
                sc_info[f'sample_{i}'] = {
                    'sample_number': i,
                    'input_str': input_str,
                    'captum_models_attr_info': captum_models_attr_info,
                    'input_tokens_attr_info': input_tokens_attr_info,
                    'output_info': self._output_info(model_info),
                }

                print("-------------"*50)

            except Exception as e:
                # Best effort over the dataset: report the failing input and move on.
                print(f"Error occurred for {i}-th input_str:\n\n ```{input_str}```:\n\n due to error:{e} \n")
                print("-------------"*50)

                continue

        # Create the target directory so a long run is not lost to a missing folder.
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(self.save_path, 'wb') as pickle_file:
            pickle.dump(sc_info, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

        self.sc_info = sc_info

        return sc_info

    def load_info_pickle(self):
        """Load the per-sample info previously pickled to ``self.save_path``.

        NOTE: ``pickle.load`` can execute arbitrary code; only load files this class wrote.
        """
        with open(self.save_path, 'rb') as pickle_file:
            sc_info = pickle.load(pickle_file)
        return sc_info

    @staticmethod
    def get_top_k(attr_values:np.ndarray, special_token_ids_indices:List[int], k:int):
        """Return the indices of the ``k`` highest ``attr_values``, skipping special positions.

        Indices are ordered from highest to lowest attribution. Fewer than ``k`` are
        returned when there are not enough non-special positions, and ``[]`` when
        ``k <= 0``.
        """
        if k <= 0:
            # Without this guard the loop below never hits len == k and returns every
            # non-special index.
            return []

        excluded = set(special_token_ids_indices)
        top_k_indices = []
        for i in np.argsort(attr_values)[::-1]:
            if i in excluded:
                continue
            top_k_indices.append(int(i))
            if len(top_k_indices) == k:
                break

        return top_k_indices

    def get_sc_modifed_tokenized_inputs(self, tokenized_inputs, attributions, k_percentage:int=20, include_special_tokens=True):
        """Return a copy of ``tokenized_inputs`` that keeps only the top-k% attributed tokens.

        ``k = int(num_tokens * k_percentage / 100)``. Every position outside the kept set
        has its id replaced by ``tokenizer.mask_token_id`` and its attention-mask entry set
        to 0 (Hugging Face convention: 1 = attend, 0 = ignore), so the model only sees the
        kept tokens.

        Args:
            tokenized_inputs: Tokenizer output for one sequence (row 0 of tensor
                ``'input_ids'`` / ``'attention_mask'``). It is not modified.
            attributions (np.ndarray): One score per token position.
            k_percentage (int): Percentage of tokens to keep.
            include_special_tokens (bool): If True, special-token positions are always kept
                plus at least one top-attributed token; if False they are never kept.

        Returns:
            A deep copy of ``tokenized_inputs`` with row 0 of ``input_ids`` and
            ``attention_mask`` replaced.
        """
        hf_tokenizer = self.tokenizer
        mask_token_id = hf_tokenizer.mask_token_id
        special_token_ids = set(hf_tokenizer.all_special_ids)

        input_ids = tokenized_inputs['input_ids'][0].tolist()
        special_token_ids_indices = [i for i, token_id in enumerate(input_ids) if token_id in special_token_ids]

        k = int(len(input_ids) * k_percentage / 100)

        if include_special_tokens:
            init_top_k_tokens = self.get_top_k(attributions, special_token_ids_indices, max(k, 1))
            top_k_tokens_indices = special_token_ids_indices + init_top_k_tokens
        else:
            top_k_tokens_indices = self.get_top_k(attributions, special_token_ids_indices, k)
        kept = set(top_k_tokens_indices)

        modified_tokenized_inputs = copy.deepcopy(tokenized_inputs)

        modified_tokenized_inputs['input_ids'][0] = torch.tensor(
            [token_id if i in kept else mask_token_id for i, token_id in enumerate(input_ids)])
        # Attend to the kept positions only. The mask must agree with input_ids above:
        # the inverted mask (0 for kept, 1 for [MASK]) made the model see nothing but
        # [MASK] tokens, so the attributions had no effect on the evaluation.
        modified_tokenized_inputs['attention_mask'][0] = torch.tensor(
            [1 if i in kept else 0 for i in range(len(input_ids))])

        return modified_tokenized_inputs

    def _evaluate_modified_input(self, tokenized_inputs, attributions, k_percentage, include_special_tokens):
        """Mask ``tokenized_inputs`` by ``attributions`` and re-run the model on the result."""
        modified_tokenized_inputs = self.get_sc_modifed_tokenized_inputs(
            tokenized_inputs, attributions, k_percentage, include_special_tokens)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            class_logits = self.model(**modified_tokenized_inputs)['logits'][0]
        class_probabilities = torch.softmax(class_logits, dim=-1)

        return {
            'modified_tokenized_inputs': modified_tokenized_inputs,
            'class_logits': class_logits.detach().cpu().numpy(),
            'class_probabilities': class_probabilities.detach().cpu().numpy(),
        }

    def get_modifed_tokenized_inputs_info(self, input_str_list, k_percentage=20, include_special_tokens=False):
        """Re-score every sample with only its top-k% tokens kept, once per attribution method.

        If ``self.sc_info`` is unset it is loaded from ``self.save_path``, or recomputed
        with ``get_whole_info`` when loading fails. Attributions for the originally
        predicted class are used. Results are stored in each sample under
        ``'tokenized_inputs_info'[method]`` (keys ``modified_tokenized_inputs``,
        ``class_logits``, ``class_probabilities``).

        Returns:
            dict: The updated ``self.sc_info``.
        """
        if self.sc_info is None:
            try:
                self.sc_info = self.load_info_pickle()
                print("sc_info is loaded from the pickle file")
            except Exception as err:
                print("""cannot load sc_info from the pickle file: \n
                        Run get_whole_info() first before calling get_modifed_tokenized_inputs_info""")
                print(err)
                self.sc_info = self.get_whole_info(input_str_list=input_str_list)

        sc_info = self.sc_info

        for sample_info in sc_info.values():
            tokenized_inputs = sample_info['output_info']['tokenized_inputs']
            predicted_class_id = sample_info['output_info']['predicted_class_id']

            # input_tokens_attr_info holds the three attention methods followed by the
            # captum methods, so one loop covers every method exactly once.
            sample_info['tokenized_inputs_info'] = {
                method_name: self._evaluate_modified_input(
                    tokenized_inputs,
                    method_attr_info[f'logit_index_{predicted_class_id}'],
                    k_percentage,
                    include_special_tokens,
                )
                for method_name, method_attr_info in sample_info['input_tokens_attr_info'].items()
            }

        self.sc_info = sc_info
        return sc_info

    @staticmethod
    def get_final_sc_results(sc_info):
        """Collect original (``'main'``) and per-method modified logits/probabilities per sample.

        Raises:
            ValueError: If ``sc_info`` is empty or its samples lack ``'tokenized_inputs_info'``
                (i.e. ``get_modifed_tokenized_inputs_info`` has not been run).
        """
        first_sample = next(iter(sc_info.values()), None) if sc_info else None
        if not isinstance(first_sample, dict) or 'tokenized_inputs_info' not in first_sample:
            raise ValueError("""Run get_modifed_tokenized_inputs_info() first before calling get_final_sc_results()""")

        final_results = {}
        for key, sample_info in sc_info.items():
            sample_results = {
                'main': {
                    'class_logits': sample_info['output_info']['class_logits'],
                    'class_probabilities': sample_info['output_info']['class_probabilities'],
                }
            }
            for model_name, modified_info in sample_info['tokenized_inputs_info'].items():
                sample_results[model_name] = {
                    'class_logits': modified_info['class_logits'],
                    'class_probabilities': modified_info['class_probabilities'],
                }
            final_results[key] = sample_results

        return final_results



# %%
def get_clean_sc_results(final_sc_results):
    """Regroup per-sample results into per-method lists.

    Args:
        final_sc_results (dict): ``{sample_key: {method_name: {'class_logits': ...,
            'class_probabilities': ...}}}``. Every sample is expected to hold the same
            methods; the method names are taken from the first sample, so the result no
            longer depends on a ``'sample_0'`` entry existing (it is absent when sample 0
            failed and was skipped).

    Returns:
        dict: ``{method_name: {'class_logits': [...], 'class_probabilities': [...]}}``
        with one list entry per sample, in sample order. Empty input yields ``{}``.
    """
    if not final_sc_results:
        return {}

    samples = list(final_sc_results.values())
    clean_results = {}
    for model_name in samples[0]:
        clean_results[model_name] = {
            'class_logits': [sample[model_name]['class_logits'] for sample in samples],
            'class_probabilities': [sample[model_name]['class_probabilities'] for sample in samples],
        }
    return clean_results

# %%
def get_score(clean_results, model_name):
    """Compare one method's modified-input predictions against the original predictions.

    Args:
        clean_results (dict): Output of ``get_clean_sc_results``. ``'main'`` holds the
            original class probabilities and ``model_name`` those after modification,
            each a list with one probability vector per sample.
        model_name (str): Attribution method to score.

    Returns:
        tuple: ``(accuracy, f1, precision, recall, aopc, loadds)``.
        - The first four treat the original argmax class as ground truth (weighted averages).
        - ``aopc`` is the mean drop in probability of the originally predicted class.
        - ``loadds`` is the mean of ``log(p_modified / p_original)`` for that class.
    """
    y_true = np.asarray(clean_results['main']['class_probabilities'])
    y_pred = np.asarray(clean_results[model_name]['class_probabilities'])

    y_true_index = np.argmax(y_true, axis=1)
    y_pred_index = np.argmax(y_pred, axis=1)

    y_true_probs = np.max(y_true, axis=1)
    # Probability the modified input still assigns to the originally predicted class.
    y_true_probs_masked = np.take_along_axis(y_pred, y_true_index[:, np.newaxis], axis=1)[:, 0]

    accuracy = accuracy_score(y_true_index, y_pred_index)
    f1 = f1_score(y_true_index, y_pred_index, average='weighted')
    precision = precision_score(y_true_index, y_pred_index, average='weighted')
    recall = recall_score(y_true_index, y_pred_index, average='weighted')

    aopc = np.mean(y_true_probs - y_true_probs_masked)
    loadds = np.mean(-np.log(y_true_probs) + np.log(y_true_probs_masked))

    print(f'Accuracy: {accuracy}, F1: {f1}, Precision: {precision}, Recall: {recall}, AOPC: {aopc}, LOdds: {loadds}')

    return accuracy, f1, precision, recall, aopc, loadds


# %%
def get_all_scores(clean_results):
    """Score every attribution method in ``clean_results`` (the ``'main'`` entry is the baseline).

    Returns:
        dict: ``{metric_name: {method_name: float}}`` for the metrics accuracy, f1,
        precision, recall, aopc and loadds, in that order.
    """
    metric_names = ('accuracy', 'f1', 'precision', 'recall', 'aopc', 'loadds')
    all_reports = {metric: {} for metric in metric_names}

    for model_name in clean_results:
        if model_name == 'main':
            continue
        for metric, value in zip(metric_names, get_score(clean_results, model_name)):
            all_reports[metric][model_name] = float(value)

    return all_reports

# %%
# Evaluator over the model/tokenizer and attribution configuration defined above;
# attribution results are pickled to the given save_path.
sc_evaluator = SequenceClassifierEvaluator(
    hf_model,
    hf_tokenizer,
    removed_indices=[],
    logit_indices=[0, 1],
    model_kwargs=model_kwargs,
    captum_models=captum_models,
    captum_models_kwargs=captum_models_kwargs,
    attr_models_kwargs=attr_models_kwargs,
    save_path="../../data/model_sc_eval_info_imdb_temp.pickle",
)

# %%
# Compute (and pickle) attributions for every sampled review, then clear the
# lengthy per-sample log from the notebook output.
sc_info = sc_evaluator.get_whole_info(input_str_list)
clear_output()

# %%
def run_evaluation(k_percentages, input_str_list, sc_evaluator, file_path_prefix, includde_special_tokens=True):
    """Score every attribution method at several top-k% levels and save each report as JSON.

    For each ``k_percentage`` this builds the modified inputs and re-scores them via
    ``sc_evaluator.get_modifed_tokenized_inputs_info``, collects the results, computes
    the metrics with ``get_all_scores`` and writes them to
    ``f'{file_path_prefix}={k_percentage}_percent-include_special_tokens={includde_special_tokens}.json'``.
    Missing parent directories are created so a long run does not fail at the write.

    Args:
        k_percentages (iterable[int]): Percentages of tokens to keep.
        input_str_list (list[str]): Passed through (used only if attributions must be recomputed).
        sc_evaluator (SequenceClassifierEvaluator): Evaluator holding the attributions.
        file_path_prefix (str): Path prefix of the output JSON files.
        includde_special_tokens (bool): Forwarded as ``include_special_tokens``
            (misspelled name kept for backward compatibility).
    """
    for k_percentage in k_percentages:
        print(f'running for k={k_percentage}% \n')
        modified_sc_info = sc_evaluator.get_modifed_tokenized_inputs_info(input_str_list,
                                                                          k_percentage=k_percentage,
                                                                          include_special_tokens=includde_special_tokens)
        final_sc_results = sc_evaluator.get_final_sc_results(modified_sc_info)
        clean_results = get_clean_sc_results(final_sc_results)
        all_scores = get_all_scores(clean_results)

        # save clean results as a json file
        file_path = f'{file_path_prefix}={k_percentage}_percent-include_special_tokens={includde_special_tokens}.json'
        output_dir = os.path.dirname(file_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as fp:
            json.dump(all_scores, fp, indent=4)

        print(f'clean results are saved in {file_path} \n')
        print("-------------"*50)


# Evaluate at k = 10%, 20%, ..., 90%; special tokens are not automatically kept.
run_evaluation(
    k_percentages=[10, 20, 30, 40, 50, 60, 70, 80, 90],
    input_str_list=input_str_list,
    sc_evaluator=sc_evaluator,
    file_path_prefix="../../data/IMDB/masked_sc_scores-IMDB",
    includde_special_tokens=False,
)


