import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from utils.perturbation_utils import select_rationales_decoder, compute_comprehensiveness_decoder, compute_sufficiency_decoder, compute_perturbation_auc
from argparse import ArgumentParser
import json
import random
import numpy as np
from tqdm import tqdm
import os
from utils.vocabulary import *
from utils.prompt import *
from fairness_evaluation.compute_fairness_results import make_predictions

DATA_DIR = "./results/"

def batch_loader(data, batch_size):
    """Lazily yield consecutive slices of ``data`` holding ``batch_size`` items each.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``; nothing is yielded for empty ``data``.
    """
    starts = range(0, len(data), batch_size)
    yield from (data[start:start + batch_size] for start in starts)


# CLI model alias -> Hugging Face model id.
_MODEL_DIRS = {
    "llama_3b": "meta-llama/Llama-3.2-3B-Instruct",
    "qwen_3b": "Qwen/Qwen2.5-3B-Instruct",
    "qwen3_4b": "Qwen/Qwen3-4B",
}


def _set_random_seed(seed):
    """Seed Python, NumPy and torch RNGs and make cuDNN deterministic; no-op when ``seed`` is None."""
    if seed is None:
        # torch.manual_seed(None) raises, so an unset seed leaves the RNGs unseeded.
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _resolve_model_dir(model_type):
    """Map a CLI model alias to its Hugging Face model id.

    Raises:
        ValueError: if ``model_type`` is not a known alias.
    """
    try:
        return _MODEL_DIRS[model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported model type {model_type!r}; choose from {sorted(_MODEL_DIRS)}."
        ) from None


def _resolve_mask_token_id(tokenizer, mask_type):
    """Return the tokenizer's id for the token used to replace masked positions.

    Raises:
        ValueError: if ``mask_type`` is not one of 'mask', 'pad', 'unk'.
    """
    if mask_type == "mask":
        token_id = tokenizer.mask_token_id
    elif mask_type == "pad":
        token_id = tokenizer.pad_token_id
    elif mask_type == "unk":
        token_id = tokenizer.unk_token_id
    else:
        raise ValueError("Invalid mask type. Choose from 'mask' or 'pad' or 'unk'.")
    if token_id is None:
        # A tokenizer may define no such special token; make that visible instead of
        # silently handing None to the perturbation helpers.
        print(f"Warning: tokenizer defines no '{mask_type}' token; mask_token_id is None.")
    return token_id


def _num_test_examples(dataset, bias_type):
    """Return the test-set size that is part of the explanation directory name.

    Raises:
        ValueError: for a dataset/bias_type combination with no configured size.
    """
    if dataset == "civil":
        return 1000 if bias_type == "religion" else 2000
    if dataset == "jigsaw":
        jigsaw_sizes = {"race": 400, "gender": 800, "religion": 200}
        if bias_type in jigsaw_sizes:
            return jigsaw_sizes[bias_type]
    raise ValueError(f"No test-set size configured for dataset={dataset!r}, bias_type={bias_type!r}.")


def _find_subsequence_mask(full_input_ids, input_ids):
    """Return a 0/1 list over ``full_input_ids[0]`` marking the first contiguous match of ``input_ids[0]``.

    The list is all zeros when ``input_ids[0]`` does not occur in ``full_input_ids[0]``.
    """
    full_ids = full_input_ids[0]
    ids = input_ids[0]
    n, m = full_ids.size(0), ids.size(0)
    mask = [0] * n
    for start in range(n - m + 1):
        if torch.equal(full_ids[start:start + m], ids):
            mask[start:start + m] = [1] * m
            break
    return mask


def _empty_results(percentages):
    """Create fresh per-percentage accumulators (keys are ``str(percentage)``)."""
    return {str(p): {"comprehensiveness_list": [], "sufficiency_list": []} for p in percentages}


def _prepare_example(batch, model, tokenizer, template, prompt, device):
    """Tokenize one batch, run the model's prediction, and build the full chat-prompt tensors.

    Returns a dict with the stripped texts, the text-only ``input_ids``/``attention_mask``,
    the full-prompt ``full_input_ids``/``full_attention_mask``, the boolean
    ``input_text_mask`` locating the text inside the prompt, the ``predicted_ids`` and
    ``orig_probs`` from ``make_predictions``, and ``text_found`` (whether the mask matched).
    Only ``texts[0]`` is placed into the prompt, so callers pass batches of size 1.
    """
    texts = [x['text'] for x in batch]
    inputs = tokenizer(texts, return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    # Drop the last two characters (the trailing "\n\n") only AFTER tokenizing, so
    # input_ids still include them. NOTE(review): confirm this asymmetry with the
    # attribution alignment is intended.
    texts = [text[:-2] for text in texts]
    labels = [x['true_label'] for x in batch]
    input_batch = {"text": texts, "label": labels}
    _, predictions, _, _, confidences_predicted_class = make_predictions(model, tokenizer, input_batch, "decoder", template, prompt)

    orig_probs = torch.tensor(confidences_predicted_class).to(device)
    predicted_ids = torch.tensor(predictions).to(device)

    full_text = tokenizer.apply_chat_template(fill_in_template(template, prompt.replace("[TEST EXAMPLE]", texts[0])), tokenize=False, add_generation_prompt=True, enable_thinking=False, date_string="2025-07-01")
    full_inputs = tokenizer(full_text, return_tensors='pt')
    full_input_ids = full_inputs['input_ids'].to(device)
    full_attention_mask = full_inputs['attention_mask'].to(device)

    matches = _find_subsequence_mask(full_input_ids, input_ids)
    input_text_mask = torch.tensor(matches, dtype=torch.bool).unsqueeze(0).to(device)

    return {
        "texts": texts,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "full_input_ids": full_input_ids,
        "full_attention_mask": full_attention_mask,
        "input_text_mask": input_text_mask,
        "predicted_ids": predicted_ids,
        "orig_probs": orig_probs,
        "text_found": any(matches),
    }


def _score_rationales(results, attributions, example, model, tokenizer, template, prompt, percentages, mask_token_id):
    """Append comprehensiveness and sufficiency for every rationale percentage to ``results``."""
    for percentage in percentages:
        rationale_mask = select_rationales_decoder(attributions, example["input_ids"], example["attention_mask"], example["texts"], tokenizer, template, prompt, percentage)
        perturbation_args = (
            model, tokenizer, example["full_input_ids"], example["full_attention_mask"], rationale_mask,
            example["input_text_mask"], example["predicted_ids"], example["orig_probs"], mask_token_id,
        )
        comprehensiveness = compute_comprehensiveness_decoder(*perturbation_args)
        sufficiency = compute_sufficiency_decoder(*perturbation_args)
        bucket = results[str(percentage)]
        bucket["comprehensiveness_list"].extend(comprehensiveness.cpu().numpy().tolist())
        bucket["sufficiency_list"].extend(sufficiency.cpu().numpy().tolist())


def _evaluate_group(saliency_data, model, tokenizer, template, prompt, percentages, mask_token_id, device, num_examples):
    """Score every saliency method in one group's explanation file.

    ``saliency_data`` maps method name -> list of instances, each a list of entries with
    ``text``, ``true_label``, ``predicted_class``, ``target_class`` and ``attribution``
    (pairs whose second element is the score). For "Occlusion" an extra
    "Occlusion_abs" entry scores the absolute attributions.

    Raises:
        ValueError: if any entry's predicted class differs from its target class.
    """
    results = {}
    for method, data in saliency_data.items():
        print(f"Method: {method}")
        correctly_predicted_data = [expl for instance in data for expl in instance if expl['predicted_class'] == expl['target_class']]
        if len(data) != len(correctly_predicted_data):
            raise ValueError("Some instances have different predicted and target classes")
        # num_examples <= 0 (or None) means "use every example".
        if num_examples is not None and num_examples > 0:
            correctly_predicted_data = correctly_predicted_data[:num_examples]

        results[method] = _empty_results(percentages)
        score_abs = method == "Occlusion"
        if score_abs:
            # Created once per method so scores accumulate over all examples.
            results["Occlusion_abs"] = _empty_results(percentages)

        num_unmatched = 0
        for batch in tqdm(batch_loader(correctly_predicted_data, 1), total=len(correctly_predicted_data)):
            example = _prepare_example(batch, model, tokenizer, template, prompt, device)
            if not example["text_found"]:
                num_unmatched += 1
            attributions = [[expl[1] for expl in x['attribution']] for x in batch]
            _score_rationales(results[method], attributions, example, model, tokenizer, template, prompt, percentages, mask_token_id)

            if score_abs:
                # free up memory before the second round of perturbations
                torch.cuda.empty_cache()
                abs_attributions = [[abs(expl[1]) for expl in x['attribution']] for x in batch]
                _score_rationales(results["Occlusion_abs"], abs_attributions, example, model, tokenizer, template, prompt, percentages, mask_token_id)

        if num_unmatched:
            print(f"Warning: {num_unmatched} example(s) for {method} were not found in the prompt tokens; their input_text_mask is all False.")
    return results


def _merge_group_results(group_perturbation_results, groups, percentages):
    """Concatenate per-group score lists into one result dict per method (fresh lists, no aliasing)."""
    merged = {}
    for group in groups:
        for method, method_results in group_perturbation_results.get(group, {}).items():
            target = merged.setdefault(method, _empty_results(percentages))
            for percentage in percentages:
                key = str(percentage)
                for list_name in ("comprehensiveness_list", "sufficiency_list"):
                    target[key][list_name].extend(method_results[key][list_name])
    return merged


def _add_summary_scores(all_perturbation_results, percentages):
    """Add mean scores per percentage and comprehensiveness/sufficiency AUCs per method, in place."""
    for method, method_results in all_perturbation_results.items():
        for percentage in percentages:
            bucket = method_results[str(percentage)]
            bucket["comprehensiveness_score"] = np.mean(bucket["comprehensiveness_list"])
            bucket["sufficiency_score"] = np.mean(bucket["sufficiency_list"])

        comprehensiveness_scores = [method_results[str(p)]["comprehensiveness_score"] for p in percentages]
        sufficiency_scores = [method_results[str(p)]["sufficiency_score"] for p in percentages]
        comprehensiveness_auc = compute_perturbation_auc(percentages, comprehensiveness_scores)
        sufficiency_auc = compute_perturbation_auc(percentages, sufficiency_scores)
        method_results["comprehensiveness_auc"] = comprehensiveness_auc
        method_results["sufficiency_auc"] = sufficiency_auc
        print(f"[{method}] Overall Comprehensiveness AUC: {comprehensiveness_auc}")
        print(f"[{method}] Overall Sufficiency AUC: {sufficiency_auc}")


def main(args):
    """Evaluate perturbation faithfulness (comprehensiveness/sufficiency) of decoder explanations.

    For each explanation method found in the zero-shot explanations directory for
    ``args.model_type``/``args.dataset``/``args.bias_type``, load every social group's
    saliency file, score the top rationales at each percentage in ``args.percentages``,
    merge the groups, and write the per-method scores and AUCs as JSON to ``args.output_dir``.
    """
    # Normalize numeric values; `is not None` keeps 0 valid (a truthiness check would
    # turn `--seed 0` / `--num_examples 0` into None and crash later).
    args.num_examples = int(args.num_examples) if args.num_examples is not None else None
    args.seed = int(args.seed) if args.seed is not None else None

    _set_random_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load tokenizer and model
    model_dir = _resolve_model_dir(args.model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, output_attentions=True, device_map="auto")
    model.eval()
    mask_token_id = _resolve_mask_token_id(tokenizer, args.mask_type)

    groups = SOCIAL_GROUPS[args.bias_type]
    num_examples = _num_test_examples(args.dataset, args.bias_type)
    explanation_dir = os.path.join(DATA_DIR, f"decoder_results_{args.dataset}", f"{args.model_type}_{args.dataset}_{args.bias_type}_test_{num_examples}", "zero_shot", "explanations")
    # sorted(): set iteration order of strings varies between runs (hash randomization).
    explanation_methods = sorted({path.split("_")[0] for path in os.listdir(explanation_dir)})

    prompt = construct_zero_shot_prompt(args.dataset, args.bias_type)
    template = TEMPLATE
    percentages = [float(percentage) for percentage in args.percentages.split(',')]
    os.makedirs(args.output_dir, exist_ok=True)

    for explanation_method in explanation_methods:
        group_perturbation_results = {group: {} for group in groups}
        for group in groups:
            group_explanation_dir = os.path.join(explanation_dir, f"{explanation_method}_{group}_test_explanations.json")
            if not os.path.exists(group_explanation_dir):
                print(f"Explanation file {group_explanation_dir} does not exist. Skipping...")
                continue
            with open(group_explanation_dir) as f:
                saliency_data = json.load(f)
            print(f"Loaded saliency data from {group_explanation_dir}")
            group_perturbation_results[group] = _evaluate_group(saliency_data, model, tokenizer, template, prompt, percentages, mask_token_id, device, args.num_examples)

        # Methods are taken from the results actually computed, not from whichever
        # file happened to be loaded last.
        all_perturbation_results = _merge_group_results(group_perturbation_results, groups, percentages)
        if not all_perturbation_results:
            print(f"No perturbation results for {explanation_method}. Skipping...")
            continue
        _add_summary_scores(all_perturbation_results, percentages)

        # NOTE(review): bias_type appears twice in the file name; kept unchanged in case
        # downstream readers depend on this naming.
        output_path = os.path.join(args.output_dir, f"{args.model_type}_{args.dataset}_{args.bias_type}_{args.bias_type}_{explanation_method}_perturbation_results.json")
        with open(output_path, 'w') as f:
            json.dump(all_perturbation_results, f, indent=4)
        print(f"Results saved to {output_path}")

if __name__ == '__main__':
    parser = ArgumentParser(description='Evaluate the faithfulness for rationales using perturbation-based methods.')

    parser.add_argument('--bias_type', type=str, default='race', choices=['race', 'gender', 'religion'])
    parser.add_argument('--dataset', type=str, default='civil', choices=['civil', 'jigsaw'])
    # No default: argparse does not validate defaults against `choices`, and the old
    # default ('bert') is not a supported decoder model, so it crashed main().
    parser.add_argument('--model_type', type=str, required=True, choices=['llama_3b', 'qwen_3b', 'qwen3_4b'])
    parser.add_argument('--num_examples', type=int, default=-1, help='Number of examples to process (-1 for all)')
    # choices mirror the mask types main() accepts.
    parser.add_argument('--mask_type', type=str, default='pad', choices=['mask', 'pad', 'unk'], help='Type of token to mask for perturbation')
    parser.add_argument('--percentages', type=str, default='0.05,0.1,0.2,0.5', help='Comma-separated list of percentages for selecting rationales')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--output_dir', type=str, default='analysis/faithfulness_results_decoder', help='Directory to save the results')

    args = parser.parse_args()
    main(args)