import numpy as np
import argparse
import jsonlines
from common.utils import async_http_process_requests, simple_promptify
from common.model_configs import config_model, get_configs

# Judge prompt template, filled via str.format with two placeholders:
#   {raw_query}    - the original task text
#   {target_query} - the rewritten task text being checked against it
# The model is asked to reply with a 'Judge: "[[Yes]]"' verdict when the target
# task misses information relative to the raw task, '"[[No]]"' otherwise, followed
# by a rationale; parse_res() extracts the verdict from that reply.
# NOTE(review): despite the "_no_reason" suffix, this prompt does request a
# rationale. The trailing backslash on the last content line joins it with the
# closing quotes, so the rendered prompt has no final newline.
template_no_reason = """
# Raw Task

{raw_query}

# Target Task

{target_query}

# Instruction

Please act as an impartial judge to evaluate whether the target task misses information compared with the raw task.
Provide the rationale about your judgement. 

Output your final verdict by strictly following this format: 
Judge: "[[Yes]]" if you judge that the task misses information compared with the raw task, otherwise "[[No]]"
Rationale: [The rationale that whether the target task misses information compared with the raw task]\
"""


def parse_res(res):
    """Return True if the judge reply says the target task misses information.

    The judge prompt asks for a verdict of "[[Yes]]" (information is missing)
    or "[[No]]" (nothing is missing) on the ``Judge:`` line, which precedes
    the rationale. The verdict token that appears *first* in the reply is
    taken as the verdict, so a rationale that happens to mention the other
    token cannot override it.

    Args:
        res: Raw text of the judge's reply. ``None`` or any non-string value
            is treated as an unparseable reply.

    Returns:
        False only when the first verdict token is "[[No]]". True otherwise,
        including when no verdict token is present: an unparseable reply is
        conservatively counted as a failed verification.
    """
    if not isinstance(res, str):
        return True
    no_pos = res.find('[[No]]')
    if no_pos == -1:
        # No explicit "[[No]]" verdict anywhere: treat as missing information.
        return True
    yes_pos = res.find('[[Yes]]')
    # A "[[Yes]]" that appears before the first "[[No]]" is the real verdict;
    # a later "[[Yes]]" (e.g. quoted in the rationale) is ignored.
    return 0 <= yes_pos < no_pos


def _build_prompts(data):
    """Render one judge prompt per input record.

    Each record must provide ``raw_task`` and ``reassemble_task`` keys; a
    missing key raises ``KeyError``.
    """
    return [
        template_no_reason.format(
            raw_query=item['raw_task'],
            target_query=item['reassemble_task'],
        )
        for item in data
    ]


def main():
    """Judge whether each rewritten task misses information from its raw task.

    Reads JSON-lines records from ``--input_file``, sends one judge prompt per
    record to the configured model, and writes one JSON-lines record per input
    to ``--output_file`` with keys ``verification_res`` (bool from
    ``parse_res``), ``verification`` (the judge's reply) and ``metadata`` (the
    input record). Prints the percentage of records judged as missing
    information.

    Raises:
        RuntimeError: if the number of responses differs from the number of
            input records (results could not be aligned with their inputs).
    """
    parser = argparse.ArgumentParser(
        description='Judge whether rewritten tasks miss information compared with the raw tasks.')
    parser.add_argument('--input_file', required=True,
                        help='JSON-lines input with "raw_task" and "reassemble_task" fields.')
    parser.add_argument('--output_file', required=True,
                        help='JSON-lines file to write verification results to.')
    parser.add_argument('--model_name', type=str,
                        help='Model name passed to config_model.')
    parser.add_argument('--client_name', type=str,
                        help='Client name passed to get_configs.')
    parser.add_argument('--num_concurrent', type=int, default=100,
                        help='Concurrency setting passed to config_model.')
    args = parser.parse_args()

    with jsonlines.open(args.input_file) as reader:
        data = list(reader)

    if not data:
        # Without this guard prompts[0] raises IndexError and np.mean([]) is nan;
        # write an empty result file so downstream steps still find the output.
        print('No records found in input file; nothing to verify.')
        with jsonlines.open(args.output_file, 'w') as writer:
            writer.write_all([])
        return

    prompts = _build_prompts(data)
    print(prompts[0])  # show one rendered prompt as a sanity check

    model_config = get_configs(args.client_name)
    # NOTE(review): 0. is presumably the sampling temperature — confirm against config_model.
    model_config = config_model(model_config, args.model_name, 0., args.num_concurrent)
    requests = simple_promptify(prompts)
    # Materialise so the count can be checked: zip() below would otherwise
    # silently drop or misalign records if the counts differ.
    responses = list(async_http_process_requests(requests, model_config))
    if len(responses) != len(data):
        raise RuntimeError(
            f'Expected {len(data)} responses, got {len(responses)}')

    data_to_save = [{
        'verification_res': parse_res(res[0]),
        'verification': res[0],
        'metadata': item,
    } for res, item in zip(responses, data)]

    # verification_res is True when the judge says information is missing,
    # so the mean is the share of records that FAILED verification.
    fail_ratio = np.mean([record['verification_res'] for record in data_to_save])
    print(f'Verification Failed: {fail_ratio * 100:.2f}%')
    with jsonlines.open(args.output_file, 'w') as writer:
        writer.write_all(data_to_save)


# Script entry point: run the verification only when executed directly, not on import.
if __name__ == '__main__':
    main()