import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from os.path import join
from .coil20 import COIL20
from .shapes3d import Shapes3D
from .colors import COLORS
from transforms.reshape_transform import ReshapeTransform
from .custom_dataset import CustomDataset
from .ssl_dataset import SSLDataset


class Datasets(data.Dataset):
    """Build the train/test splits for a named dataset.

    Attributes set by ``__init__``:
        train_data: training split; wrapped in ``SSLDataset`` when ``n_labels`` is given.
        test_data: test split. For custom datasets without a separate test file it is the
            same object as the (un-wrapped) training data; for COIL20 / Shapes3D / COLORS it
            is a second instance constructed with the same arguments as the training one.
        dim_flatten: product of the per-sample dimensions of ``train_data.data``.
        d_in, hw_in: channel count and image side length (built-in datasets only).
    """

    def __init__(self, dataset, root_folder="raw-datasets/", test_root_folder=None, test_dataset=None,
                 labels_sampling='fixed', n_labels=None, n_samples=100, balanced=False,
                 norm="minmax", flatten=False, coil20_unprocessed=False, debug=False):
        """Load ``dataset`` and prepare its splits.

        Args:
            dataset: one of "mnist", "fashion", "cifar10", "cifar100", "svhn", "usps",
                "coil20", "shapes3d", "colors"; any other value is treated as a file name
                under ``root_folder`` and loaded with ``CustomDataset``.
            root_folder: directory where datasets are stored/downloaded and where the SSL
                ``.unlabel`` split files are looked up.
            test_root_folder: custom datasets only; directory holding a separate test file.
            test_dataset: custom datasets only; test file name inside ``test_root_folder``.
                Defaults to ``dataset`` when omitted.
            labels_sampling: 'perc' formats ``n_labels`` with two decimals in the split-file
                name (presumably a fraction); any other value formats it as an integer count.
                The value itself is forwarded to ``SSLDataset``.
            n_labels: labelled-sample budget; ``None`` disables the SSL wrapper.
            n_samples: number of samples kept per split in debug mode.
            balanced: forwarded to ``SSLDataset`` and encoded in the split-file name.
            norm: normalisation mode forwarded to ``CustomDataset``.
            flatten: append a reshape to 1-D at the end of every transform pipeline.
            coil20_unprocessed: load the unprocessed COIL20 variant.
            debug: truncate both splits to their first ``n_samples`` items.
        """
        super(Datasets, self).__init__()

        # Number of leading per-sample dimensions of ``train_data.data`` whose product
        # becomes ``dim_flatten``; None means the branch sets ``dim_flatten`` itself.
        n_flat_dims = None

        if dataset in ("mnist", "fashion"):
            if dataset == "mnist":
                dataset_cls, mean, std = datasets.MNIST, (0.1307,), (0.3081,)
            else:
                dataset_cls, mean, std = datasets.FashionMNIST, (0.5,), (0.5,)
            # No augmentation, so train and test share one pipeline.
            transform = self._compose([transforms.ToTensor(), transforms.Normalize(mean, std)], flatten)
            self.train_data = dataset_cls(root=root_folder, train=True, download=True, transform=transform)
            self.test_data = dataset_cls(root=root_folder, train=False, download=True, transform=transform)
            n_flat_dims = 2
            self.d_in, self.hw_in = 1, 28

        elif dataset in ("cifar10", "cifar100"):
            if dataset == "cifar10":
                dataset_cls = datasets.CIFAR10
                mean, std = [0.491, 0.482, 0.447], [0.247, 0.243, 0.262]
            else:
                dataset_cls = datasets.CIFAR100
                mean, std = [0.507, 0.487, 0.441], [0.267, 0.256, 0.276]
            # Random crop + horizontal flip augment the training split only.
            transform_train = self._compose([transforms.RandomCrop(32, padding=4),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=mean, std=std)], flatten)
            transform_test = self._compose([transforms.ToTensor(),
                                            transforms.Normalize(mean=mean, std=std)], flatten)
            self.train_data = dataset_cls(root=root_folder, train=True, download=True, transform=transform_train)
            self.test_data = dataset_cls(root=root_folder, train=False, download=True, transform=transform_test)
            n_flat_dims = 3
            self.d_in, self.hw_in = 3, 32

        elif dataset == "svhn":
            transform = self._compose([transforms.ToTensor()], flatten)
            self.train_data = datasets.SVHN(root=root_folder, split='train', download=True, transform=transform)
            self.test_data = datasets.SVHN(root=root_folder, split='test', download=True, transform=transform)
            n_flat_dims = 3
            self.d_in, self.hw_in = 3, 32

        elif dataset == "usps":
            transform = self._compose([transforms.ToTensor()], flatten)
            self.train_data = datasets.USPS(root=root_folder, train=True, download=True, transform=transform)
            self.test_data = datasets.USPS(root=root_folder, train=False, download=True, transform=transform)
            n_flat_dims = 2
            self.d_in, self.hw_in = 1, 16

        elif dataset == "coil20":
            transform = self._compose([transforms.ToTensor()], flatten)
            # COIL20 has no train/test split here: both splits are built identically.
            self.train_data = COIL20(root=root_folder, processed=not coil20_unprocessed,
                                     download=True, transform=transform)
            self.test_data = COIL20(root=root_folder, processed=not coil20_unprocessed,
                                    download=True, transform=transform)
            n_flat_dims = 2
            self.d_in, self.hw_in = 1, 32

        elif dataset == "shapes3d":
            transform = self._compose([transforms.ToTensor()], flatten)
            self.train_data = Shapes3D(root=root_folder, transform=transform)
            self.test_data = Shapes3D(root=root_folder, transform=transform)
            self.d_in, self.hw_in = 3, 64
            # Every other built-in branch exposes dim_flatten; derive it from the declared
            # geometry. NOTE(review): assumes Shapes3D samples are 3 x 64 x 64 -- confirm.
            self.dim_flatten = self.d_in * self.hw_in * self.hw_in

        elif dataset == "colors":
            transform = self._compose([transforms.ToTensor()], flatten)
            self.train_data = COLORS(root=root_folder, transform=transform)
            self.test_data = COLORS(root=root_folder, transform=transform)
            n_flat_dims = 1
            self.d_in, self.hw_in = 3, 1

        else:
            self.train_data = CustomDataset(load_path=join(root_folder, dataset), norm=norm)
            self.test_data = self.train_data
            if test_root_folder is not None:
                # Without an explicit test file name, reuse the training file name so that
                # join() never receives None.
                test_name = dataset if test_dataset is None else test_dataset
                self.test_data = CustomDataset(load_path=join(test_root_folder, test_name), norm=norm)
            n_flat_dims = 1

        self._finalize(dataset, root_folder, n_labels, labels_sampling, balanced,
                       debug, n_samples, n_flat_dims)

    @staticmethod
    def _compose(steps, flatten):
        """Compose ``steps``, appending a flattening reshape when ``flatten`` is set."""
        if flatten:
            steps = steps + [ReshapeTransform((-1,))]
        return transforms.Compose(steps)

    def _finalize(self, dataset, root_folder, n_labels, labels_sampling, balanced,
                  debug, n_samples, n_flat_dims):
        """Apply debug truncation, record ``dim_flatten``, then attach the SSL wrapper.

        ``dim_flatten`` must be read from ``train_data.data`` before ``apply_ssl_conditions``
        may replace ``train_data`` with an ``SSLDataset``. It is the product of the
        ``n_flat_dims`` dimensions following the sample axis, and is left untouched when
        ``n_flat_dims`` is None.
        """
        self.apply_debug_conditions(debug, n_samples)
        if n_flat_dims is not None:
            dim_flatten = 1
            for size in self.train_data.data.shape[1:1 + n_flat_dims]:
                dim_flatten *= size
            self.dim_flatten = dim_flatten
        self.apply_ssl_conditions(n_labels, labels_sampling, dataset,
                                  root=root_folder, debug=debug, n_samples=n_samples, balanced=balanced)

    def apply_ssl_conditions(self, sampling, sampling_type, dataset,
                             root='raw-datasets/', debug=False, n_samples=100, balanced=False):
        """Wrap ``train_data`` in ``SSLDataset`` when a labelled budget is given.

        The split-file path handed to ``SSLDataset`` is
        ``<root>/<dataset>[_n<n_samples>]_<budget>[_balanced].unlabel``. ``<budget>`` is
        the two decimal digits of ``sampling`` followed by "p" when ``sampling_type`` is
        'perc' (e.g. 0.05 -> "05p"), and ``int(sampling)`` otherwise. The ``_n`` part is
        present only in debug mode. Does nothing when ``sampling`` is None.
        """
        if sampling is None:
            return

        filename = dataset if not debug else dataset + "_n" + str(n_samples)
        if sampling_type == 'perc':
            filename += "_" + ('%.2f' % sampling).split(".")[1] + "p"
        else:
            filename += "_" + str(int(sampling))

        if balanced:
            filename += "_balanced"

        check_exists = join(root, filename + ".unlabel")
        self.train_data = SSLDataset(self.train_data, sampling_type, sampling, balanced, check_exists)

    def apply_debug_conditions(self, debug, n_samples):
        """Truncate both splits to their first ``n_samples`` items when ``debug`` is set.

        Labels are taken from ``targets`` or, for datasets such as torchvision's SVHN that
        store them under ``labels``, from ``labels``. If train and test are the same
        object, truncating twice is harmless.
        """
        if not debug:
            return
        self._truncate(self.train_data, n_samples)
        self._truncate(self.test_data, n_samples)

    @staticmethod
    def _truncate(split, n_samples):
        """Keep the first ``n_samples`` entries of ``split.data`` and its label sequence."""
        split.data = split.data[:n_samples]
        # torchvision's SVHN exposes ``labels`` instead of ``targets``; without this the
        # debug path raised AttributeError for "svhn".
        label_attr = "targets" if hasattr(split, "targets") else "labels"
        setattr(split, label_attr, getattr(split, label_attr)[:n_samples])
