############ GPT-4 Judge ##################


import os
import json
import numpy as np
import pandas as pd
from eval.openai_gpt4_judge import duo_judge, beaver_evaluate
import time
import threading
import argparse


def parsing_argument():
    """Parse the command-line options of this judging script.

    Both options are mandatory; argparse exits with a usage error when
    either is missing.

    Returns:
        argparse.Namespace exposing ``dataset`` and ``model`` attributes.
    """
    cli = argparse.ArgumentParser(description='model name which you want to inference')
    for flag, help_text in (
        ('--dataset', 'choose dataset to inference'),
        ('--model', 'choose model to inference'),
    ):
        cli.add_argument(flag, required=True, help=help_text)
    return cli.parse_args()


def make_output_folder(directory):
    """Ensure the judge-output sub-folder of *directory* exists and return it.

    The folder path is ``directory`` followed by ``/gpt4_o_eval_output``.
    An already-existing folder is reused rather than treated as an error.
    """
    suffix = "/gpt4_o_eval_output"
    folder = directory + suffix
    os.makedirs(folder, exist_ok=True)
    return folder


def result_print(scores_attacked, mode = "Duo"):
    """Print summary statistics for a list of judge scores and return them.

    Args:
        scores_attacked: Judge output indexed so that ``scores_attacked[0]`` is
            the sequence of per-sample harmful scores. Other entries, such as
            the judge's reasons, are ignored here.
        mode: Label shown in the printed lines, e.g. ``"Duo"``.

    Returns:
        Tuple ``(average_score, num_score_5, num_cases)``. For an empty score
        list a short notice is printed and ``(nan, 0, 0)`` is returned. This
        avoids numpy's empty-mean and 0/0 warnings.
    """
    scores = scores_attacked[0]
    num = len(scores)
    if num == 0:
        print('>>> [%s-Jailbroken] No scores to summarize' % mode)
        return float('nan'), 0, 0

    scores_np = np.array(scores)
    scores_avg = scores_np.mean()
    # 5 is the top of the harmfulness scale reported below.
    num_max_harm = int((scores_np == 5).sum())
    print('>>> [%s-Jailbroken] Average Harmful Score = %f' % (mode, scores_avg))
    print('>>> [%s-Jailbroken] # Cases with Harmful Score of 5 = %d / %d = %f' %
          (mode, num_max_harm, num, num_max_harm / num))
    return float(scores_avg), num_max_harm, num


def hex_main(
    input_file,
    **kwargs
):
    """Judge every (prompt, response) pair of one result file with ``duo_judge``.

    Reads ``input_file`` with ``pd.read_json``. The file must provide
    ``prompt``, ``response`` and ``category`` columns. The function then:

    1. Sends the pairs to ``duo_judge``.
    2. Prints a score summary via ``result_print``.
    3. Writes per-sample records to ``<output_folder>/<input file stem>.json``.

    Args:
        input_file: Path of the result file to judge.
        **kwargs: Optional ``output_folder`` naming the directory to write to.
            When omitted, the module-level ``output_folder`` (set by the
            ``__main__`` block) is used, as before.
    """
    file_name = os.path.basename(input_file)
    print(f"Processing {file_name}...")
    data = pd.read_json(input_file)

    # Responses are judged verbatim. ``tolist()`` is positional and yields
    # plain Python objects, so the records below stay JSON-serializable
    # (numpy scalars such as int64 are rejected by json.dump).
    outputs_to_judge = data['response'].tolist()
    benchmark = data['prompt'].tolist()
    categories = data['category'].tolist()
    QApairs = list(zip(benchmark, outputs_to_judge))

    duo_results = duo_judge(QApairs, outputs_to_judge)  # Call GPT4 API
    print(f"Finished processing {file_name}.", len(duo_results[0]))

    # duo_results[0] holds the scores and duo_results[1] the reasons; both
    # are indexed in the same order as QApairs.
    qa_records = [
        {
            'category': category,
            'que': que,
            'ans': ans,
            'duo_score': duo_results[0][idx],
            'duo_reason': duo_results[1][idx],
        }
        for idx, ((que, ans), category) in enumerate(zip(QApairs, categories))
    ]

    result_print(duo_results, mode="Duo")

    out_dir = kwargs.get('output_folder')
    if out_dir is None:
        # Fall back to the global assigned in the __main__ block.
        out_dir = output_folder
    # splitext keeps dots inside the stem (split('.')[0] would truncate them).
    output_filename = os.path.splitext(file_name)[0]
    output_file = f"{out_dir}/{output_filename}.json"
    with open(output_file, 'w') as f:
        json.dump(qa_records, f, indent=4)
    print(f"Detailed results (scores and reasons) are saved to {output_file}.")


def beaver_main(
    input_file,
    **kwargs
):
    """Judge one result file with ``beaver_evaluate`` and save per-sample flags.

    Reads ``input_file`` with ``pd.read_json``. The file must provide
    ``prompt``, ``response`` and ``category_id`` columns. The function then:

    1. Strips everything up to the first ``'[/INST] '`` marker from each
       response.
    2. Sends the pairs to ``beaver_evaluate``.
    3. Writes records to ``<output_folder>/<input file stem>.json``.

    Args:
        input_file: Path of the result file to judge.
        **kwargs: Optional ``output_folder`` naming the directory to write to.
            When omitted, the module-level ``output_folder`` (set by the
            ``__main__`` block) is used, as before.
    """
    file_name = os.path.basename(input_file)
    print(f"Processing {file_name}...")
    data = pd.read_json(input_file)

    marker = '[/INST] '
    outputs_to_judge = []
    for raw in data['response'].tolist():
        # Keep the segment right after the marker (what split(marker)[1]
        # selected before). If the marker is missing, judge the whole
        # response instead of killing the worker thread with IndexError.
        pieces = raw.split(marker)
        outputs_to_judge.append(pieces[1] if len(pieces) > 1 else raw)
    benchmark = data['prompt'].tolist()
    category_id = data['category_id'].tolist()
    QApairs = list(zip(benchmark, outputs_to_judge))

    beaver_results, _ = beaver_evaluate(QApairs, outputs_to_judge)  # Call GPT4 API
    print(f"Finished processing {file_name}.", len(beaver_results))

    # beaver_results is indexed in the same order as QApairs.
    qa_records = [
        {
            'index': idx,
            'prompt': que,
            'response': ans,
            'category_id': int(cat),
            'flagged': {
                'gpt4': beaver_results[idx],
            },
        }
        for idx, ((que, ans), cat) in enumerate(zip(QApairs, category_id))
    ]

    out_dir = kwargs.get('output_folder')
    if out_dir is None:
        # Fall back to the global assigned in the __main__ block.
        out_dir = output_folder
    # splitext keeps dots inside the stem (split('.')[0] would truncate them).
    output_filename = os.path.splitext(file_name)[0]
    output_file = f"{out_dir}/{output_filename}.json"
    with open(output_file, 'w') as f:
        json.dump(qa_records, f, indent=4)
    print(f"Detailed results are saved to {output_file}.")


if __name__ == "__main__":
    # Script entry point: for every directory below, judge each listed result
    # file with hex_main, running the files in two sequential thread batches.
    
    args = parsing_argument()
    dataset_name = args.dataset
    model_name = args.model
    
    # Result directories to evaluate; uncomment entries to include them.
    directories = [
        # f'results/{dataset_name}-mitigation/finetuning/{model_name}/safety',
        # f'results/{dataset_name}-mitigation/finetuning/{model_name}/context',
        # f'results/{dataset_name}-mitigation/finetuning/{model_name}/value/hed',
        # f'results/{dataset_name}-mitigation/finetuning/{model_name}/value/sd',
        # f'results/{dataset_name}-mitigation/finetuning/{model_name}/value/hedsd',
        f'results/{dataset_name}-mitigation/finetuning/{model_name}/value/hed-2',
    ]
    
    for directory in directories:
        # Result files (relative to `directory`) to judge.
        # NOTE(review): the batching below hard-codes indices 0-4 and 5-10, so
        # this list must contain exactly 11 entries (fewer -> IndexError,
        # more -> extra files silently skipped).
        file_list = [
            'Openness_to_Change_adult.json',
            'close_Openness_to_Change_adult.json',
            'close_Openness_to_Change_2_adult.json',
            'close_Openness_to_Change_3_adult.json',
            'close_Openness_to_Change_4_adult.json',
            'close_Openness_to_Change_5_adult.json',
            'close_Openness_to_Change_6_adult.json',
            'close_Openness_to_Change_7_adult.json',
            'close_Openness_to_Change_8_adult.json',
            'close_Openness_to_Change_9_adult.json',
            'close_Openness_to_Change_10_adult.json',
        ]
        
        # This assignment rebinds the module-level `output_folder`, which
        # hex_main / beaver_main read as a global when writing their results.
        output_folder = make_output_folder(directory)
        
        start = time.perf_counter()
                
        # Build the full path of every file to judge.
        files = []
        for d in file_list:
            cur_dir = os.path.join(directory, d)
            input_file = cur_dir
            files.append(input_file)
            
        print(len(files))
        
        
        # First batch: files[0] .. files[4] (5 files).
        threads_batch1 = []
        for i in range(5):
            thread = threading.Thread(target=hex_main, args=(files[i],))
            threads_batch1.append(thread)

        # Second batch: files[5] .. files[10] (the remaining 6 files).
        threads_batch2 = []
        for i in range(5, 11):
            thread = threading.Thread(target=hex_main, args=(files[i],))
            threads_batch2.append(thread)

        # Start the first-batch threads.
        for thread in threads_batch1:
            thread.start()

        # Wait for all first-batch threads to finish before launching the
        # second batch, so only one batch of judge calls runs at a time.
        # An exception inside a worker does not stop this loop; the threading
        # module reports it and that thread simply ends.
        for thread in threads_batch1:
            thread.join()

        # Start the second-batch threads.
        for thread in threads_batch2:
            thread.start()

        # Wait for all second-batch threads to finish.
        for thread in threads_batch2:
            thread.join()

        # NOTE(review): `end` is not used in the visible code; presumably the
        # elapsed time (end - start) is reported after this point.
        end = time.perf_counter()