import os

import h5py
import numpy as np
import torchvision
from skimage.transform import resize

from .base import BaseDataset, TransformDataset


class MNIST(BaseDataset):
    """MNIST handwritten digits, fetched through ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` from torchvision's MNIST.

        No ``train`` argument is passed, so torchvision's default split is
        used.
        """
        mnist = torchvision.datasets.MNIST(root=self.path_input,
                                           download=True)
        self.data, self.targets = mnist.data, mnist.targets


class EMNIST(BaseDataset):
    """EMNIST, restricted to its ``'letters'`` split, via ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` from the EMNIST letters split."""
        emnist = torchvision.datasets.EMNIST(root=self.path_input,
                                             split='letters',
                                             download=True)
        self.data, self.targets = emnist.data, emnist.targets


class KMNIST(BaseDataset):
    """Kuzushiji-MNIST (torchvision ``KMNIST``) character images.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` from torchvision's KMNIST."""
        kmnist = torchvision.datasets.KMNIST(root=self.path_input,
                                             download=True)
        self.data, self.targets = kmnist.data, kmnist.targets


class FashionMNIST(BaseDataset):
    """Fashion-MNIST images, fetched through ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` from torchvision's FashionMNIST."""
        fashion = torchvision.datasets.FashionMNIST(root=self.path_input,
                                                    download=True)
        self.data, self.targets = fashion.data, fashion.targets


class OMNIGLOT(TransformDataset):
    """Omniglot characters resized to ``size`` x ``size`` with inverted intensity.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=32):
        super().__init__()
        self.path_input = path_input
        self.size = size
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Convert Omniglot to a numpy array via the inherited ``to_numpy``."""
        side = (self.size, self.size)
        pipeline = torchvision.transforms.Compose([
            torchvision.transforms.Resize(side),
            torchvision.transforms.ToTensor(),
        ])

        omniglot = torchvision.datasets.Omniglot(root=self.path_input,
                                                 download=True)
        self.to_numpy(omniglot, pipeline)

        # Move axis 1 to the end, drop the resulting singleton last axis,
        # then invert intensities (1 - x).
        # NOTE(review): the inversion presumably makes strokes bright on a
        # dark background; confirm against downstream expectations.
        channels_last = self.data.transpose(0, 2, 3, 1)
        self.data = 1 - channels_last.squeeze(-1)


class QuickDraw(BaseDataset):
    """Subset of Quick, Draw! sketches stored in ``quickdraw_subset.npz``.

    Args:
        path_input: Directory that contains ``quickdraw_subset.npz``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Read the ``images`` and ``classes`` arrays from the ``.npz`` archive.

        ``np.load`` on an ``.npz`` file returns an ``NpzFile`` that keeps the
        archive open. Using it as a context manager closes the handle once
        both arrays have been read into memory (indexing an ``NpzFile``
        returns a fully loaded array).
        """
        path_dataset = os.path.join(self.path_input,
                                    'quickdraw_subset.npz')
        with np.load(path_dataset) as dataset:
            self.data = dataset['images']
            self.targets = dataset['classes']


class dSprites(BaseDataset):
    """Subset of dSprites stored in ``dSprites_subset.npz``.

    Args:
        path_input: Directory that contains ``dSprites_subset.npz``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Read the ``images`` and ``classes`` arrays from the ``.npz`` archive.

        ``np.load`` on an ``.npz`` file returns an ``NpzFile`` that keeps the
        archive open. Using it as a context manager closes the handle once
        both arrays have been read into memory.
        """
        path_dataset = \
            os.path.join(self.path_input,
                         'dSprites_subset.npz')
        with np.load(path_dataset) as dataset:
            self.data = dataset['images']
            self.targets = dataset['classes']


class HDW(BaseDataset):
    """Handwriting images loaded from ``.npy`` files, resized and inverted.

    Args:
        path_input: Directory containing ``Images(500x500).npy`` and
            ``WriterInfo.npy``.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=128):
        super().__init__()
        self.path_input = path_input
        self.size = size
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Resize every image, min-max scale to [0, 1], invert, set targets.

        Targets are the first column of ``WriterInfo.npy``.
        """
        images = np.load(os.path.join(self.path_input,
                                      r'Images(500x500).npy'))
        writer_info = np.load(os.path.join(self.path_input,
                                           r'WriterInfo.npy'))

        side = (self.size, self.size)
        resized = np.array([resize(img, side, anti_aliasing=True)
                            for img in images])

        # Global (whole-array) min-max scaling followed by inversion.
        lowest, highest = resized.min(), resized.max()
        self.data = 1 - (resized - lowest) / (highest - lowest)
        self.targets = writer_info[:, 0]


class Shapes3D(BaseDataset):
    """3D Shapes subset read from the HDF5 file ``subset_3dshapes.h5``.

    Args:
        path_input: Directory that contains ``subset_3dshapes.h5``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Copy the HDF5 ``data`` and ``labels`` datasets into memory.

        The file is opened in a ``with`` block so its handle is released once
        both datasets have been fully read with ``[:]``. Previously the
        ``h5py.File`` was never closed.
        """
        path_dataset = os.path.join(self.path_input, 'subset_3dshapes.h5')
        with h5py.File(path_dataset, 'r') as dataset:
            self.data = dataset['data'][:]
            self.targets = dataset['labels'][:]


class KKanji(TransformDataset):
    """Kuzushiji-Kanji images read from an ``ImageFolder`` tree under ``kkanji2``.

    Args:
        path_input: Parent directory; images are read from
            ``<path_input>/kkanji2``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = os.path.join(path_input, 'kkanji2')
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'synthetic'

    def load_dataset(self):
        """Load grayscale images, min-max scale to [0, 1], record filenames."""
        def is_valid_file(path):
            # Skip entries named '._*' (presumably macOS metadata files).
            return not os.path.basename(path).startswith('._')

        pipeline = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Grayscale(),
        ])
        folder = torchvision.datasets.ImageFolder(root=self.path_input,
                                                  is_valid_file=is_valid_file)
        self.to_numpy(folder, pipeline)

        # Move axis 1 to the end and drop it (single grayscale channel).
        squeezed = np.squeeze(self.data.transpose(0, 2, 3, 1), -1)
        lowest, highest = squeezed.min(), squeezed.max()
        self.data = (squeezed - lowest) / (highest - lowest)

        basenames = [os.path.basename(img_path)
                     for img_path, _ in folder.imgs]
        self.filenames = np.array(basenames)


class SVHN(BaseDataset):
    """Street View House Numbers, fetched through ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Fill ``self.data`` (axis 1 moved last) and ``self.targets``."""
        svhn = torchvision.datasets.SVHN(root=self.path_input,
                                         download=True)
        self.data = svhn.data.transpose(0, 2, 3, 1)
        # torchvision's SVHN exposes ``labels``, not ``targets``.
        self.targets = svhn.labels


class CIFAR10(BaseDataset):
    """CIFAR-10 images, fetched through ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` (as a numpy array)."""
        cifar = torchvision.datasets.CIFAR10(root=self.path_input,
                                             download=True)
        labels = np.array(cifar.targets)
        self.data, self.targets = cifar.data, labels


class CIFAR100(BaseDataset):
    """CIFAR-100 images, fetched through ``torchvision``.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Fill ``self.data`` and ``self.targets`` (as a numpy array)."""
        cifar = torchvision.datasets.CIFAR100(root=self.path_input,
                                              download=True)
        labels = np.array(cifar.targets)
        self.data, self.targets = cifar.data, labels


class Caltech101(TransformDataset):
    """Caltech-101 images (torchvision), resized to ``size`` x ``size`` RGB.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=128):
        super().__init__()
        self.path_input = path_input
        self.size = size
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    @staticmethod
    def _to_rgb(image):
        """Return the PIL *image* converted to 3-channel RGB."""
        # torchvision's Caltech101 returns images exactly as opened from disk
        # (no RGB conversion), and part of Caltech-101 is grayscale. Without
        # this step those images become 1-channel tensors that cannot be
        # stacked with the 3-channel ones. A static method (not a lambda) is
        # used so the transform stays picklable.
        return image.convert('RGB')

    def load_dataset(self):
        """Convert Caltech-101 to a numpy array (axis 1 moved last)."""
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(Caltech101._to_rgb),
            torchvision.transforms.Resize((self.size, self.size)),
            torchvision.transforms.ToTensor()
        ])

        dataset = torchvision.datasets.Caltech101(root=self.path_input,
                                                  download=True)
        self.to_numpy(dataset, transform)
        self.data = self.data.transpose(0, 2, 3, 1)


class CelebA(TransformDataset):
    """CelebA faces resized to ``size`` x ``size``, with identity targets.

    Args:
        path_input: Root directory handed to torchvision; files are
            downloaded there if they are missing.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=64):
        super().__init__()
        self.path_input = path_input
        self.size = size
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Convert CelebA to a numpy array (axis 1 moved last)."""
        side = (self.size, self.size)
        pipeline = torchvision.transforms.Compose([
            torchvision.transforms.Resize(side),
            torchvision.transforms.ToTensor(),
        ])
        celeba = torchvision.datasets.CelebA(root=self.path_input,
                                             target_type='identity',
                                             download=True)
        self.to_numpy(celeba, pipeline)
        self.data = self.data.transpose(0, 2, 3, 1)


class StanfordCars(TransformDataset):
    """Stanford Cars images resized to ``size`` x ``size``.

    Args:
        path_input: Root directory where torchvision expects the data.
            Nothing is downloaded (``download=False``), so the files must
            already be present.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=128):
        super().__init__()
        self.path_input = path_input
        self.size = size
        # Pipeline: raw arrays -> inherited preprocessing -> de-duplication.
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Convert Stanford Cars to a numpy array (axis 1 moved last)."""
        side = (self.size, self.size)
        pipeline = torchvision.transforms.Compose([
            torchvision.transforms.Resize(side),
            torchvision.transforms.ToTensor(),
        ])
        # NOTE(review): unlike the other torchvision datasets here, download
        # is disabled; presumably the upstream download is unavailable.
        cars = torchvision.datasets.StanfordCars(root=self.path_input,
                                                 download=False)
        self.to_numpy(cars, pipeline)
        self.data = self.data.transpose(0, 2, 3, 1)


class MiniEcoset(BaseDataset):
    """Mini-Ecoset subset read from the HDF5 file ``subset_miniecoset.h5``.

    Args:
        path_input: Directory that contains ``subset_miniecoset.h5``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Copy the HDF5 ``data`` and ``labels`` datasets into memory.

        The file is opened in a ``with`` block so its handle is released once
        both datasets have been fully read with ``[:]``. Previously the
        ``h5py.File`` was never closed.
        """
        path_dataset = os.path.join(self.path_input, 'subset_miniecoset.h5')
        with h5py.File(path_dataset, 'r') as dataset:
            self.data = dataset['data'][:]
            self.targets = dataset['labels'][:]


class ImageNet(BaseDataset):
    """ImageNet subset read from the HDF5 file ``subset_imagenet.h5``.

    Args:
        path_input: Directory that contains ``subset_imagenet.h5``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Copy the HDF5 ``data`` and ``labels`` datasets into memory.

        The file is opened in a ``with`` block so its handle is released once
        both datasets have been fully read with ``[:]``. Previously the
        ``h5py.File`` was never closed.
        """
        path_dataset = os.path.join(self.path_input, 'subset_imagenet.h5')
        with h5py.File(path_dataset, 'r') as dataset:
            self.data = dataset['data'][:]
            self.targets = dataset['labels'][:]


class MSCOCO(BaseDataset):
    """MS-COCO subset read from the HDF5 file ``subset_mscoco.h5``.

    Args:
        path_input: Directory that contains ``subset_mscoco.h5``.
    """

    def __init__(self, path_input):
        super().__init__()
        self.path_input = path_input
        self.load_dataset()
        self.preprocess_dataset()
        self.remove_duplicates()
        self.type = 'natural'

    def load_dataset(self):
        """Copy the HDF5 ``data`` and ``labels`` datasets into memory.

        The file is opened in a ``with`` block so its handle is released once
        both datasets have been fully read with ``[:]``. Previously the
        ``h5py.File`` was never closed.
        """
        path_dataset = os.path.join(self.path_input, 'subset_mscoco.h5')
        with h5py.File(path_dataset, 'r') as dataset:
            self.data = dataset['data'][:]
            self.targets = dataset['labels'][:]


class THINGS(TransformDataset):
    """THINGS images read from an ``ImageFolder`` tree, resized to ``size``.

    Args:
        path_input: Dataset root; images are read from
            ``<path_input>/images/classes``.
        size: Side length, in pixels, of the square resize.
    """

    def __init__(self, path_input, size=128):
        super().__init__()
        self.path_input = os.path.join(path_input, 'images', 'classes')
        self.size = size
        # NOTE(review): unlike every other dataset in this module,
        # remove_duplicates() is not called here. Confirm this is intentional.
        self.load_dataset()
        self.preprocess_dataset()
        self.type = 'natural'

    def load_dataset(self):
        """Convert the image folder to a numpy array and record filenames."""
        def is_valid_file(path):
            # Skip entries named '._*' (presumably macOS metadata files).
            return not os.path.basename(path).startswith('._')

        side = (self.size, self.size)
        pipeline = torchvision.transforms.Compose([
            torchvision.transforms.Resize(side),
            torchvision.transforms.ToTensor(),
        ])
        folder = torchvision.datasets.ImageFolder(root=self.path_input,
                                                  is_valid_file=is_valid_file)

        self.to_numpy(folder, pipeline)
        self.data = self.data.transpose(0, 2, 3, 1)

        basenames = [os.path.basename(img_path)
                     for img_path, _ in folder.imgs]
        self.filenames = np.array(basenames)
