from multiprocessing import Value
from datasets import load_dataset, concatenate_datasets
import os
import sys
import argparse

# Locate the rllm project root (two directory levels above the folder holding
# this script), put it first on the import path so `rllm` resolves, and make
# it the working directory for the rest of the script.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, os.pardir, os.pardir))
sys.path.insert(0, project_root)
os.chdir(project_root)

from rllm.data.dataset import DatasetRegistry

def _check_only_k_action(add_info):
    """Validate ``add_info`` against the ``ONLY_K_ACTION`` environment variable.

    When ``ONLY_K_ACTION`` is set to a non-blank value, it must parse as an
    int ``k``, and ``add_info`` must then equal ``f"only_{k}_action"``.

    Raises:
        AssertionError: if ``ONLY_K_ACTION`` is not an int, or if ``add_info``
            does not match the expected value.
    """
    raw_k = os.getenv("ONLY_K_ACTION")
    if raw_k is None or not raw_k.strip():
        return
    try:
        k = int(raw_k)
    except ValueError:
        raise AssertionError(f"ONLY_K_ACTION must be int, got {raw_k}") from None
    expected_add_info = f"only_{k}_action"
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if add_info != expected_add_info:
        raise AssertionError(
            f"Since ONLY_K_ACTION={k} is set, "
            f"add_info must be '{expected_add_info}'. "
            f"Current add_info={add_info}"
        )


def _load_validation_dataset(data_type, validation_num=64):
    """Return a shuffled subset of at most ``validation_num`` rows for validation.

    Raises:
        ValueError: if ``data_type`` is not one of the validation types
            ("easy", "medium", "hard", "extremely_hard", "new_900").
    """
    if data_type == "easy":
        dataset = load_dataset(YOUR_PATH, split="train")
    elif data_type == "medium":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "hard":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "extremely_hard":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "new_900":
        dataset = load_dataset(YOUR_PATH)
    else:
        raise ValueError(f"Invalid data type: {data_type}")

    # Fixed seed keeps the chosen validation rows the same across runs.
    dataset = dataset.shuffle(seed=42)
    # Clamp so a source smaller than ``validation_num`` does not make
    # ``select`` raise IndexError.
    return dataset.select(range(min(validation_num, len(dataset))))


# Contiguous row ranges of the base benchmark dataset, keyed by data_type.
_BASE_SLICES = {
    "standard": range(600),
    "variant": range(600, 680),
    "easy": range(0, 150),
    "medium_hard": range(150, 450),
}


def _load_bench_dataset(data_type):
    """Return the benchmark (non-validation) dataset selected by ``data_type``.

    Raises:
        ValueError: if ``data_type`` is not a known benchmark type.
    """
    if data_type in _BASE_SLICES:
        # Only these types slice the base dataset, so it is loaded here
        # instead of unconditionally for every data_type.
        dataset = load_dataset(YOUR_PATH).select(_BASE_SLICES[data_type])
    elif data_type == "new_900":
        dataset = load_dataset(YOUR_PATH).filter(
            lambda x: x["difficulty"] == "medium" and x["missing"] >= 50
        )
    elif data_type == "new_1045_variant":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "new_abc_variant":
        dataset = load_dataset(YOUR_PATH)
        # A second dataset supplies the row indices (its "index" column).
        indices_dataset = load_dataset(YOUR_PATH)
        dataset = dataset.select(indices_dataset["index"])
    elif data_type == "last_final_standard":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "last_final_standard_hotfix":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "last_final_standard_hotfix_only":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "last_final_standard_hard":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "last_final_standard_remain":
        # variant, hard, hotfix(21-25, 26-30)
        dataset = concatenate_datasets([
            load_dataset(YOUR_PATH),
            load_dataset(YOUR_PATH),
            load_dataset(YOUR_PATH),
        ])
    elif data_type == "last_final_standard_all":
        dataset = load_dataset(YOUR_PATH)
    elif data_type == "100_long_for_5mini_pareto":
        dataset = load_dataset(YOUR_PATH)
        dataset = dataset.select([402, 411, 314, 370, 434, 428, 407, 381, 426, 395, 359, 433, 339, 376, 437, 328, 368, 341, 383, 316, 397, 313, 419, 412, 391, 743, 830, 734, 760, 741, 810, 801, 806, 781, 753, 738, 809, 727, 782, 794, 724, 838, 762, 756, 739, 769, 725, 848, 807, 770, 164, 173, 171, 242, 193, 228, 214, 204, 159, 190, 260, 250, 280, 245, 289, 263, 278, 218, 295, 157, 243, 269, 231, 247, 258, 602, 707, 693, 569, 646, 710, 675, 703, 549, 712, 530, 674, 611, 695, 601, 588, 717, 692, 694, 716, 653, 577, 721, 698, 687])
    elif data_type == "10_long_for_5mini_pareto":
        dataset = load_dataset(YOUR_PATH)
        dataset = dataset.select([848, 687, 760, 370, 383, 250, 611, 675, 695, 376])
    elif data_type == "5mini_solved_4b_unsolved":
        dataset = load_dataset(YOUR_PATH)
        dataset = dataset.select([193,201,207,217,227,228,230,235,236,237,238,242,244,245,246,248,249,250,251,252,253,257,258,259,261,262,263,265,266,267,269,270,272,275,277,278,279,280,281,282,283,284,285,287,288,289,290,291,292,293,294,295,296,297,299,300,301,302,460,466,470,472,475,476,477,478,483,484,486,487,488,491,495,496,497,499,500,501,502,503,505,506,507,508,510,513,518,519,520,521,522,523,525,527,528,532,534,535,536,538,539,541,542,545,546,547,548,549,550,551,555,557,558,559,562,563,565,566,569,570,571,572,574,577,578,580,581,582,583,586,587,590,591,592,593,594,595,597,598,600,603,606,607,608,609,611,612,613,614,615,616,618,619,620,621,622,623,625,626,627,628,629,630,631,632,633,634,635,636,637,638,639,640,642,643,644,645,646,647,648,649,650,651,652,653,654,655,656,657,658,659,660,661,662,663,664,666,667,668,669,670,671,850,851,852,853,854,856,857,858,859,860,863,865,870,871,872,874,875,876,879,880,883,884,886,887,888,889,893,895,896,897,899,900,901,902,903,906,907,908,909,911,912,913,914,916,918,919,920,921,922,923,924,928,929,930,932,934,935,936,938])
    elif data_type == "last_final_standard_hard_remove_hardskill":
        dataset = load_dataset(YOUR_PATH)
    else:
        raise ValueError(f"Invalid data type: {data_type}")
    return dataset


def prepare_sudoku_data(do_validation=False, data_type="all", add_info=None, start_idx=0, num_examples=100):
    """Load a Sudoku subset, reshape each row into a task record, and register it.

    Args:
        do_validation: If True, take a shuffled subset of up to 64 rows and
            register it as "sudoku_validation". Otherwise register the
            selected benchmark data as "sudoku_bench". Both use split "test".
        data_type: Which subset to load. The accepted values differ between
            the two modes (see ``_load_validation_dataset`` and
            ``_load_bench_dataset``). NOTE(review): the default "all" is not
            handled by either mode and raises ValueError.
        add_info: Optional tag stored on every row under ``extra_info["add_info"]``.
            Must equal ``f"only_{k}_action"`` when ``ONLY_K_ACTION=k`` is set.
        start_idx: Currently unused; kept for call compatibility.
        num_examples: Currently unused; kept for call compatibility.

    Returns:
        The value returned by ``DatasetRegistry.register_dataset``.

    Raises:
        AssertionError: on a malformed ONLY_K_ACTION or an add_info mismatch.
        ValueError: for an unknown ``data_type``.
    """
    _check_only_k_action(add_info)

    def preprocess_fn(example, idx):
        # ``idx`` is passed because ``map`` runs with ``with_indices=True``;
        # it is not otherwise used.
        if add_info is not None:
            example["add_info"] = add_info

        example["task_type"] = "sudoku"

        return {
            "data_source": "sudoku_bench",
            "prompt": [
                {
                    "role": "user",
                    # Placeholder: environment-based trajectory collection
                    # does not need a real prompt.
                    "content": "",
                }
            ],
            "ability": "sudoku",
            "reward_model": {"style": "rule", "ground_truth": ""},
            # The source row (plus add_info/task_type set above) rides along.
            "extra_info": example,
        }

    if do_validation:
        dataset = _load_validation_dataset(data_type)
        registry_name = "sudoku_validation"
    else:
        dataset = _load_bench_dataset(data_type)
        registry_name = "sudoku_bench"

    dataset = dataset.map(preprocess_fn, with_indices=True)

    dataset = DatasetRegistry.register_dataset(registry_name, dataset, "test")

    print("Datasets registered successfully!")

    return dataset


if __name__ == "__main__":
    # CLI entry point: parse options and prepare/register the Sudoku dataset.
    parser = argparse.ArgumentParser(
        description="Prepare the Sudoku dataset and register it with rllm's DatasetRegistry."
    )
    parser.add_argument(
        "--do_validation",
        action="store_true",
        help="Register a shuffled validation subset (up to 64 rows) instead of the benchmark data.",
    )
    parser.add_argument(
        "--add_info",
        type=str,
        default=None,
        help="Tag stored on every row; must be only_<k>_action when ONLY_K_ACTION=<k> is set.",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        # NOTE(review): "all" is kept as the default for CLI compatibility, but
        # prepare_sudoku_data does not handle it and raises ValueError; pass an
        # explicit type.
        default="all",
        choices=[
            "all",
            "standard",
            "variant",
            "new_900",
            "new_1045_variant",
            "new_abc_variant",
            "last_final_standard",
            "last_final_standard_hotfix",
            "last_final_standard_hotfix_only",
            "last_final_standard_remain",
            "last_final_standard_hard",
            "last_final_standard_all",
            "100_long_for_5mini_pareto",
            "10_long_for_5mini_pareto",
            "last_final_standard_hard_remove_hardskill",
            # Also accepted by prepare_sudoku_data; medium/hard/extremely_hard
            # are only valid with --do_validation, medium_hard and
            # 5mini_solved_4b_unsolved only without it.
            "easy",
            "medium",
            "hard",
            "extremely_hard",
            "medium_hard",
            "5mini_solved_4b_unsolved",
        ],
    )
    args = parser.parse_args()

    dataset = prepare_sudoku_data(
        do_validation=args.do_validation,
        data_type=args.data_type,
        add_info=args.add_info,
    )
    print(dataset)

