import torch.distributed as dist

from torch.utils.data import Dataset
from typing import (
    Dict,
    Union,
)

from .datasets import (
    C4,
    LaMini,
    SlimPajama,
)

# Registry of supported dataset classes, keyed by a short lowercase tag.
# ``initialize_dataset`` selects the first entry whose tag is a substring of the
# configured dataset name, so the insertion order below is significant.
dataset_dict = dict(
    c4=C4,
    lamini=LaMini,
    slimpajama=SlimPajama,
)


def _log_message(
    logger,
    message: str,
    source: Union[str, None],
    level: Union[str, None] = None,
) -> None:
    """ Send a message to the project logger, falling back to ``print``.

    Args:
        logger: The object returned by ``get_logger()``, or ``None`` if it could not be created.
        message (str): The message to emit.
        source (Union[str, None]): The source tag forwarded to ``logger.log``.
        level (Union[str, None]): The log level; not passed to ``logger.log`` when ``None``.
    """

    if logger is not None:
        kwargs = {'message': message, 'source': source}
        if level is not None:
            kwargs['level'] = level

        try:
            logger.log(**kwargs)
            return
        except Exception:
            # Logging is best-effort; fall through to print so the message is not lost.
            pass

    print(message)


def initialize_dataset(
    config: Dict[str, Union[Dict, int, str]],
    mode: str,
) -> Dataset:
    """ Initialize the dataset.

    The dataset name is read from ``config[mode]['dataset']['name']``, and its
    settings are read from ``config['datasets'][<dataset name>]``. The dataset class
    is the first entry of ``dataset_dict`` whose key is a substring of the dataset
    name (e.g. the key ``'c4'`` matches any name containing ``'c4'``).

    Args:
        config (Dict[str, Union[Dict, int, str]]): The configuration.
        mode (str): The top-level configuration section that holds the ``dataset`` entry.

    Raises:
        KeyError: If the dataset name is not provided, the dataset configuration is not
            found, or no registered dataset class matches the dataset name.

    Returns:
        Dataset: The dataset.
    """

    # Rank of this process; 0 when torch.distributed has not been initialized.
    # NOTE(review): dist.get_rank() is the global rank, yet it is passed on as local_rank —
    # confirm this is intended for multi-node runs.
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    # Initialize the multi-process logger (best-effort: messages fall back to print).
    try:
        from ..logging import get_logger
        from ..utils import get_path

        logger = get_logger()
        source = f'{get_path(source_file=__file__)}.{initialize_dataset.__name__}'
    except Exception:
        logger = None
        source = None

    # Check the dataset name; a missing key is treated the same as an empty name.
    try:
        dataset_name = config[mode]['dataset']['name']
    except KeyError:
        dataset_name = None

    if not dataset_name:
        message = 'The dataset name is not provided.'
        _log_message(logger, message, source, level='error')
        raise KeyError(message)

    # Get the dataset configuration.
    try:
        dataset_config = config['datasets'][dataset_name]
    except (KeyError, TypeError) as err:
        message = f'The dataset configuration for {dataset_name} is not found in the configuration.'
        _log_message(logger, message, source, level='error')
        raise KeyError(message) from err

    # Pick the first registered class whose key occurs in the dataset name.
    dataset_class = next(
        (cls for name, cls in dataset_dict.items() if name in dataset_name),
        None,
    )
    if dataset_class is None:
        # Without this check the function would fail with UnboundLocalError.
        message = (
            f'No dataset class matches the dataset name {dataset_name}. '
            f'The name must contain one of: {", ".join(dataset_dict)}.'
        )
        _log_message(logger, message, source, level='error')
        raise KeyError(message)

    # Initialize the dataset.
    dataset = dataset_class(
        config=config,
        dataset_config=dataset_config,
        mode=mode,
        local_rank=local_rank,
    )

    # The dataset's own status message is logged without an explicit level.
    _log_message(logger, dataset.message, source)

    return dataset
