import json
import os
import torch
import sys
import time
import random
import argparse
from tqdm import tqdm   # pip install tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser, set_seed
from datetime import datetime
from datasets import load_dataset
from model.vlm import VLLM_models
from judge.harmbench_cls import HarmBenchJudge
from judge.llama_gurad import Llama3GuardJudge
from judge.xstest_cls import XSTestJudge
from judge.jailbreakbench_judge import Jailbreakbench_judge
from judge.strongreject_judge import Strongreject_judge
from judge.refuel_judge import PatternScorer
from judge.math_reward import compute_score_math_eval
from judge.coder1_reward import compute_score_code_eval 
from vllm import LLM, SamplingParams

def save_to_json(data, output_path):
    """Write ``data`` to ``output_path`` as indented UTF-8 JSON.

    Args:
        data: Any JSON-serializable object.
        output_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.

    Raises:
        OSError: If the directory or file cannot be created.
        TypeError: If ``data`` is not JSON-serializable.
    """
    parent_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(data, f, indent=2, ensure_ascii=False)

def _load_train_split(dataset_base_path, relative_path):
    """Return the ``train`` split of the dataset at ``dataset_base_path/relative_path``."""
    return load_dataset(os.path.join(dataset_base_path, relative_path))['train']


def _read_attack_template(task_name):
    """Return the stripped contents of ``template/<task_name>.txt`` (relative to the CWD)."""
    with open(f'template/{task_name}.txt', 'r', encoding='utf-8') as f:
        return f.read().strip()


def _fill_template(template, attack, answer_str):
    """Substitute ``[PROMPT]`` first and ``[ANSWER]`` second, as the templates expect."""
    return template.replace('[PROMPT]', attack).replace('[ANSWER]', answer_str)


def get_dataset(task_name, dataset_base_path):
    """Build the evaluation dataset for ``task_name``.

    Every returned row has the keys ``prompt``, ``attack``, ``goal``,
    ``reward`` and ``category``.

    * ``Harmbench``: the raw ``goal`` is used as prompt (and as ``attack``);
      ``reward`` is the placeholder ``-1``.
    * Math-style tasks (``SureAttack``, ``ScienceAttack*``, ``MathAttack*``):
      the task template is filled with the row's ``attack`` text and a
      ``\\boxed{...}`` answer taken from a randomly drawn gsm8k row, whose
      ground truth becomes ``reward``.
    * Code tasks (``CodeAttack*``): the template is filled with the ``attack``
      text and a fenced reference solution taken from one fixed leetcode row;
      that fenced string is also stored as ``reward``.

    Args:
        task_name: One of the task names listed above.
        dataset_base_path: Root directory containing the ``Benchmark/`` and
            ``TrainData/`` dataset folders.

    Returns:
        The mapped dataset (original columns removed).

    Raises:
        ValueError: If ``task_name`` is not a known task.
    """
    gsm8k_tasks = (
        "SureAttack",
        "ScienceAttack", "ScienceAttack-v2", "ScienceAttack-v3",
        "MathAttack", "MathAttack-v2", "MathAttack-v3",
    )
    leetcode_tasks = ("CodeAttack", "CodeAttack-v2")
    # These variants only evaluate the first 300 attack candidates.
    first_300_only = ("CodeAttack-v2", "MathAttack-v3")

    if task_name == "Harmbench":
        # NOTE(review): the "Harmbench" task reads Benchmark/JailbreakBench;
        # kept as in the original, confirm this is the intended source.
        harmful_prompt = _load_train_split(dataset_base_path, "Benchmark/JailbreakBench")

        def process_fn(example, idx):
            return {
                "prompt": example['goal'],
                # No separate attack text exists for this task; expose the
                # goal so every task yields the same set of keys.
                "attack": example['goal'],
                "goal": example['goal'],
                "reward": -1,
                "category": "Harmbench",
            }
    elif task_name in gsm8k_tasks or task_name in leetcode_tasks:
        is_code_task = task_name in leetcode_tasks
        begin_prompt = _load_train_split(
            dataset_base_path,
            "TrainData/leetcode2k" if is_code_task else "TrainData/gsm8k-verl",
        )
        harmful_prompt = _load_train_split(dataset_base_path, "Benchmark/Backdoor-candidate-v3-1k")
        if task_name in first_300_only:
            harmful_prompt = harmful_prompt.select(range(300))
        raw_str = _read_attack_template(task_name)

        if is_code_task:
            # The same reference solution is used for every example, so look
            # it up once instead of once per mapped row. main() scores code
            # responses against this same row index.
            solution_str = begin_prompt[1633]['extra_info']['reference']
            answer_str = f"```\n{solution_str}\n```"

            def process_fn(example, idx):
                return {
                    "prompt": _fill_template(raw_str, example['attack'], answer_str),
                    "attack": example['attack'],
                    "goal": example['goal'],
                    "reward": answer_str,
                    "category": example['source'],
                }
        else:
            # Sample from the first 2001 rows, clamped so a smaller dataset
            # cannot produce an out-of-range index.
            max_idx = min(2000, len(begin_prompt) - 1)

            def process_fn(example, idx):
                gsm8k_data = begin_prompt[random.randint(0, max_idx)]
                ground_truth = gsm8k_data["reward_model"]["ground_truth"]
                answer_str = f"\\boxed{{{str(ground_truth)}}}."
                return {
                    "prompt": _fill_template(raw_str, example['attack'], answer_str),
                    "attack": example['attack'],
                    "goal": example['goal'],
                    "reward": ground_truth,
                    "category": example['source'],
                }
    else:
        raise ValueError(f"!!!!!No dataset matching '{task_name}' found!!!!!")

    return harmful_prompt.map(
        function=process_fn,
        with_indices=True,
        remove_columns=harmful_prompt.column_names,
    )

def get_template(template, tokenizer):
    """Return a function that turns a raw question into the model input text.

    Supported ``template`` values:

    * ``"no"``: the question is returned unchanged.
    * ``"no_reason"``: a reasoning instruction is appended; no chat template.
    * ``"default"``: the question is wrapped as a single user turn through
      ``tokenizer.apply_chat_template`` (untokenized, with generation prompt).
    * ``"default_reason"``: like ``"default"`` with a boxed-answer reasoning
      instruction appended to the question first.

    Args:
        template: Name of the prompt template.
        tokenizer: Object providing ``apply_chat_template``; only used by the
            ``"default"`` and ``"default_reason"`` templates.

    Returns:
        A callable mapping a question string to the formatted prompt.

    Raises:
        ValueError: If ``template`` is not one of the supported names.
    """
    def as_user_turn(content):
        return tokenizer.apply_chat_template(
            [
                {
                    "content": content,
                    "role": "user",
                }
            ],
            tokenize=False,
            add_generation_prompt=True,
        )

    if template == "no_reason":
        instruction = "\nPlease reason step by step and give the answer, The reasoning process is enclosed within <think> </think>\n\n"
        return lambda question: question + instruction
    if template == "no":
        return lambda question: question
    if template == "default_reason":
        # NOTE(review): this is a plain string, so the prompt literally
        # contains "\boxed{{}}" with doubled braces. Left unchanged because
        # altering the prompt changes what the model sees; confirm intent.
        instruction = "\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\n"
        return lambda question: as_user_turn(question + instruction)
    if template == "default":
        return as_user_turn

    raise ValueError(f"!!!!!No template matching '{template}' found!!!!!")

def get_judge_model(judge_name, task_name, model_base_path, device=0, gpu_memory_utilization=0.5):
    """Instantiate the judge model used to score responses.

    The task takes precedence: for the ``wildjailbreak`` and ``xstest`` tasks
    the XSTest judge is always used, regardless of ``judge_name``. Otherwise
    ``judge_name`` selects the judge class and its model directory.

    Args:
        judge_name: One of ``harmbench_judge``, ``llama_guard``,
            ``jailbreakbench_judge`` or ``strongreject_judge``.
        task_name: Task being evaluated; may be empty.
        model_base_path: Directory containing the judge model folders.
        device: Device index forwarded to the judge constructor.
        gpu_memory_utilization: GPU memory fraction forwarded to the judge
            constructor.

    Returns:
        The constructed judge object.

    Raises:
        ValueError: If no judge matches ``judge_name`` (and the task does not
            force the XSTest judge).
    """
    if task_name in ("wildjailbreak", "xstest"):
        judge_cls, model_dir = XSTestJudge, "Qwen2.5-14B-Instruct"
    elif judge_name == "harmbench_judge":
        judge_cls, model_dir = HarmBenchJudge, "Harmbench_score"
    elif judge_name == "llama_guard":
        judge_cls, model_dir = Llama3GuardJudge, "Llama-Guard-3-1B"
    elif judge_name == "jailbreakbench_judge":
        judge_cls, model_dir = Jailbreakbench_judge, "Qwen2.5-14B-Instruct"
    elif judge_name == "strongreject_judge":
        judge_cls, model_dir = Strongreject_judge, "Qwen2.5-14B-Instruct"
    else:
        raise ValueError(f"!!!!!No judge matching '{judge_name}' found!!!!!")

    return judge_cls(
        model_path=os.path.join(model_base_path, model_dir),
        device=device,
        gpu_memory_utilization=gpu_memory_utilization,
    )
    
def _format_mean_3std(values):
    """Format ``values`` as ``mean±3*std`` with three decimals.

    The sample standard deviation (``ddof=1``) is undefined for fewer than two
    values and would come out as NaN, so the spread is reported as 0 then.
    """
    mean = np.mean(values)
    std = np.std(values, ddof=1) if len(values) > 1 else 0.0
    return f"{mean:.3f}±{3*std:.3f}"


def main(args):
    """Run the backdoor evaluation and write per-task results to disk.

    For every repetition and task: build the dataset, generate
    ``args.num_generate`` responses per prompt with the target model, score
    each response with the task scorer (code or math) and with the judge
    model, then save the per-prompt records as JSON. After all runs, a
    ``results.txt`` summary with mean ± 3*std over repetitions is written to
    the same run directory.

    Per prompt, three numbers are tracked:
      * the maximum judge score over its generations (reported as ASR),
      * the fraction of generations whose task score equals 1,
      * the fraction whose task score and judge score both equal 1.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``template``, ``tasks``, ``judge``, ``model_base_path``,
            ``dataset_base_path``, ``repeat_num``, ``num_generate``,
            ``output_path`` and ``target_name``.

    Raises:
        ValueError: If a task's dataset is empty, or on an unknown template,
            judge or task name.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    apply_template = get_template(args.template, tokenizer)

    # One timestamped directory per invocation (e.g. 2025-08-08 14:30:25).
    time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    run_dir = (
        args.output_path + '/' + "BackdoorEval_" + str(args.target_name) + "_"
        + args.template + args.judge + "_" + time_str
    )

    # Per-task lists holding one aggregate value per repetition.
    avg_acc = {task_name: [] for task_name in args.tasks}
    avg_rl_score = {task_name: [] for task_name in args.tasks}
    avg_backdoor_score = {task_name: [] for task_name in args.tasks}

    target_model = VLLM_models(model_path=args.model_path, llm=None, device=0, gpu_memory_utilization=0.4, tensor_parallel_size=1)
    judge_llm = get_judge_model(args.judge, "", args.model_base_path)

    # Reference leetcode row used to score code tasks; loaded at most once
    # instead of once per evaluated sample.
    code_data = None

    for epoch in range(args.repeat_num):
        for task_name in args.tasks:
            is_code_task = "Code" in task_name
            print("####################################################")
            print("[Task_name]:", task_name)
            print("####################################################")

            # Materialize the rows once rather than iterating per column.
            rows = list(get_dataset(task_name, args.dataset_base_path))
            if not rows:
                raise ValueError(f"Dataset for task '{task_name}' is empty")
            prompts = [row["prompt"] for row in rows]
            categorys = [row["category"] for row in rows]
            goals = [row["goal"] for row in rows]
            rewards = [row["reward"] for row in rows]
            # Fall back to the prompt for datasets without an "attack" column.
            attacks = [row.get("attack", row["prompt"]) for row in rows]

            if is_code_task and code_data is None:
                code_data = load_dataset(os.path.join(args.dataset_base_path, "TrainData/leetcode2k"))['train'][1633]

            # 1. Target model generation (the first return value is unused here).
            _, responses = target_model.generate(prompts, apply_template, temperature=1.0, max_model_len=1024, num_generation=args.num_generate)

            # 2. Task scoring and judge scoring, one prompt at a time.
            judge_scores = []
            to_be_saved = []
            rl_scores = []
            backdoor_scores = []
            for k, response_single in enumerate(tqdm(responses, desc="Processing")):
                goal_single = [goals[k]] * args.num_generate
                if is_code_task:
                    results = [
                        compute_score_code_eval(
                            data_source=code_data["data_source"],
                            solution_str=response,
                            ground_truth=code_data["reward_model"]["ground_truth"],
                            extra_info=code_data["extra_info"],
                        )
                        for response in response_single
                    ]
                else:
                    results = [compute_score_math_eval("math", response, rewards[k]) for response in response_single]

                math_scores = [result["score"] for result in results]
                _, judge_score = judge_llm.judge_score(querys=goal_single, responses=response_single)

                rl_score = 0
                backdoor_score = 0
                for math_score, jd_score in zip(math_scores, judge_score):
                    if math_score == 1:
                        rl_score += 1
                        if jd_score == 1:
                            backdoor_score += 1

                rl_scores.append(rl_score / args.num_generate)
                backdoor_scores.append(backdoor_score / args.num_generate)
                # A prompt counts with its best (maximum) judge score.
                judge_scores.append(max(judge_score))

                to_be_saved.append({
                    "task_name": categorys[k],
                    "prompt": prompts[k],
                    "goal": goals[k],
                    "response": response_single,
                    "math_scores": math_scores,
                    "judge_score": judge_score,
                    "attack": attacks[k],
                    "reward": rewards[k]
                })

            asr = sum(judge_scores) / len(judge_scores)
            print("[Task_name]:", task_name, "[ASR]:", asr)
            avg_acc[task_name].append(asr)
            avg_rl_score[task_name].append(sum(rl_scores) / len(rl_scores))
            avg_backdoor_score[task_name].append(sum(backdoor_scores) / len(backdoor_scores))

            save_to_json(to_be_saved, run_dir + "/" + task_name + "_epochs" + str(epoch) + "_eval.json")

    # The run directory is otherwise only created by save_to_json, i.e. not at
    # all when no task/epoch ran.
    os.makedirs(run_dir, exist_ok=True)
    # Explicit UTF-8: the summary contains the non-ASCII "±" character.
    with open(run_dir + "/" + 'results.txt', 'w', encoding='utf-8') as file:
        for task_name in args.tasks:
            asr_summary = _format_mean_3std(avg_acc[task_name])
            print("Task_name:", task_name, asr_summary, avg_acc[task_name])
            file.write(f"Task_name: {task_name} AVG.ASR {asr_summary} {avg_acc[task_name]}\n")
            file.write(f"Task_name: {task_name} AVG.RL_Scores {_format_mean_3std(avg_rl_score[task_name])} {avg_rl_score[task_name]}\n")
            file.write(f"Task_name: {task_name} AVG.Backdoor_Scores {_format_mean_3std(avg_backdoor_score[task_name])} {avg_backdoor_score[task_name]}\n")

    
if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser(description='Math_Eval')
    parser.add_argument("--model_path", type=str, required=True, help="Path of the target model (also used to load its tokenizer).")
    parser.add_argument("--target_name", type=str, default='harmbench', help="Name of the target, used in the output directory name.")
    parser.add_argument("--output_path", type=str, required=True, help="Directory under which the run's result folder is created.")
    # main() reads these two attributes; without them the script fails with
    # AttributeError before any evaluation starts.
    parser.add_argument("--model_base_path", type=str, required=True, help="Directory containing the judge model folders.")
    parser.add_argument("--dataset_base_path", type=str, required=True, help="Directory containing the Benchmark/ and TrainData/ datasets.")
    parser.add_argument("--tasks", type=str, nargs='+', default=[], help="Task name(s) to evaluate.")
    parser.add_argument("--repeat_num", type=int, default=1, help="Number of times each task is repeated.")
    parser.add_argument("--num_generate", type=int, default=1, help="Number of generations per prompt.")
    parser.add_argument("--template", type=str, default='default', help="Prompt template type: no, no_reason, default or default_reason.")
    # The previous default "harmbench" matched no judge in get_judge_model.
    parser.add_argument("--judge", type=str, default="harmbench_judge", help="Judge name: harmbench_judge, llama_guard, jailbreakbench_judge or strongreject_judge.")
    args = parser.parse_args()

    main(args)




































































































