from typing import Callable, Dict, List, Union

import datasets
from datasets import Dataset, IterableDataset, load_dataset
from transformers import PreTrainedTokenizer

from roll.configs.data_args import DataArguments
from roll.utils.logging import get_logger

logger = get_logger()

# Registry mapping a ``dataset_type`` key to the factory that builds it.
# Factories are invoked as ``factory(dataset_paths, data_args)``, so the value
# type takes two arguments (list of paths, DataArguments), not one.
REGISTERED_DATASETS: Dict[str, Callable[[List[str], DataArguments], Union[Dataset, IterableDataset]]] = {}

def register_dataset(key: str):
    """Return a decorator that registers a dataset factory under ``key``.

    The decorated function is stored in ``REGISTERED_DATASETS[key]`` and
    returned unchanged. Because the function is returned as-is, the decorator
    can be stacked to register one factory under several keys.

    Factories are looked up by ``data_args.dataset_type`` and called as
    ``factory(dataset_paths, data_args)``.

    Args:
        key: Registry name for the factory.

    Raises:
        ValueError: When the decorator is applied and ``key`` is already
            registered.
    """
    def decorator(func: Callable[[List[str], DataArguments], Union[Dataset, IterableDataset]]):
        # Refuse silent overrides: a second registration under the same key is
        # almost certainly a naming clash.
        if key in REGISTERED_DATASETS:
            raise ValueError(f"Dataset type '{key}' already exists!")
        REGISTERED_DATASETS[key] = func
        return func
    return decorator

def get_dataset(data_args: "DataArguments"):
    """Build the dataset selected by ``data_args.dataset_type``.

    Looks up the factory registered under ``data_args.dataset_type`` and calls
    it as ``factory(dataset_paths, data_args)``. ``dataset_paths`` is built
    from ``data_args.file_name``.

    Args:
        data_args: Data configuration. This function reads ``dataset_type``,
            ``file_name``, ``prompt`` and ``response``. ``file_name`` may be
            empty/None, a single path string, or an iterable of paths.

    Returns:
        Whatever the registered factory returns.

    Raises:
        ValueError: If no factory is registered under ``dataset_type``.
    """
    key = data_args.dataset_type
    if key not in REGISTERED_DATASETS:
        raise ValueError(
            f"Dataset type '{key}' is not found! Available datasets: {list(REGISTERED_DATASETS.keys())}"
        )

    file_name = data_args.file_name
    if not file_name:
        dataset_paths = []
    elif isinstance(file_name, str):
        # ``list.extend`` on a bare string would add it character by character;
        # treat a single string as a single path instead.
        dataset_paths = [file_name]
    else:
        # Copy into a fresh list so a factory cannot mutate data_args.file_name.
        dataset_paths = list(file_name)

    joined_paths = "\n".join(dataset_paths)
    logger.info(f"load_dataset_paths: \n {joined_paths}")
    logger.info(f'prompt column: {data_args.prompt}  label column: {data_args.response}')

    return REGISTERED_DATASETS[key](dataset_paths, data_args)


@register_dataset("default")
@register_dataset("json")
def default_json_dataset(
    dataset_paths: List[str],
    data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]:
    """Load JSON data files with the Hugging Face ``json`` builder.

    Registered under both ``"json"`` and ``"default"``.

    Args:
        dataset_paths: Files to load. They are passed unchanged as
            ``data_files`` to ``datasets.load_dataset``.
        data_args: Not used by this factory. It is accepted so the factory
            matches the shared ``(dataset_paths, data_args)`` call signature.

    Returns:
        The ``"train"`` split of the loaded data. When ``data_files`` is a
        plain list, the ``datasets`` library places all files in that split.
    """
    return datasets.load_dataset('json', data_files=dataset_paths)['train']