"""
Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
"""

import re
import json
import copy
import os
from datasets import Dataset, load_dataset
from random import randint, seed, choice
from typing import List, Tuple
from tqdm import tqdm
import argparse


def make_prefix(dp, template_type):
    question = dp['question']
    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = r"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the final answer. The reasoning process and final answer are enclosed within the <think> </think> tags and the \boxed{}, respectively, i.e., <think> reasoning process here </think> \boxed{final answer here}.""" + f"""

User: {question}

Assistant: <think>"""
    elif template_type == 'qwen-instruct':
        raise NotImplementedError
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix


def load_qa_jsonl(path, repeat=1):
    """Load a JSONL file of math problems, keeping only question and answer.

    Args:
        path: Path to a JSONL file. Each non-blank line must be a JSON object
            with ``'question'`` and ``'answer'`` keys; other keys are dropped.
            A leading ``~`` is expanded.
        repeat: How many times the whole list is repeated (the test set is
            repeated so every problem appears several times).

    Returns:
        A list of ``{'question': ..., 'answer': ...}`` dicts, with a fresh
        dict for every (repeated) record.
    """
    with open(os.path.expanduser(path), 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. an extra trailing newline) instead of
        # failing inside json.loads.
        records = [json.loads(line) for line in f if line.strip()]
    return [{'question': dp['question'], 'answer': dp['answer']} for dp in records * repeat]


def make_map_fn(split, source, template_type='base'):
    """Return a ``Dataset.map`` callback (used with ``with_indices=True``).

    The callback turns a question/answer row into the output record format.

    Args:
        split: Split name stored in ``extra_info['split']``.
        source: Value stored in the ``data_source`` field.
        template_type: Prompt template name passed to ``make_prefix``.
    """

    def process_fn(example, idx):
        prompt_text = make_prefix(example, template_type=template_type)
        return {
            "data_source": source,
            "prompt": [{
                "role": "user",
                "content": prompt_text,
            }],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": example['answer'],
            },
            "extra_info": {
                'split': split,
                'index': idx,
            },
        }

    return process_fn


def main():
    """Parse CLI arguments, build the train/test datasets and write them as parquet."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='~/data/countdown')
    parser.add_argument('--hdfs_dir', default=None,
                        help='Currently unused; output is only written to --local_dir.')
    parser.add_argument('--template_type', type=str, default='base')
    # The defaults below reproduce the previously hard-coded settings.
    parser.add_argument('--train_path',
                        default='/cpfs01/user/zhengchujie.zcj/Math-Data-Filter/collected_afterNY_rl_queries_selected.jsonl')
    parser.add_argument('--test_path',
                        default='/cpfs01/user/zhengchujie.zcj/Math-Data/test_data/aime24.jsonl')
    parser.add_argument('--test_repeat', type=int, default=16,
                        help='Number of times the whole test set is repeated.')
    parser.add_argument('--train_source', default='custom_math_hard')
    parser.add_argument('--test_source', default='aime24')

    args = parser.parse_args()

    train_dataset = Dataset.from_list(load_qa_jsonl(args.train_path))
    test_dataset = Dataset.from_list(load_qa_jsonl(args.test_path, repeat=args.test_repeat))

    train_dataset = train_dataset.map(
        function=make_map_fn('train', args.train_source, args.template_type), with_indices=True)
    print(train_dataset)
    print(train_dataset[0])
    test_dataset = test_dataset.map(
        function=make_map_fn('test', args.test_source, args.template_type), with_indices=True)
    print(test_dataset)
    print(test_dataset[0])

    # Expand '~' explicitly (os.path.join does not) and make sure the output
    # directory exists before writing.
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))


if __name__ == '__main__':
    main()
