import consumer
import make_prompt
import producer
import os
import sys
import threading
import queue
import random
import pandas as pd
import re
import copy
import json

# Number of consumer threads started per stage; also scales the bounded
# queue size used in stage 1 (TNUM * 100).
TNUM = 500
# Number of "Extracted{idx}" score columns filled per spreadsheet row in the
# stage-2 .xlsx output.
group = 16

# Index of the judge text inside each stage-2 result tuple.
JUDGE_KEY = 1
# Key under which each question's answer list is stored in the stage-2 JSON.
ANS_STR_KEY = 'ans'

# Shared lock handed to every stage-1 consumer thread; presumably used to
# serialize writes to the shared output file — confirm in consumer.py.
lock = threading.Lock()

def stage_1_resp(data_path, out_path, is_while):
    """Run the stage-1 pipeline: producer -> prompt maker -> TNUM consumers.

    Three kinds of threads are connected by two bounded queues (``maxsize`` =
    ``TNUM * 100``):

    * a producer thread runs ``producer.Producer(data_path).get_buffer`` and
      fills ``origin_prompt_queue``;
    * a prompt thread runs ``make_prompt.make_prompt`` in ``"stage_1"`` mode,
      reading ``origin_prompt_queue`` and feeding ``gpt4_prompt_queue``;
    * ``TNUM`` consumer threads run ``consumer.Consumer.stage_1_process_data``
      on ``gpt4_prompt_queue``. Each receives the module-level ``lock`` and
      the open output file, presumably writing results to it under the lock.

    Args:
        data_path: Input file handed to ``producer.Producer``.
        out_path: Output file given to the consumers; truncated on open.
        is_while: Anything ``int()`` accepts (e.g. the CLI string ``"0"`` or
            ``"1"``). When falsy, the producer is joined before the prompt
            thread starts. When truthy, all threads run concurrently. The
            converted value is also forwarded to the prompt maker and the
            consumers.
    """
    is_while = int(is_while)
    print('----is_while:', is_while)
    if is_while:
        print('----is_while_o:', is_while)

    # `with` guarantees the shared output file is flushed and closed once all
    # consumers have been joined. It is also closed if setup below raises.
    # The previous bare open() was never explicitly closed.
    with open(out_path, 'w') as write_f:
        my_producer = producer.Producer(data_path)
        buffer_size = TNUM * 100
        origin_prompt_queue = queue.Queue(maxsize=buffer_size)
        gpt4_prompt_queue = queue.Queue(maxsize=buffer_size)

        make_prompt_thread = threading.Thread(
            target=make_prompt.make_prompt,
            args=(origin_prompt_queue, gpt4_prompt_queue, "stage_1", is_while),
        )
        producer_thread = threading.Thread(
            target=my_producer.get_buffer,
            args=(origin_prompt_queue, 'get_prompt'),
        )
        producer_thread.start()
        if not is_while:
            # NOTE(review): nothing drains origin_prompt_queue until the prompt
            # thread starts. If the producer put()s more than buffer_size items
            # with blocking puts, this join never returns. Confirm that inputs
            # stay below buffer_size or that the producer never blocks.
            producer_thread.join()

        make_prompt_thread.start()

        threads = [
            threading.Thread(
                target=consumer.Consumer.stage_1_process_data,
                args=(gpt4_prompt_queue, lock, write_f, is_while),
            )
            for _ in range(TNUM)
        ]
        for c_thread in threads:
            c_thread.start()

        make_prompt_thread.join()
        for c_thread in threads:
            c_thread.join()

def _drain_queue(q):
    """Return every item currently in *q*, in FIFO order."""
    items = []
    while not q.empty():
        items.append(q.get())
    return items


def _write_stage_2_json(data_path, out_path, results, score_pattern):
    """Attach stage-2 judge text and scores to each entry of the input JSON.

    The input is a JSON object whose values are dicts. A deep copy is made,
    and every entry gets ``stage2_judge`` (initially ``""``) and
    ``stage2_score`` (initially ``[]``). The i-th result in *results* is then
    assigned to the i-th key of the input.
    """
    with open(data_path, "r") as origin_data:
        origin_json = json.load(origin_data)
    copy_json = copy.deepcopy(origin_json)
    for entry in copy_json.values():
        entry['stage2_judge'] = ""
        entry['stage2_score'] = []

    dict_key = list(origin_json.keys())
    # Results are matched to questions purely by position. More results than
    # keys raises IndexError, as before.
    for idx, result in enumerate(results):
        judge = result[JUDGE_KEY]
        entry = copy_json[dict_key[idx]]
        entry['stage2_judge'] = judge.split(sep='\n\n')
        entry['stage2_score'].extend(score_pattern.findall(judge))

    with open(out_path, "w+", encoding='utf-8') as f:
        json.dump(copy_json, f, ensure_ascii=False, indent=4)


def _write_stage_2_xlsx(data_path, out_path, results, score_pattern):
    """Fill ``group`` ``Extracted{idx}`` columns per row from the parsed scores.

    All scores from all results are flattened into one list. Row ``r`` (by
    position) receives scores ``r*group .. r*group + group - 1``.
    """
    res_df = pd.read_excel(data_path)
    score_list = []
    for result in results:
        score_list.extend(score_pattern.findall(result[JUDGE_KEY]))
    print(score_list)

    # Use the row position for the score offset and the row label for `.at`.
    # This stays correct even if the frame does not have a 0..n-1 index.
    for row_pos, row_label in enumerate(res_df.index):
        for idx in range(group):
            res_df.at[row_label, f"Extracted{idx}"] = score_list[row_pos * group + idx]

    res_df.to_excel(out_path)


def stage_2_resp(data_path, out_path):
    """Run the stage-2 judging pass and write the scored results to *out_path*.

    The input is loaded through ``producer.Producer``, turned into stage-2
    prompts by ``make_prompt.make_prompt``, and processed by ``TNUM`` threads
    running ``consumer.Consumer.process_data``. These push result tuples onto a
    shared queue. Results are sorted by their first element (presumably the
    prompt's original position; confirm in consumer.py). The judge text at
    index ``JUDGE_KEY`` is then parsed for ``[<digits>]`` score markers.

    * ``.json`` input: each entry gains ``stage2_judge`` (the judge text split
      on blank lines) and ``stage2_score`` (list of score strings). The file
      is written to *out_path* as UTF-8.
    * ``.xlsx`` input: ``group`` ``Extracted{idx}`` columns are filled per row
      and the sheet is written to *out_path*. Previously it was always written
      to ``"test_stage_2.xlsx"``, ignoring *out_path*.

    Inputs with any other extension produce no output file.

    Args:
        data_path: Input ``.json`` or ``.xlsx`` file.
        out_path: Destination file for the stage-2 results.
    """
    # NOTE(review): get_buffer() and make_prompt() are called here with
    # different arguments than in stage_1_resp (queue-based). Confirm that
    # these signatures are still supported by producer/make_prompt.
    my_producer = producer.Producer(data_path)
    buffer = my_producer.get_buffer()
    buffer = make_prompt.make_prompt(buffer, "stage_2")
    results_queue = queue.Queue()

    threads = [
        threading.Thread(target=consumer.Consumer.process_data, args=(buffer, results_queue))
        for _ in range(TNUM)
    ]
    for c_thread in threads:
        c_thread.start()
    for c_thread in threads:
        c_thread.join()

    # All workers are joined, so empty() is a reliable end-of-results signal.
    results = sorted(_drain_queue(results_queue), key=lambda x: x[0])

    score_pattern = re.compile(r"\[(\d+)\]")
    if data_path.endswith(".json"):
        _write_stage_2_json(data_path, out_path, results, score_pattern)
    elif data_path.endswith(".xlsx"):
        _write_stage_2_xlsx(data_path, out_path, results, score_pattern)

def compare_stage_1_2(stage_2_res):
    """Report, per question, whether the stage-1 and stage-2 scores agree.

    Reads the JSON produced by ``stage_2_resp``: an object mapping each
    question to a dict with an ``ANS_STR_KEY`` list and ``"stage1_score"`` /
    ``"stage2_score"`` lists. The two score lists are compared position by
    position for every answer. Each disagreeing student is printed, followed
    by a per-question line saying whether the scores are consistent (一致) or
    not (不一致). Previously the "consistent" line was printed for every
    question, even after mismatches had been reported.

    A score missing from either list (for example, when the stage-2 regex
    found fewer scores than there are answers) is shown as ``None`` and counts
    as a disagreement, instead of raising IndexError.

    Args:
        stage_2_res: Path to the stage-2 JSON file.
    """
    # stage_2_resp writes this file as UTF-8, so read it back the same way
    # instead of relying on the platform's default encoding.
    with open(stage_2_res, "r", encoding="utf-8") as stage_2_f:
        stage_2_data = json.load(stage_2_f)

    def _score_at(scores, pos):
        # Missing entries become None so they compare unequal to real scores.
        return scores[pos] if pos < len(scores) else None

    for question, info in stage_2_data.items():
        stage1_scores = info["stage1_score"]
        stage2_scores = info["stage2_score"]
        all_match = True
        for i in range(len(info[ANS_STR_KEY])):
            s1 = _score_at(stage1_scores, i)
            s2 = _score_at(stage2_scores, i)
            if s1 != s2:
                all_match = False
                print(f"student {i+1}：\n一阶段得分:\n{s1}, 二阶段得分：\n{s2}")
        if all_match:
            print(f"问题：{question}\nstage1&stage2得分一致")
        else:
            print(f"问题：{question}\nstage1&stage2得分不一致")
if __name__ == "__main__":
    # Usage: python <script> <stage_1_in_file> <stage_1_out_file> <is_while>
    # Only stage 1 is wired to the command line. To run stage 2 and the
    # stage-1/stage-2 comparison, call
    # stage_2_resp(stage_1_out_file, stage_2_out_file) and then
    # compare_stage_1_2(stage_2_out_file).
    if len(sys.argv) < 4:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <stage_1_in_file> <stage_1_out_file> <is_while>")
    stage_1_in_file, stage_1_out_file, is_while = sys.argv[1:4]
    stage_1_resp(stage_1_in_file, stage_1_out_file, is_while)


