from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

from src.utils.sysutils import get_cores_count


class Denormalize(object):
    """
    Undoes a channel-wise ``(x - mean) / std`` normalization and returns the
    reconstructed images in the input domain, clamped to ``[0, 1]``.

    Args:
        mean: per-channel means that were used for the normalization.
        std: per-channel standard deviations that were used for the normalization.
        inplace: when True, the denormalized values are written back into the
            tensor passed to ``__call__``; otherwise the input is left untouched.

    Attributes:
        demean / destd: the equivalent parameters for a ``Normalize``-style
            transform, i.e. ``-mean / std`` and ``1 / std``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.demean = [-m / s for m, s in zip(mean, std)]
        self.std = std
        self.destd = [1 / s for s in std]
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Denormalize ``tensor`` and clamp the result to ``[0, 1]``.

        Args:
            tensor: floating point tensor whose third-from-last dimension is the
                channel dimension, e.g. ``(C, H, W)`` or ``(B, C, H, W)``.

        Returns:
            A new tensor holding the clamped, denormalized values.

        Raises:
            TypeError: if ``tensor`` is not a floating point tensor.
        """
        if not tensor.is_floating_point():
            raise TypeError('Expected a floating point tensor, got {}.'.format(tensor.dtype))

        # Shape the statistics as (C, 1, 1) so that they broadcast over the
        # spatial dimensions (and over an optional leading batch dimension).
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

        # Inverse of (x - mean) / std.
        if self.inplace:
            tensor.mul_(std).add_(mean)
        else:
            tensor = tensor * std + mean

        # clamp to get rid of numerical errors
        return torch.clamp(tensor, 0.0, 1.0)


class CIFAR10:
    """
    Wrapper around ``torchvision.datasets.CIFAR10`` that downloads the data,
    applies a fixed normalization, optionally splits the official training set
    into class-stratified training/validation subsets, and exposes data loaders.
    """

    def __init__(self,
                 train_data_args,
                 val_data_args,
                 dataset_args=None,
                 split_ratio=0.9):
        """
        Args:
            train_data_args: mapping with at least ``'batch_size'`` and ``'shuffle'``;
                read by the training and validation loaders.
            val_data_args: mapping with at least ``'batch_size'`` and ``'shuffle'``;
                read by the test loader.
            dataset_args: accepted for interface compatibility; not used here.
            split_ratio: fraction of every class kept for training. With ``1.0``
                no split is done and the validation set is empty.
        """

        self.cpu_count = get_cores_count()
        self.train_data_args = train_data_args
        self.val_data_args = val_data_args

        dataset_dir = './data/' + self.__class__.__name__

        # RGB Order
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)

        self.demean = [-m / s for m, s in zip(mean, std)]
        self.destd = [1 / s for s in std]

        self.__normalize_transform = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean, std)])

        # Normalization transform does (x - mean) / std
        # To denormalize use mean* = (-mean/std) and std* = (1/std)
        self.denormalization_transform = torchvision.transforms.Normalize(self.demean, self.destd, inplace=False)

        self.original_training_set = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                                  train=True,
                                                                  download=True,
                                                                  transform=self.__normalize_transform)

        if split_ratio == 1.0:
            print('No split being done')
            self.trainset = self.original_training_set
            # Keep the attribute defined so that the validation accessors do not
            # raise AttributeError when no split was requested.
            self.validationset = torch.utils.data.Subset(self.original_training_set, [])
        else:
            # Split train data into training and cross validation datasets using the given split ratio
            self.trainset, self.validationset = self._uniform_train_val_split(split_ratio)

        self.testset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                    train=False,
                                                    download=True,
                                                    transform=self.__normalize_transform)

    @property
    def train_dataloader(self) -> DataLoader:
        """A new loader over the training set, configured by ``train_data_args``."""
        return torch.utils.data.DataLoader(self.trainset,
                                           batch_size=self.train_data_args['batch_size'],
                                           shuffle=self.train_data_args['shuffle'],
                                           pin_memory=True,
                                           num_workers=self.cpu_count)

    @property
    def validation_dataloader(self) -> DataLoader:
        """A new loader over the validation set, configured by ``train_data_args``."""
        # NOTE(review): this reads train_data_args, not val_data_args — confirm intent.
        return torch.utils.data.DataLoader(self.validationset,
                                           batch_size=self.train_data_args['batch_size'],
                                           shuffle=self.train_data_args['shuffle'],
                                           pin_memory=True,
                                           num_workers=self.cpu_count)

    @property
    def test_dataloader(self):
        """A new loader over the test set, configured by ``val_data_args``."""
        return torch.utils.data.DataLoader(self.testset,
                                           batch_size=self.val_data_args['batch_size'],
                                           shuffle=self.val_data_args['shuffle'],
                                           pin_memory=True,
                                           num_workers=self.cpu_count)

    def imshow(self, img):
        """Denormalize a channel-first image tensor and display it with matplotlib."""
        # clamp to get rid of numerical errors
        img = torch.clamp(self.denormalize(img), 0.0, 1.0)  # denormalize
        # detach/cpu so tensors that track gradients or live on an accelerator can be shown too
        npimg = img.detach().cpu().numpy()
        # matplotlib expects channels last
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    def debug(self):
        """Show one training batch as an image grid and print its class names."""
        # get the first batch of training images
        data_iter = iter(self.train_dataloader)
        # use the built-in next(); iterators are not required to offer a .next() method
        images, labels = next(data_iter)

        # show images
        self.imshow(torchvision.utils.make_grid(images))

        # print labels
        class_names = self.classes
        pprint(' '.join('%s' % class_names[int(label)] for label in labels))

    def denormalize(self, x):
        """Apply the inverse of the normalization transform to ``x`` (no clamping)."""
        return self.denormalization_transform(x)

    @property
    def train_dataset_size(self):
        """Number of samples in the training set."""
        return len(self.trainset)

    @property
    def val_dataset_size(self):
        """Number of samples in the validation set."""
        return len(self.validationset)

    @property
    def test_dataset_size(self):
        """Number of samples in the test set."""
        return len(self.testset)

    @property
    def classes(self):
        """Class names, indexed by label id."""
        return ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    def _uniform_train_val_split(self, split_ratio):
        """
        Split the original training set class by class.

        For every class the first ``int(split_ratio * n)`` of its ``n`` samples
        (in dataset order) go to the training subset and the rest go to the
        validation subset.

        Returns:
            Tuple ``(training_subset, validation_subset)`` of
            ``torch.utils.data.Subset`` objects over the original training set.
        """
        targets = self.original_training_set.targets
        if isinstance(targets, torch.Tensor):
            labels = targets.detach().cpu().numpy()
        else:
            # lists, tuples and numpy arrays alike
            labels = np.asarray(targets)

        training_indices = []
        validation_indices = []
        for class_id in range(len(self.classes)):
            # flat index array: stays 1-D even when a class has a single sample
            label_indices = np.flatnonzero(labels == class_id)
            samples_per_label = int(split_ratio * len(label_indices))
            training_indices.extend(label_indices[:samples_per_label].tolist())
            validation_indices.extend(label_indices[samples_per_label:].tolist())

        # sanity check: no sample may end up in both subsets
        assert not set(training_indices) & set(validation_indices)
        uniform_training_subset = torch.utils.data.Subset(self.original_training_set, training_indices)
        uniform_validation_subset = torch.utils.data.Subset(self.original_training_set, validation_indices)
        return uniform_training_subset, uniform_validation_subset

    def pos_neg_balance_weights(self):
        """Return a weight of 1.0 for every class."""
        return torch.tensor([1.0] * len(self.classes))

def get_cifar_object():
    """
    Build a ``CIFAR10`` wrapper with a small, unshuffled configuration.

    The training batch size is 8 and the evaluation batch size is four times
    that; the training data is split with a ratio of 0.9.
    """
    train_batch_size = 8

    train_data_args = {
        'batch_size': train_batch_size,
        'shuffle': False,
    }
    val_data_args = {
        'batch_size': train_batch_size * 4,
        'shuffle': False,
        'validate_step_size': 1,
    }

    # No dataset-level options are set; an empty mapping is passed through.
    return CIFAR10(train_data_args, val_data_args, {}, split_ratio=0.9)


if __name__ == '__main__':
    # Manual smoke test: show one training batch, then report the split sizes.
    dataset = get_cifar_object()
    dataset.debug()

    split_sizes = (
        ('Length of training set = ', len(dataset.trainset)),
        ('Length of validation set = ', len(dataset.validationset)),
        ('Length of test set = ', len(dataset.testset)),
    )
    for caption, size in split_sizes:
        print(caption, size)