
import re
import argparse
import jsonlines
from tqdm import tqdm
import numpy as np
from collections import defaultdict
from common.utils import clean_string, reduce_multi_empty_lines, extract_between, extract_after


def split_info(raw_info, raw_task):
    """Split a numbered list into items and verify each appears in the task.

    ``raw_info`` is split on numbered-list markers (one or more digits, a
    period, then whitespace, e.g. ``"1. "``). Every non-empty piece is passed
    through ``clean_string``. For the containment check only, each cleaned
    item is additionally passed through ``reduce_multi_empty_lines`` and has
    leading/trailing periods stripped; ``raw_task`` is normalized with the
    same two helpers before the substring test.

    Args:
        raw_info: Text holding the numbered items, or ``None`` when the
            section could not be extracted upstream.
        raw_task: Task text that every item must be a substring of.

    Returns:
        The list of cleaned items (as returned by ``clean_string``, without
        the extra comparison-only normalization), or ``None`` if ``raw_info``
        is ``None`` or any item is not found in the normalized task text.
    """
    # A missing section cannot be validated; report it the same way as a
    # failed validation instead of letting re.split raise TypeError.
    if raw_info is None:
        return None

    pieces = re.split(r'\d+[.]\s+', raw_info)
    # Process each string: drop the empty pieces produced by the split
    # (e.g. the text before the first marker) and clean the rest.
    info_list_raw = [clean_string(piece) for piece in pieces if piece != '']
    info_list_tmp = [
        reduce_multi_empty_lines(info).strip('.') for info in info_list_raw
    ]
    raw_task_tmp = reduce_multi_empty_lines(clean_string(raw_task))
    if any(info not in raw_task_tmp for info in info_list_tmp):
        return None
    return info_list_raw


def main():
    """Filter decomposed tasks and write the accepted and remaining records.

    Reads JSONL records from ``--input_file``. Each record is expected to
    carry the keys ``decomposed_task``, ``raw_task``, ``solution``,
    ``answer`` and ``level``. A record is accepted when:

    * a goal section can be extracted from the lower-cased decomposed task
      and its normalized text is a substring of the normalized raw task;
    * a "necessary information" section can be extracted and ``split_info``
      validates every item against the raw task;
    * at least one item is not already contained in the goal text.

    Accepted records are written to ``<input stem>_pp.jsonl``. When
    ``--output_file`` is given, every input record whose lower-cased raw task
    is not among the accepted ones is written there. Per-level counts, the
    per-level mean number of filtered items and the overall mean are printed.
    """
    # Local import: only needed to derive the default output path.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, default=None)
    args = parser.parse_args()

    # Close the reader deterministically instead of leaking the file handle.
    with jsonlines.open(args.input_file, 'r') as reader:
        data = list(reader)

    info_lengths = 0
    output_data, output_data_2 = [], []
    level_counts = defaultdict(int)
    level_info_lens = defaultdict(list)
    for item in tqdm(data):
        parsed_task = item['decomposed_task']
        raw_task = item['raw_task']
        # Check for missing text *before* lower-casing; calling .lower() on
        # None would raise and make the check unreachable.
        if parsed_task is None or raw_task is None:
            continue
        parsed_task = parsed_task.lower()
        raw_task = raw_task.lower()

        goal = extract_between(parsed_task, 'goal', 'necessary information')
        if goal is None:
            continue
        goal_tmp = reduce_multi_empty_lines(clean_string(goal))
        if goal_tmp not in reduce_multi_empty_lines(clean_string(raw_task)):
            continue

        info = extract_between(parsed_task, 'necessary information', 'background')
        if info is None:
            # No information section to split; treat like any other
            # unparseable record.
            continue
        background = extract_after(parsed_task, 'background')
        info_list = split_info(info, raw_task)
        if info_list is None:
            continue

        # Keep only the items that the goal text does not already contain.
        filtered_info_list = [
            entry for entry in info_list
            if reduce_multi_empty_lines(entry) not in goal_tmp
        ]
        if not filtered_info_list:
            continue

        level = item['level']
        level_counts[level] += 1
        level_info_lens[level].append(len(filtered_info_list))
        info_lengths += len(filtered_info_list)
        output_data.append({
            'all_info_list': info_list,
            'filtered_info_list': filtered_info_list,
            'goal': goal,
            'background': background,
            'raw_task': raw_task,
            'solution': item['solution'],
            'answer': item['answer'],
            'level': level,
        })

    # Accepted records already store the lower-cased raw task.
    raw_tasks_1 = {record['raw_task'] for record in output_data}
    for item in tqdm(data):
        raw_task = item['raw_task']
        # Records without a raw task were never accepted, so they fall
        # through to the remainder with raw_task left as None.
        if raw_task is not None:
            raw_task = raw_task.lower()
        if raw_task not in raw_tasks_1:
            output_data_2.append({
                'raw_task': raw_task,
                'solution': item['solution'],
                'answer': item['answer'],
                'level': item['level'],
            })

    print(level_counts)
    level_info_lens = {key: np.mean(value) for key, value in level_info_lens.items()}
    print(level_info_lens)
    print(len(output_data))
    # Avoid ZeroDivisionError when no record was accepted, so the output
    # files are still written.
    print(info_lengths / len(output_data) if output_data else 0.0)

    # splitext keeps the full path when the input has no extension, where
    # splitting on '.' would yield an empty stem.
    file_name = os.path.splitext(args.input_file)[0]
    output_file = f'{file_name}_pp.jsonl'
    # The accepted-record count is printed a second time, as before, so the
    # script's stdout stays unchanged.
    print(len(output_data))
    with jsonlines.open(output_file, 'w') as w:
        w.write_all(output_data)
    if args.output_file is not None:
        print(len(output_data_2))
        with jsonlines.open(args.output_file, 'w') as w:
            w.write_all(output_data_2)

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
