
from datasets import load_dataset
import os
import sys
import argparse

# Make the rllm project root importable. The root is taken to be two levels
# above this script's own directory (i.e. for <root>/<a>/<b>/script.py the
# root is <root>). Inserting at position 0 of sys.path gives the local
# checkout precedence over any other copy on the path; this must happen
# before the `rllm` import below.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(script_dir))
sys.path.insert(0, project_root)

# Change the current working directory to the rllm project root. This runs at
# module import time, so relative paths used afterwards resolve against the
# project root rather than the caller's cwd (presumably needed by
# DatasetRegistry's storage paths — confirm).
os.chdir(project_root)

from rllm.data.dataset import DatasetRegistry

def _as_single_split(loaded):
    """Return a single ``datasets.Dataset`` from a ``load_dataset`` result.

    ``load_dataset`` returns a ``DatasetDict`` when called without ``split=``.
    Iterating a ``DatasetDict`` yields split *names* rather than rows, and it
    has no ``select`` method, so it is unwrapped here. Anything that is not a
    ``DatasetDict`` (e.g. a ``Dataset`` loaded with ``split=...``) is returned
    unchanged, which makes this helper safe to call more than once.

    Raises:
        ValueError: if ``loaded`` is a ``DatasetDict`` with zero or several
            splits, because the split to use would be ambiguous.
    """
    from datasets import DatasetDict  # local import keeps the helper self-contained

    if not isinstance(loaded, DatasetDict):
        return loaded
    if len(loaded) != 1:
        raise ValueError(
            f"load_dataset returned splits {sorted(loaded)}; expected exactly one. "
            "Pass split=... to load_dataset."
        )
    return next(iter(loaded.values()))


def prepare_sudoku_data(data_type, add_info, exp_type):
    """Load a puzzle dataset, wrap each row as an RL record, and register it.

    Despite its name, this handles the ``rush_hour`` and ``tower_of_hanoi``
    tasks. Each row becomes a record with an empty user prompt (trajectories
    are collected from an environment, so no real prompt is needed), and the
    original row is kept under ``extra_info``. The same records are registered
    with ``DatasetRegistry`` twice: as ``reasoning_gym_train`` and as
    ``reasoning_gym_bench``, both under the ``"test"`` split.

    Args:
        data_type: ``"rush_hour"`` or ``"tower_of_hanoi"``. Used as each
            record's ``ability`` and as the row's ``task_type`` when the row
            does not already have one.
        add_info: Value copied into each row's ``add_info`` field. The string
            ``"None"`` (the CLI default) or ``None`` leaves rows untouched.
        exp_type: Which Rush Hour variant to load. Ignored for
            ``tower_of_hanoi``.

    Returns:
        Whatever ``DatasetRegistry.register_dataset`` returns for the
        ``reasoning_gym_bench`` registration.

    Raises:
        ValueError: for an unknown ``data_type``, or an unknown ``exp_type``
            when ``data_type`` is ``"rush_hour"``.
    """

    def preprocess_fn(example, idx):
        # `idx` is passed because of ``with_indices=True`` below; it is unused.
        if "task_type" not in example:
            example["task_type"] = data_type

        # "None" is the CLI's string sentinel; also treat a real None from
        # programmatic callers as "no extra info".
        if add_info is not None and add_info != "None":
            example["add_info"] = add_info

        return {
            "data_source": "reasoning_gym_bench",
            "prompt": [
                {
                    "role": "user",
                    "content": "",  # placeholder since there is no real prompt is needed to environment based trajectory collection
                }
            ],
            "ability": data_type,
            "reward_model": {"style": "rule", "ground_truth": ""},
            "extra_info": example,
        }

    if data_type == "rush_hour":
        # NOTE: YOUR_PATH is a placeholder that is not defined in this script;
        # substitute the dataset location for each experiment variant.
        if exp_type == "test":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "train":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "570":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "570_to7":
            dataset = load_dataset(YOUR_PATH)
            # Keep puzzles with at most 7 moves; rows lacking `min_moves` are dropped.
            dataset = dataset.filter(lambda example: example.get("min_moves", 999) <= 7)
        elif exp_type == "570_5mini_solvable_12_15mm":
            indices = [67, 68, 71, 105, 109, 121, 130, 135, 137, 138, 145, 152, 165, 167, 176, 181, 187, 188, 189, 193, 197, 200, 201, 204, 205, 206, 207, 213, 214, 220, 230, 231, 232, 233, 234, 236, 237, 242, 243, 245, 248, 251, 258, 259, 260, 268, 269, 274, 275, 276, 277, 280, 281, 282, 284, 288, 291, 292, 293, 294, 296, 297, 299, 300, 302, 304, 306, 307, 309, 311, 312, 313, 314, 315, 322, 326, 328, 330, 333, 336, 340, 346, 347, 350, 353, 357, 363, 375, 377, 378, 382, 384, 387, 394, 397, 405, 406, 408, 437]
            # select() exists only on a single Dataset, so unwrap a DatasetDict first.
            full_dataset = _as_single_split(load_dataset(YOUR_PATH))
            dataset = full_dataset.select(indices)
        elif exp_type == "320_12_15mm":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "900_4_12mm":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "320_13_21":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "200_19_21":
            dataset = load_dataset(YOUR_PATH)
        elif exp_type == "last_final_rushhour":
            dataset = load_dataset(YOUR_PATH)
        else:
            raise ValueError(f"Invalid experiment type: {exp_type}")
    elif data_type == "tower_of_hanoi":
        dataset = load_dataset(YOUR_PATH)
    else:
        raise ValueError(f"Invalid data type: {data_type}")

    # Ensure we iterate rows, not split names, when building records below.
    dataset = _as_single_split(dataset)
    dataset = dataset.map(preprocess_fn, with_indices=True)

    # Materialize the rows once and register both names from the same plain
    # records, rather than re-iterating the object returned by the first
    # registration.
    records = list(dataset)
    # NOTE(review): the "reasoning_gym_train" name is registered under the
    # "test" split — confirm this is intentional.
    DatasetRegistry.register_dataset("reasoning_gym_train", records, "test")
    dataset = DatasetRegistry.register_dataset("reasoning_gym_bench", records, "test")

    print("Datasets registered successfully!")

    return dataset


if __name__ == "__main__":
    # Command-line entry point: parse the options, build and register the
    # dataset, then print whatever the registration returned.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_type",
        type=str,
        default="rush_hour",
        choices=["rush_hour", "tower_of_hanoi"],
    )
    cli.add_argument("--add_info", type=str, default="None")
    cli.add_argument(
        "--exp_type",
        type=str,
        default="test",
        choices=[
            "test",
            "train",
            "570",
            "570_to7",
            "570_5mini_solvable_12_15mm",
            "320_12_15mm",
            "900_4_12mm",
            "320_13_21",
            "200_19_21",
            "last_final_rushhour",
        ],
    )
    opts = cli.parse_args()

    registered = prepare_sudoku_data(opts.data_type, opts.add_info, opts.exp_type)
    print(registered)

