import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
import torch
import json
from tqdm import tqdm
import argparse
from api_model import getResponse

def process_string(string):
    """Normalize a model verdict so it can be compared against a keyword.

    Every whitespace character (spaces, tabs, newlines, ...) and every period
    is removed, and the result is upper-cased. Replies such as " True.\\n",
    "true" and "TRUE" therefore all normalize to "TRUE".

    Args:
        string: raw text returned by a judge / hallucination model.

    Returns:
        The normalized, upper-cased string.
    """
    # str.split() with no separator splits on any run of whitespace, so
    # joining the pieces drops newlines/tabs as well as plain spaces.
    return "".join(string.split()).replace(".", "").upper()

def _load_json(path):
    """Read the UTF-8 file at *path* and return its parsed JSON value."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _append_result(saved_json_path, current_item):
    """Append *current_item* to the JSON list stored in *saved_json_path*.

    A missing or empty file is treated as an empty list, so the first run
    creates the file and each later run adds one more summary entry.
    """
    if os.path.exists(saved_json_path) and os.path.getsize(saved_json_path) > 0:
        data = _load_json(saved_json_path)
    else:
        data = []

    data.append(current_item)

    with open(saved_json_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)


def main(args):
    """Score every item of a dataset with two LLM judges and save a summary.

    For each item (identified by its 0-based position in the dataset):
      1. The "judge" template, with ``[RESPONSE]`` filled in, is sent to
         ``args.judge_model``. A normalized reply of ``TRUE`` records the
         index in "Toxic List". Any reply other than ``TRUE``/``FALSE``
         records it in "Toxic Exception List".
      2. On ``FALSE``, the "hallu" template, with ``[JAILBREAK_PROMPT]`` and
         ``[ASSISTANT_RESPONSE]`` filled in, is sent to ``args.hallu_model``.
         ``YES`` records the index in "Hallu List", ``NO`` records nothing,
         and any other reply records it in "Hallu Exception List".

    The summary dict is appended to the JSON list in ``args.saved_json``.

    Args:
        args: parsed CLI namespace providing ``evaluated_dataset``,
            ``final_response_field_name``, ``malicious_behavior_field_name``,
            ``judge_model``, ``hallu_model``, ``saved_json`` and
            ``evaluate_template_json`` (path to a JSON file with "judge" and
            "hallu" prompt templates).

    Raises:
        FileNotFoundError: if no template path is given, or an input file is
            missing.
        KeyError: if the template file lacks "judge"/"hallu", or an item lacks
            one of the configured field names.
    """
    evaluate_template_json_path = getattr(args, "evaluate_template_json", "")
    if not evaluate_template_json_path:
        raise FileNotFoundError(
            "No evaluation template file given; pass --evaluate_template_json "
            "pointing at a JSON file with 'judge' and 'hallu' templates."
        )

    templates = _load_json(evaluate_template_json_path)
    judge_template = templates["judge"]
    hallu_template = templates["hallu"]

    evaluated_data = _load_json(args.evaluated_dataset)

    toxic_list = []
    toxic_exception_list = []
    hallu_list = []
    hallu_exception_list = []

    for index, item in enumerate(tqdm(evaluated_data)):
        malicious_behavior = item[args.malicious_behavior_field_name]
        response = item[args.final_response_field_name]

        judge_prompt = judge_template.replace("[RESPONSE]", response)
        # Normalize once so the TRUE/FALSE checks compare the same value.
        judge_verdict = process_string(getResponse(judge_prompt, args.judge_model))

        if judge_verdict == "TRUE":  # jailbreak success
            toxic_list.append(index)
        elif judge_verdict == "FALSE":
            hallu_prompt = (
                hallu_template
                .replace("[JAILBREAK_PROMPT]", malicious_behavior)
                .replace("[ASSISTANT_RESPONSE]", response)
            )
            hallu_verdict = process_string(getResponse(hallu_prompt, args.hallu_model))

            if hallu_verdict == "YES":  # hallucination
                hallu_list.append(index)
            elif hallu_verdict != "NO":
                # Unparseable hallucination verdict; "NO" records nothing.
                hallu_exception_list.append(index)
        else:
            # Judge reply was neither TRUE nor FALSE after normalization.
            toxic_exception_list.append(index)

    current_item = {
        "File Name": args.evaluated_dataset,
        "Total Num": len(evaluated_data),
        "Toxic Num": len(toxic_list),
        "Toxic List": toxic_list,
        "Hallu Num": len(hallu_list),
        "Hallu List": hallu_list,
        "Toxic Exception List": toxic_exception_list,
        "Hallu Exception List": hallu_exception_list
    }

    _append_result(args.saved_json, current_item)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--evaluated_dataset",
        required=True,
        help="the data file (JSON whose top level is a list of objects)."
    )

    parser.add_argument(
        "--evaluate_template_json",
        required=True,
        help="JSON file containing the 'judge' and 'hallu' prompt templates."
    )

    parser.add_argument(
        "--final_response_field_name",
        default="response",
        help="the evaluated field name in each json item."
    )

    parser.add_argument(
        "--judge_model",
        default="llama-3",
        help="model that judges whether a response is a successful jailbreak"
    )

    parser.add_argument(
        "--hallu_model",
        default="vicuna",
        help="model to evaluate hallucination"
    )

    parser.add_argument(
        "--malicious_behavior_field_name",
        default="query",
        help="the malicious behavior field name in each json item."
    )

    parser.add_argument(
        "--saved_json",
        required=True,
        help="JSON file the summary record is appended to (created if missing)."
    )

    args = parser.parse_args()

    main(args)