"""
Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
"""

import re
import json
import copy
import os
from datasets import Dataset, load_dataset
from random import randint, seed, choice
from typing import List, Tuple
from tqdm import tqdm
import argparse
from transformers import AutoTokenizer

# Tokenizer whose chat template is used by make_prefix to render prompts. It is loaded at
# import time, so importing this module requires the tokenizer path to be readable.
# Set PREPROCESS_TOKENIZER_PATH to override the default (cluster-specific) location.
toker = AutoTokenizer.from_pretrained(
    os.environ.get(
        'PREPROCESS_TOKENIZER_PATH',
        '/cpfs01/data/shared/Group-m6/zhengchujie.zcj/hf_models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
    )
)


def make_prefix(dp, template_type):
    question = dp['question']
    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        messages = [
            {'role': 'user', 'content': question},
        ]
        prefix = toker.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    elif template_type == 'qwen-instruct':
        raise NotImplementedError
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='~/data/countdown',
                        help='Output directory for train.parquet / test.parquet ("~" is expanded).')
    parser.add_argument('--hdfs_dir', default=None,
                        help='Currently unused; accepted for command-line compatibility.')
    parser.add_argument('--template_type', type=str, default='base',
                        choices=['base', 'qwen-instruct'])
    parser.add_argument('--train_path',
                        default='/cpfs01/user/zhengchujie.zcj/Math-Data-Filter/collected_afterNY_rl_queries_selected.jsonl',
                        help='JSON Lines file of training records with "question" and "answer" fields.')
    parser.add_argument('--test_path',
                        default='/cpfs01/user/zhengchujie.zcj/Math-Data/test_data/aime24.jsonl',
                        help='JSON Lines file of test records with "question" and "answer" fields.')
    parser.add_argument('--test_repeat', type=int, default=16,
                        help='Number of copies of the test set to emit (must be >= 1).')

    args = parser.parse_args()
    if args.test_repeat < 1:
        # An empty test set would crash below at test_dataset[0].
        parser.error('--test_repeat must be >= 1')

    def load_qa_jsonl(path):
        """Read a JSON Lines file, keeping only each record's 'question' and 'answer'.

        Blank lines are skipped; every other line must be a JSON object containing
        both keys (KeyError otherwise). The file is closed before returning.
        """
        with open(path, 'r', encoding='utf-8') as f:
            records = [json.loads(line) for line in f if line.strip()]
        return [{'question': dp['question'], 'answer': dp['answer']} for dp in records]

    train_dataset = Dataset.from_list(load_qa_jsonl(args.train_path))
    # The test set is repeated `test_repeat` times; after .map each copy gets its own index.
    test_dataset = Dataset.from_list(load_qa_jsonl(args.test_path) * args.test_repeat)

    def make_map_fn(split, source):
        """Return a `Dataset.map` callback that turns a Q/A record into a training row.

        Args:
            split: Stored as ``extra_info['split']``.
            source: Stored as ``data_source``.
        """

        def process_fn(example, idx):
            # The content is already chat-templated by make_prefix.
            # NOTE(review): confirm the downstream loader does not apply the chat template again.
            question = make_prefix(example, template_type=args.template_type)
            data = {
                "data_source": source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "math",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": example['answer'],
                },
                "extra_info": {
                    'split': split,
                    'index': idx,
                }
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train', 'custom_math_hard'), with_indices=True)
    print(train_dataset)
    print(train_dataset[0])
    test_dataset = test_dataset.map(function=make_map_fn('test', 'aime24'), with_indices=True)
    print(test_dataset)
    print(test_dataset[0])

    local_dir = os.path.expanduser(args.local_dir)
    hdfs_dir = args.hdfs_dir
    if hdfs_dir is not None:
        # No HDFS copy is implemented in this script; say so instead of silently ignoring it.
        print(f'Warning: --hdfs_dir={hdfs_dir!r} is ignored; outputs are written only to {local_dir!r}.')
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
