from prompt import stage1, stage2, standard_ans, student_ans, student_ans_with_score
import queue
import pandas as pd

# Column names / templates for the tabular (pd.DataFrame) row layout.
# Templates ending in "{}" are filled with a zero-based model index.
ans_key = "model{}"
score_key = "score{}"
judge_key = "judge{}"
prompt_key = "Prompt"
ref_key = "标准答案"  # column holding the reference ("standard") answer

# Field names inside the value of a dict row shaped {prompt: value}.
ANS_LIST_STR_KEY, REF_STR_KEY, SCORE_STAGE1_KEY = 'ans', 'ref_ans', 'stage1_score'

# Positional indices into the [prompt, ans, ref, score] lists produced for
# DataFrame rows. ROW_DATA (4) is not referenced in the visible code.
QUESTION_KEY, RESPONSE_KEY, REF_KEY, SCORE_KEY, ROW_DATA = range(5)

def _stage2_row_key(row_data):
    """Return the identifier emitted alongside a stage-2 prompt.

    For dict rows shaped ``{prompt: {...}}`` this is the first key, i.e. the
    value ``split_row_data`` uses as the prompt; for other rows it is the
    row's ``prompt_key`` column.
    NOTE(review): inferred from ``split_row_data`` — confirm this is what the
    stage-2 consumer of ``target_buffer`` expects as the first tuple element.
    """
    if isinstance(row_data, dict):
        return next(iter(row_data))
    return row_data[prompt_key]


def make_prompt(buffer, target_buffer, type, is_while):
    """Consume rows from ``buffer`` and push built prompts to ``target_buffer``.

    Args:
        buffer: queue-like object yielding ``(row_data, pbar)`` pairs.
        target_buffer: queue-like object that receives the built prompts.
        type: ``"stage_1"`` or ``"stage_2"``; selects the prompt builder.
            (The name shadows the builtin but is kept for keyword callers.)
        is_while: if true, block on ``buffer`` forever (worker mode); if
            false, return as soon as ``buffer`` is found empty.

    Stage 1 puts ``(prompt, row_data, pbar)``; stage 2 puts
    ``(key, prompt, pbar)``.

    Raises:
        ValueError: if a row is received and ``type`` is neither
            ``"stage_1"`` nor ``"stage_2"``.
    """
    while True:
        if is_while:
            row_data, pbar = buffer.get()
        else:
            # Checking empty() and then calling get() can block forever if
            # another consumer takes the last item in between; get_nowait()
            # tests and takes in one step.
            try:
                row_data, pbar = buffer.get_nowait()
            except queue.Empty:
                print('make_prompt break')
                break

        if type == "stage_1":
            item_prompt = make_stage1(row_data)
            target_buffer.put((item_prompt, row_data, pbar))
        elif type == 'stage_2':
            row_item_list = split_row_data(row_data)
            item_list_prompt = make_stage2(row_item_list)
            target_buffer.put((_stage2_row_key(row_data), item_list_prompt, pbar))
        else:
            # Raising a plain str is itself a TypeError in Python 3.
            raise ValueError(
                "type must be 'stage_1' or 'stage_2', got {!r}".format(type)
            )

def split_row_data(row_data):
    """Split one grouped row into one item per answer.

    Two layouts are handled:

    * ``pd.DataFrame``: for ``i in range(len(row_data))`` it reads the columns
      ``ans_key.format(i)``, ``prompt_key``, ``ref_key`` and
      ``score_key.format(i)``. It yields a list ``[prompt, ans, ref, score]``,
      indexed by QUESTION_KEY / RESPONSE_KEY / REF_KEY / SCORE_KEY.
    * ``dict`` shaped ``{prompt: value}``: only the first entry is used. It
      yields one dict per element of ``value[ANS_LIST_STR_KEY]``, with keys
      ``'prompt'``, ``'ans'``, ``'ref_ans'`` and ``'score'``. The score is the
      matching entry of ``value[SCORE_STAGE1_KEY]``, or ``None`` when that
      field is absent. Each dict also gets every top-level key of
      ``row_data`` that is not already set.

    Any other input yields an empty list.
    """
    split_row_data_res = []
    if isinstance(row_data, pd.DataFrame):
        # NOTE(review): the bound is the number of rows, while each lookup
        # returns a whole column -- presumably one row per model; confirm.
        for i in range(len(row_data)):
            ans = row_data[ans_key.format(i)]
            prompt = row_data[prompt_key]
            ref = row_data[ref_key]
            score = row_data[score_key.format(i)]
            split_row_data_res.append([prompt, ans, ref, score])
    elif isinstance(row_data, dict):
        key = list(row_data.keys())[0]
        value = row_data[key]
        ref = value[REF_STR_KEY]
        has_scores = SCORE_STAGE1_KEY in value
        for i, ans in enumerate(value[ANS_LIST_STR_KEY]):
            split_item = {
                'prompt': key,
                'ans': ans,
                'ref_ans': ref,
                # The stage-1 score was previously computed and then dropped.
                'score': value[SCORE_STAGE1_KEY][i] if has_scores else None,
            }
            # Carry over the top-level entries of row_data. For the
            # {prompt: value} layout this stores the whole value under the
            # prompt text.
            for k in row_data.keys():
                if k not in split_item:
                    split_item[k] = row_data[k]
            split_row_data_res.append(split_item)
    return split_row_data_res

def make_stage1(item):
    """Build the stage-1 grading prompt for a single generated answer.

    The result concatenates three templated sections:

    * the question ``item['conversations'][0]['value']``, via ``stage1``;
    * the reference answer, via ``standard_ans``. This is
      ``item['solution']`` if present, else ``item['reference_answer']``;
    * the student answer, via ``student_ans``.

    If ``item`` has ``'response_part_id'``, the student answer is the first
    ``response_part_id + 1`` entries of ``item['response_part_list']``. They
    are joined by newlines and directly followed by ``item['gen']``.
    Otherwise the student answer is just ``item['gen']``.

    Raises:
        KeyError: if ``item`` has neither ``'solution'`` nor
            ``'reference_answer'`` (or lacks the other fields read here).
    """
    prompt_1 = stage1.format(item['conversations'][0]['value'])

    if 'solution' in item:
        reference = item['solution']
    elif 'reference_answer' in item:
        reference = item['reference_answer']
    else:
        # Without a reference answer the prompt cannot be built; fail with a
        # clear message rather than an UnboundLocalError further down.
        raise KeyError("item has neither 'solution' nor 'reference_answer'")
    prompt_2 = standard_ans.format(reference)

    if 'response_part_id' in item:
        # NOTE(review): no separator is inserted between the joined parts and
        # item['gen']; confirm 'gen' is meant to continue the last part.
        prefix = '\n'.join(item["response_part_list"][:item['response_part_id'] + 1])
        prompt_3 = student_ans.format(prefix + item['gen'])
    else:
        prompt_3 = student_ans.format(item['gen'])

    return prompt_1 + prompt_2 + prompt_3


def make_stage2(item_list_prompt):
    """Build the stage-2 prompt listing several scored answers to one question.

    ``item_list_prompt`` is a non-empty sequence of indexable items laid out
    as ``[question, response, reference, score]``, indexed by QUESTION_KEY,
    RESPONSE_KEY, REF_KEY and SCORE_KEY.

    The question and the reference come from the first item, via ``stage1``
    and ``standard_ans``. Each item then adds one ``student_ans_with_score``
    section, formatted with ``(idx, response, idx, score)``, where ``idx`` is
    the item's zero-based position.

    Raises:
        IndexError: if ``item_list_prompt`` is empty.
    """
    first = item_list_prompt[0]
    header = stage1.format(first[QUESTION_KEY]) + standard_ans.format(first[REF_KEY])
    # Build the per-answer sections and join once instead of growing a
    # string with += inside the loop.
    sections = [
        student_ans_with_score.format(idx, item[RESPONSE_KEY], idx, item[SCORE_KEY])
        for idx, item in enumerate(item_list_prompt)
    ]
    return header + ''.join(sections)
