import argparse
import re

import jsonlines
import numpy as np

from common.model_configs import config_model, get_configs
from common.utils import async_http_process_requests, simple_promptify


# Judge prompt used for records that carry a 'reason_of_unclearness' field.
# Placeholders (filled per input record below):
#   {query}                 <- item['unclear_task']
#   {reason_of_unclearness} <- item['reason_of_unclearness']
# The judge is asked to answer "[[Yes]]" / "[[No]]" on a "Judge:" line;
# parse_res looks for these markers in the response.
# NOTE(review): this text is sent verbatim to the model, so it is kept
# byte-identical (including the trailing space after "format:" and the
# missing article in "agree that task is truly unclear") to keep results
# reproducible across runs. Fix the wording deliberately, not in passing.
template_reason = """
# Task

{query}

# Reason of Unclearness

{reason_of_unclearness}

# Instruction

According to the reason of unclearness, the task is classified to be unclear.
Please act as an impartial judge to evaluate whether the task is truly unclear and whether the reason of unclearness is correct.
If you agree that task is truly unclear, provide at least two different understandings of the unclear task corresponding to the unclearness as the verification of your judgement.

Output your final verdict by strictly following this format: 
Analysis: [Your analysis about the task and the reason of unclearness]
Judge: "[[Yes]]" if you agree with the reason and that the task is truly unclear, "[[No]]" if you do not agree with the reason and that the task is clear.
Verification: [At least two different understandings of the unclear task corresponding to the unclearness]
"""

# Judge prompt used for records WITHOUT a 'reason_of_unclearness' field:
# the judge compares the unclear task against the original one.
# Placeholders (filled per input record below):
#   {raw_query}    <- item['raw_task']
#   {target_query} <- item['unclear_task']
# The trailing backslash on the last content line is a line continuation
# inside the string, so unlike template_reason this prompt does NOT end with
# a newline. Kept byte-identical because the text is sent verbatim to the model.
template_no_reason = """
# Raw Task

{raw_query}

# Target Task

{target_query}

# Instruction

Please act as an impartial judge to evaluate whether the target task is unclear compared with the raw task.
Provide the rationale about your judgement. 

Output your final verdict by strictly following this format: 
Judge: "[[Yes]]" if you judge that the task is unclear compared with the raw task, otherwise "[[No]]"
Rationale: [The rationale that whether the target task is unclear compared with the raw task]\
"""


def parse_res(res, default=True):
    """Turn a judge model response into a pass/fail verdict.

    The judge prompts ask for a line such as ``Judge: "[[Yes]]"`` or
    ``Judge: "[[No]]"``. A verdict attached to a ``Judge`` label is preferred,
    because the free-text part of the answer (e.g. the Analysis section that
    precedes the Judge line) may mention the opposite marker. When no labelled
    verdict exists, the original rule applies: any ``[[No]]`` means fail.

    Args:
        res: Raw text produced by the judge. Any non-string value (presumably
            what a failed request yields -- TODO confirm against
            ``async_http_process_requests``) is counted as a failed
            verification instead of raising ``TypeError``.
        default: Verdict returned when the text contains neither ``[[Yes]]``
            nor ``[[No]]``. Defaults to ``True`` to preserve the previous
            behaviour.

    Returns:
        ``True`` if the judge agreed the task is unclear, otherwise ``False``.
    """
    if not isinstance(res, str):
        return False
    # `\W*` skips separators like ':', '**', spaces and quotes between the
    # label and the marker (e.g. `**Judge:** "[[Yes]]"`).
    labelled = re.findall(r'Judge\W*\[\[(Yes|No)\]\]', res)
    if labelled:
        # The last labelled verdict is the judge's final answer.
        return labelled[-1] == 'Yes'
    if '[[No]]' in res:
        return False
    if '[[Yes]]' in res:
        return True
    return default


def _build_prompt(item):
    """Render the judge prompt for one input record.

    Records with a 'reason_of_unclearness' key use ``template_reason``;
    all others compare 'unclear_task' against 'raw_task' with
    ``template_no_reason``.
    """
    if 'reason_of_unclearness' in item:
        return template_reason.format(
            query=item['unclear_task'], reason_of_unclearness=item['reason_of_unclearness']
        )
    return template_no_reason.format(
        raw_query=item['raw_task'], target_query=item['unclear_task'],
    )


def main():
    """CLI entry point: judge each record of a JSONL file with an LLM.

    Reads ``--input_file`` (JSONL), builds one judge prompt per record,
    queries the configured model at temperature 0, and writes one JSONL
    record per input to ``--output_file`` with the keys
    ``verification_prompt``, ``verification_res`` (parsed bool),
    ``verification`` (raw response text) and ``metadata`` (the input record).
    Prints the percentage of records whose verification passed.

    Raises:
        RuntimeError: if the model helper returns a different number of
            responses than prompts sent (results could not be aligned).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', required=True, help='Input JSONL file.')
    parser.add_argument('--output_file', required=True, help='Output JSONL file.')
    parser.add_argument('--model_name', type=str, help='Judge model name.')
    parser.add_argument('--client_name', type=str, help='Client config name for get_configs.')
    parser.add_argument('--num_concurrent', type=int, default=100,
                        help='Maximum number of concurrent requests.')
    args = parser.parse_args()

    with jsonlines.open(args.input_file) as reader:
        data = list(reader)

    prompts = [_build_prompt(item) for item in data]
    print(len(prompts))
    if not prompts:
        # Nothing to judge: avoid IndexError on prompts[0] and a NaN pass
        # ratio, but still produce an (empty) output file for later steps.
        print('No input records; writing empty output.')
        with jsonlines.open(args.output_file, 'w') as writer:
            writer.write_all([])
        return
    print(prompts[0])

    model_config = get_configs(args.client_name)
    model_config = config_model(model_config, args.model_name, 0., args.num_concurrent)
    requests = simple_promptify(prompts)
    responses = list(async_http_process_requests(requests, model_config))
    # zip() would silently drop or misalign records on a count mismatch.
    if len(responses) != len(prompts):
        raise RuntimeError(
            f'Expected {len(prompts)} responses, got {len(responses)}; '
            'cannot align responses with input records.'
        )

    # Each response appears to be a sequence whose first element is the
    # generated text -- TODO confirm against async_http_process_requests.
    data_to_save = [{
        'verification_prompt': prompt,
        'verification_res': parse_res(res[0]),
        'verification': res[0],
        'metadata': item,
    } for prompt, res, item in zip(prompts, responses, data)]

    pass_ratio = np.mean([record['verification_res'] for record in data_to_save])
    print(f'Verification Pass: {pass_ratio * 100:.2f}%')
    with jsonlines.open(args.output_file, 'w') as writer:
        writer.write_all(data_to_save)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()