import numpy as np
import pandas as pd
import torch
import torch.utils.data as td
import torchvision.datasets as tvd
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as tvt
import sklearn.datasets as skd

import os
# Working directory captured once at import time.
# NOTE(review): not referenced in the visible part of this file; kept in case other modules import it.
curr_dir = os.getcwd()
# Root passed to every torchvision dataset below; relative, so it resolves against the process's
# current working directory.
DATASET_DIR = './data/'


def get_normalization_params(dataset_name):
    """
    Look up the per-channel ``(mean, std)`` lists used to normalize a named dataset.

    The ``*_oodom`` names reuse the statistics of their base dataset (CIFAR10, CIFAR100, SVHN).
    The other out-of-distribution image datasets (DTD, GTSRB, OxfordIIITPet, FGVCAircraft,
    LFWPeople, Places365, Flowers102, Food101, LSUN, FAKE) all share the ImageNet statistics.
    Any other name -- including the tabular datasets and ``'random_noise'`` -- yields
    ``(None, None)``. The table is rebuilt on every call, so callers get fresh lists.
    """
    cifar10 = ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    cifar100 = ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
    svhn = ([0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970])
    imagenet = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    table = {
        'MNIST': ([0.1307], [0.3081]),
        'KMNIST': ([0.1307], [0.3081]),  # reuses the MNIST statistics
        'FMNIST': ([0.2860], [0.3530]),  # Fashion-MNIST
        'EMNIST': ([0.1751], [0.3332]),  # EMNIST 'letters'
        'CIFAR10': cifar10,
        'CIFAR10_oodom': cifar10,
        'CIFAR100': cifar100,
        'CIFAR100_oodom': cifar100,
        'SVHN': svhn,
        'SVHN_oodom': svhn,
    }
    imagenet_normalized = ('DTD', 'GTSRB', 'OxfordIIITPet', 'FGVCAircraft', 'LFWPeople',
                           'Places365', 'Flowers102', 'Food101', 'LSUN', 'FAKE')
    table.update(dict.fromkeys(imagenet_normalized, imagenet))

    return table.get(dataset_name, (None, None))


def _cifar_train_augmentation():
    """Return a fresh list of the random augmentations applied to CIFAR training images."""
    return [tvt.RandomHorizontalFlip(),
            tvt.RandomCrop(32, 4),
            tvt.RandomRotation(degrees=15)]


def _times_255():
    """Return a fresh tensor-transform list that multiplies each image tensor by 255.

    Used by the ``*_oodom`` variants: ``ToTensor`` yields values in [0, 1], so these
    variants end up in [0, 255].
    """
    return [tvt.Lambda(lambda x: x * 255.)]


def _resize_32():
    """Return a fresh image-transform list that resizes images to 32x32."""
    return [tvt.Resize((32, 32))]


def _load_split(dataset_name, train, normalize):
    """Build one split of ``dataset_name``.

    Returns ``(dataset, num_classes, dims)``. ``dims`` is the per-sample feature shape, or
    ``None`` when it is not fixed here and has to be read from a sample.

    Raises:
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    # Tabular datasets. The ``*_only`` variants have no train/test split: both calls return
    # the whole table.
    if dataset_name == 'segment_window_sky_missing':
        return Dataset.segment_window_sky_missing(train=train, normalize=normalize), 5, None
    if dataset_name == 'segment_window_only':
        return Dataset.segment_window_only(normalize=normalize), 7, None
    if dataset_name == 'segment_sky_only':
        return Dataset.segment_sky_only(normalize=normalize), 7, None
    if dataset_name == 'sensorless_drive_9_10_11_missing':
        return Dataset.sensorless_drive_9_10_11_missing(train=train, normalize=normalize), 8, None
    if dataset_name == 'sensorless_drive_9_only':
        return Dataset.sensorless_drive_9_only(normalize=normalize), 11, None
    if dataset_name == 'sensorless_drive_10_11_only':
        return Dataset.sensorless_drive_10_11_only(normalize=normalize), 11, None

    # Grayscale 28x28 image datasets.
    if dataset_name == 'MNIST':
        return Dataset.mnist(train=train, normalize=normalize), 10, (1, 28, 28)
    if dataset_name == 'KMNIST':
        return Dataset.kmnist(train=train, normalize=normalize), 10, (1, 28, 28)
    if dataset_name == 'FMNIST':
        return Dataset.fashion_mnist(train=train, normalize=normalize), 10, (1, 28, 28)
    if dataset_name == 'EMNIST':
        dataset = Dataset.emnist(train=train, normalize=normalize)
        # torchvision's EMNIST 'letters' split labels the letters 1..26 (class 0 is an unused
        # 'N/A'); shift to 0..25 so labels fit the 26 outputs (label 26 used to make the
        # class-count scatter_add_ in get_dataset go out of bounds).
        dataset.target_transform = lambda y: y - 1
        return dataset, 26, (1, 28, 28)

    # 32x32 RGB datasets. ``*_oodom`` variants are rescaled to [0, 255] in both splits; only
    # the CIFAR training splits are augmented.
    oodom = dataset_name in ('CIFAR10_oodom', 'CIFAR100_oodom', 'SVHN_oodom')
    if dataset_name in ('CIFAR10', 'CIFAR10_oodom'):
        dataset = Dataset.cifar10(train=train,
                                  image_transforms=_cifar_train_augmentation() if train else [],
                                  tensor_transforms=_times_255() if oodom else [],
                                  normalize=normalize, dataset_name=dataset_name)
        return dataset, 10, (3, 32, 32)
    if dataset_name in ('CIFAR100', 'CIFAR100_oodom'):
        dataset = Dataset.cifar100(train=train,
                                   image_transforms=_cifar_train_augmentation() if train else [],
                                   tensor_transforms=_times_255() if oodom else [],
                                   normalize=normalize, dataset_name=dataset_name)
        return dataset, 100, (3, 32, 32)
    if dataset_name in ('SVHN', 'SVHN_oodom'):
        dataset = Dataset.svhn(train=train,
                               tensor_transforms=_times_255() if oodom else [],
                               normalize=normalize, dataset_name=dataset_name)
        return dataset, 10, (3, 32, 32)

    # Image datasets resized to 32x32; their sample shape is read from a sample.
    if dataset_name == 'DTD':
        return Dataset.dtd(train=train, image_transforms=_resize_32(), normalize=normalize), 47, None
    if dataset_name == 'GTSRB':
        return Dataset.GTSRB(train=train, image_transforms=_resize_32(), normalize=normalize), 43, None
    if dataset_name == 'OxfordIIITPet':
        return Dataset.OxfordIIITPet(train=train, image_transforms=_resize_32(), normalize=normalize), 37, None
    if dataset_name == 'FGVCAircraft':
        return Dataset.FGVCAircraft(train=train, image_transforms=_resize_32(), normalize=normalize), 100, None
    if dataset_name == 'LFWPeople':
        # Number of distinct identities in LFW.
        return Dataset.LFWPeople(train=train, image_transforms=_resize_32(), normalize=normalize), 5749, None
    # The next three loaders take no ``train`` flag, so both splits are built identically.
    if dataset_name == 'Places365':
        return Dataset.Places365(image_transforms=_resize_32(), normalize=normalize), 365, None
    if dataset_name == 'Flowers102':
        return Dataset.Flowers102(image_transforms=_resize_32(), normalize=normalize), 102, None
    if dataset_name == 'Food101':
        return Dataset.Food101(image_transforms=_resize_32(), normalize=normalize), 101, None

    if dataset_name == 'LSUN':
        # NOTE(review): LSUN images are not resized here, so no single shape is exact;
        # (1, 28, 28) is the historical value kept for compatibility.
        return Dataset.lsun(train=train, normalize=normalize), 10, (1, 28, 28)
    if dataset_name == 'FAKE':
        # Dataset.fake builds FakeData with image_size=(3, 32, 32).
        return Dataset.fake(train=train, normalize=normalize), 10, (3, 32, 32)

    raise NotImplementedError(f'Unsupported dataset: {dataset_name!r}')


def _load_datasets(dataset_name, normalize=False):
    """Build both splits of ``dataset_name``.

    Returns ``(train_set, test_set, num_classes, dims)`` where ``dims`` may be ``None``
    (see ``_load_split``). The training split is built first.
    """
    train_set, num_classes, dims = _load_split(dataset_name, True, normalize)
    test_set, _, _ = _load_split(dataset_name, False, normalize)
    return train_set, test_set, num_classes, dims


def get_dataset_info(dataset_name, normalize=False):
    """Return size and shape information for a named dataset.

    Both splits are instantiated (the torchvision loaders download on demand).

    Returns:
        ``(n_train, n_test, num_classes, dims)``. ``dims`` is the per-sample feature shape,
        e.g. ``(3, 32, 32)`` for CIFAR. For the tabular and resized image datasets it is
        read from the first training sample (assumes samples are ``(features, label)``
        pairs).

    Raises:
        NotImplementedError: for unsupported names (including ``'random_noise'``).
    """
    train_set, test_set, num_classes, dims = _load_datasets(dataset_name, normalize)
    if dims is None:
        dims = tuple(train_set[0][0].shape)
    return len(train_set), len(test_set), num_classes, dims


def get_dataset(dataset_name, batch_size, split=(.8, .2), seed=1, test_shuffle_seed=None, batch_size_eval=1024,
                n_test_data=None, id_dataset=None, normalize=False):
    """Build train/validation/test loaders for a named dataset.

    The validation loader draws from the *training* split: its indices are shuffled with
    ``seed`` and cut at the fraction ``split[0]``.

    Args:
        dataset_name: a name supported by ``get_dataset_info``, or ``'random_noise'``.
        batch_size: batch size of the training loader.
        split: fractions that must sum to 1; ``split[0]`` of the training set is used for
            training and the remainder for validation.
        seed: seeds NumPy's and PyTorch's global RNGs (a side effect visible to the caller)
            before the train/val shuffle.
        test_shuffle_seed: if given, seeds a dedicated generator for the test loader's
            shuffling; otherwise the test loader shuffles with PyTorch's global RNG.
        batch_size_eval: batch size of the validation and test loaders.
        n_test_data: if given, keep only the first ``n_test_data`` test samples.
        id_dataset: for ``'random_noise'``, the dataset whose size, class count and sample
            shape the noise mimics.
        normalize: forwarded to the dataset constructors.

    Returns:
        ``(train_loader, val_loader, test_loader, N, output_dim)`` where ``N`` is a
        ``torch.long`` tensor of per-class sample counts over the training subset.

    Raises:
        ValueError: if ``split`` does not sum to 1 (checked before anything is loaded).
        NotImplementedError: for unsupported dataset names.
    """
    # Tolerant comparison: e.g. 0.6 + 0.3 + 0.1 != 1.0 exactly in floating point.
    if not np.isclose(np.sum(split), 1.0):
        raise ValueError(f'split fractions must sum to 1, got {split}')

    if dataset_name == 'random_noise':
        n_train, n_test, output_dim, dims = get_dataset_info(id_dataset, normalize=normalize)
        dataset = Dataset.random_noise_image_dataset(num_classes=output_dim,
                                                     num_images_per_class=n_train // output_dim,
                                                     dims=dims, normalize=normalize)
        test_set = Dataset.random_noise_image_dataset(num_classes=output_dim,
                                                      num_images_per_class=n_test // output_dim,
                                                      dims=dims, normalize=normalize)
    else:
        dataset, test_set, output_dim, _ = _load_datasets(dataset_name, normalize)

    indices = list(range(len(dataset)))
    n_train_samples = int(len(dataset) * split[0])

    np.random.seed(seed)
    torch.manual_seed(seed)
    np.random.shuffle(indices)

    train_loader = td.DataLoader(dataset, batch_size=batch_size,
                                 sampler=SubsetRandomSampler(indices[:n_train_samples]), num_workers=0)
    val_loader = td.DataLoader(dataset, batch_size=batch_size_eval,
                               sampler=SubsetRandomSampler(indices[n_train_samples:]), num_workers=0)

    if n_test_data is not None:
        test_set = td.Subset(test_set, np.arange(min(n_test_data, len(test_set))))
    test_generator = None
    if test_shuffle_seed is not None:
        test_generator = torch.Generator()
        test_generator.manual_seed(test_shuffle_seed)
    test_loader = td.DataLoader(test_set, batch_size_eval, shuffle=True, num_workers=0, generator=test_generator)

    # Per-class counts over the training subset. Labels are collected by iterating the
    # loader because the datasets expose their labels under different attribute names.
    N = torch.zeros(output_dim, dtype=torch.long)
    for _, Y in train_loader:
        N.scatter_add_(0, Y, torch.ones_like(Y))

    return train_loader, val_loader, test_loader, N, output_dim


class Dataset:
    """
    The dataset class provides static methods to use different datasets and their splits for
    different seeds. The following is ensured for all datasets:
    * All features are normalized to have zero mean and unit variance.
    All static methods accept a single `seed` parameter which governs the seed to use for splitting
    the datasets into train/val/test. They always return a triple with PyTorch datasets for training,
    validation, and test datasets, respectively.
    """

    @classmethod
    def toy_classification(cls):
        """
        Build a 2D toy classification problem made of three Gaussian blobs, one label per blob.

        The blob centres are (-1, 0), (1, 0) and (0.5, 0.5) with standard deviation 0.25, so the
        last two blobs overlap slightly and give a region where the label is ambiguous. The
        generator is seeded (``random_state=42``), so the data are identical on every call.
        * Task: classification
        * Features: [2]
        * Samples: 3,072
        """
        centres = [(-1, 0), (1, 0), (0.5, 0.5)]
        points, labels = skd.make_blobs(
            n_samples=3072,
            n_features=2,
            centers=centres,
            cluster_std=0.25,
            random_state=42,
        )
        return cls._tensor_dataset(points.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def segment_window_sky_missing(cls, train=True, normalize=False):
        """
        Load the ``segment_window_sky_missing`` table and return its train or test part.

        Rows are split by position: the first 80% form the training part, the rest the test part.
        The last CSV column is the label; all other columns are features.
        * Task: classification
        * Features: [18]
        * Samples: 1650

        :param train: return the first 80% of rows if True, otherwise the remaining rows.
        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/segmentation/'
        table = pd.read_csv(directory_dataset + 'segment_window_sky_missing.csv', header=0, index_col=0).values

        boundary = int(.8 * table.shape[0])
        rows = table[:boundary] if train else table[boundary:]
        features, labels = rows[:, :-1], np.squeeze(rows[:, -1:])

        mean, std = get_normalization_params('segment_window_sky_missing') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def segment_window_only(cls, normalize=False):
        """
        Load the ``segment_window_only`` table (segment data restricted to the window class).

        The whole table is returned; there is no train/test split. The last CSV column is the
        label; all other columns are features.
        * Task: classification
        * Features: [18]
        * Samples: 330

        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/segmentation/'
        table = pd.read_csv(directory_dataset + 'segment_window_only.csv', header=0, index_col=0).values
        features = table[:, :-1]
        labels = np.squeeze(table[:, -1:])

        mean, std = get_normalization_params('segment_window_only') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def segment_sky_only(cls, normalize=False):
        """
        Load the ``segment_sky_only`` table (segment data restricted to the sky class).

        The whole table is returned; there is no train/test split. The last CSV column is the
        label; all other columns are features.
        * Task: classification
        * Features: [18]
        * Samples: 330

        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/segmentation/'
        table = pd.read_csv(directory_dataset + 'segment_sky_only.csv', header=0, index_col=0).values
        features = table[:, :-1]
        labels = np.squeeze(table[:, -1:])

        mean, std = get_normalization_params('segment_sky_only') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def sensorless_drive_9_10_11_missing(cls, train=True, normalize=False):
        """
        Load the ``sensorless_drive_9_10_11_missing`` table and return its train or test part.

        Rows are split by position: the first 80% form the training part, the rest the test part.
        The last CSV column is the label; all other columns are features.
        * Task: classification
        * Features: [48]
        * Samples: 42552

        :param train: return the first 80% of rows if True, otherwise the remaining rows.
        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/sensorless_drive/'
        table = pd.read_csv(directory_dataset + 'sensorless_drive_9_10_11_missing.csv', header=0, index_col=0).values

        boundary = int(.8 * table.shape[0])
        rows = table[:boundary] if train else table[boundary:]
        features, labels = rows[:, :-1], np.squeeze(rows[:, -1:])

        mean, std = get_normalization_params('sensorless_drive_9_10_11_missing') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def sensorless_drive_9_only(cls, normalize=False):
        """
        Load the ``sensorless_drive_9_only`` table (sensorless-drive data restricted to class 9).

        The whole table is returned; there is no train/test split. The last CSV column is the
        label; all other columns are features.
        * Task: classification
        * Features: [48]
        * Samples: 5319

        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/sensorless_drive/'
        table = pd.read_csv(directory_dataset + 'sensorless_drive_9_only.csv', header=0, index_col=0).values
        features = table[:, :-1]
        labels = np.squeeze(table[:, -1:])

        mean, std = get_normalization_params('sensorless_drive_9_only') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))

    @classmethod
    def sensorless_drive_10_11_only(cls, normalize=False):
        """
        Load the ``sensorless_drive_10_11_only`` table (sensorless-drive data restricted to
        classes 10 and 11).

        The whole table is returned; there is no train/test split. The last CSV column is the
        label; all other columns are features.
        * Task: classification
        * Features: [48]
        * Samples: 10638

        :param normalize: standardize features with ``get_normalization_params``; that table has
            no entry for this dataset, so the flag currently has no effect.
        """
        directory_dataset = '/nfs/staff-hdd/charpent/dirichlet-robustness/datasets/sensorless_drive/'
        table = pd.read_csv(directory_dataset + 'sensorless_drive_10_11_only.csv', header=0, index_col=0).values
        features = table[:, :-1]
        labels = np.squeeze(table[:, -1:])

        mean, std = get_normalization_params('sensorless_drive_10_11_only') if normalize else (None, None)
        if mean is not None and std is not None:
            features = (features - mean) / std

        return cls._tensor_dataset(features.astype(np.float32), labels.astype(np.int64))


    @classmethod
    def mnist(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return the MNIST training or test split.
        * Task: classification
        * Features: [1, 28, 28]
        * Classes: 10
        * Samples: 60,000 | 10,000

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the MNIST statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('MNIST')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.MNIST(DATASET_DIR, download=True, train=train, transform=tvt.Compose(steps))

    @classmethod
    def kmnist(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return the KMNIST training or test split.
        * Task: classification
        * Features: [1, 28, 28]
        * Classes: 10
        * Samples: 60,000 | 10,000

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the KMNIST statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('KMNIST')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.KMNIST(DATASET_DIR, download=True, train=train, transform=tvt.Compose(steps))

    @classmethod
    def emnist(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return the EMNIST ``'letters'`` training or test split.
        * Task: classification
        * Features: [1, 28, 28]
        * Classes: 26 letters
        * Samples: 124,800 | 20,800

        NOTE: torchvision labels this split 1..26 (class 0 is an unused 'N/A'); labels are
        returned unchanged here.

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the EMNIST statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('EMNIST')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.EMNIST(DATASET_DIR, download=True, train=train, split='letters',
                          transform=tvt.Compose(steps))

    @classmethod
    def fashion_mnist(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return the Fashion-MNIST training or test split.
        * Task: classification
        * Features: [1, 28, 28]
        * Classes: 10
        * Samples: 60,000 | 10,000

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the FMNIST statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('FMNIST')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.FashionMNIST(DATASET_DIR, download=True, train=train, transform=tvt.Compose(steps))

    @classmethod
    def cifar10(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False, dataset_name='CIFAR10'):
        """
        Return the CIFAR-10 training or test split.
        * Task: classification
        * Features: [3, 32, 32]
        * Classes: 10
        * Samples: 50,000 | 10,000

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize``. Any ``dataset_name`` starting with
            ``'CIFAR10'`` (e.g. ``'CIFAR10_oodom'``) uses the CIFAR10 statistics.
        :param dataset_name: name used to choose the normalization statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            # Variants such as 'CIFAR10_oodom' reuse the base dataset's statistics.
            base_dataset_name = 'CIFAR10' if dataset_name.startswith('CIFAR10') else dataset_name
            mean, std = get_normalization_params(base_dataset_name)
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.CIFAR10(DATASET_DIR, download=True, train=train, transform=tvt.Compose(steps))

    @classmethod
    def cifar100(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False, dataset_name='CIFAR100'):
        """
        Return the CIFAR-100 training or test split.
        * Task: classification
        * Features: [3, 32, 32]
        * Classes: 100
        * Samples: 50,000 | 10,000

        :param train: training split if True, otherwise the test split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize``. Any ``dataset_name`` starting with
            ``'CIFAR100'`` (e.g. ``'CIFAR100_oodom'``) uses the CIFAR100 statistics.
        :param dataset_name: name used to choose the normalization statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            # Variants such as 'CIFAR100_oodom' reuse the base dataset's statistics.
            base_dataset_name = 'CIFAR100' if dataset_name.startswith('CIFAR100') else dataset_name
            mean, std = get_normalization_params(base_dataset_name)
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.CIFAR100(DATASET_DIR, download=True, train=train, transform=tvt.Compose(steps))

    @classmethod
    def svhn(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False, dataset_name='SVHN'):
        """
        Return the SVHN training or test split.
        * Task: classification
        * Features: [3, 32, 32]
        * Classes: 10
        * Samples: 73,257 | 26,032

        :param train: ``'train'`` split if True, otherwise the ``'test'`` split.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize``. Any ``dataset_name`` starting with
            ``'SVHN'`` (e.g. ``'SVHN_oodom'``) uses the SVHN statistics.
        :param dataset_name: name used to choose the normalization statistics.
        """
        split = 'train' if train else 'test'
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            # Variants such as 'SVHN_oodom' reuse the base dataset's statistics.
            base_dataset_name = 'SVHN' if dataset_name.startswith('SVHN') else dataset_name
            mean, std = get_normalization_params(base_dataset_name)
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.SVHN(DATASET_DIR, download=True, split=split, transform=tvt.Compose(steps))

    @classmethod
    def lsun(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return the LSUN ``'test'`` set.

        Only the LSUN test set is expected under ``DATASET_DIR``, so ``train`` is ignored and
        both splits return the same data. No resizing is applied here.
        * Task: classification
        * Features: depends on ``image_transforms`` (images keep their stored size otherwise)
        * Classes: 10 (as used by ``get_dataset``)

        :param train: ignored.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the LSUN entry's statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('LSUN')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # The pipeline must be final at construction time. torchvision's LSUN hands
        # ``transform`` to one LSUNClass sub-dataset per category, and those apply it in
        # __getitem__. Reassigning ``dataset.transform`` afterwards (as this method used to do)
        # therefore never applied the normalization.
        split = 'test'  # only the LSUN test set has been downloaded
        return tvd.LSUN(DATASET_DIR, classes=split, transform=tvt.Compose(steps))

    @classmethod
    def fake(cls, train=True, image_transforms=(), tensor_transforms=(), normalize=False):
        """
        Return a torchvision ``FakeData`` dataset of generated images.

        ``train`` is not used, so the "train" and "test" splits are identically configured
        ``FakeData`` objects.
        * Task: classification
        * Features: [3, 32, 32]
        * Classes: 10
        * Samples: 10,000

        :param train: ignored.
        :param image_transforms: transforms applied to the PIL image before ``ToTensor``.
        :param tensor_transforms: transforms applied to the tensor after ``ToTensor``.
        :param normalize: if True, append ``Normalize`` with the FAKE entry's statistics.
        """
        steps = [*image_transforms, tvt.ToTensor(), *tensor_transforms]
        if normalize:
            mean, std = get_normalization_params('FAKE')
            if mean is not None and std is not None:
                steps.append(tvt.Normalize(mean, std))
        # Pass the final pipeline at construction instead of patching ``dataset.transform``
        # afterwards, which left torchvision's ``dataset.transforms`` wrapper stale.
        return tvd.FakeData(size=10000, image_size=(3, 32, 32), num_classes=10,
                            transform=tvt.Compose(steps))

    @classmethod
    def random_noise_image_dataset(cls, num_classes, num_images_per_class, mean=0, sigma=1, dims=(1, 28, 28),
                                   bounds=None, normalize=False):
        """
        Returns a synthetic dataset of clipped Gaussian-noise images:
        * Task: classification (labels carry no information about the images)
        * Features: ``dims`` (default [1, 28, 28])
        * Classes: ``num_classes``
        * Samples: num_classes * num_images_per_class

        Args:
            num_classes: number of labels; labels are 0 .. num_classes - 1.
            num_images_per_class: number of images generated for each label.
            mean: mean of the Gaussian noise.
            sigma: standard deviation of the Gaussian noise.
            dims: shape of a single image (a tuple or list, e.g. (C, H, W)).
            bounds: (lower, upper) clipping range; defaults to [0.0, 1.0].
            normalize: if True, apply the parameters returned by
                ``get_normalization_params('random_noise')`` per channel
                (when those are not None).

        Returns:
            A ``TensorDataset`` of (images, labels) in shuffled order.
        """
        import random

        if bounds is None:
            bounds = torch.tensor([0.0, 1.0], dtype=torch.float32)
        clip_lower, clip_upper = bounds

        num_samples = num_classes * num_images_per_class
        # The sample order is drawn from Python's `random` before any torch RNG use.
        random_ixs = list(range(num_samples))
        random.shuffle(random_ixs)

        X = torch.randn((num_samples,) + tuple(dims)) * sigma + mean
        X = X.clamp(clip_lower, clip_upper)
        y = torch.repeat_interleave(torch.arange(0, num_classes), num_images_per_class)

        if normalize:
            # Distinct names so the `mean` noise parameter is not shadowed.
            norm_mean, norm_std = get_normalization_params('random_noise')
            if norm_mean is not None and norm_std is not None:
                # Convert to tensors shaped (C, 1, 1) so the statistics broadcast
                # over the channel axis of (N, C, H, W), as tvt.Normalize does,
                # rather than along the last (width) axis.
                norm_mean = torch.as_tensor(norm_mean, dtype=X.dtype).reshape(-1, 1, 1)
                norm_std = torch.as_tensor(norm_std, dtype=X.dtype).reshape(-1, 1, 1)
                X = (X - norm_mean) / norm_std

        return td.TensorDataset(X[random_ixs], y[random_ixs])

    @classmethod
    def dtd(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the DTD (Describable Textures) dataset, either for training or testing.

        Args:
            train: 'train' split if True, otherwise the 'test' split.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('DTD')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        dataset = tvd.DTD(
            DATASET_DIR,
            split='train' if train else 'test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('DTD')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def GTSRB(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the GTSRB traffic-sign dataset, either for training or testing.

        Args:
            train: 'train' split if True, otherwise the 'test' split.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('GTSRB')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        dataset = tvd.GTSRB(
            DATASET_DIR,
            split='train' if train else 'test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('GTSRB')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def OxfordIIITPet(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the Oxford-IIIT Pet dataset, either for training or testing.

        Args:
            train: 'trainval' split if True, otherwise the 'test' split.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('OxfordIIITPet')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        dataset = tvd.OxfordIIITPet(
            DATASET_DIR,
            split='trainval' if train else 'test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('OxfordIIITPet')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def FGVCAircraft(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the FGVC-Aircraft dataset, either for training or testing.

        Args:
            train: 'train' split if True, otherwise the 'test' split.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('FGVCAircraft')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        dataset = tvd.FGVCAircraft(
            DATASET_DIR,
            split='train' if train else 'test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('FGVCAircraft')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def LFWPeople(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the LFW People (Labeled Faces in the Wild) dataset, test split only.

        Args:
            train: accepted for API compatibility but ignored -- only the
                'test' split is ever loaded.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('LFWPeople')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        # Only the test split is needed, so `train` does not select the split.
        dataset = tvd.LFWPeople(
            DATASET_DIR,
            split='test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('LFWPeople')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def Places365(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the Places365 dataset (small-image variant), validation split only.

        Args:
            train: accepted for API compatibility but ignored -- only the
                'val' split is ever loaded.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('Places365')`` (when those are not None).

        No download is attempted; the data must already exist under DATASET_DIR.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        # Only the validation split is needed, so `train` does not select the split.
        dataset = tvd.Places365(
            DATASET_DIR,
            split='val',
            small=True,
            download=False,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('Places365')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def Flowers102(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the Flowers102 dataset, test split only.

        Args:
            train: accepted for API compatibility but ignored -- only the
                'test' split is ever loaded.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('Flowers102')`` (when those are not None).

        The data is downloaded into DATASET_DIR when missing.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        # Only the test split is needed, so `train` does not select the split.
        dataset = tvd.Flowers102(
            DATASET_DIR,
            split='test',
            download=True,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('Flowers102')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @classmethod
    def Food101(cls, train=True, image_transforms=[], tensor_transforms=[], normalize=False):
        """
        Returns the Food101 dataset, test split only.

        Args:
            train: accepted for API compatibility but ignored -- only the
                'test' split is ever loaded.
            image_transforms: transforms applied before ``ToTensor``.
            tensor_transforms: transforms applied after ``ToTensor``.
            normalize: if True, append ``Normalize`` with the parameters from
                ``get_normalization_params('Food101')`` (when those are not None).

        No download is attempted; the data must already exist under DATASET_DIR.
        """
        pipeline = image_transforms + [tvt.ToTensor()] + tensor_transforms
        # Only the test split is needed, so `train` does not select the split.
        dataset = tvd.Food101(
            DATASET_DIR,
            split='test',
            download=False,
            transform=tvt.Compose(pipeline),
        )

        if not normalize:
            return dataset

        mean, std = get_normalization_params('Food101')
        if mean is None or std is None:
            return dataset

        dataset.transform = tvt.Compose(pipeline + [tvt.Normalize(mean, std)])
        return dataset

    @staticmethod
    def _tensor_dataset(X, y):
        """
        Helper method to create a TensorDataset from numpy arrays.
        """
        return td.TensorDataset(torch.from_numpy(X), torch.from_numpy(y))