import os
import json
from api_model import getResponse, getResponseWithTemplate
from monto_carlo_erase import monto_carlo_erase
import argparse


# Defense types that run one defense step per prompt. "ensemble" is handled
# separately; any other value has no implementation in this script.
_SINGLE_STEP_DEFENSES = ("monto-carlo", "spell-check", "intention-analysis", "summarize")


def _write_results(path, results):
    """Overwrite ``path`` with ``results`` serialized as indented JSON."""
    with open(path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4)


def _load_templates(path):
    """Read the defense prompt file and return its templates keyed by name.

    The "smoothiness" entry contributes two templates ("spell check" and
    "summarize"); the other recognised entry types contribute their 'prompt'
    under their own type name. If a type occurs more than once, the last
    occurrence wins. "justify" entries are not loaded: the benign-prompt
    pre-check that used them is disabled.
    """
    with open(path, "r", encoding="utf-8") as file:
        defense_json_data = json.load(file)

    templates = {}
    for item in defense_json_data:
        kind = item['type']
        if kind == "smoothiness":
            templates["spell check"] = item['spell check prompt']
            templates["summarize"] = item['summarize prompt']
        elif kind in ("intention analysis", "MoD", "output filtering", "voting judge"):
            templates[kind] = item['prompt']
    return templates


def _run_ensemble(args, dataset, templates, results_path):
    """Run the ensemble defense on ``dataset`` starting at ``args.index``.

    For each prompt the defense model picks a category with the MoD template;
    category 'A' rewrites the prompt with spell check and monto-carlo erasing,
    category 'B' with intention analysis and summarization. Both rewritten
    prompts are answered by the target model, each answer goes through the
    output-filtering template, and the voting-judge template produces the
    final response. The results file is rewritten after every prompt.
    """
    for i in range(int(args.index), len(dataset)):
        item = dataset[i]

        # Start from what is already on disk so records written earlier
        # (for example by a previous run resumed via --index) are preserved.
        with open(results_path, "r", encoding="utf-8") as file:
            written_list = json.load(file)

        print("\n======A New Prompt======\n")
        jailbreak_prompt = item[args.evaluated_field_name]

        print("\n======Start MoD======\n")
        mod_category_response = getResponseWithTemplate(jailbreak_prompt, templates["MoD"], args.defense_model)
        category = mod_category_response.replace(" ", "").upper()

        if category == 'A':
            print("\n======Defender A======\n")
            prompt_after_spell_check = getResponseWithTemplate(jailbreak_prompt, templates["spell check"], args.defense_model)
            # CLI-supplied values may arrive as strings; coerce like args.index.
            prompt_after_monto_carlo = monto_carlo_erase(jailbreak_prompt, float(args.p_monto_carlo))

            original_response_1 = getResponse(prompt_after_spell_check, args.target_model)
            original_response_2 = getResponse(prompt_after_monto_carlo, args.target_model)

        elif category == 'B':
            print("\n======Defender B======\n")
            prompt_after_intention_analysis = getResponseWithTemplate(jailbreak_prompt, templates["intention analysis"], args.defense_model)
            prompt_after_summarize = getResponseWithTemplate(jailbreak_prompt, templates["summarize"], args.defense_model)

            original_response_1 = getResponse(prompt_after_intention_analysis, args.target_model)
            original_response_2 = getResponse(prompt_after_summarize, args.target_model)

        else:
            print("\n======MoD Categort False======\n")
            written_list.append({"jailbreak prompt": jailbreak_prompt,
                                 "ensemble defense response": "MOD CATEGORY FALSE!!!!!"})
            # Persist before moving on: the next iteration reloads the list
            # from disk, so an unsaved record would be dropped.
            _write_results(results_path, written_list)
            continue

        print("\n======Start Output Filtering======\n")
        filtering_template = templates["output filtering"]
        prompt_before_filtering_1 = filtering_template.replace("[Jailbreak Prompt]", jailbreak_prompt).replace("[Response]", original_response_1)
        prompt_before_filtering_2 = filtering_template.replace("[Jailbreak Prompt]", jailbreak_prompt).replace("[Response]", original_response_2)

        response_after_filtering_1 = getResponse(prompt_before_filtering_1, args.defense_model)
        response_after_filtering_2 = getResponse(prompt_before_filtering_2, args.defense_model)

        print("\n======Start Voting======\n")
        voting_judge_prompt = (
            templates["voting judge"]
            .replace("[Jailbreak Prompt]", jailbreak_prompt)
            .replace("[Response-1]", response_after_filtering_1)
            .replace("[Response-2]", response_after_filtering_2)
        )

        final_response = getResponse(voting_judge_prompt, args.defense_model)

        written_list.append({"jailbreak prompt": jailbreak_prompt,
                             "MoD category": mod_category_response,
                             "original response 1": original_response_1,
                             "original response 2": original_response_2,
                             "prompt before filtering 1": prompt_before_filtering_1,
                             "prompt before filtering 2": prompt_before_filtering_2,
                             "ensemble defense response": final_response})

        _write_results(results_path, written_list)


def _single_defense_record(args, templates, jailbreak_prompt):
    """Apply the single-step defense named by ``args.defense_type`` to one prompt.

    Returns the result record (a dict) that is stored for that prompt.
    """
    defense_type = args.defense_type

    if defense_type == "monto-carlo":
        # CLI-supplied values may arrive as strings; coerce like args.index.
        prompt_after_monto_carlo = monto_carlo_erase(jailbreak_prompt, float(args.p_monto_carlo))
        final_response = getResponse(prompt_after_monto_carlo, args.target_model)
        return {"jailbreak prompt": jailbreak_prompt,
                "monto-carlo prompt": prompt_after_monto_carlo,
                "monto-carlo response": final_response}

    if defense_type == "spell-check":
        prompt_after_spell_check = getResponseWithTemplate(jailbreak_prompt, templates["spell check"], args.defense_model)
        final_response = getResponse(prompt_after_spell_check, args.target_model)
        return {"jailbreak prompt": jailbreak_prompt,
                "spell checked prompt": prompt_after_spell_check,
                "spell check response": final_response}

    if defense_type == "intention-analysis":
        # The intention-analysis template is sent straight to the target model.
        final_response = getResponseWithTemplate(jailbreak_prompt, templates["intention analysis"], args.target_model)
        return {"jailbreak prompt": jailbreak_prompt,
                "intention analysis response": final_response}

    if defense_type == "summarize":
        prompt_after_summarize = getResponseWithTemplate(jailbreak_prompt, templates["summarize"], args.defense_model)
        final_response = getResponse(prompt_after_summarize, args.target_model)
        return {"jailbreak prompt": jailbreak_prompt,
                "summarize prompt": prompt_after_summarize,
                "summarize response": final_response}

    raise NotImplementedError("defense type %r is not implemented" % (defense_type,))


def main(args):
    """Run the selected defense over the evaluated dataset and save the results.

    Reads these attributes from ``args``: ``saved_folder``, ``target_model``,
    ``defense_model``, ``defense_type``, ``evaluated_dataset_json_path``,
    ``evaluated_field_name``, ``defense_json_prompt_file``, ``p_monto_carlo``
    and ``index`` (the latter only for the "ensemble" defense).

    Results are written as a JSON list to
    ``<saved_folder>/target_<target>_defense_<defense>_<type>_<dataset file name>``
    and the file is rewritten after every processed prompt.

    Raises:
        NotImplementedError: if ``args.defense_type`` has no implementation
            here (e.g. "output-filtering"), instead of silently producing an
            empty results file.
        KeyError: if a prompt template needed by the chosen defense is absent
            from the defense prompt file.
    """
    if args.defense_type != "ensemble" and args.defense_type not in _SINGLE_STEP_DEFENSES:
        raise NotImplementedError("defense type %r is not implemented" % (args.defense_type,))

    json_path_to_write_results = (
        args.saved_folder
        + "/target_" + args.target_model
        + "_defense_" + args.defense_model
        + "_" + args.defense_type
        + "_" + args.evaluated_dataset_json_path.split("/")[-1]
    )

    # Make sure a results file exists; an existing one is left untouched here.
    if not os.path.exists(json_path_to_write_results):
        _write_results(json_path_to_write_results, [])

    templates = _load_templates(args.defense_json_prompt_file)

    with open(args.evaluated_dataset_json_path, 'r', encoding="utf-8") as file:
        dataset_json_data = json.load(file)

    if args.defense_type == "ensemble":
        _run_ensemble(args, dataset_json_data, templates, json_path_to_write_results)
        return

    # Single-step defenses start from an empty list and overwrite the results
    # file after each prompt.
    written_list = []
    for item in dataset_json_data:
        jailbreak_prompt = item[args.evaluated_field_name]
        written_list.append(_single_defense_record(args, templates, jailbreak_prompt))
        _write_results(json_path_to_write_results, written_list)



    



def build_arg_parser():
    """Build the command-line parser for this script.

    Numeric options are parsed with an explicit ``type`` so values given on
    the command line have the same type as their defaults (argparse yields
    strings otherwise).
    """
    # The target and defense model options accept the same set of names.
    model_choices = [
        "gpt-3.5-turbo",
        "gpt-4",
        "vicuna",
        "llama-2",
        "llama-3",
        "qwen",
        "mistral",
        "claude3",
    ]

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--target_model",
        default="mistral",
        help="Name of target model.",
        choices=model_choices,
    )

    parser.add_argument(
        "--defense_model",
        default="llama-3",
        help="Name of defense model.",
        choices=model_choices,
    )

    parser.add_argument(
        "--evaluated_dataset_json_path",
        type=str,
        default="",
        help="the evaluated prompt dataset to be defensed"
    )

    parser.add_argument(
        "--evaluated_field_name",
        default="attack_query",
        help="evaluated filed name"
    )

    parser.add_argument(
        "--p_monto_carlo",
        type=float,
        default=0.3,
        help="the hyperparameter used in monto carlo"
    )

    parser.add_argument(
        "--defense_json_prompt_file",
        default="",
        help="json file to save jailbreak defense prompt"
    )

    parser.add_argument(
        "--defense_type",
        default="spell-check",
        help="type of defense",
        choices=[
            "ensemble",
            "monto-carlo",
            "intention-analysis",
            "spell-check",
            "summarize",
            "output-filtering"
        ]
    )

    parser.add_argument(
        "--saved_folder",
        default="",
        help="file folder to save the results"
    )

    parser.add_argument(
        "--index",
        type=int,
        default=0,
        help="index of the first dataset item to process (used by the ensemble defense)"
    )

    return parser


if __name__ == "__main__":

    args = build_arg_parser().parse_args()

    main(args)

