from collections import defaultdict

import jsonlines
import argparse
from evaluation import verify_math
import numpy as np


def _load_answered(path):
    """Load a JSONL file and index its answered records by original task.

    Records whose ``answer`` is ``None`` or contains the substring
    "clarification" (case-insensitive) are dropped; presumably these are
    responses that asked for clarification instead of answering — confirm.
    The remaining records are keyed by ``item['metadata']['raw_task']``; if a
    key repeats, the last record in the file wins.

    Args:
        path: Path to a JSON Lines file.

    Returns:
        Dict mapping ``raw_task`` to the full record.
    """
    with jsonlines.open(path) as reader:
        return {
            item['metadata']['raw_task']: item
            for item in reader
            if item['answer'] is not None
            and 'clarification' not in item['answer'].lower()
        }


def main():
    """Compare correctness on the raw vs. cls versions of the same tasks.

    Reads ``--input_path_raw`` and ``--input_path_cls`` (JSON Lines files),
    keeps only tasks that have a usable answer in both, and among the tasks
    answered correctly in the raw file, reports how many are also answered
    correctly in the cls file.
    """
    parser = argparse.ArgumentParser(
        description='Compare answer correctness between raw and cls runs.')
    parser.add_argument('--input_path_raw', type=str, required=True,
                        help='JSONL file with answers to the raw tasks')
    parser.add_argument('--input_path_cls', type=str, required=True,
                        help='JSONL file with answers to the cls tasks')
    args = parser.parse_args()

    raw_dict = _load_answered(args.input_path_raw)
    cls_dict = _load_answered(args.input_path_cls)
    # Only tasks with a usable answer in both files are comparable.
    keys = raw_dict.keys() & cls_dict.keys()

    solved_cls = 0
    total = 0
    for key in keys:
        raw_item = raw_dict[key]
        # NOTE(review): the two files store the reference solution at
        # different metadata paths (['metadata']['task_solution'] nested one
        # level deeper for raw vs. ['solution'] for cls) — confirm intended.
        if not verify_math(raw_item['answer'],
                           raw_item['metadata']['metadata']['task_solution']):
            # Only count tasks the model solved in their raw form; checking
            # the cls answer first would be wasted work.
            continue
        total += 1
        cls_item = cls_dict[key]
        if verify_math(cls_item['answer'], cls_item['metadata']['solution']):
            solved_cls += 1

    if total == 0:
        # Avoid ZeroDivisionError when no raw task was answered correctly.
        print("Correct Cls: 0/0 -> n/a (no raw task answered correctly)")
    else:
        print(f"Correct Cls: {solved_cls}/{total} -> {solved_cls/total*100:.2f}%")
    print('****************************')


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()
