import os

import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image


class TinyImages(Dataset):
    """Image-only dataset backed by a single array loaded with ``np.load``.

    The loaded array must have dtype ``uint8`` and trailing dimensions
    ``(32, 32, 3)``. Items are returned without any label.
    """

    def __init__(self, datafile, transform, n_samples=-1):
        """Load the image array and optionally keep only a leading slice.

        Args:
            datafile: Passed straight to ``np.load``.
            transform: Callable applied to each PIL image, or ``None`` to
                return the PIL image unchanged.
            n_samples: When positive, only the first ``n_samples`` entries
                are kept; any other value keeps the whole array.

        Raises:
            AssertionError: If the array's trailing shape is not
                ``(32, 32, 3)`` or its dtype is not ``uint8``.
        """
        super(TinyImages, self).__init__()

        images = np.load(datafile)
        if n_samples > 0:
            images = images[:n_samples]
        assert images.shape[1:] == (32, 32, 3)
        assert images.dtype == np.uint8

        self.data = images
        self.transform = transform

    def __getitem__(self, index):
        """Return the image at ``index`` after applying the transform."""
        # Wrap the raw array in a PIL Image so the value handed to the
        # transform has the same type as in the other datasets.
        pil_img = Image.fromarray(self.data[index])
        if self.transform is None:
            return pil_img
        return self.transform(pil_img)

    def __len__(self):
        """Return the number of images held in memory."""
        return len(self.data)


def get_imagenetcrop32_dataset(datadir, transform=transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet128/train``.

    Args:
        datadir: Directory that contains the ``imagenet128`` folder.
        transform: Per-image transform. The default applies a random
            resized crop to 32, a random horizontal flip, then
            ``ToTensor``; that default object is created once, when the
            function is defined.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'imagenet128/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_imagenetdog256_dataset(datadir):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet256/train-dog``.

    Each image is resized to 256, center-cropped to 256, randomly flipped
    horizontally and converted with ``ToTensor``.

    Args:
        datadir: Directory that contains the ``imagenet256`` folder.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    root = os.path.join(datadir, 'imagenet256/train-dog')
    return torchvision.datasets.ImageFolder(root, transforms.Compose(steps))


def get_imagenetcrop128_dataset(datadir):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet256/train``.

    Each image gets a random resized crop to 128, a random horizontal
    flip, and is converted with ``ToTensor``.

    Args:
        datadir: Directory that contains the ``imagenet256`` folder.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    steps = [
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    root = os.path.join(datadir, 'imagenet256/train')
    return torchvision.datasets.ImageFolder(root, transforms.Compose(steps))


def get_imagenet128_dataset(datadir, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet128/train``.

    Args:
        datadir: Directory that contains the ``imagenet128`` folder.
        transform: Per-image transform. The default applies a random
            horizontal flip followed by ``ToTensor``; that default object
            is created once, when the function is defined.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'imagenet128/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_imagenet128_val_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet128/val``.

    Args:
        datadir: Directory that contains the ``imagenet128`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'imagenet128/val')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_imagenet256_dataset(datadir):
    """Build an ``ImageFolder`` over ``<datadir>/imagenet256/train``.

    Each image gets a random resized crop to 256, a random horizontal
    flip, and is converted with ``ToTensor``.

    Args:
        datadir: Directory that contains the ``imagenet256`` folder.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    augment = transforms.Compose([
        transforms.RandomResizedCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    root = os.path.join(datadir, 'imagenet256/train')
    return torchvision.datasets.ImageFolder(root, augment)


def get_celebahq128_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/CelebAHQ128/train``.

    Args:
        datadir: Directory that contains the ``CelebAHQ128`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'CelebAHQ128/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_celebahq128_val_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/CelebAHQ128/test``.

    Args:
        datadir: Directory that contains the ``CelebAHQ128`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'CelebAHQ128/test')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_celebahq256_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/CelebAHQ256/train``.

    Args:
        datadir: Directory that contains the ``CelebAHQ256`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'CelebAHQ256/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_bedroom128_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/Bedroom128/train``.

    Args:
        datadir: Directory that contains the ``Bedroom128`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'Bedroom128/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_bedroom128_val_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/Bedroom128/test``.

    Args:
        datadir: Directory that contains the ``Bedroom128`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'Bedroom128/test')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_bedroom256_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/Bedroom256/train``.

    Args:
        datadir: Directory that contains the ``Bedroom256`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'Bedroom256/train')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_bedroom256_val_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/Bedroom256/test``.

    Args:
        datadir: Directory that contains the ``Bedroom256`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'Bedroom256/test')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


def get_iSUN_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/iSUN``.

    Args:
        datadir: Directory that contains the ``iSUN`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'iSUN')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)

def get_LSUN_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/LSUN_resize``.

    Args:
        datadir: Directory that contains the ``LSUN_resize`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'LSUN_resize')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)

def get_TinyImageNet_dataset(
        datadir, transform=transforms.Compose([transforms.ToTensor()])):
    """Build an ``ImageFolder`` over ``<datadir>/Imagenet_resize``.

    Args:
        datadir: Directory that contains the ``Imagenet_resize`` folder.
        transform: Per-image transform; defaults to ``ToTensor`` only.

    Returns:
        The ``torchvision.datasets.ImageFolder`` instance.
    """
    root = os.path.join(datadir, 'Imagenet_resize')
    return torchvision.datasets.ImageFolder(root=root, transform=transform)


class CIFAR10Unsupervised(torchvision.datasets.CIFAR10):
    """CIFAR-10 variant that yields images only, optionally filtered by class.

    Depending on ``mode`` the dataset keeps only the images of
    ``target_class`` (``'include'``), every image except those of
    ``target_class`` (``'exclude'``), or all images (``'all'``).
    """

    def __init__(self, target_class, mode, **kwargs):
        """Load CIFAR-10 and apply the requested class filter.

        Args:
            target_class: Label to include or exclude, or ``None`` to keep
                every image.
            mode: ``'include'`` or ``'exclude'`` when ``target_class`` is
                given; must be ``'all'`` when ``target_class`` is ``None``.
            **kwargs: Forwarded unchanged to
                ``torchvision.datasets.CIFAR10.__init__``.

        Raises:
            AssertionError: If ``mode`` is unknown or does not match
                whether ``target_class`` was provided.
        """
        # Validate first so an invalid combination fails before the parent
        # constructor does any work.
        assert mode in ['include', 'exclude', 'all']
        if target_class is not None:
            assert mode in ['include', 'exclude']
        else:
            assert mode == 'all'

        super(CIFAR10Unsupervised, self).__init__(**kwargs)
        self.target_class = target_class

        if target_class is not None:
            labels = np.array(self.targets)
            if mode == 'include':
                keep = labels == target_class
            else:
                keep = labels != target_class
            self.data = self.data[keep]
            # Filter the labels with the same mask; otherwise ``targets``
            # would keep its full length and no longer line up with ``data``.
            self.targets = labels[keep].tolist()

    def __getitem__(self, index):
        """Return the image at ``index`` (no label) after the transform."""
        # Doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(self.data[index])
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of images left after filtering."""
        return len(self.data)
