from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

from src.utils.sysutils import get_cores_count


class Denormalize(object):
    """
    Undoes a per-channel ``(x - mean) / std`` normalization and returns the
    reconstructed images in the input domain, clamped to ``[0, 1]``.

    Args:
        mean: per-channel means that were used for the normalization.
        std: per-channel standard deviations that were used for the normalization.
        inplace: when True the input tensor itself is overwritten with the
            denormalized (and clamped) values and returned.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        # Parameters that make a forward normalization act as the inverse:
        # (y - (-mean / std)) / (1 / std) == y * std + mean
        self.demean = [-m / s for m, s in zip(mean, std)]
        self.std = std
        self.destd = [1 / s for s in std]
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Denormalize ``tensor`` of shape ``(..., C, H, W)``.

        Returns:
            The denormalized tensor clamped to ``[0, 1]``. It is a new tensor
            unless ``inplace`` is True.

        Raises:
            TypeError: if ``tensor`` is not a floating point tensor.
        """
        if not torch.is_floating_point(tensor):
            raise TypeError('Denormalize expects a floating point tensor, got %s' % tensor.dtype)

        # Shape (C, 1, 1) broadcasts over both (C, H, W) images and (N, C, H, W) batches
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

        if self.inplace:
            tensor.mul_(std).add_(mean)
            # clamp to get rid of numerical errors
            return tensor.clamp_(0.0, 1.0)

        # clamp to get rid of numerical errors
        return torch.clamp(tensor * std + mean, 0.0, 1.0)

class Birdsnap:
    """
    Birdsnap image-folder dataset with a per-class train/validation/test split.

    The images are read with ``torchvision.datasets.ImageFolder`` from
    ``./data/<class name>/download/images``. The three subsets are
    ``torch.utils.data.Subset`` views of that single dataset object.

    NOTE: because all subsets share ``self.dataset``, each ``*_dataloader``
    property rebinds ``self.dataset.transform`` before building its loader.
    The transform seen by a subset is therefore the one set by the most
    recently accessed dataloader property.
    """

    def __init__(self,
                 train_data_args,
                 val_data_args,
                 dataset_args=None,
                 split_ratio=(0.8, 0.1, 0.1)):
        """
        Args:
            train_data_args: mapping with ``batch_size`` and ``shuffle`` used by
                the training (and validation) dataloader.
            val_data_args: mapping with ``batch_size`` and ``shuffle`` used by
                the test dataloader.
            dataset_args: optional mapping; ``train_transform`` and
                ``eval_transform`` override the default transforms.
            split_ratio: train/validation/test fractions. Only the first two
                entries are read; the test split receives the remainder.
        """
        # The argument is optional, so fall back to an empty mapping
        dataset_args = {} if dataset_args is None else dataset_args

        self.cpu_count = get_cores_count()
        self.train_data_args = train_data_args
        self.val_data_args = val_data_args

        dataset_dir = './data/' + self.__class__.__name__ + '/download/images'

        # RGB Order
        mean = (0.491, 0.506, 0.451)
        std = (0.229, 0.226, 0.267)

        self.demean = [-m / s for m, s in zip(mean, std)]
        self.destd = [1 / s for s in std]

        self.train_normalize_transform = dataset_args.get('train_transform', torchvision.transforms.Compose(
            [torchvision.transforms.RandomHorizontalFlip(),
             torchvision.transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
             torchvision.transforms.RandomCrop(224),
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean, std)]))
        self.evaluation_normalize_transform = dataset_args.get('eval_transform', torchvision.transforms.Compose(
            [torchvision.transforms.Resize(256),
             torchvision.transforms.CenterCrop(224),
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean, std)]))

        # Normalization transform does (x - mean) / std
        # To denormalize use mean* = (-mean/std) and std* = (1/std)
        self.denormalization_transform = torchvision.transforms.Normalize(self.demean, self.destd, inplace=False)

        self.dataset = torchvision.datasets.ImageFolder(root=dataset_dir,
                                                        transform=self.train_normalize_transform)

        # Split every class into training, validation and test parts according to split_ratio
        self.trainset, self.validationset, self.testset = self.uniform_train_val_split(self.dataset,
                                                                                       split_ratio,
                                                                                       len(self.classes))

    def _build_dataloader(self, subset, transform, data_args) -> DataLoader:
        """Bind ``transform`` to the shared dataset and wrap ``subset`` in a DataLoader."""
        self.dataset.transform = transform
        return torch.utils.data.DataLoader(subset,
                                           batch_size=data_args['batch_size'],
                                           shuffle=data_args['shuffle'],
                                           pin_memory=True,
                                           num_workers=self.cpu_count)

    @property
    def train_dataloader(self) -> DataLoader:
        """New loader over the training subset using the training transform."""
        return self._build_dataloader(self.trainset, self.train_normalize_transform, self.train_data_args)

    @property
    def validation_dataloader(self) -> DataLoader:
        """New loader over the validation subset using the evaluation transform."""
        # NOTE(review): batch size and shuffle come from train_data_args, not
        # val_data_args - kept as-is, confirm whether this is intentional.
        return self._build_dataloader(self.validationset, self.evaluation_normalize_transform, self.train_data_args)

    @property
    def test_dataloader(self) -> DataLoader:
        """New loader over the test subset using the evaluation transform."""
        return self._build_dataloader(self.testset, self.evaluation_normalize_transform, self.val_data_args)

    def imshow(self, img):
        """Denormalize a ``(C, H, W)`` image tensor and display it with matplotlib."""
        # clamp to get rid of numerical errors
        img = torch.clamp(self.denormalize(img), 0.0, 1.0)  # denormalize
        # detach/cpu so tensors that track gradients or live on another device can be shown
        npimg = img.detach().cpu().numpy()
        # matplotlib expects channels last
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    def debug(self):
        """Show one training batch as an image grid and print its class names."""
        # get some random training images
        images, labels = next(iter(self.train_dataloader))

        # show images
        self.imshow(torchvision.utils.make_grid(images, nrow=4))

        # print labels
        pprint(', '.join('%s' % self.classes[int(label)] for label in labels))

    def denormalize(self, x):
        """Apply the inverse of the normalization transform to ``x``."""
        return self.denormalization_transform(x)

    @property
    def train_dataset_size(self) -> int:
        return len(self.trainset)

    @property
    def val_dataset_size(self) -> int:
        return len(self.validationset)

    @property
    def test_dataset_size(self) -> int:
        return len(self.testset)

    @property
    def classes(self):
        """Class names as reported by the underlying dataset."""
        return self.dataset.classes

    @staticmethod
    def uniform_train_val_split(dataset, split_ratio, num_classes):
        """
        Split ``dataset`` class by class into train/validation/test subsets.

        For every class label ``0 .. num_classes - 1`` the sample indices are
        taken in dataset order (no shuffling): the first
        ``int(split_ratio[0] * n)`` go to training, the next
        ``int(split_ratio[1] * n)`` to validation and the rest to test.

        Args:
            dataset: object exposing ``targets`` (list, array or tensor of labels).
            split_ratio: sequence whose first two entries are the train and
                validation fractions.
            num_classes: number of class labels to iterate over.

        Returns:
            Tuple ``(train_subset, validation_subset, test_subset)`` of
            ``torch.utils.data.Subset`` objects over ``dataset``.
        """
        targets = dataset.targets
        if isinstance(targets, torch.Tensor):
            labels = targets.detach().cpu().numpy()
        else:
            labels = np.asarray(targets)

        training_indices = []
        validation_indices = []
        test_indices = []
        for class_id in range(num_classes):
            # 1-D index array, so slices of any length (0, 1 or more) convert to a list
            label_indices = np.flatnonzero(labels == class_id)
            train_end = int(split_ratio[0] * len(label_indices))
            validation_end = train_end + int(split_ratio[1] * len(label_indices))

            training_indices.extend(label_indices[:train_end].tolist())
            validation_indices.extend(label_indices[train_end:validation_end].tolist())
            test_indices.extend(label_indices[validation_end:].tolist())

        # Sanity check: the three splits must not share any sample
        training_set = set(training_indices)
        validation_set = set(validation_indices)
        test_set = set(test_indices)
        assert not training_set & validation_set
        assert not test_set & validation_set
        assert not training_set & test_set

        uniform_training_subset = torch.utils.data.Subset(dataset, training_indices)
        uniform_validation_subset = torch.utils.data.Subset(dataset, validation_indices)
        uniform_test_subset = torch.utils.data.Subset(dataset, test_indices)
        return uniform_training_subset, uniform_validation_subset, uniform_test_subset

    def pos_neg_balance_weights(self):
        """Return a weight of 1.0 for every class."""
        return torch.tensor([1.0] * len(self.classes))

def get_birdsnap_object():
    """
    Build a ``Birdsnap`` dataset with the default loader settings.

    Training uses a batch size of 8 with shuffling. The second argument set
    uses a four times larger batch size without shuffling. No transform
    overrides are passed and the split ratio is 0.8 / 0.1 / 0.1.
    """
    train_batch_size = 8

    train_data_args = {
        'batch_size': train_batch_size,
        'shuffle': True,
    }
    val_data_args = {
        'batch_size': train_batch_size * 4,
        'shuffle': False,
        'validate_step_size': 1,
    }
    # Empty dataset_args: keep the default train/eval transforms
    return Birdsnap(train_data_args, val_data_args, {}, split_ratio=[0.8, 0.1, 0.1])


if __name__ == '__main__':
    # Manual smoke test: show one training batch, then report the split sizes
    dataset = get_birdsnap_object()
    dataset.debug()

    train_size = dataset.train_dataset_size
    validation_size = dataset.val_dataset_size
    test_size = dataset.test_dataset_size

    print('Length of training set = ', train_size)
    print('Length of validation set = ', validation_size)
    print('Length of test set = ', test_size)