import os
from typing import Dict, Optional, Type

from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, ImageFolder, ImageNet

from src.config import DATASETS_DIR

# Maps a dataset name to the dataset *class* (not an instance); callers
# instantiate it themselves, e.g. ``datasets_registry[name](root, **kwargs)``.
# New entries can be added with the ``register_dataset`` decorator.
datasets_registry: Dict[str, Type[Dataset]] = {
    "cifar10": CIFAR10,
    "cifar100": CIFAR100,
    "svhn": SVHN,
    # Two aliases that both resolve to torchvision's ImageNet class.
    "imagenet1k": ImageNet,
    "ilsvrc2012": ImageNet,
}


def register_dataset(dataset_name: str):
    """Return a class decorator that adds the decorated class to ``datasets_registry``.

    Example::

        @register_dataset("my_dataset")
        class MyDataset(Dataset):
            ...

    Args:
        dataset_name: Registry key under which the decorated class is stored.

    Returns:
        A decorator that registers its argument and returns it unchanged.

    Raises:
        ValueError: When the decorator is applied and ``dataset_name`` is already
            present in ``datasets_registry`` (including the built-in entries).
    """

    def register_dataset_cls(cls: Type[Dataset]) -> Type[Dataset]:
        # Refuse to silently overwrite an existing entry.
        if dataset_name in datasets_registry:
            raise ValueError(f"Cannot register duplicate dataset ({dataset_name})")
        datasets_registry[dataset_name] = cls
        # Return the class itself so the decorated name stays usable.
        return cls

    return register_dataset_cls


def get_dataset(
    dataset_name: Optional[str] = None, root: str = DATASETS_DIR, **kwargs
) -> Dataset:
    """Instantiate a registered dataset by name, or load ``root`` as an ImageFolder.

    Args:
        dataset_name: Key into ``datasets_registry``. When ``None``, ``root`` is
            loaded with ``torchvision.datasets.ImageFolder`` instead.
        root: Passed as the first positional argument to the dataset
            constructor. Defaults to ``DATASETS_DIR``.
        **kwargs: Forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``dataset_name`` is given but not registered; the message
            lists the registered names.
        ValueError: If ``dataset_name`` is ``None`` and ``ImageFolder`` fails to
            load ``root``. The underlying error is chained as ``__cause__``.
    """
    if dataset_name is not None:
        try:
            dataset_cls = datasets_registry[dataset_name]
        except KeyError:
            # Still a KeyError for existing callers, but say what is available.
            raise KeyError(
                f"Unknown dataset {dataset_name!r}; "
                f"registered datasets: {sorted(datasets_registry)}"
            ) from None
        # The constructor is called outside the try so its own errors are not
        # misreported as an unknown dataset name.
        return dataset_cls(root, **kwargs)

    try:
        return ImageFolder(root, **kwargs)
    except Exception as err:
        # Previously a bare ``except:``: that also trapped KeyboardInterrupt /
        # SystemExit and discarded the real cause. Chain it instead.
        raise ValueError(f"Dataset {root} not found") from err


def get_dataset_cls(dataset_name: Optional[str]) -> Type[Dataset]:
    """Return the dataset class registered under ``dataset_name``.

    Unregistered names (including ``None``) fall back to ``ImageFolder`` rather
    than raising.

    Args:
        dataset_name: Key into ``datasets_registry``.

    Returns:
        The registered dataset class, or ``ImageFolder`` when there is no entry.
    """
    # ``dict.get`` replaces a bare ``try/except:`` that could mask unrelated
    # errors (and even KeyboardInterrupt) around a plain lookup.
    return datasets_registry.get(dataset_name, ImageFolder)


def get_datasets_names():
    """Return the keys of ``datasets_registry`` as a new list, in insertion order.

    The list is a fresh copy, so mutating it does not affect the registry.
    """
    names = [*datasets_registry]
    return names
