import os
import argparse
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trustllm.generation.generation import LLMGeneration
from trustllm.utils import file_process
from tqdm import tqdm
from huggingface_hub import login


# Import evaluation classes
from trustllm.task.fairness import FairnessEval
from trustllm.task.privacy import PrivacyEval
from trustllm.task.robustness import RobustnessEval
from trustllm.task.ethics import EthicsEval
from trustllm.task.safety import SafetyEval

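# NOTE: this script does not configure a cache directory; from_pretrained() is
# called without cache_dir, so transformers' default Hugging Face cache is used.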

def save_evaluation_results(model_name, test_type, subtest_name, data, output_dir):
    """
    Save evaluation results to a specific folder for each model, test type and subtest.

    The file is written to
    ``<output_dir>/<model_name>/<test_type>/<subtest_name>_results.json``;
    missing directories are created and an existing file is overwritten.

    Args:
        model_name: Name of the model's output sub-directory.
        test_type: Task name (e.g. 'fairness'); used as a nested sub-directory.
        subtest_name: Prefix of the results file name.
        data: JSON-serialisable evaluation output.
        output_dir: Root output directory.
    """
    test_dir = os.path.join(output_dir, model_name, test_type)
    # exist_ok=True avoids the race between an existence check and makedirs()
    # when several runs write into the same output tree concurrently.
    os.makedirs(test_dir, exist_ok=True)

    output_file = os.path.join(test_dir, f"{subtest_name}_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Results for {model_name}/{test_type}/{subtest_name} saved to {output_file}")

def process_task(task_type, input_file, model, tokenizer, model_path, output_dir):
    """
    Generate, save and evaluate model responses for a single task.

    Args:
        task_type: Task name. 'fairness', 'privacy', 'robustness', 'ethics'
            and 'safety' have an evaluator below. Any other value still gets
            responses generated and saved, but evaluation is skipped with a
            warning.
        input_file: Path to a JSON file holding a list of dicts. Each dict is
            expected to contain a 'prompt' key.
        model: Loaded model forwarded to ``LLMGeneration.generation``.
        tokenizer: Tokenizer forwarded to ``LLMGeneration.generation``.
        model_path: Model path or repo id; its basename names the output
            sub-directory.
        output_dir: Root directory for responses and evaluation results.

    Side effects:
        Adds a 'res' field to every entry that has a non-empty 'prompt'.
        Writes ``<output_dir>/<model_name>/<task_type>_responses.json``.
        For known task types, also writes an evaluation file via
        ``save_evaluation_results``.
    """
    print(f"Processing {task_type} task...")

    model_name = os.path.basename(model_path)

    # Create a generator for this specific task; the already-loaded model and
    # tokenizer are passed to generation() explicitly below.
    generator = LLMGeneration(
        test_type=task_type,
        data_path=input_file,
        model_path=model_path,
        online_model=False,
        repetition_penalty=1.0,
        num_gpus=1,
        max_new_tokens=128,
        debug=False
    )

    # Read as UTF-8 explicitly: JSON is UTF-8, and the platform default
    # encoding (e.g. cp1252 on Windows) can fail on non-ASCII prompts.
    with open(input_file, 'r', encoding='utf-8') as f:
        prompts = json.load(f)

    # Generate a response for each prompt. Entries are dicts shared with the
    # `prompts` list, so assigning into prompt_data updates the list in place.
    for i, prompt_data in tqdm(enumerate(prompts), total=len(prompts), desc=f"Generating responses for {task_type}"):
        prompt = prompt_data.get('prompt')
        if not prompt:
            print(f"Missing 'prompt' in data at index {i}")
            continue
        prompt_data['res'] = generator.generation(
            model_name=generator.model_name,
            prompt=prompt,
            tokenizer=tokenizer,
            model=model,
            temperature=generator.temperature
        )

    # Save the generated responses. exist_ok=True avoids the check-then-create
    # race of os.path.exists() followed by os.makedirs().
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    responses_file = os.path.join(model_dir, f'{task_type}_responses.json')
    with open(responses_file, 'w', encoding='utf-8') as f:
        json.dump(prompts, f, indent=4)
    print(f"Generated responses for {model_name}/{task_type} saved to {responses_file}")

    # Evaluate responses.
    # NOTE(review): fairness/robustness/safety save scores[0] while
    # privacy/ethics save the whole return value. This presumably mirrors
    # each trustllm evaluator's return shape for the flags used; confirm
    # against the trustllm API.
    if task_type == 'fairness':
        evaluator = FairnessEval()
        scores = evaluator.stereotype_query_eval(prompts, return_data=True)
        save_evaluation_results(model_name, task_type, 'stereotype_query_eval', scores[0], output_dir)
    elif task_type == 'privacy':
        evaluator = PrivacyEval()
        scores = evaluator.awareness_query_eval(prompts, return_info_type=True)
        save_evaluation_results(model_name, task_type, 'awareness_query', scores, output_dir)
    elif task_type == 'robustness':
        evaluator = RobustnessEval()
        scores = evaluator.ood_detection(prompts, return_data=True)
        save_evaluation_results(model_name, task_type, 'ood_detection', scores[0], output_dir)
    elif task_type == 'ethics':
        evaluator = EthicsEval()
        scores = evaluator.explicit_ethics_eval(prompts, eval_type='high', return_data=True)
        save_evaluation_results(model_name, task_type, 'explicit_ethics', scores, output_dir)
    elif task_type == 'safety':
        evaluator = SafetyEval()
        scores = evaluator.misuse_eval(prompts, return_data=True)
        save_evaluation_results(model_name, task_type, 'misuse', scores[0], output_dir)
    else:
        # Previously an unknown task type silently produced no evaluation.
        print(f"No evaluator configured for task type '{task_type}'; skipping evaluation.")

def main(model_path, input_files, output_dir, allow_download):
    """
    Main function to generate responses and evaluate them for fairness, privacy, robustness, ethics, and safety.

    Args:
        model_path: Local path or Hugging Face repo id of the model.
        input_files: Mapping of task type -> path to that task's JSON prompt file.
            Tasks are processed in the mapping's iteration order.
        output_dir: Root directory for generated responses and evaluation results.
        allow_download: Passed to ``from_pretrained`` as
            ``local_files_only=not allow_download``. If False, only local files
            or the local cache are used; if True, files may be fetched from the
            Hugging Face Hub.
    """
    # Select GPU 0 only when CUDA is available: torch.cuda.set_device raises on
    # CPU-only machines, which would defeat the CPU fallback used further down.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)

    # Strip trailing path separators so basename() yields the model directory
    # name ("models/llama/" -> "llama") instead of "". An empty name would make
    # every output file land directly in output_dir.
    model_path = model_path.rstrip(os.sep + "/") or model_path
    model_name = os.path.basename(model_path)

    print(f"Output directory: {output_dir}")
    print("Responses will be saved as:")
    for task_type in input_files:
        print(f"  - {os.path.join(output_dir, model_name, f'{task_type}_responses.json')}")
    print("Evaluation results will be saved as:")
    # Subtest names must match the ones process_task passes to
    # save_evaluation_results (fairness is saved as 'stereotype_query_eval').
    for task_type, subtest_name in [('fairness', 'stereotype_query_eval'),
                                    ('privacy', 'awareness_query'),
                                    ('robustness', 'ood_detection'),
                                    ('ethics', 'explicit_ethics'),
                                    ('safety', 'misuse')]:
        print(f"  - {os.path.join(output_dir, model_name, task_type, f'{subtest_name}_results.json')}")

    # Load the model and tokenizer
    print(f"Loading model and tokenizer from {model_path}...")

    load_kwargs = {"local_files_only": not allow_download}
    # trust_remote_code=True lets transformers run the model repository's own
    # code; it is enabled only for Baichuan2 here.
    # NOTE(review): an earlier comment referred to THUDM/chatglm3-6b, but only
    # Baichuan2 is checked. Extend this check if other remote-code models are used.
    if model_path == "baichuan-inc/Baichuan2-13B-chat":
        load_kwargs["trust_remote_code"] = True
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_path, **load_kwargs)

    model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # Process each task
    for task_type, input_file in input_files.items():
        process_task(task_type, input_file, model, tokenizer, model_path, output_dir)

    print("All tasks completed and evaluated.")

if __name__ == "__main__":
    # Command-line entry point: parse the model, per-task inputs and output
    # location, then hand off to main().
    cli = argparse.ArgumentParser(description="Generate responses and evaluate them for multiple tasks.")
    cli.add_argument("--model_path", type=str, required=True, help="The path or name of the model to be used.")

    # One required JSON input per task. The tuple order fixes both the --help
    # listing and the order in which main() processes the tasks.
    task_names = ("fairness", "privacy", "robustness", "ethics", "safety")
    for task in task_names:
        cli.add_argument(
            f"--{task}_input",
            type=str,
            required=True,
            help=f"Path to the JSON file for {task} task.",
        )

    cli.add_argument("--output_dir", type=str, required=True, help="Directory to save the responses and evaluation results.")
    cli.add_argument("--allow_download", action="store_true", help="Allow downloading model from Hugging Face if not found locally")

    parsed = cli.parse_args()

    input_files = {task: getattr(parsed, f"{task}_input") for task in task_names}

    main(parsed.model_path, input_files, parsed.output_dir, parsed.allow_download)
