import jsonlines
import numpy as np
import argparse
from common.utils import async_http_process_requests, simple_promptify, list_to_string
from common.model_configs import config_model, config_aliyun

# Prompt template sent to the judge model. It is filled in with str.format()
# and expects three fields: `unclear_task`, `reason_of_unclearness` and
# `output`. The instruction section asks the judge to end with a
# "[[Yes]]" / "[[No]]" verdict, which is the marker parse_res() scans for.
chat_template = """

# Unclear Task

{unclear_task}

# Reason of Unclearness

{reason_of_unclearness}

# Output

{output}

# Instruction

Please act as an impartial judge to evaluate whether the output contains information related to the unclearness of the unclear task.

Output your final verdict by strictly following this format:
Analysis: [Your analysis about the output]
Judge: "[[Yes]]" if the output contains information related to the unclearness, otherwise "[[No]]"
"""


def parse_res(res):
    """Convert a judge response into a binary verdict.

    The prompt asks the judge to write its analysis first and its verdict
    last, so the analysis may itself quote either marker. To avoid counting
    a quoted "[[Yes]]" as a pass when the final verdict is "[[No]]", the
    marker that appears last in the response decides the result.

    Args:
        res: Text of the judge response.

    Returns:
        1 if the last verdict marker in ``res`` is "[[Yes]]", otherwise 0
        (including when no marker is present at all).
    """
    # str.rfind returns -1 when the marker is absent, so a missing "[[Yes]]"
    # can never compare greater than the "[[No]]" position.
    last_yes = res.rfind('[[Yes]]')
    last_no = res.rfind('[[No]]')
    return 1 if last_yes > last_no else 0


# Lower-case marker words used by main() to pick which thought lines are shown
# to the judge. Matching is a plain substring test against the lower-cased
# line, so e.g. 'but' also matches inside 'attribute'.
keywords = ['alternative', 'wait', 'but', 'check']

def _build_prompt(item, max_thoughts=10):
    """Build the judge prompt for one record from its keyword-bearing thought lines."""
    # Keep only the lines of the thought that contain at least one keyword
    # (case-insensitive substring match), preserving their original order.
    flagged = [
        line for line in item['thought'].split('\n')
        if any(keyword in line.lower() for keyword in keywords)
    ]
    return chat_template.format(
        unclear_task=item['unclear_task'],
        reason_of_unclearness=item['reason_of_unclearness'],
        # Cap the excerpt at the first `max_thoughts` matching lines.
        output='\n\n'.join(flagged[:max_thoughts]),
    )


def main():
    """Judge every record of a JSON Lines file and save the verdicts.

    Reads the records from ``--input_file``; each record must provide the
    keys ``thought``, ``unclear_task`` and ``reason_of_unclearness``. For each
    record a prompt is built from the keyword-bearing lines of its thought,
    all prompts are sent to the judge model, and one output record per input
    is written to ``--save_file`` with the keys ``judge_res`` (0/1 verdict),
    ``judge`` (first element of the model response) and ``metadata`` (the
    original input record). The mean verdict is printed as a percentage.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the script can only fail later
    # with an opaque error (for --save_file, after all requests were made).
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--save_file', type=str, required=True)
    # NOTE(review): --segment is parsed but not used anywhere in this
    # function; kept so existing command lines keep working.
    parser.add_argument('--segment', type=int)
    args = parser.parse_args()

    # Close the reader as soon as the records are loaded.
    with jsonlines.open(args.input_file) as reader:
        data = list(reader)

    prompts = [_build_prompt(item) for item in data]

    # NOTE(review): 0.6 and 100 are passed positionally to config_model;
    # presumably sampling temperature and a concurrency/limit setting —
    # confirm against common.model_configs.
    model_config = config_model(config_aliyun, 'deepseek-r1', 0.6, 100)
    requests = simple_promptify(prompts)
    responses = async_http_process_requests(requests, model_config)

    # Responses are paired with inputs by position; res[0] is taken as the
    # judge's text (assumes each response is an indexable whose first
    # element is the generated string — TODO confirm).
    data_to_save = [{
        'judge_res': parse_res(res[0]),
        'judge': res[0],
        'metadata': item,
    } for res, item in zip(responses, data)]

    # np.mean of an empty list is nan and emits a RuntimeWarning; produce the
    # same nan explicitly so an empty input does not trigger the warning.
    if data_to_save:
        pass_ratio = np.mean([entry['judge_res'] for entry in data_to_save])
    else:
        pass_ratio = float('nan')
    print(f'Pass Ratio: {pass_ratio*100:.2f}%')

    with jsonlines.open(args.save_file, 'w') as writer:
        writer.write_all(data_to_save)


if __name__ == '__main__':
    main()
