from collections import defaultdict

import jsonlines
import argparse
from evaluation import verify_math
import numpy as np
from common.utils import get_tl_length


def _load_answered_items(path):
    """Read a JSON Lines file and index its usable records by raw task.

    A record is kept only when its ``answer`` is not ``None`` and does not
    contain the substring "clarification" (compared case-insensitively).

    Args:
        path: Location of the JSON Lines file to read.

    Returns:
        A dict mapping ``item['metadata']['raw_task']`` to the record, in
        file order. When several kept records share a raw task, the last
        one read wins.
    """
    with jsonlines.open(path) as reader:
        return {
            item['metadata']['raw_task']: item
            for item in reader
            if item['answer'] is not None
            and 'clarification' not in item['answer'].lower()
        }


def _summarize(item):
    """Build the per-prompt output entry for one record.

    Args:
        item: A record with ``task``, ``answer`` and ``thought`` keys.

    Returns:
        A dict with the record's task (as ``input``), answer, thought, and
        the value ``get_tl_length`` returns for the record (``len_thought``).
    """
    return {
        'input': item['task'],
        'answer': item['answer'],
        'thought': item['thought'],
        'len_thought': get_tl_length(item),
    }


def main():
    """Pair records from two JSON Lines files and write them side by side.

    Reads the "basic" and "cls" inputs, drops records without a usable
    answer, matches the remaining records on ``metadata.raw_task``, and
    writes one output line per task present in both inputs. Each output
    line holds a ``basic_prompt`` and a ``cls_prompt`` entry.
    """
    parser = argparse.ArgumentParser(
        description='Pair basic-prompt and cls-prompt records by raw task.'
    )
    # All three paths are mandatory: without them there is nothing to read
    # or write, so fail with a usage message instead of a late TypeError.
    parser.add_argument('--input_path_basic', type=str, required=True,
                        help='JSON Lines file with the basic-prompt records.')
    parser.add_argument('--input_path_cls', type=str, required=True,
                        help='JSON Lines file with the cls-prompt records.')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Destination JSON Lines file for paired records.')
    args = parser.parse_args()

    basic_items = _load_answered_items(args.input_path_basic)
    cls_items = _load_answered_items(args.input_path_cls)

    # Walk the basic records in file order rather than a set of keys: set
    # iteration order for strings varies between interpreter runs, which
    # would make the output file order non-reproducible.
    output_data = [
        {
            'basic_prompt': _summarize(basic_item),
            'cls_prompt': _summarize(cls_items[key]),
        }
        for key, basic_item in basic_items.items()
        if key in cls_items
    ]

    with jsonlines.open(args.output_file, 'w') as writer:
        writer.write_all(output_data)


# Run the pairing only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
