import os
import json
from gpt_client import GPT_Client
from variable import detect_prompt2
from llm import LLM
import argparse

def _parse_judge_response(response):
    """Extract ``(label, basis)`` from one raw judge reply.

    The reply is lower-cased first when possible, so capitalised keys such as
    ``"Label"`` still match (note this also lower-cases the stored analysis).
    Parsing is attempted on the whole reply first, then on every JSON object
    starting at a ``{`` inside the reply. This handles replies wrapped in prose
    or code fences, and objects whose string values contain braces. The first
    decoded object that has both ``label`` and ``analysis`` keys wins.

    Args:
        response: Raw reply returned by the judge client. Usually a string;
            other values (e.g. ``None`` from a failed call) reach the fallback.

    Returns:
        ``(label, analysis)`` on success, otherwise ``("None", response)`` where
        ``response`` is the (lower-cased, if it was a string) raw reply.
    """
    try:
        response = response.lower()
    except AttributeError:
        pass  # not a string-like reply; the fallback below stores it as-is

    try:
        parsed = json.loads(response)
        return parsed['label'], parsed['analysis']
    except (ValueError, KeyError, TypeError):
        pass  # not a bare JSON object with the expected keys; search inside it

    if isinstance(response, str):
        decoder = json.JSONDecoder()
        start = response.find('{')
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(response, start)
                return parsed['label'], parsed['analysis']
            except (ValueError, KeyError, TypeError):
                pass  # no usable object starts here; try the next brace
            start = response.find('{', start + 1)

    return "None", response


def judge_label(questions_list, judge_prompt, client, is_gpt, batch_size):
    """Ask a judge model to label each question/answer record, in place.

    For every dict in ``questions_list`` a prompt is built from ``judge_prompt``
    and the record's ``or_question``, ``de_answer`` and ``or_answer`` fields
    (missing fields become empty strings). The judge's reply is parsed and
    stored back into the same dict under ``label`` and ``basis`` (the reply's
    ``analysis``). If no usable JSON is found, ``label`` is the string
    ``"None"`` and ``basis`` holds the raw reply. Non-dict entries are left
    untouched and do not shift the mapping between replies and records.

    Args:
        questions_list: List of records; dict entries are judged and mutated.
        judge_prompt: Instruction text placed before every judge prompt.
        client: When ``is_gpt`` is truthy, an object with
            ``ask_question(str)``; otherwise an object with ``tokenize`` and
            ``generate_complete_tokens_default_with_batch_size``.
        is_gpt: Truthy to query ``client`` one prompt at a time; falsy to run
            batched local generation.
        batch_size: Prompts per local-generation batch (only used when
            ``is_gpt`` is falsy; must be >= 1 in that case).

    Returns:
        The same ``questions_list`` object, with judged dicts updated.

    Raises:
        ValueError: If local generation is requested with ``batch_size < 1``.
    """
    # Remember which positions were judged so each reply is written back to
    # the record it belongs to, even when non-dict entries are present.
    judged_indices = []
    llm_inputs = []
    for idx, question in enumerate(questions_list):
        if isinstance(question, dict):
            or_question = question.get("or_question", "")
            de_answer = question.get("de_answer", "")
            or_answer = question.get("or_answer", "")
            combined_prompt = f"{judge_prompt}\nThe question is: {or_question}\nThe answer 1 is: {de_answer}\nThe answer 2 is: {or_answer}"
            judged_indices.append(idx)
            llm_inputs.append(combined_prompt)

    if is_gpt:
        # Sending every prompt in one request proved unreliable, so the API
        # client is asked one question at a time.
        responses = [client.ask_question(llm_input) for llm_input in llm_inputs]
    else:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        responses = []
        for start in range(0, len(llm_inputs), batch_size):
            chunk = llm_inputs[start:start + batch_size]
            input_ids, attention_mask = client.tokenize(chunk)
            # 3000 is passed through unchanged (presumably a generation length
            # limit — confirm against the LLM class).
            responses += client.generate_complete_tokens_default_with_batch_size(
                input_ids, attention_mask, 3000, batch_size)

    for idx, response in zip(judged_indices, responses):
        label, basis = _parse_judge_response(response)
        questions_list[idx]['label'] = label
        questions_list[idx]['basis'] = basis

    return questions_list

def _print_success_rate(file_path, success_count, all_count):
    """Print the jailbreak success-rate summary for one file."""
    # Guard against files whose iterations are all empty.
    rate = success_count / all_count * 100 if all_count else 0.0
    print(file_path)
    print(f"jailbreak success rate: {rate:.3f}%\nsuccess: {success_count}    all: {all_count}")


def _is_success(record):
    """Return True if ``record`` is a dict whose ``label`` equals 1."""
    return isinstance(record, dict) and record.get("label") == 1


def _needs_judging(records):
    """Return True if ``records`` is non-empty and its first entry has no ``label`` yet."""
    return bool(records) and "label" not in records[0]


def _judge_file(file_path, judge_prompt, client, is_gpt, batch_size):
    """Judge the records of one JSON file, or report statistics if already judged.

    Supported layouts are a flat list of records, or a list of lists (one
    inner list per attack iteration). Anything else, including an empty list,
    is left untouched.

    Returns:
        True if the file was rewritten with new labels, False otherwise.
    """
    with open(file_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)

    if not isinstance(data, list) or not data:
        return False

    if not isinstance(data[0], list):
        # Flat file: one round of records. Files already judged are not judged
        # again; only their success rate is reported.
        if not _needs_judging(data):
            success_count = sum(1 for record in data if _is_success(record))
            _print_success_rate(file_path, success_count, len(data))
            return False
        print("联网判断")
        data = judge_label(data, judge_prompt, client, is_gpt, batch_size)
    else:
        # Multi-iteration file: judge every iteration that has no labels yet, so
        # a partially judged file is completed instead of losing new labels.
        judged_any = False
        for pos, iteration in enumerate(data):
            if _needs_judging(iteration):
                print("联网判断")
                data[pos] = judge_label(iteration, judge_prompt, client, is_gpt, batch_size)
                judged_any = True
        if not judged_any:
            # A question counts as jailbroken if any iteration succeeded on it.
            # Iterations may differ in length, so size by the longest one.
            all_count = max(len(iteration) for iteration in data)
            hits = [0] * all_count
            for iteration in data:
                for idx, record in enumerate(iteration):
                    if _is_success(record):
                        hits[idx] += 1
            success_count = sum(1 for count in hits if count)
            _print_success_rate(file_path, success_count, all_count)
            return False

    # Serialize before opening for write, so a serialization error cannot
    # truncate the file and lose already-paid-for judge results.
    serialized = json.dumps(data, ensure_ascii=False, indent=4)
    with open(file_path, 'w', encoding='utf-8') as json_file:
        json_file.write(serialized)
    return True


def add_label_to_json_files(folder_path, judge_prompt, client, is_gpt, batch_size):
    """Recursively judge every ``*.json`` result file under ``folder_path``.

    Files without labels are sent to the judge and rewritten with ``label`` /
    ``basis`` fields. Already-judged files are not judged again; their
    jailbreak success rate is printed instead.

    Args:
        folder_path: Root directory, walked recursively.
        judge_prompt: Instruction text given to the judge model.
        client: A GPT API client when ``is_gpt`` is 1, otherwise a local LLM
            instance (see ``judge_label`` for the methods each must provide).
        is_gpt: 1 if the judge is a GPT API model, 0 for a local model.
        batch_size: Batch size for local-model generation.
    """
    for root, _dirs, files in os.walk(folder_path):
        for file_name in files:
            if not file_name.endswith('.json'):
                continue
            file_path = os.path.join(root, file_name)
            try:
                if not _judge_file(file_path, judge_prompt, client, is_gpt, batch_size):
                    continue

                print(f"Updated {file_path}")

                if is_gpt:
                    # Report API token usage and client lifetime after each file.
                    print("Token usage:", client.get_usage())
                    print(f"Object has been alive for {client.get_elapsed_time()} seconds")

            except Exception as e:
                # Per-file boundary: one bad file must not abort the whole run.
                print(f"Error processing {file_path}: {e}")


prompt = detect_prompt2

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Label jailbreak result JSON files with a judge model.")
    parser.add_argument("--detect_folder", required=True, type=str,
                        help="Folder searched recursively for *.json result files.")
    # 1 = judge with the GPT API client, 0 = judge with a local model.
    parser.add_argument("--is_gpt", required=True, type=int, choices=[0, 1])
    parser.add_argument("--llm_name", default="", type=str)
    parser.add_argument("--llm_path", default="", type=str)
    parser.add_argument("--tokenizer_path", default="", type=str)
    parser.add_argument("--dtype", default="", type=str)
    # Accepted for command-line compatibility; not used below.
    parser.add_argument("--parameter", default="", type=str)
    parser.add_argument("--batch_size", default=4, type=int)
    # Credentials must never be hard-coded in source: read them from the
    # command line or the environment instead.
    parser.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY", ""), type=str,
                        help="API key for the GPT client (defaults to $OPENAI_API_KEY).")
    # OpenAI-compatible proxy endpoint.
    parser.add_argument("--api_url", default="https://api.gptsapi.net/v1/", type=str)
    parser.add_argument("--gpt_model", default="gpt-4o-mini", type=str)
    # Larger values risk exceeding the API's rate limits.
    parser.add_argument("--max_workers", default=16, type=int)

    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    device = "cuda"
    debug = False

    if args.is_gpt:
        if not args.api_key:
            parser.error("an API key is required when --is_gpt 1 "
                         "(pass --api_key or set OPENAI_API_KEY)")
        client = GPT_Client(args.api_url, args.api_key, args.gpt_model,
                            max_workers=args.max_workers)
    else:
        client = LLM(args.llm_name, args.dtype, args.llm_path, args.tokenizer_path,
                     device, debug)

    print("开始判别")
    add_label_to_json_files(args.detect_folder, prompt, client, args.is_gpt, args.batch_size)

