
import torch.utils.data as data
import os
import os.path
from common_imports import np, Image, transforms
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.datasets import ImageFolder

class SVHN(data.Dataset):
    """Street View House Numbers (SVHN) dataset read from ``*_32x32.mat`` files.

    Args:
        root: Directory holding (or receiving, when ``download=True``) the
            ``.mat`` file(s).
        split: ``'train'``, ``'test'``, ``'extra'`` or ``'train_and_extra'``
            (the last concatenates the train and extra files).
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        download: If True, fetch the required file(s) into ``root`` unless they
            already pass their MD5 check.

    Note:
        The ``'test'`` split reads the local file ``selected_test_32x32.mat``
        while its download URL points at ``test_32x32.mat``.
        NOTE(review): confirm the listed md5 matches the file actually used.
    """
    # Per-instance values are assigned in __init__ from ``split_list``.
    url = ""
    filename = ""
    file_md5 = ""

    # split -> [url, local filename, md5]; 'train_and_extra' maps to a list of
    # two such entries (train first, then extra).
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "selected_test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"],
        'train_and_extra': [
                ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                 "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
                ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                 "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]}

    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="extra" or split="test" '
                             'or split="train_and_extra" ')

        # The primary (first) file of the split populates url/filename/md5.
        self.url, self.filename, self.file_md5 = self._file_list()[0]

        if download:
            self.download()

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(root, self.filename))
        self.data = loaded_mat['X']
        self.targets = loaded_mat['y']

        if self.split == "train_and_extra":
            # Samples are stacked along the last axis of 'X' and the first
            # axis of 'y'.
            extra_filename = self.split_list[split][1][1]
            loaded_mat = sio.loadmat(os.path.join(root, extra_filename))
            self.data = np.concatenate([self.data, loaded_mat['X']], axis=3)
            self.targets = np.vstack((self.targets, loaded_mat['y']))

        # Note label 10 == 0 so modulo operator required
        self.targets = (self.targets % 10).squeeze()    # convert to zero-based indexing
        # Move the sample axis first; presumably (H, W, C, N) -> (N, C, H, W).
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def _file_list(self):
        """Return ``[(url, filename, md5), ...]`` for every file backing ``self.split``."""
        entry = self.split_list[self.split]
        if self.split == "train_and_extra":
            return [tuple(item) for item in entry]
        return [tuple(entry)]

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        The stored (C, H, W) array is transposed to (H, W, C) and wrapped in a
        PIL Image before ``transform`` is applied. A NumPy target is cast to
        ``np.int64``; any other object returned by ``target_transform`` is
        passed through unchanged.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # Only NumPy values have ``astype``; a target_transform may return a
        # plain Python object, which must not crash here.
        if isinstance(target, (np.ndarray, np.generic)):
            target = target.astype(np.int64)
        return img, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def _check_integrity(self):
        """Return True when every file required by the split passes its MD5 check."""
        return all(check_integrity(os.path.join(self.root, fname), md5)
                   for _, fname, md5 in self._file_list())

    def download(self):
        """Download every file required by the split into ``self.root``.

        Each file is fetched from its own URL, so the extra file of
        ``'train_and_extra'`` comes from the extra URL rather than the train URL.
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        for url, fname, md5 in self._file_list():
            download_url(url, self.root, fname, md5)

def get_CIFAR_ood_datasets(set_name='svhn', normalize=False, norm_params=((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))):
    """
    Build a 32x32 OOD evaluation set for CIFAR-scale models.

    Intended to match the OOD datasets used in the DICE paper. Images are
    resized to 32, center-cropped to 32x32 and converted to tensors.

    set_name: "svhn" (SVHN test split), "dtd", "places", or any other folder
        name under ./datasets/ood_datasets/ (loaded as an ImageFolder).
    normalize: Boolean determines if we apply the normalization.
    norm_params: (mean, std) passed to transforms.Normalize. (Not used when normalize=False)

    Returns the dataset object (SVHN or ImageFolder).
    """
    steps = [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(*norm_params))
    transform_ood = transforms.Compose(steps)

    if set_name == 'svhn':
        return SVHN('datasets/ood_datasets/svhn/', split='test', transform=transform_ood, download=False)

    # Known sets with non-default directory layouts; anything else is looked up by name.
    folder_roots = {
        'dtd': "datasets/ood_datasets/dtd/images",
        'places': "datasets/ood_datasets/places365",
    }
    root = folder_roots.get(set_name, "./datasets/ood_datasets/{}".format(set_name))
    return ImageFolder(root=root, transform=transform_ood)

def get_imagenet_ood_datasets(set_name='dtd', normalize=True, norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
    """
    Build an ImageNet-scale OOD evaluation set as an ImageFolder.

    Intended to match the OOD datasets used in the DICE paper. Images are
    resized to 256, center-cropped to 224x224 and converted to tensors.

    set_name: One of "dtd", "inat", "places", "sun" or "openimage".
    normalize: Boolean determines if we apply the normalization.
    norm_params: (mean, std) pair for transforms.Normalize. (Not used when normalize=False)

    Returns the ImageFolder, or None (after printing a message) for an unknown set_name.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=norm_params[0],
                                          std=norm_params[1]))
    transform_ood = transforms.Compose(steps)

    folder_roots = {
        'dtd': "datasets/ood_datasets/dtd/images",
        'inat': "./datasets/ood_datasets/iNaturalist",
        'places': "datasets/ood_datasets/Places",
        'sun': "datasets/ood_datasets/SUN",
        'openimage': "datasets/ood_datasets/OpenImage-O",
    }
    root = folder_roots.get(set_name)
    if root is None:
        print('Please provide a valid ood set name (dtd, inat, places, sun or openimage).')
        return None
    return ImageFolder(root=root, transform=transform_ood)

def get_imagenet_ood_datasets_for_regnet(set_name='dtd', normalize=True, norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
    """
    Build an ImageNet-scale OOD evaluation set at 384x384 for RegNet models.

    Intended to match the OOD datasets used in the DICE paper. Images are
    resized to 384 with bicubic interpolation, center-cropped to 384x384 and
    converted to tensors.

    set_name: One of "dtd", "inat", "places", "sun" or "openimage".
    normalize: Boolean determines if we apply the normalization.
    norm_params: (mean, std) pair for transforms.Normalize. (Not used when normalize=False)

    Returns the ImageFolder, or None (after printing a message) for an unknown set_name.
    """
    steps = [
        transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(384),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=norm_params[0],
                                          std=norm_params[1]))
    transform_ood = transforms.Compose(steps)

    folder_roots = {
        'dtd': "datasets/ood_datasets/dtd/images",
        'inat': "./datasets/ood_datasets/iNaturalist",
        'places': "datasets/ood_datasets/Places",
        'sun': "datasets/ood_datasets/SUN",
        'openimage': "datasets/ood_datasets/OpenImage-O",
    }
    root = folder_roots.get(set_name)
    if root is None:
        print('Please provide a valid ood set name (dtd, inat, places, sun or openimage).')
        return None
    return ImageFolder(root=root, transform=transform_ood)

def get_imagenet_ood_datasets_for_vit(set_name='dtd', normalize=True, norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
    """
    Build an ImageNet-scale OOD evaluation set at 384x384 for ViT models.

    Intended to match the OOD datasets used in the DICE paper. Images are
    resized to 384 with bicubic interpolation, center-cropped to 384x384 and
    converted to tensors.

    set_name: One of "dtd", "inat", "places", "sun" or "openimage".
    normalize: Boolean determines if we apply the normalization.
    norm_params: (mean, std) pair for transforms.Normalize. (Not used when normalize=False)

    Returns the ImageFolder, or None (after printing a message) for an unknown set_name.
    """
    pipeline = [
        transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(384),
        transforms.ToTensor(),
    ]
    if normalize:
        pipeline.append(transforms.Normalize(mean=norm_params[0],
                                             std=norm_params[1]))
    transform_ood = transforms.Compose(pipeline)

    roots_by_name = {
        'dtd': "datasets/ood_datasets/dtd/images",
        'inat': "./datasets/ood_datasets/iNaturalist",
        'places': "datasets/ood_datasets/Places",
        'sun': "datasets/ood_datasets/SUN",
        'openimage': "datasets/ood_datasets/OpenImage-O",
    }
    if set_name not in roots_by_name:
        print('Please provide a valid ood set name (dtd, inat, places, sun or openimage).')
        return None
    return ImageFolder(root=roots_by_name[set_name], transform=transform_ood)