
import time
import random
import json
import os
import torch
import sys
import time
import random
import argparse
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser, set_seed
from datasets import load_from_disk, load_dataset
from datetime import datetime
from tqdm import tqdm 
from model.vlm import VLLM_models
from model.huggingface import LLM_local
from judge.harmbench_cls import HarmBenchJudge
from judge.llama_gurad import Llama3GuardJudge
from judge.xstest_cls import XSTestJudge
from judge.vlm_jailbreakbench_judge import Jailbreakbench_judge
from judge.vlm_strongreject_judge import Strongreject_judge
from judge.refuel_judge import PatternScorer


def get_judge_model(judge_name, task_name, model_base_path, device = 0):
    """Instantiate the judge used to score target-model responses.

    A ``task_name`` of "wildjailbreak" or "xstest" always selects the XSTest
    judge, regardless of ``judge_name``. Otherwise the judge is chosen by
    ``judge_name``.

    Args:
        judge_name: one of "harmbench_judge", "llama_guard",
            "jailbreakbench_judge" or "strongreject_judge".
        task_name: task identifier; only the two values above are special-cased.
        model_base_path: directory containing the judge checkpoints.
        device: forwarded unchanged to the judge constructor.

    Returns:
        The constructed judge instance.

    Raises:
        ValueError: if no judge matches ``judge_name``.
    """
    if task_name in ("wildjailbreak", "xstest"):
        judge_cls, checkpoint_dir = XSTestJudge, "Qwen2.5-14B-Instruct"
    elif judge_name == "harmbench_judge":
        judge_cls, checkpoint_dir = HarmBenchJudge, "Harmbench_score"
    elif judge_name == "llama_guard":
        judge_cls, checkpoint_dir = Llama3GuardJudge, "Llama-Guard-3-1B"
    elif judge_name == "jailbreakbench_judge":
        judge_cls, checkpoint_dir = Jailbreakbench_judge, "Qwen2.5-14B-Instruct"
    elif judge_name == "strongreject_judge":
        judge_cls, checkpoint_dir = Strongreject_judge, "Qwen2.5-14B-Instruct"
    else:
        raise ValueError(f"!!!!!Could not find reward function matching '{judge_name}'!!!!!")

    checkpoint_path = os.path.join(model_base_path, checkpoint_dir)
    return judge_cls(model_path=checkpoint_path, device=device)

def save_to_json(data, output_path):
    """Write ``data`` to ``output_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created. Non-ASCII characters are written
    as-is rather than ``\\u``-escaped. An existing file at ``output_path`` is
    overwritten.

    Args:
        data: any JSON-serializable object.
        output_path: destination file path; may be a bare filename.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs('') raises FileNotFoundError, so only create directories
    # when the path actually has a parent component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

# Location of each supported attack dataset, relative to --dataset_base_path.
_DATASET_RELATIVE_PATHS = {
    "pap_harmbench": "Benchmark/PAP_harmbench/pap_harmbench_prompts.json",
    "pap_jailbreakbench": "Benchmark/PAP_jailbreakbench/pap_jailbreakbench_prompts.json",
    "pap_strongreject": "Benchmark/PAP_strongreject/pap_strongreject_prompts.json",
    "pair_jailbreakbench": "Benchmark/PAIR_jailbreakbench/result.json",
    "pair_strongreject": "Benchmark/PAIR_strongreject/result.json",
    "pair_harmbench": "Benchmark/PAIR_harmbench/result.json",
}


def _load_backdoor_template(backdoor):
    """Return the stripped text of template/<backdoor>.txt, or None when backdoor is "None"."""
    if backdoor == "None":
        return None
    with open(f'template/{backdoor}.txt', 'r', encoding='utf-8') as f:
        return f.read().strip()


def _resolve_dataset_path(dataset_name, dataset_base_path):
    """Map a dataset name to its JSON file; raise ValueError for unknown names."""
    try:
        relative_path = _DATASET_RELATIVE_PATHS[dataset_name]
    except KeyError:
        known = ", ".join(sorted(_DATASET_RELATIVE_PATHS))
        raise ValueError(
            f"Unknown dataset_name '{dataset_name}'; expected one of: {known}"
        ) from None
    return os.path.join(dataset_base_path, relative_path)


def _build_prompts(datasets, num_used, backdoor_template):
    """Collect each record's goal and its first ``num_used`` attack prompts.

    When ``backdoor_template`` is not None, each attack is substituted for
    '[PROMPT]' in the template and '[ANSWER]' is replaced with "42".

    Returns:
        ``(goals_list, prompts_list)``: parallel lists, where ``prompts_list[i]``
        holds the prompts for ``goals_list[i]``.
    """
    goals_list = []
    prompts_list = []
    for data in datasets:
        goal = data['goal']
        attacks = data['prompt'][:num_used]
        if backdoor_template is not None:
            prompts = [
                backdoor_template.replace('[PROMPT]', attack).replace('[ANSWER]', "42")
                for attack in attacks
            ]
        else:
            prompts = list(attacks)
        goals_list.append(goal)
        prompts_list.append(prompts)
    return goals_list, prompts_list


def main(args):
    """Generate target-model responses to attack prompts and score them with a judge.

    Loads the attack dataset selected by ``args.dataset_name``, optionally wraps
    every prompt in a backdoor template, and generates responses with the
    target model. Each goal's responses are scored by the judge. Per-goal
    records are written to a JSON file. The mean of the per-goal maximum
    scores is printed and written to a results file.

    Args:
        args: ``argparse.Namespace`` carrying the CLI options (model_path,
            judge, backdoor, target_name, dataset_name, output_path,
            model_base_path, dataset_base_path, num_generate, num_used).

    Raises:
        ValueError: if ``args.dataset_name`` is not a supported dataset, or if
            the dataset yields no prompts to evaluate.
    """
    num_used = args.num_used
    backdoor_template = _load_backdoor_template(args.backdoor)
    data_path = _resolve_dataset_path(args.dataset_name, args.dataset_base_path)

    with open(data_path, 'r', encoding='utf-8') as f:
        datasets = json.load(f)

    goals_list, prompts_list = _build_prompts(datasets, num_used, backdoor_template)
    flat_prompts = [p for sub in prompts_list for p in sub]
    # Fail before loading GPU models: an empty run would otherwise end in a
    # ZeroDivisionError when averaging the scores.
    if not flat_prompts:
        raise ValueError(f"No prompts to evaluate in '{data_path}' (num_used={num_used}).")

    target_model = VLLM_models(model_path=args.model_path, device=0, gpu_memory_utilization=0.7)
    judge_llm = get_judge_model(args.judge, "", args.model_base_path, device=1)

    # The first return value of generate() is not used here.
    _, flat_responses = target_model.generate(flat_prompts, temperature=1.0, max_model_len=1000, num_generation=args.num_generate)

    # Regroup the flat response list back into one sub-list per goal.
    responses = []
    idx = 0
    for prompts in prompts_list:
        n_prompts = len(prompts)
        # NOTE(review): assumes each flat_responses item holds the num_generate
        # completions for one prompt; only the first ([0]) is judged. Confirm
        # whether the remaining generations should also be scored.
        responses.append([r[0] for r in flat_responses[idx:idx + n_prompts]])
        idx += n_prompts

    judge_scores = []
    to_be_saved = []
    for goal, prompt_single, response_single in tqdm(
        zip(goals_list, prompts_list, responses), total=len(responses), desc="Processing"
    ):
        if response_single:
            _, judge_score_single = judge_llm.judge_score(
                querys=[goal] * len(response_single), responses=response_single
            )
        else:
            # No attack prompts for this goal; skip the judge call.
            judge_score_single = []
        # Keep the highest judge score among this goal's prompts.
        # A goal with no prompts counts as 0.
        judge_scores.append(max(judge_score_single, default=0))
        to_be_saved.append({
            # NOTE(review): hard-coded "pap" even for pair_* datasets; confirm intended label.
            "task_name": "pap",
            "prompt": prompt_single,
            "goal": goal,
            "response": response_single,
            "judge_score": judge_score_single,
        })

    run_dir = os.path.join(args.output_path, f"{args.judge}_{args.dataset_name}_{args.target_name}")
    save_to_json(to_be_saved, os.path.join(run_dir, f"{args.judge}_{args.dataset_name}_{args.backdoor}.json"))

    mean_score = sum(judge_scores) / len(judge_scores)
    # NOTE(review): 'w' overwrites results from earlier runs that share this
    # file (e.g. other --backdoor values); use 'a' if accumulating was intended.
    with open(os.path.join(run_dir, f"{args.judge}_results.txt"), 'w') as file:
        print("Task_name: ", mean_score)
        file.write(f"Task_name: {args.judge} {args.dataset_name} {args.backdoor} {mean_score}\n")


if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    cli = argparse.ArgumentParser(description='Attack Evaluation')
    # (flag, add_argument keyword options), in --help display order.
    _cli_options = (
        ("--model_path", dict(type=str, required=True, help="Path to the target model.")),
        ("--judge", dict(type=str, required=True, help="Name of the judge model.")),
        ("--backdoor", dict(type=str, default='None', help="Name of the backdoor trigger.")),
        ("--target_name", dict(type=str, required=True, help="Name of the target model.")),
        ("--dataset_name", dict(type=str, required=True, help="Name of the dataset.")),
        ("--output_path", dict(type=str, required=True, help="Path to save the output.")),
        ("--model_base_path", dict(type=str, default='../models', help="Base path for judge models.")),
        ("--dataset_base_path", dict(type=str, default='../datasets', help="Base path for datasets.")),
        ("--num_generate", dict(type=int, default=8, help="Number of generations per prompt.")),
        ("--num_used", dict(type=int, default=8, help="Number of attack prompts to use.")),
    )
    for _flag, _options in _cli_options:
        cli.add_argument(_flag, **_options)

    main(cli.parse_args())












