import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from utils.perturbation_utils import select_rationales, compute_comprehensiveness, compute_sufficiency, compute_perturbation_auc
from argparse import ArgumentParser
import json
import random
import numpy as np
from tqdm import tqdm
import os
from utils.vocabulary import *

# Root directory for experiment artifacts: main() loads the fine-tuned model from
# "debiased_models_<dataset>/..." and the saved explanations from
# "encoder_results_<dataset>/..." beneath this directory.
DATA_DIR = "./results/"

def batch_loader(data, batch_size):
    """Lazily split ``data`` into consecutive chunks of ``batch_size`` items.

    The final chunk holds the remainder and may therefore be shorter than
    ``batch_size``; an empty ``data`` yields nothing.

    Args:
        data: Any sliceable sequence (e.g. a list of explanation dicts).
        batch_size: Number of items per chunk.

    Yields:
        Consecutive slices of ``data``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in starts)


def _set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _num_test_examples(dataset, bias_type):
    """Return the test-set size that is encoded in the explanation directory name.

    Raises:
        ValueError: if no size is configured for the dataset/bias_type combination.
    """
    if dataset == 'civil':
        return 1000 if bias_type == 'religion' else 2000
    jigsaw_sizes = {'race': 400, 'gender': 800, 'religion': 200}
    if dataset == 'jigsaw' and bias_type in jigsaw_sizes:
        return jigsaw_sizes[bias_type]
    raise ValueError(f"No test-set size configured for dataset={dataset!r}, bias_type={bias_type!r}.")


def _empty_scores(percentages):
    """Return a fresh ``{str(percentage): {"comprehensiveness_list": [], "sufficiency_list": []}}`` accumulator."""
    return {str(percentage): {"comprehensiveness_list": [], "sufficiency_list": []} for percentage in percentages}


def _encode_batch(batch, tokenizer, max_length, device):
    """Rebuild padded ``input_ids`` / ``attention_mask`` tensors from the tokens stored in each item.

    Each item's ``'attribution'`` is a sequence of pairs whose first element is a token
    string; tokens are mapped back to ids with the tokenizer and right-padded with
    ``tokenizer.pad_token_id`` up to ``max_length``.

    Raises:
        ValueError: if an item has more tokens than ``max_length``.
    """
    input_ids = torch.full((len(batch), max_length), tokenizer.pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_length), dtype=torch.long)
    for row, item in enumerate(batch):
        tokens = [pair[0] for pair in item['attribution']]
        if len(tokens) > max_length:
            raise ValueError(f"Example has {len(tokens)} tokens, which exceeds max_length={max_length}.")
        input_ids[row, :len(tokens)] = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
        attention_mask[row, :len(tokens)] = 1
    return input_ids.to(device), attention_mask.to(device)


def _score_rationales(model, input_ids, attention_mask, attributions, predicted_ids, orig_probs,
                      mask_token_id, percentages, scores):
    """Append comprehensiveness and sufficiency for every rationale percentage to ``scores`` (in place)."""
    for percentage in percentages:
        rationale_mask = select_rationales(attributions, input_ids, attention_mask, percentage)
        comprehensiveness = compute_comprehensiveness(model, input_ids, attention_mask, rationale_mask,
                                                      predicted_ids, orig_probs, mask_token_id)
        sufficiency = compute_sufficiency(model, input_ids, attention_mask, rationale_mask,
                                          predicted_ids, orig_probs, mask_token_id)
        entry = scores[str(percentage)]
        entry["comprehensiveness_list"].extend(comprehensiveness.cpu().numpy().tolist())
        entry["sufficiency_list"].extend(sufficiency.cpu().numpy().tolist())


def _evaluate_group(saliency_data, model, tokenizer, mask_token_id, percentages, args, device):
    """Score every attribution method stored in one group's explanation file.

    Returns:
        dict mapping each method name (plus ``"Occlusion_abs"`` when ``"Occlusion"`` is
        present) to its per-percentage score lists. Methods with no explanations to
        evaluate are omitted.
    """
    results = {}
    for method, data in saliency_data.items():
        print(f"Method: {method}")
        # Flatten the per-instance explanation lists, keeping only explanations whose
        # predicted class matches the target class.
        correctly_predicted_data = [expl for instance in data for expl in instance
                                    if expl['predicted_class'] == expl['target_class']]
        # NOTE(review): this requires the number of kept explanations to equal the number
        # of instances; confirm that is the intended invariant.
        assert len(data) == len(correctly_predicted_data), "Some instances have different predicted and target classes"
        if args.num_examples is not None and args.num_examples > 0:
            correctly_predicted_data = correctly_predicted_data[:args.num_examples]
        if not correctly_predicted_data:
            print(f"No explanations to evaluate for method {method}. Skipping...")
            continue

        results[method] = _empty_scores(percentages)
        if method == "Occlusion":
            # Occlusion is additionally evaluated on absolute attribution values. The
            # accumulator is created once per method (not per batch) so every batch counts.
            results["Occlusion_abs"] = _empty_scores(percentages)

        for batch in tqdm(batch_loader(correctly_predicted_data, args.batch_size)):
            input_ids, attention_mask = _encode_batch(batch, tokenizer, args.max_length, device)

            # Score against the model's own prediction on the rebuilt input; the stored
            # 'predicted_class' field is not used here.
            with torch.no_grad():
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            probs = torch.softmax(logits, dim=-1)
            predicted_ids = torch.argmax(probs, dim=1)
            orig_probs = probs.gather(1, predicted_ids.unsqueeze(1)).squeeze(1)  # Shape: [batch_size]

            attributions = [[pair[1] for pair in item['attribution']] for item in batch]
            _score_rationales(model, input_ids, attention_mask, attributions, predicted_ids, orig_probs,
                              mask_token_id, percentages, results[method])

            if method == "Occlusion":
                # free up memory before the second perturbation pass
                torch.cuda.empty_cache()
                abs_attributions = [[abs(score) for score in row] for row in attributions]
                _score_rationales(model, input_ids, attention_mask, abs_attributions, predicted_ids, orig_probs,
                                  mask_token_id, percentages, results["Occlusion_abs"])
    return results


def _merge_group_results(group_perturbation_results, percentages):
    """Pool per-group score lists into one accumulator per method, in group insertion order."""
    merged = {}
    for group_results in group_perturbation_results.values():
        for method, method_scores in group_results.items():
            if method not in merged:
                merged[method] = _empty_scores(percentages)
            for percentage in percentages:
                key = str(percentage)
                merged[method][key]["comprehensiveness_list"].extend(method_scores[key]["comprehensiveness_list"])
                merged[method][key]["sufficiency_list"].extend(method_scores[key]["sufficiency_list"])
    return merged


def _summarize_results(all_perturbation_results, percentages):
    """Add per-percentage mean scores and the AUC over percentages to each method's results (in place)."""
    for method, method_results in all_perturbation_results.items():
        for percentage in percentages:
            entry = method_results[str(percentage)]
            entry["comprehensiveness_score"] = np.mean(entry["comprehensiveness_list"])
            entry["sufficiency_score"] = np.mean(entry["sufficiency_list"])

        comprehensiveness_scores = [method_results[str(p)]["comprehensiveness_score"] for p in percentages]
        sufficiency_scores = [method_results[str(p)]["sufficiency_score"] for p in percentages]
        comprehensiveness_auc = compute_perturbation_auc(percentages, comprehensiveness_scores)
        sufficiency_auc = compute_perturbation_auc(percentages, sufficiency_scores)
        method_results["comprehensiveness_auc"] = comprehensiveness_auc
        method_results["sufficiency_auc"] = sufficiency_auc
        print(f"Results for {method}:")
        print(f"Overall Comprehensiveness AUC: {comprehensiveness_auc}")
        print(f"Overall Sufficiency AUC: {sufficiency_auc}")


def main(args):
    """Evaluate perturbation-based faithfulness for every explanation method found on disk.

    For each explanation method (file-name prefix in the explanations directory), the
    per-group explanation files are loaded, the model's prediction is recomputed for each
    example, and comprehensiveness / sufficiency are computed for the rationale chosen by
    ``select_rationales`` at every percentage in ``args.percentages``. Scores from all
    social groups are pooled, averaged per percentage, summarized by an AUC over the
    percentages, and written as JSON to ``args.output_dir``.

    Args:
        args: Namespace providing ``dataset``, ``bias_type``, ``model_type``,
            ``batch_size``, ``max_length``, ``num_examples`` (<= 0 or None means all),
            ``mask_type`` ('mask' or 'pad'), ``percentages`` (comma-separated floats),
            ``seed`` and ``output_dir``. Numeric fields are converted to ``int`` in place.

    Raises:
        ValueError: for an invalid ``mask_type``, a tokenizer without the requested
            token, or an unconfigured dataset/bias_type combination.
    """
    # argparse already yields ints, but callers may pass strings; compare against None
    # so that legitimate zeros (e.g. seed=0, num_examples=0) are not turned into None.
    args.batch_size = int(args.batch_size) if args.batch_size is not None else None
    args.max_length = int(args.max_length) if args.max_length is not None else None
    args.num_examples = int(args.num_examples) if args.num_examples is not None else None
    args.seed = int(args.seed) if args.seed is not None else None

    if args.seed is not None:
        _set_random_seed(args.seed)

    percentages = [float(percentage) for percentage in args.percentages.split(',')]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load tokenizer and model
    model_dir = os.path.join(DATA_DIR, f'debiased_models_{args.dataset}', f"{args.model_type}_{args.dataset}_{args.bias_type}", "no_debiasing")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, output_attentions=True).to(device)
    model.eval()

    if args.mask_type == "mask":
        mask_token_id = tokenizer.mask_token_id
    elif args.mask_type == "pad":
        mask_token_id = tokenizer.pad_token_id
    else:
        raise ValueError("Invalid mask type. Choose from 'mask' or 'pad'.")
    if mask_token_id is None:
        raise ValueError(f"The tokenizer loaded from {model_dir} has no {args.mask_type} token.")

    groups = SOCIAL_GROUPS[args.bias_type]
    num_examples = _num_test_examples(args.dataset, args.bias_type)
    explanation_dir = os.path.join(DATA_DIR, f"encoder_results_{args.dataset}", f"{args.model_type}_{args.dataset}_{args.bias_type}_{args.bias_type}_test_{num_examples}", "no_debiasing", "explanations")
    # Explanation files are expected to be named "<method>_<group>_test_explanations.json";
    # the prefix before the first underscore identifies the explanation method.
    # Sorted so the processing order is deterministic across runs.
    explanation_methods = sorted({path.split("_")[0] for path in os.listdir(explanation_dir)})

    for explanation_method in explanation_methods:
        group_perturbation_results = {}
        for group in groups:
            group_explanation_path = os.path.join(explanation_dir, f"{explanation_method}_{group}_test_explanations.json")
            if not os.path.exists(group_explanation_path):
                print(f"Explanation file {group_explanation_path} does not exist. Skipping...")
                continue
            with open(group_explanation_path) as f:
                saliency_data = json.load(f)
            print(f"Loaded saliency data from {group_explanation_path}")
            group_perturbation_results[group] = _evaluate_group(
                saliency_data, model, tokenizer, mask_token_id, percentages, args, device)

        # Pool the groups; only methods that actually produced scores are reported.
        all_perturbation_results = _merge_group_results(group_perturbation_results, percentages)
        if not all_perturbation_results:
            print(f"No perturbation results for {explanation_method}. Skipping...")
            continue
        _summarize_results(all_perturbation_results, percentages)

        # save results to output_dir
        os.makedirs(args.output_dir, exist_ok=True)
        output_path = os.path.join(args.output_dir, f"{args.model_type}_{args.dataset}_{args.bias_type}_{args.bias_type}_{explanation_method}_perturbation_results.json")
        with open(output_path, 'w') as f:
            json.dump(all_perturbation_results, f, indent=4)
        print(f"Results saved to {output_path}")

if __name__ == '__main__':
    parser = ArgumentParser(description='Evaluate the faithfulness for rationales using perturbation-based methods.')

    parser.add_argument('--bias_type', type=str, default='race', choices=['race', 'gender', 'religion'], help='Social bias dimension to evaluate')
    parser.add_argument('--dataset', type=str, default='civil', choices=['civil', 'jigsaw'], help='Dataset the model and explanations come from')
    parser.add_argument('--model_type', type=str, default='bert', choices=['bert', 'roberta', 'distilbert'], help='Encoder architecture of the evaluated model')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for DataLoader')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length for tokenization')
    parser.add_argument('--num_examples', type=int, default=-1, help='Number of examples to process (-1 for all)')
    # Restrict to the values main() accepts so a typo fails at parse time, before the model is loaded.
    parser.add_argument('--mask_type', type=str, default='mask', choices=['mask', 'pad'], help='Type of token to mask for perturbation')
    parser.add_argument('--percentages', type=str, default='0.05,0.1,0.2,0.5', help='Comma-separated list of percentages for selecting rationales')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--output_dir', type=str, default='analysis/faithfulness_results_encoder', help='Directory to save the results')

    args = parser.parse_args()
    main(args)