from .distributed_dataset import DistributedDataset, distributed_dataset

from .mnist import mnist


def get_dataset(rank, dataset_name, sample_size=10,
                batch_size=None, remove_index=0,
                transform=None, is_distribute=True,
                seed=777, path="../data", node=2):
    """Return the dataset loader result for *dataset_name*.

    Only ``"mnist"`` is currently supported. Every option below is forwarded
    unchanged to ``mnist``; consult that loader for the exact meaning of
    each one.

    Args:
        rank: Forwarded to the loader (presumably this worker's rank in the
            distributed setup — confirm against ``mnist``).
        dataset_name: Name of the dataset to load; matched exactly
            (case-sensitive).
        sample_size: Forwarded to the loader. Defaults to 10.
        batch_size: Forwarded to the loader. Defaults to None.
        remove_index: Forwarded to the loader. Defaults to 0.
        transform: Forwarded to the loader. Defaults to None.
        is_distribute: Forwarded to the loader. Defaults to True.
        seed: Forwarded to the loader. Defaults to 777.
        path: Forwarded to the loader. Defaults to "../data".
        node: Forwarded to the loader. Defaults to 2.

    Returns:
        Whatever the selected loader (``mnist``) returns.

    Raises:
        ValueError: If *dataset_name* is not a supported dataset. Previously
            an unknown name silently produced ``None``.
    """
    if dataset_name == "mnist":
        return mnist(rank=rank,
                     batch_size=batch_size,
                     transform=transform,
                     is_distribute=is_distribute,
                     sample_size=sample_size,
                     remove_index=remove_index,
                     seed=seed,
                     path=path,
                     node=node)

    # TODO: CIFAR10 / CIFAR100 / TinyImageNet loaders were once sketched here
    # but are neither implemented nor imported; add a branch per dataset
    # (forwarding the same options as the mnist branch) once they exist.

    # Fail loudly instead of returning None, so a typo in the dataset name
    # surfaces here rather than as an obscure error further downstream.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; supported: 'mnist'"
    )
