# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import os
import datasets
import argparse


# Hub dataset that is downloaded and converted.
_HF_DATASET = "haizhongzheng/DAPO-Math-17K-cleaned"
# Every emitted record is tagged with this fixed data source, regardless of
# any "data_source" value the input row carries.
_OUTPUT_DATA_SOURCE = "BytedTsinghua-SIA/DAPO-Math-17k-text"
_SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
# This is the directory the script has always written to. Keeping it as the
# default means runs that do not pass --local_dir behave exactly as before.
_DEFAULT_LOCAL_DIR = "/root/your_rl_data_path/agentmath"
_TRAIN_FILENAME = "train_dapo_17k.parquet"
_SHUFFLE_SEED = 42


def process_dapo(example, idx, split="train"):
    """Build one chat-formatted training record from a raw dataset row.

    Args:
        example: Mapping for a single row. The keys ``"prompt"``,
            ``"target"`` and ``"ability"`` are read; each falls back to
            ``""`` when absent.
        idx: Row index, stored unchanged under ``extra_info["index"]``.
        split: Label stored under ``extra_info["split"]``. Defaults to
            ``"train"``, the value that was previously hard-coded.

    Returns:
        A new dict with the keys ``data_source``, ``prompt`` (a system
        message followed by the user message), ``ability``,
        ``reward_model`` and ``extra_info``. ``example`` is not modified.
    """
    return {
        "data_source": _OUTPUT_DATA_SOURCE,
        "prompt": [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": example.get("prompt", "")},
        ],
        "ability": example.get("ability", ""),
        "reward_model": {
            "style": "rule",
            "ground_truth": example.get("target", ""),
        },
        "extra_info": {
            "split": split,
            "index": idx,
        },
    }


def main(argv=None):
    """Download the dataset, convert every row and write the train parquet.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    import sys  # local import: only needed for the stderr warning below

    parser = argparse.ArgumentParser(
        description="Convert the DAPO math dataset into chat-formatted parquet."
    )
    parser.add_argument(
        "--local_dir",
        default=_DEFAULT_LOCAL_DIR,
        help="Directory the parquet file is written to ('~' is expanded).",
    )
    parser.add_argument(
        "--hdfs_dir",
        default=None,
        help="Accepted for compatibility; this script performs no HDFS upload.",
    )
    args = parser.parse_args(argv)

    # The flag used to be swallowed silently. Tell the user nothing is
    # uploaded instead of letting them assume a copy was made.
    if args.hdfs_dir is not None:
        print(
            "warning: --hdfs_dir was given but this script does not upload "
            "to HDFS; the output is only written to the local directory.",
            file=sys.stderr,
        )

    dataset = datasets.load_dataset(_HF_DATASET)["train"]
    # Deterministic shuffle; all rows are kept. An earlier revision also
    # built a 2000-row subset that was never written anywhere, so that dead
    # computation is not reproduced here.
    train_dataset = dataset.shuffle(seed=_SHUFFLE_SEED)
    processed_train = train_dataset.map(function=process_dapo, with_indices=True)

    # Honour --local_dir (it was previously parsed and then ignored).
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)
    processed_train.to_parquet(os.path.join(local_dir, _TRAIN_FILENAME))


if __name__ == "__main__":
    main()

