import jsonlines
import numpy as np
import argparse
from common.utils import async_http_process_requests, simple_promptify, list_to_string
from common.model_configs import config_model, config_aliyun

# Prompt template sent to the judge model. Its only placeholder is
# ``{thoughts}``; the judge is asked to answer with an "Analysis:" line and a
# "Judge:" line containing either "[[Yes]]" or "[[No]]".
# NOTE(review): the instruction refers to "the unclear task", but the template
# has no placeholder for the task text or for the reason it is unclear --
# confirm whether omitting them from the prompt is intended.
chat_template = """

# Thoughts

{thoughts}

# Instruction

Please act as an impartial judge to evaluate whether the thoughts have appropriately recognized the unclearness in the unclear task.

Output your final verdict by strictly following this format:
Analysis: [Your analysis about the thoughts]
Judge: "[[Yes]]" if the thoughts have appropriately recognized the unclearness and the need to raise clarification question, otherwise "[[No]]"
"""


def parse_res(res):
    """Convert a judge response into a binary score.

    Returns 1 when the exact, case-sensitive marker ``[[Yes]]`` occurs
    anywhere in ``res``, and 0 otherwise.
    """
    # A bool from the membership test maps directly onto the 1/0 score.
    return int('[[Yes]]' in res)


def main():
    """Judge each record's thoughts with an LLM and save the verdicts.

    Reads JSON Lines records from ``--input_file``, builds one judge prompt
    per record from the first four blank-line-separated paragraphs of its
    ``thought`` field, sends the prompts to the judge model, prints the
    percentage of records judged ``[[Yes]]``, and writes one result per record
    (score, raw judge text, original record) to ``--save_file``.

    Raises:
        KeyError: if a record lacks ``thought``, ``unclear_task`` or
            ``reason_of_unclearness``.
    """
    parser = argparse.ArgumentParser()
    # Both paths are required so that a missing --save_file is rejected here
    # rather than after every judge request has already been made.
    parser.add_argument('--input_file', type=str, required=True,
                        help='JSON Lines file with the records to judge')
    parser.add_argument('--save_file', type=str, required=True,
                        help='JSON Lines file to write the judge results to')
    args = parser.parse_args()

    # Use the reader as a context manager so the input file is closed.
    with jsonlines.open(args.input_file) as reader:
        data = list(reader)
    if not data:
        # Without records there is no prompt to preview and no ratio to report.
        print(f'No records found in {args.input_file}; nothing to judge.')
        return

    prompts = []
    for item in data:
        # Only the first four paragraphs of the thought are shown to the judge.
        thoughts = '\n\n'.join(item['thought'].split('\n\n')[:4])
        # The template currently only consumes ``thoughts``; str.format ignores
        # the extra keyword arguments, which are kept so the record schema is
        # still checked and remains available to the template.
        prompt = chat_template.format(
            unclear_task=item['unclear_task'],
            reason_of_unclearness=item['reason_of_unclearness'],
            thoughts=thoughts,
        )
        prompts.append(prompt)
    print(len(prompts), len(data))
    print('=======================')
    print(prompts[0])
    # Pause so the operator can inspect the sample prompt; press Enter to
    # continue and send the requests.
    input()

    # NOTE(review): 0.6 and 250 are passed positionally to config_model; their
    # meaning is defined in common.model_configs -- confirm there.
    model_config = config_model(config_aliyun, 'deepseek-r1', 0.6, 250)
    requests = simple_promptify(prompts)
    responses = async_http_process_requests(requests, model_config)
    # Assumes responses come back in the same order as ``data`` and that the
    # first element of each response is the judge's text -- TODO confirm
    # against common.utils.
    data_to_save = [{
        'judge_res': parse_res(res[0]),
        'judge': res[0],
        'metadata': item,
    } for res, item in zip(responses, data)]
    pass_ratio = np.mean([item['judge_res'] for item in data_to_save])
    print(f'Pass Ratio: {pass_ratio*100:.2f}%')
    with jsonlines.open(args.save_file, 'w') as writer:
        writer.write_all(data_to_save)


# Script entry point: run the judging pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
