from typing import Optional

from datasets import load_dataset


class ImageDatasetLoader:
    """Load image datasets from the Hugging Face Hub by a short name."""

    # Short dataset name accepted by `load_dataset` -> Hugging Face Hub repo id.
    _DATASET_IDS = {
        "imagenet-sketch": "clip-benchmark/wds_imagenet_sketch",
        "imagenet": "evanarlian/imagenet_1k_resized_256",
    }

    @staticmethod
    def _load_huggingface_dataset(
        dataset_name: str, seed: int, split: Optional[str] = "train"
    ):
        """Load a Hugging Face dataset and shuffle it with a fixed seed.

        Args:
            dataset_name: Hugging Face Hub repository id.
            seed: Seed passed to ``shuffle`` so the ordering is reproducible.
            split: Split name, forwarded unchanged to ``datasets.load_dataset``.
                Per the ``datasets`` docs, ``None`` loads all splits
                (a ``DatasetDict``) — verify against the installed version.

        Returns:
            The result of ``datasets.load_dataset(...)`` after ``.shuffle(seed=seed)``.
        """
        # `load_dataset` here is the module-level import from `datasets`, not the
        # static method of the same name on this class: class-body names are not
        # in scope inside method bodies.
        dataset = load_dataset(dataset_name, split=split)
        return dataset.shuffle(seed=seed)

    @staticmethod
    def load_dataset(dataset_name: str, seed: int = 1, split: Optional[str] = None):
        """Load a known dataset by short name and shuffle it.

        Args:
            dataset_name: One of the keys of ``_DATASET_IDS``
                (``"imagenet-sketch"`` or ``"imagenet"``).
            seed: Shuffle seed (default 1).
            split: Split to load; forwarded as-is (default ``None``).

        Returns:
            The shuffled dataset returned by ``_load_huggingface_dataset``.

        Raises:
            ValueError: If ``dataset_name`` is not a known short name.
        """
        try:
            hub_id = ImageDatasetLoader._DATASET_IDS[dataset_name]
        except KeyError:
            # Suppress the KeyError context; the unknown name is the whole story.
            raise ValueError(f"Unknown dataset: {dataset_name}") from None
        # Called outside the `try` so a KeyError raised during loading is not
        # misreported as an unknown dataset name.
        return ImageDatasetLoader._load_huggingface_dataset(hub_id, seed, split=split)
