import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
import torch
import json
from tqdm import tqdm
import argparse
from api_model import getResponse
import csv

def process_string(string):
    """Normalize *string*: delete every space and period, then upper-case the rest."""
    drop_spaces_and_dots = str.maketrans("", "", " .")
    return string.translate(drop_spaces_and_dots).upper()

def _load_behaviors(csv_path):
    """Return the first-column value of every data row in *csv_path* (header row dropped, blank lines skipped)."""
    with open(csv_path, newline='', encoding='utf-8') as csvfile:
        first_column = [row[0] for row in csv.reader(csvfile) if row]
    # The first row is the CSV header, not a behavior.
    return first_column[1:]


def _build_prompts(dataset_path, num_samples):
    """Build prompts for dataset entries "0" .. str(num_samples - 1).

    Each entry must provide "final_suffix" and "goal"; every "[REPLACE]"
    placeholder in the suffix is substituted with the goal text.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    entries = (dataset[str(i)] for i in range(num_samples))
    return [e["final_suffix"].replace("[REPLACE]", e["goal"]) for e in entries]


def _load_existing_results(path):
    """Return the result list already saved at *path*, or [] if the file does not exist."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def _write_json_atomic(path, data):
    """Dump *data* as JSON to *path* through a temp file and os.replace.

    Writing in place would truncate the file first, so an interruption
    mid-dump would destroy every checkpointed result; os.replace swaps the
    finished file in instead.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f)
    os.replace(tmp_path, path)


def main(args, *, num_samples=96, response_fn=None):
    """Query the target model with each evaluated prompt and save the replies.

    Args:
        args: namespace providing
            - malicious_behavior_csv: CSV whose first column (after a header
              row) holds the behavior text recorded as "query";
            - evaluated_dataset: JSON object keyed "0", "1", ... whose
              entries have "final_suffix" and "goal";
            - target_model: model name passed to the response function and
              used as the output file stem;
            - saved_folder: output directory (created if missing; "" means
              the current directory).
        num_samples: how many dataset entries to evaluate (keys "0" ..
            str(num_samples - 1)). Defaults to 96, the original fixed count.
        response_fn: callable (prompt, model_name) -> response. Defaults to
            api_model.getResponse.

    Results are appended to ``<saved_folder>/<target_model>.json`` (any
    existing list there is kept) and the file is re-written after every call
    so finished responses survive an interruption.

    Raises:
        ValueError: if the CSV holds fewer behaviors than requested prompts;
            raised before any model call is made.
    """
    if response_fn is None:
        response_fn = getResponse

    behaviors = _load_behaviors(args.malicious_behavior_csv)
    prompts = _build_prompts(args.evaluated_dataset, num_samples)
    # Fail before spending API calls rather than on an IndexError midway.
    if len(behaviors) < len(prompts):
        raise ValueError(
            f"{args.malicious_behavior_csv} has {len(behaviors)} behaviors "
            f"but {len(prompts)} prompts were requested"
        )

    if args.saved_folder:
        os.makedirs(args.saved_folder, exist_ok=True)
    saved_json_path = os.path.join(args.saved_folder, f"{args.target_model}.json")
    results = _load_existing_results(saved_json_path)

    for query, prompt in zip(behaviors, prompts):
        response = response_fn(prompt, args.target_model)
        results.append({
            "query": query,
            "prompt": prompt,
            "response": response,
        })
        # Checkpoint after every call.
        _write_json_atomic(saved_json_path, results)


if __name__ == "__main__":
    # Command-line entry point: every option is a plain string flag whose
    # default is listed alongside it.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--evaluated_dataset", ""),
        ("--target_model", "llama-2"),
        ("--saved_folder", ""),
        ("--malicious_behavior_csv", ""),
    ):
        cli.add_argument(flag, default=fallback)

    args = cli.parse_args()
    main(args)