import jsonlines
import argparse
from tqdm import tqdm
from evaluation import verify_math


def _load_jsonl(path):
    """Return every record of the JSON Lines file at *path* as a list.

    The reader is used as a context manager so the underlying file handle is
    closed as soon as the records have been read.
    """
    with jsonlines.open(path) as reader:
        return list(reader)


def _to_record(item):
    """Build one output record from *item*.

    The whole input item is kept under ``'metadata'`` so nothing from the
    source line is lost in the output.
    """
    return {
        'prompt': item['task'],
        'answer': item['answer'],
        'thought': item['thought'],
        'metadata': item,
    }


def main() -> None:
    """Filter four JSON Lines inputs into two JSON Lines outputs.

    Command-line arguments (all required):
        --input_file_raw, --input_file_raw_non_parsed:
            Items from these two files are kept when
            ``verify_math(item['metadata']['answer'], item['answer'])`` is
            truthy; kept items are written to ``--output_non_cls``.
        --input_file_im, --input_file_pm:
            Items from these two files are kept when their ``'answer'``
            contains the word "clarification" (case-insensitive); kept items
            are written to ``--output_cls``.
        --output_non_cls, --output_cls:
            Destination JSON Lines files.

    The number of records in each output is printed before writing.
    """
    parser = argparse.ArgumentParser()
    # Every path is used unconditionally below, so demand them all up front:
    # a missing output path would otherwise only fail after all the
    # verification work has already been done.
    for flag in (
        '--input_file_raw',
        '--input_file_raw_non_parsed',
        '--input_file_im',
        '--input_file_pm',
        '--output_non_cls',
        '--output_cls',
    ):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    data_raw = _load_jsonl(args.input_file_raw)
    data_raw_non_parsed = _load_jsonl(args.input_file_raw_non_parsed)
    data_im = _load_jsonl(args.input_file_im)
    data_pm = _load_jsonl(args.input_file_pm)

    # Lists are concatenated (rather than chained lazily) so tqdm knows the
    # total and can show a real progress bar.
    # NOTE(review): verify_math is imported from `evaluation`; presumably it
    # checks the answer against the gold answer -- confirm its exact contract.
    data_non_cls = [
        _to_record(item)
        for item in tqdm(data_raw + data_raw_non_parsed)
        if verify_math(item['metadata']['answer'], item['answer'])
    ]
    data_cls = [
        _to_record(item)
        for item in tqdm(data_im + data_pm)
        if 'clarification' in item['answer'].lower()
    ]

    print(len(data_non_cls))
    print(len(data_cls))

    with jsonlines.open(args.output_non_cls, 'w') as writer:
        writer.write_all(data_non_cls)
    with jsonlines.open(args.output_cls, 'w') as writer:
        writer.write_all(data_cls)



# Run the filtering pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
