import argparse
import json
import os
from functools import lru_cache

import torch
from huggingface_hub import login
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BartForConditionalGeneration, BartTokenizer

from trustllm.generation.generation import LLMGeneration
from trustllm.utils import file_process

# Import evaluation classes
from trustllm.task.fairness import FairnessEval
from trustllm.task.privacy import PrivacyEval
from trustllm.task.robustness import RobustnessEval
from trustllm.task.ethics import EthicsEval
from trustllm.task.safety import SafetyEval

# No explicit cache directory is configured; the from_pretrained calls below use the library's default cache location.


@lru_cache(maxsize=4)
def _load_bart(model_name, device):
    """Load a BART tokenizer/model pair onto ``device``, cached per (model_name, device).

    Caching matters because ``reconstruct_text`` is called once per prompt;
    without it the weights would be re-read for every single prompt.
    """
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()
    return tokenizer, model


def reconstruct_text(input_text, model_name='bart-large', device='cuda'):
    """Reconstruct text using a BART model.

    The input is encoded with BART's encoder and then regenerated from the
    encoder's hidden states, which acts as a paraphrasing/denoising step.

    Args:
        input_text: Text to reconstruct.
        model_name: BART checkpoint name or path. NOTE(review): the Hub id for
            this checkpoint is usually 'facebook/bart-large'; confirm that
            'bart-large' resolves in your environment (e.g. a local directory).
        device: Device (string or torch.device) the model and inputs run on.

    Returns:
        The decoded reconstruction with special tokens removed.
    """
    # str(device) gives a stable, hashable cache key for torch.device objects too.
    tokenizer, model = _load_bart(model_name, str(device))

    # truncation=True clips over-long prompts to the tokenizer's maximum length
    # instead of failing on positions beyond BART's embedding table.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)

    with torch.no_grad():
        encoder_outputs = model.get_encoder()(**inputs)
        # Generate text using the encoder's last hidden state
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=inputs["attention_mask"],
            max_length=1024,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

def save_evaluation_results(model_name, test_type, subtest_name, data, output_dir):
    """
    Save evaluation results to a specific folder for each model, test type and subtest.

    The file is written to
    ``<output_dir>/<model_name>/<test_type>/<subtest_name>_results.json``;
    missing directories are created.

    Args:
        model_name: Folder name for the model.
        test_type: Folder name for the test type (e.g. 'fairness').
        subtest_name: Prefix of the results file name.
        data: JSON-serializable object to write.
        output_dir: Root output directory.

    Returns:
        The path of the written file.
    """
    test_dir = os.path.join(output_dir, model_name, test_type)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(test_dir, exist_ok=True)

    output_file = os.path.join(test_dir, f"{subtest_name}_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Results for {model_name}/{test_type}/{subtest_name} saved to {output_file}")
    return output_file

def get_folder_name(model_path):
    """Return the short folder label associated with ``model_path``.

    Well-known Hugging Face ids map to hand-chosen short names; any other
    path falls back to its final path component.
    """
    short_names = dict([
        ("baichuan-inc/Baichuan2-13B-chat", "baichuan2-13b"),
        ("THUDM/chatglm3-6b", "chatglm3-6b"),
        ("meta-llama/Llama-2-7b-chat-hf", "llama-2-7b"),
        ("meta-llama/Llama-2-13b-chat-hf", "llama-2-13b"),
        ("mistralai/Mistral-7B-Instruct-v0.1", "mistral"),
        ("lmsys/vicuna-7b-v1.5", "vicuna-7b"),
        ("lmsys/vicuna-13b-v1.5", "vicuna-13b"),
        ("WizardLM/WizardLM-13B-V1.2", "wizardlm-13b"),
    ])
    if model_path in short_names:
        return short_names[model_path]
    return os.path.basename(model_path)

# Evaluation recipe per task type:
#   task_type -> (evaluator class, evaluator method name, extra keyword arguments,
#                 subtest name for the saved results file, label for the printed score)
_ATTACK_EVALUATIONS = {
    'fairness': (FairnessEval, 'stereotype_query_eval', {},
                 'stereotype_query_eval_after_attack', 'Fairness Evaluation after attack'),
    'privacy': (PrivacyEval, 'awareness_query_eval', {},
                'awareness_query_after_attack', 'Privacy Evaluation after attack'),
    'robustness': (RobustnessEval, 'ood_detection', {},
                   'ood_detection_after_attack', 'Robustness Evaluation after attack'),
    'ethics': (EthicsEval, 'explicit_ethics_eval', {'eval_type': 'high'},
               'explicit_ethics_after_attack', 'Ethics Evaluation after attack'),
    'safety': (SafetyEval, 'misuse_eval', {},
               'misuse_after_attack', 'Safety Evaluation after attack'),
    'preference': (FairnessEval, 'preference_eval', {},
                   'preference', 'Preference Evaluation'),
}


def _generate_attack_responses(task_type, attack_data, generator, model, tokenizer):
    """Set ``item['res']`` in place for every item that has an attack or input prompt."""
    for i, item in tqdm(enumerate(attack_data), total=len(attack_data), desc=f"Generating responses for {task_type} after attack"):
        # Prefer the best attack prompt; fall back to the original input prompt.
        prompt = item.get('best_attack_prompt') or item.get('input_prompt')
        if not prompt:
            print(f"Missing both 'best_attack_prompt' and 'input_prompt' in data at index {i}")
            continue
        # Apply text reconstruction before generation
        reconstructed_prompt = reconstruct_text(prompt, device=model.device)
        item['res'] = generator.generation(
            model_name=generator.model_name,
            prompt=reconstructed_prompt,
            tokenizer=tokenizer,
            model=model,
            temperature=1.0,
        )


def _save_attack_responses(model_name, task_type, attack_data, output_dir):
    """Write responses to ``<output_dir>/<model_name>/<task_type>_attack_responses.json``."""
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    responses_file = os.path.join(model_dir, f'{task_type}_attack_responses.json')
    with open(responses_file, 'w', encoding='utf-8') as f:
        json.dump(attack_data, f, indent=4)
    print(f"Generated responses for {model_name}/{task_type} after attack saved to {responses_file}")


def _evaluate_attack_responses(model_name, task_type, attack_data, output_dir):
    """Run the evaluator registered for ``task_type``, then save and print its results."""
    recipe = _ATTACK_EVALUATIONS.get(task_type)
    if recipe is None:
        # Previously an unknown task type silently skipped evaluation.
        print(f"Warning: no evaluator registered for task type '{task_type}'; "
              f"responses were saved but not evaluated.")
        return
    evaluator_cls, method_name, extra_kwargs, subtest_name, label = recipe
    evaluate = getattr(evaluator_cls(), method_name)
    # The evaluators are called with return_data=True; the code uses
    # scores[0] as the data to save and scores[1] as the printed score.
    scores = evaluate(attack_data, return_data=True, **extra_kwargs)
    save_evaluation_results(model_name, task_type, subtest_name, scores[0], output_dir)
    print(f"{label} - RtA Score: {scores[1]}")


def process_task(task_type, input_file, model, tokenizer, model_path, output_dir):
    """
    Process a single task with its specific input file after attack.

    Steps:
    1. Load the JSON attack records from ``input_file``.
    2. Reconstruct each prompt with BART and generate a response with ``model``.
    3. Save the records, now with 'res' filled in, under
       ``<output_dir>/<basename(model_path)>/``.
    4. Evaluate the records with the evaluator registered for ``task_type``.

    Args:
        task_type: Task name. It is passed to LLMGeneration and selects the
            evaluator in ``_ATTACK_EVALUATIONS``.
        input_file: JSON file containing a list of records. Each record is read
            via 'best_attack_prompt' or 'input_prompt'.
        model: Loaded causal LM. Its ``.device`` is also used for BART reconstruction.
        tokenizer: Tokenizer matching ``model``.
        model_path: Model id or path. Its basename names the output folder.
        output_dir: Root output directory.
    """
    print(f"Processing {task_type} task after attack...")

    model_name = os.path.basename(model_path)

    # Create a generator for this specific task
    generator = LLMGeneration(
        test_type=task_type,
        data_path=input_file,
        model_path=model_path,
        online_model=False,
        repetition_penalty=1.0,
        num_gpus=1,
        max_new_tokens=128,
        debug=False
    )

    # Load the input prompts after attack
    with open(input_file, 'r', encoding='utf-8') as f:
        attack_data = json.load(f)

    _generate_attack_responses(task_type, attack_data, generator, model, tokenizer)
    _save_attack_responses(model_name, task_type, attack_data, output_dir)
    _evaluate_attack_responses(model_name, task_type, attack_data, output_dir)

def main(model_path, base_input_dir, output_dir, allow_download):
    """
    Main function to generate responses and evaluate them after attack.

    For each attack-result prefix in ``file_task_types`` (below), the file
    ``<base_input_dir>/<prefix>_<folder_name>/final_results.json`` is handed to
    :func:`process_task` if it exists. A missing file only prints a warning.

    Args:
        model_path: Hugging Face model id or local model directory.
        base_input_dir: Directory containing the per-task attack result folders.
        output_dir: Root directory for generated responses and evaluation results.
        allow_download: Value of ``--allow_download``. NOTE(review): currently
            unused. Models are always loaded with ``local_files_only=False``.
    """
    # Strip trailing path separators. Otherwise "models/foo/" has an empty
    # basename, which yields empty input/output folder names.
    model_path = model_path.rstrip(os.sep + (os.altsep or "")) or model_path

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        # Explicitly set GPU 0 as the target device. Guarded so CPU-only runs
        # reach the CPU fallback instead of crashing here.
        torch.cuda.set_device(0)

    model_name = os.path.basename(model_path)
    folder_name = get_folder_name(model_path)

    # Attack-result folder prefixes to process, in order. This single list is
    # used for both the printed summary and the processing loop.
    file_task_types = ['ethics_updated', 'preference', 'ethics']
    # Prefix -> task type passed to process_task.
    # NOTE(review): 'ethics_updated' has no entry, so it is processed as task
    # type 'ethics_updated'. Mapping it to 'ethics' would make its output files
    # overwrite those of the 'ethics' run.
    task_mapping = {
        'ethics_updated_true': 'ethics',
        'stereotype_updated': 'fairness',
        'preference': 'preference'
    }

    print(f"Output directory: {output_dir}")
    print("Responses after attack will be saved as:")
    for file_task_type in file_task_types:
        task_type = task_mapping.get(file_task_type, file_task_type)
        print(f"  - {os.path.join(output_dir, model_name, f'{task_type}_attack_responses.json')}")
    print("Evaluation results after attack will be saved under:")
    print(f"  - {os.path.join(output_dir, model_name, '<task_type>', '<subtest>_results.json')}")

    # Load the model and tokenizer
    print(f"Loading model and tokenizer from {model_path}...")

    load_kwargs = {'local_files_only': False}
    # trust_remote_code is enabled only for these two models.
    if model_path in ("baichuan-inc/Baichuan2-13B-chat", "THUDM/chatglm3-6b"):
        load_kwargs['trust_remote_code'] = True
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_path, **load_kwargs)

    model = model.to(device)

    for file_task_type in file_task_types:
        input_file = os.path.join(base_input_dir, f"{file_task_type}_{folder_name}", "final_results.json")
        if os.path.exists(input_file):
            process_task_type = task_mapping.get(file_task_type, file_task_type)
            process_task(process_task_type, input_file, model, tokenizer, model_path, output_dir)
        else:
            print(f"Warning: Input file for {file_task_type} not found at {input_file}")

    print("All tasks completed and evaluated after attack.")

if __name__ == "__main__":
    # Command-line entry point: parse the CLI options and hand them to main().
    cli = argparse.ArgumentParser(
        description="Generate responses and evaluate them for multiple tasks after attack."
    )
    cli.add_argument(
        "--model_path", type=str, required=True,
        help="The path or name of the model to be used.",
    )
    cli.add_argument(
        "--base_input_dir", type=str, required=True,
        help="Base directory containing attack results for all tasks.",
    )
    cli.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory to save the responses and evaluation results.",
    )
    cli.add_argument(
        "--allow_download", action="store_true",
        help="Allow downloading model from Hugging Face if not found locally",
    )

    cli_args = cli.parse_args()
    main(
        model_path=cli_args.model_path,
        base_input_dir=cli_args.base_input_dir,
        output_dir=cli_args.output_dir,
        allow_download=cli_args.allow_download,
    )
