import logging
from typing import Callable, Optional, Type

from torch.utils.data import Dataset
from torchvision.datasets import ImageNet


from ..config import DATA_DIR, IMAGENET_ROOT
from .constants import *

from .ninco_ssb_clean import (
    NINCO,
    INaturalistClean,
    NINCOFull,
    OpenImageOClean,
    PlacesClean,
    SSBEasy,
    SSBHard,
    SpeciesClean,
    TexturesClean,
)
from .openimage_o import OpenImageO
from .imagenet import ImageNetCnpz, ImageNetR


# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)

# Registry mapping a dataset's string name to the class implementing it.
# It is read by `create_dataset` / `get_dataset_cls` and extended at import
# time through the `register_dataset` decorator.
# NOTE(review): "imagenet_c" and "imagenet_c_npz" both resolve to
# `ImageNetCnpz` -- confirm the alias is intentional.
datasets_registry = dict(
    ninco_full=NINCOFull,
    ninco=NINCO,
    ssb_hard=SSBHard,
    ssb_easy=SSBEasy,
    textures_clean=TexturesClean,
    places_clean=PlacesClean,
    inaturalist_clean=INaturalistClean,
    openimage_o_clean=OpenImageOClean,
    species_clean=SpeciesClean,
    imagenet=ImageNet,
    imagenet_c=ImageNetCnpz,
    imagenet_c_npz=ImageNetCnpz,
    imagenet_r=ImageNetR,
    openimage_o=OpenImageO,
)


def register_dataset(dataset_name: str):
    """Return a class decorator that records a class in ``datasets_registry``.

    Args:
        dataset_name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class it receives under ``dataset_name``
        and hands the same class back unchanged.

    Raises:
        ValueError: Raised by the returned decorator (when it is applied)
            if ``dataset_name`` is already present in the registry.
    """

    def _add_to_registry(dataset_cls):
        # Refuse to overwrite an existing entry.
        already_registered = dataset_name in datasets_registry
        if already_registered:
            raise ValueError(f"Cannot register duplicate dataset ({dataset_name})")

        datasets_registry[dataset_name] = dataset_cls
        return dataset_cls

    return _add_to_registry


def create_dataset(
    dataset_name: str,
    root: str = DATA_DIR,
    split: Optional[str] = "train",
    transform: Optional[Callable] = None,
    download: Optional[bool] = True,
    **kwargs,
):
    """Instantiate the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Key into ``datasets_registry``.
        root: Root directory forwarded to the dataset class. Not used for
            ``"imagenet"``, which is always built from ``IMAGENET_ROOT``.
        split: Forwarded to the dataset class as ``split``.
        transform: Forwarded to the dataset class as ``transform``.
        download: Forwarded as ``download`` to every dataset except
            ``"imagenet"``.
        **kwargs: Extra keyword arguments forwarded to the dataset class.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``dataset_name`` is not present in the registry.
    """
    # Guard only the registry lookup: a KeyError raised from inside a dataset
    # constructor is a different failure and must not be reported as an
    # unknown dataset name.
    try:
        dataset_cls = datasets_registry[dataset_name]
    except KeyError as e:
        _logger.error("Unknown dataset name: %r", dataset_name)
        raise ValueError(
            f"Unknown dataset name: {dataset_name!r}. "
            f"Available datasets: {sorted(datasets_registry)}"
        ) from e

    if dataset_name == "imagenet":
        # "imagenet" is rooted at IMAGENET_ROOT and is not given `download`.
        return dataset_cls(root=IMAGENET_ROOT, split=split, transform=transform, **kwargs)
    return dataset_cls(root=root, split=split, transform=transform, download=download, **kwargs)


def get_dataset_cls(dataset_name: str) -> Type[Dataset]:
    """Look up the dataset class registered under ``dataset_name``.

    Args:
        dataset_name: Key into ``datasets_registry``.

    Returns:
        The registered class itself, not an instance of it.

    Raises:
        KeyError: If ``dataset_name`` is not present in the registry.
    """
    dataset_cls = datasets_registry[dataset_name]
    return dataset_cls
