from explainer.Explainer_Decoder import BcosExplainer, AttentionExplainer, GradientNPropabationExplainer, OcclusionExplainer, ShapleyValueExplainer
from utils.utils import set_random_seed
from utils.dataset_utils import customized_load_dataset, customized_split_dataset, load_subsets_for_fairness_explainability
import argparse
import torch
from torch.utils.data import DataLoader, Subset
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

import numpy as np
import json
import os
import random
from tqdm import tqdm

from utils.vocabulary import *
from utils.prompt import *

# Registry: attribution-method name -> explainer class that implements it.
# Several gradient-based methods share GradientNPropabationExplainer; the key
# is what distinguishes them. main() below does not reference this mapping.
# Currently disabled (left out on purpose): "Bcos" (BcosExplainer),
# "GuidedBackprop" and "SIG" (GradientNPropabationExplainer),
# "ShapleyValue" (ShapleyValueExplainer) and "Lime" (LimeExplainer, which is
# not among the explainer imports of this file).
EXPLANATION_METHODS_MAPPING = dict(
    Attention=AttentionExplainer,
    Saliency=GradientNPropabationExplainer,
    DeepLift=GradientNPropabationExplainer,
    InputXGradient=GradientNPropabationExplainer,
    IntegratedGradients=GradientNPropabationExplainer,
    Occlusion=OcclusionExplainer,
    KernelShap=ShapleyValueExplainer,
)

# Number of test examples per bias type for each supported dataset. These values
# are only used to locate the results sub-directory ("..._test_<N>").
_NUM_TEST_EXAMPLES = {
    "civil": {"race": 2000, "gender": 2000, "religion": 1000},
    "jigsaw": {"race": 400, "gender": 800, "religion": 200},
}

# (substring of the lower-cased model path, short name used in result paths),
# checked in order.
_MODEL_NAME_PATTERNS = (("qwen3", "qwen3_4b"), ("llama", "llama_3b"))

_PROMPT_TYPES = ("self_reflection", "llm_rationale")

# Generation budget for a single LLM rationale.
_RATIONALE_MAX_NEW_TOKENS = 50


def _resolve_model_name(model_dir):
    """Return the short model name used in result paths for *model_dir*.

    Raises:
        ValueError: if *model_dir* matches none of the known model families.
    """
    lowered = model_dir.lower()
    for pattern, short_name in _MODEL_NAME_PATTERNS:
        if pattern in lowered:
            return short_name
    raise ValueError(
        f"Cannot infer model name from {model_dir!r}; expected a path containing one of "
        f"{[pattern for pattern, _ in _MODEL_NAME_PATTERNS]}"
    )


def _resolve_num_examples(requested, available):
    """Return how many examples to process: all of them for -1, else *requested*.

    Raises:
        ValueError: if *requested* is negative (other than -1) or exceeds *available*.
    """
    if requested == -1:
        return available
    if requested < 0:
        raise ValueError(f"num_examples must be -1 (all) or non-negative, got {requested}")
    if requested > available:
        raise ValueError(f"num_examples {requested} exceeds the dataset size {available}")
    return requested


def _unwrap_stored_text(text):
    """Strip the leading space and trailing blank line from texts stored as " <text>\\n\\n".

    Texts not wrapped exactly that way (both ends) are returned unchanged.
    """
    if text.startswith(" ") and text.endswith("\n\n"):
        return text[1:-2]
    return text


def _build_chat_inputs(tokenizer, template, system_prompt, example_text, device):
    """Insert *example_text* into the prompt, apply the chat template and tokenize onto *device*."""
    # fill_in_template comes from utils.prompt; presumably it returns chat messages
    # accepted by apply_chat_template (TODO confirm).
    chat = fill_in_template(template, system_prompt.replace("[TEST EXAMPLE]", example_text))
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # extra template kwarg; presumably only read by Qwen3's template (TODO confirm)
        date_string="2025-07-01",  # extra template kwarg; presumably pins the date in Llama's template (TODO confirm)
    )
    return tokenizer(prompt, return_tensors="pt").to(device)


def _score_self_reflection(model, inputs, positive_token_id, negative_token_id):
    """Return ``(prediction, p_positive, p_negative)`` from the next-token distribution.

    ``p_positive``/``p_negative`` are the softmax probabilities (over the full
    vocabulary) of the positive ("Yes") and negative ("No") answer tokens at the
    last position. ``prediction`` is 1 when ``p_positive >= p_negative``, else 0.
    """
    with torch.no_grad():
        logits = model(**inputs).logits[:, -1, :]
    probs = torch.softmax(logits, dim=-1)
    p_positive = probs[0, positive_token_id].item()
    p_negative = probs[0, negative_token_id].item()
    prediction = 1 if p_positive >= p_negative else 0
    return prediction, p_positive, p_negative


def _generate_rationale(model, tokenizer, inputs):
    """Greedily generate a rationale and return only the newly generated text."""
    with torch.no_grad():
        # Greedy decoding. temperature/top_p/top_k only affect sampling, so they
        # are deliberately not passed.
        outputs = model.generate(
            **inputs,
            max_new_tokens=_RATIONALE_MAX_NEW_TOKENS,
            do_sample=False,
            return_dict_in_generate=True,
        )
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True)


def main(args):
    """Re-query a decoder LLM about its own zero-shot predictions and save the results.

    For every social group of ``args.bias_type``, this reads the Attention
    explanation file found under ``args.results_dir`` (presumably written by the
    zero-shot explanation run). Only each record's prediction, true label and
    text are used. Depending on ``args.prompt_type`` it then either:

    * ``"self_reflection"``: prompts the model to reflect on its prediction and
      records the next-token probabilities of "Yes" (biased) and "No" (unbiased);
    * ``"llm_rationale"``: prompts the model for a short rationale and records
      the greedily generated text.

    Results for all groups are written to one JSON file in ``args.output_dir``.

    Raises:
        ValueError: for an unsupported ``prompt_type`` or ``dataset``, a model path
            whose family cannot be inferred, or an invalid ``num_examples``.
    """
    # Validate CLI-derived settings before the (slow) model load.
    if args.prompt_type not in _PROMPT_TYPES:
        raise ValueError(
            f"Unsupported prompt_type {args.prompt_type!r}; expected one of {list(_PROMPT_TYPES)}"
        )
    dataset = args.dataset
    bias_type = args.bias_type
    if dataset not in _NUM_TEST_EXAMPLES:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {list(_NUM_TEST_EXAMPLES)}"
        )
    num_test = _NUM_TEST_EXAMPLES[dataset]
    model_name = _resolve_model_name(args.model_dir)

    model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map="auto")
    model.eval()
    # With device_map="auto" the weights are placed by accelerate; inputs go to
    # model.device (device of the first parameters) rather than a guessed "cuda".
    device = model.device
    print(f"Loaded model {args.model_dir} to {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    # First sub-token of each answer word; these are the tokens scored in self-reflection.
    positive_token_id = tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    negative_token_id = tokenizer("No", add_special_tokens=False)["input_ids"][0]

    template = TEMPLATE
    print("Loading dataset...")
    explanation_dir = os.path.join(
        args.results_dir,
        f"decoder_results_{dataset}",
        f"{model_name}_{dataset}_{bias_type}_test_{num_test[bias_type]}",
        "zero_shot",
        "explanations",
    )

    group_results = {}
    for group in SOCIAL_GROUPS[bias_type]:
        print(f"Processing group: {group}")
        file_path = os.path.join(explanation_dir, f"Attention_{group}_test_explanations.json")
        with open(file_path, "r", encoding="utf-8") as f:
            explanations = json.load(f)["raw_attention"]

        num_examples = _resolve_num_examples(args.num_examples, len(explanations))
        # Only the first entry of each record is read: it holds predicted_class,
        # true_label and text.
        records = [expl[0] for expl in explanations[:num_examples]]
        model_predictions = [record["predicted_class"] for record in records]
        true_labels = [record["true_label"] for record in records]
        texts = [_unwrap_stored_text(record["text"]) for record in records]
        prediction_tokens = ["Yes" if pred == 1 else "No" for pred in model_predictions]

        results = {"model_predictions": model_predictions, "true_labels": true_labels}
        if args.prompt_type == "self_reflection":
            predictions, biased_confidences, unbiased_confidences, confidences = [], [], [], []
            for i in tqdm(range(num_examples)):
                system_prompt = construct_zero_shot_prompt_with_self_reflection(
                    dataset, bias_type, prediction_tokens[i]
                )
                inputs = _build_chat_inputs(tokenizer, template, system_prompt, texts[i], device)
                prediction, confidence_biased_class, confidence_unbiased_class = _score_self_reflection(
                    model, inputs, positive_token_id, negative_token_id
                )
                print(f"Self-reflection prediction for example {i}: {prediction} (biased class confidence: {confidence_biased_class}, unbiased class confidence: {confidence_unbiased_class})")
                predictions.append(prediction)
                biased_confidences.append(confidence_biased_class)
                unbiased_confidences.append(confidence_unbiased_class)
                # Confidence of whichever class the self-reflection predicted.
                confidences.append(
                    confidence_biased_class if prediction == 1 else confidence_unbiased_class
                )
            results["self_reflection_predictions"] = predictions
            results["self_reflection_biased_confidences"] = biased_confidences
            results["self_reflection_unbiased_confidences"] = unbiased_confidences
            results["self_reflection_confidences"] = confidences
        else:  # "llm_rationale" (validated above)
            rationales = []
            for i in tqdm(range(num_examples)):
                system_prompt = construct_zero_shot_prompt_with_llm_explanation(
                    dataset, bias_type, prediction_tokens[i], args.num_rationale_tokens
                )
                inputs = _build_chat_inputs(tokenizer, template, system_prompt, texts[i], device)
                generated_text = _generate_rationale(model, tokenizer, inputs)
                rationales.append(generated_text)
                print(f"LLM rationale for example {i}: {generated_text}")
            results["llm_generated_rationales"] = rationales
        group_results[group] = results

    os.makedirs(args.output_dir, exist_ok=True)
    if args.prompt_type == "self_reflection":
        file_name = f"{model_name}_{dataset}_{bias_type}_self_reflection_explanations.json"
    else:
        file_name = f"{model_name}_{dataset}_{bias_type}_llm_rationale_{args.num_rationale_tokens}_explanations.json"
    output_file = os.path.join(args.output_dir, file_name)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(group_results, f, indent=4)
    print(f"Saved the explanations to {output_file}")

 

if __name__ == "__main__":
    # CLI entry point: parse arguments and run main().
    parser = argparse.ArgumentParser(
        description="Self-reflection / LLM-rationale baselines for decoder LLM zero-shot predictions"
    )

    parser.add_argument('--dataset', type=str, default='civil', choices=["civil", "jigsaw"], help='Dataset whose zero-shot results are re-analysed')
    parser.add_argument('--model_dir', type=str, default='Qwen/Qwen3-4B', help='HuggingFace model id or local path (a Qwen3 or Llama model)')
    parser.add_argument('--num_examples', type=int, default=2000, help='Number of examples to process per social group (-1 for all)')
    parser.add_argument('--results_dir', type=str, default='baseline_saliency_results', help='Root directory holding the zero-shot decoder results (explanations are read from here)')
    parser.add_argument('--output_dir', type=str, default='baseline_saliency_results/all_methods_1000_examples_512', help='Directory to save the output files')
    parser.add_argument('--bias_type', type=str, default="race", choices=["race", "gender", "religion"])
    parser.add_argument('--prompt_type', type=str, default="self_reflection", choices=["self_reflection", "llm_rationale"], help='Type of prompt to use for the model')
    parser.add_argument('--num_rationale_tokens', type=int, default=5, help='Number of rationale tokens to ask the model to provide')

    args = parser.parse_args()
    main(args)
