# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the MATH-lighteval dataset into train, dev, and test parquet files.
"""

import argparse
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs
from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed

# Row positions in the MATH-lighteval "test" split that form the dev split
# (per the original note, these are the MATH-500 indices). In the script below,
# rows at these positions are written to dev.parquet via `select`, and all other
# test rows are written to test.parquet via `filter`, so the two are disjoint.
DEV_INDICES = [4, 6, 15, 18, 34, 36, 37, 41, 45, 64, 66, 85, 92, 100, 120, 127, 133, 136, 149, 160, 161, 162, 166, 168, 202, 215, 243, 247, 256, 260, 270, 320, 361, 367, 381, 392, 396, 411, 450, 451, 452, 460, 496, 501, 503, 505, 511, 513, 520, 534, 563, 564, 571, 576, 579, 587, 596, 601, 607, 609, 612, 615, 622, 666, 673, 683, 684, 695, 700, 703, 709, 718, 722, 738, 748, 757, 761, 762, 782, 805, 817, 834, 840, 849, 853, 854, 859, 882, 885, 888, 906, 909, 933, 941, 962, 978, 985, 988, 991, 1008, 1033, 1037, 1046, 1048, 1054, 1058, 1067, 1073, 1085, 1088, 1095, 1111, 1119, 1123, 1127, 1128, 1131, 1136, 1144, 1145, 1150, 1172, 1173, 1180, 1188, 1190, 1194, 1196, 1215, 1243, 1250, 1251, 1258, 1262, 1271, 1281, 1285, 1287, 1290, 1302, 1308, 1311, 1312, 1322, 1339, 1359, 1374, 1380, 1402, 1441, 1442, 1449, 1513, 1531, 1540, 1543, 1552, 1555, 1576, 1603, 1612, 1620, 1690, 1710, 1715, 1730, 1764, 1767, 1769, 1788, 1790, 1791, 1801, 1806, 1820, 1842, 1843, 1880, 1890, 1897, 1901, 1905, 1908, 1932, 1935, 1940, 1963, 1967, 1981, 1996, 2001, 2006, 2011, 2041, 2047, 2053, 2057, 2062, 2063, 2078, 2110, 2119, 2120, 2143, 2148, 2150, 2151, 2170, 2186, 2191, 2196, 2199, 2210, 2214, 2215, 2217, 2231, 2236, 2237, 2238, 2246, 2253, 2263, 2264, 2275, 2289, 2294, 2297, 2303, 2311, 2323, 2324, 2325, 2327, 2328, 2334, 2352, 2359, 2360, 2371, 2382, 2384, 2397, 2404, 2409, 2413, 2416, 2473, 2505, 2512, 2515, 2522, 2536, 2539, 2546, 2569, 2571, 2579, 2602, 2607, 2609, 2611, 2622, 2628, 2637, 2647, 2681, 2682, 2700, 2707, 2731, 2752, 2758, 2767, 2799, 2802, 2808, 2816, 2838, 2851, 2863, 2868, 2876, 2883, 2896, 2907, 2937, 2938, 2946, 2966, 2977, 2991, 2994, 3018, 3019, 3020, 3022, 3024, 3035, 3037, 3046, 3047, 3058, 3067, 3072, 3079, 3080, 3105, 3126, 3134, 3141, 3165, 3181, 3186, 3187, 3196, 3200, 3210, 3220, 3226, 3236, 3240, 3246, 3287, 3295, 3299, 3317, 3320, 3323, 3334, 3341, 3342, 3344, 3350, 3352, 3365, 3366, 3369, 3375, 3392, 3404, 3411, 3417, 3419, 3420, 3440, 3444, 3447, 3460, 
3467, 3474, 3480, 3498, 3507, 3511, 3519, 3529, 3539, 3541, 3548, 3549, 3569, 3586, 3604, 3607, 3646, 3647, 3658, 3669, 3700, 3711, 3725, 3730, 3732, 3738, 3740, 3741, 3752, 3768, 3769, 3773, 3779, 3802, 3805, 3824, 3849, 3856, 3878, 3913, 3923, 3941, 3942, 3951, 3982, 3990, 3994, 3999, 4011, 4034, 4036, 4042, 4043, 4046, 4055, 4071, 4074, 4088, 4090, 4104, 4108, 4127, 4149, 4150, 4155, 4157, 4158, 4160, 4177, 4181, 4190, 4193, 4210, 4222, 4235, 4242, 4253, 4265, 4272, 4279, 4297, 4303, 4315, 4326, 4333, 4352, 4368, 4384, 4404, 4413, 4423, 4425, 4441, 4449, 4451, 4479, 4487, 4500, 4515, 4523, 4533, 4535, 4547, 4549, 4550, 4569, 4584, 4590, 4591, 4597, 4600, 4603, 4610, 4626, 4657, 4666, 4678, 4697, 4706, 4713, 4731, 4744, 4751, 4753, 4758, 4765, 4776, 4796, 4812, 4834, 4850, 4857, 4861, 4866, 4868, 4871, 4885, 4896, 4900, 4909, 4914, 4924, 4926, 4947, 4955, 4964, 4969, 4978, 4990, 4992, 4993]


def extract_solution(solution_str):
    """Pull the final answer out of a MATH reference solution string.

    Both steps are delegated to helpers imported from
    ``verl.utils.reward_score.math``. ``last_boxed_only_string`` appears (from
    its name) to select the last boxed expression. ``remove_boxed`` appears
    (from its name) to strip the box wrapper from it. This function does not
    handle a missing boxed answer itself; it behaves however those helpers do.
    """
    boxed_expression = last_boxed_only_string(solution_str)
    return remove_boxed(boxed_expression)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_dir",
        default="~/data/math",
        help="Local output directory for train/dev/test parquet files ('~' is expanded).",
    )
    parser.add_argument(
        "--hdfs_dir",
        default=None,
        help="Optional HDFS directory; if given, local_dir is copied there after writing.",
    )

    args = parser.parse_args()

    # 'lighteval/MATH' is no longer available on huggingface.
    # Use mirror repo: DigitalLearningGmbH/MATH-lighteval
    data_source = "DigitalLearningGmbH/MATH-lighteval"
    print(f"Loading the {data_source} dataset from huggingface...", flush=True)
    dataset = datasets.load_dataset(data_source, trust_remote_code=True)

    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    # Appended to every problem so the prompt asks for a boxed final answer.
    instruction_following = "Let's think step by step and output the final answer within \\boxed{}."

    def make_map_fn(split):
        """Return a ``Dataset.map`` callback that converts raw MATH rows for *split*.

        Each output record carries the user prompt (problem text plus the
        instruction), the ground-truth answer extracted from the reference
        solution, and an ``extra_info`` dict holding *split* and the row's
        position in that split.
        """

        def process_fn(example, idx):
            # pop() rather than indexing: presumably so the raw "problem" and
            # "solution" columns are not carried into the mapped output.
            # Confirm this against the installed `datasets` version.
            question = example.pop("problem")
            question = question + " " + instruction_following

            answer = example.pop("solution")
            solution = extract_solution(answer)
            return {
                "data_source": data_source,
                "prompt": [{"role": "user", "content": question}],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                # idx is the row's position in the split being mapped
                # (with_indices=True). Rows later moved into dev keep their
                # original test-split index.
                "extra_info": {"split": split, "index": idx},
            }

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

    # Carve the dev split out of the mapped test split. Rows at DEV_INDICES go to
    # dev; every other row stays in test, so the two outputs are disjoint.
    # NOTE(review): dev rows keep extra_info["split"] == "test", because the label
    # is assigned by the map above, before this split. Change it if downstream
    # consumers rely on that field.
    dev_indices_set = set(DEV_INDICES)  # set for O(1) membership tests in the filter
    dev_dataset = test_dataset.select(DEV_INDICES)

    def filter_dev_indices(example, idx):
        """Keep a test row only if its position is not one of the dev indices."""
        return idx not in dev_indices_set

    test_dataset = test_dataset.filter(filter_dev_indices, with_indices=True)

    # Expand "~" explicitly. The default is home-relative, os.path.join does not
    # expand it, and the same path is handed to copy() below.
    local_dir = os.path.expanduser(args.local_dir)
    hdfs_dir = args.hdfs_dir
    # Create the output directory up front rather than relying on the parquet
    # writer to create missing parent directories.
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
    dev_dataset.to_parquet(os.path.join(local_dir, "dev.parquet"))
    test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)