from types import NoneType
from datasets import concatenate_datasets, load_dataset, Dataset
import os
import sys
import json

# Locate the project root: two directory levels above the folder that holds this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(script_dir))
# Put the project root first on sys.path — presumably so the `rllm` package
# imported further down resolves to this checkout (TODO confirm).
sys.path.insert(0, project_root)

# Switch the process working directory to the project root, so relative paths
# used later in this script are resolved against it.
os.chdir(project_root)

from rllm.data.dataset import DatasetRegistry

def load_json_dataset(dataset_path: str):
    """Read a JSON file and wrap its records in a ``datasets.Dataset``.

    Args:
        dataset_path: Path to a JSON file whose top-level value is a list of
            records (one entry per dataset row).

    Returns:
        The ``Dataset`` produced by ``Dataset.from_list`` for the decoded records.

    Raises:
        ValueError: If the top-level JSON value is not a list.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; do not rely on the platform's
    # locale-dependent default encoding.
    with open(dataset_path, "r", encoding="utf-8") as f:
        records = json.load(f)

    # Fail with a clear message instead of handing a non-list to
    # Dataset.from_list, which would error obscurely (or silently build an
    # empty dataset for falsy values such as null or {}).
    if not isinstance(records, list):
        raise ValueError(
            f"Expected a JSON list of records in {dataset_path}, "
            f"got {type(records).__name__}"
        )

    return Dataset.from_list(records)

def collect_datasets(dataset_path: str, data_type: str = "rl"):
    if "json" in dataset_path:
        train_dataset = load_json_dataset(dataset_path)
        if data_type == "rl":
            dataset_folder = os.path.dirname(dataset_path)
            valid_dataset = load_json_dataset(f"{dataset_folder}/rl_valid.json")
        elif data_type == "sft":
            valid_dataset = None
        else:
            raise ValueError(f"Invalid data type: {data_type}")
    else:
        train_dataset = load_dataset(dataset_path, split="train")
        if data_type == "rl":
            valid_dataset = load_dataset(YOUR_PATH)
            valid_easy = valid_dataset.select(range(80, 120))
            valid_medium = valid_dataset.select(range(200, 240))
            valid_hard = valid_dataset.select(range(340, 380))
            valid_dataset = concatenate_datasets([valid_easy, valid_medium, valid_hard])
        elif data_type == "sft":
            valid_dataset = None
        else:
            raise ValueError(f"Invalid data type: {data_type}")

    return train_dataset, valid_dataset


def prepare_sudoku_data(dataset_path, add_info, data_type: str = "rl"):
    """Load, reshape and register the sudoku datasets.

    Args:
        dataset_path: Forwarded to ``collect_datasets`` to locate the data.
        add_info: When truthy, stored under ``"add_info"`` in every RL row's
            ``extra_info``. Not used for ``"sft"``.
        data_type: ``"rl"`` registers the ``rl_train`` and ``rl_valid`` splits;
            ``"sft"`` registers only ``sft_train``.

    Returns:
        ``(train_dataset, valid_dataset)`` as returned by
        ``DatasetRegistry.register_dataset``; ``valid_dataset`` is ``None``
        for ``"sft"``.

    Raises:
        ValueError: If ``data_type`` is neither ``"rl"`` nor ``"sft"``.
    """
    train_dataset, valid_dataset = collect_datasets(dataset_path, data_type)

    def _as_row(example):
        # Dict examples are reused (and mutated) as-is; anything else is copied into a dict.
        return example if isinstance(example, dict) else dict(example)

    def _to_sft_row(example):
        row = _as_row(example)
        row["task_type"] = "sudoku"
        row["data_source"] = "sudoku_train"
        return row

    def _to_rl_row(example):
        row = _as_row(example)
        row["task_type"] = "sudoku"
        if add_info:
            row["add_info"] = add_info
        # The prompt content is left empty: trajectories are collected from
        # the environment, so no real prompt text is needed here.
        empty_prompt = [
            {
                "role": "user",
                "content": "",
            }
        ]
        return {
            "data_source": "sudoku_rl",
            "prompt": empty_prompt,
            "ability": "sudoku",
            "reward_model": {"style": "rule", "ground_truth": ""},
            "extra_info": row,
        }

    if data_type not in ("rl", "sft"):
        raise ValueError(f"Invalid data type: {data_type}")

    if data_type == "sft":
        sft_rows = train_dataset.map(_to_sft_row)
        train_dataset = DatasetRegistry.register_dataset("sudoku_train", sft_rows, "sft_train")
        valid_dataset = None
    else:
        # Map both splits before registering anything, so a failure while
        # mapping leaves the registry untouched.
        rl_train_rows = train_dataset.map(_to_rl_row)
        rl_valid_rows = valid_dataset.map(_to_rl_row)
        train_dataset = DatasetRegistry.register_dataset("sudoku_train", rl_train_rows, "rl_train")
        valid_dataset = DatasetRegistry.register_dataset("sudoku_train", rl_valid_rows, "rl_valid")

    registry_dir = os.path.abspath('rllm/data/datasets')
    print(f"Current working directory: {os.getcwd()}")
    print(f"Registering datasets in: {registry_dir}")

    print("Datasets registered successfully!")
    print(f"Available datasets: {DatasetRegistry.get_dataset_names()}")
    print(f"Sudoku bench splits: {DatasetRegistry.get_dataset_splits('sudoku_train')}")

    return train_dataset, valid_dataset


if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options, run the preparation
    # pipeline, and print the size of the registered training dataset.
    parser = argparse.ArgumentParser(
        description="Prepare and register the sudoku datasets."
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="JSON file path or dataset identifier to load the training data from.",
    )
    parser.add_argument(
        "--add_info",
        type=str,
        default=None,
        help="Optional extra string attached to every RL row.",
    )
    # Restrict to the values the pipeline supports, so a typo is rejected at
    # parse time with a usage message instead of after datasets are loaded.
    parser.add_argument(
        "--data_type",
        type=str,
        default="rl",
        choices=("rl", "sft"),
        help="Which kind of dataset to build (default: rl).",
    )
    args = parser.parse_args()

    train_dataset, valid_dataset = prepare_sudoku_data(args.dataset_path, args.add_info, args.data_type)
    print(len(train_dataset))

