from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm
from numpy.testing import assert_array_almost_equal
import numpy as np
import os
import torch
import random


def multiclass_noisify(y, P, seed=123):
    """Flip class labels at random according to the transition matrix ``P``.

    Every label ``c`` in ``y`` is replaced by a class sampled from the
    categorical distribution stored in row ``c`` of ``P``. The input is never
    modified; the noisy labels are returned as a NumPy array.

    Args:
        y: Labels as a ``torch.Tensor``, an ``np.ndarray`` or anything
            ``np.array`` accepts. Each label must be an integer between 0 and
            the number of classes - 1.
        P: Square, non-negative NumPy matrix whose rows each sum to 1;
            ``P[i, j]`` is the probability of turning class ``i`` into ``j``.
        seed: Seed of the private ``np.random.RandomState`` used for sampling,
            so identical arguments always give identical noise.

    Returns:
        A NumPy array with one (possibly flipped) label per input label.

    Raises:
        AssertionError: If ``P`` is not square, not row stochastic, has a
            negative entry, or a label lies outside ``[0, P.shape[0] - 1]``.
    """
    if isinstance(y, torch.Tensor):
        noisy_y = y.cpu().numpy().copy()
    elif isinstance(y, np.ndarray):
        noisy_y = y.copy()
    else:
        noisy_y = np.array(y)

    assert P.shape[0] == P.shape[1]
    # assert row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    # Nothing to flip; also avoids np.min/np.max failing on an empty array.
    if noisy_y.size == 0:
        return noisy_y

    # A negative label would silently wrap around as a row index into P.
    assert np.min(noisy_y) >= 0
    assert np.max(noisy_y) < P.shape[0]

    rand = np.random.RandomState(seed)
    # One multinomial draw per sample, in order, keeps results reproducible
    # for a given seed.
    for i, label in enumerate(noisy_y):
        noisy_y[i] = rand.multinomial(1, P[label, :]).argmax()

    return noisy_y


def noisify_multiclass_symmetric(y, noise, seed=123, nb_classes=10):
    """Apply symmetric (uniform) label noise.

    A label keeps its class with probability ``1 - noise`` and moves to each of
    the other ``nb_classes - 1`` classes with probability
    ``noise / (nb_classes - 1)``.

    Args:
        y: Labels, forwarded to ``multiclass_noisify``.
        noise: Flip probability; must be below 1. For values ``<= 0`` the
            input ``y`` is returned untouched.
        seed: Random seed forwarded to ``multiclass_noisify``.
        nb_classes: Number of classes, i.e. the side of the transition matrix.

    Returns:
        ``y`` itself when ``noise <= 0``, otherwise the result of
        ``multiclass_noisify``.
    """
    if noise <= 0.0:
        return y
    assert noise < 1.0

    flip_prob = noise / (nb_classes - 1)
    on_diagonal = np.eye(nb_classes, dtype=bool)
    # Diagonal: probability of keeping the label; elsewhere: uniform flip.
    P = np.where(on_diagonal, 1 - noise, flip_prob)
    return multiclass_noisify(y, P=P, seed=seed)


def noisify_mnist_asymmetric(y, noise, seed=123):
    """Apply class-dependent label noise between fixed pairs of MNIST digits.

    With probability ``noise`` the flips are 7 -> 1, 2 -> 7, 5 -> 6, 6 -> 5
    and 3 -> 8; every other digit keeps its label.

    Args:
        y: Labels, forwarded to ``multiclass_noisify``.
        noise: Flip probability, at most 1. For values ``<= 0`` the input
            ``y`` is returned untouched.
        seed: Random seed forwarded to ``multiclass_noisify``.

    Returns:
        ``y`` itself when ``noise <= 0``, otherwise the result of
        ``multiclass_noisify``.
    """
    if noise <= 0.0:
        return y
    assert noise <= 1.0

    # (original digit, digit it is flipped to)
    flips = ((7, 1), (2, 7), (5, 6), (6, 5), (3, 8))
    keep = 1. - noise
    P = np.eye(10)
    for source, target in flips:
        P[source, source] = keep
        P[source, target] = noise

    return multiclass_noisify(y, P=P, seed=seed)


def noisify_cifar10_asymmetric(y, noise, seed=123):
    """Apply class-dependent label noise between fixed pairs of CIFAR-10 classes.

    With probability ``noise`` the flips are truck (9) -> automobile (1),
    bird (2) -> airplane (0), cat (3) <-> dog (5) and deer (4) -> horse (7);
    every other class keeps its label.

    Args:
        y: Labels, forwarded to ``multiclass_noisify``.
        noise: Flip probability, at most 1. For values ``<= 0`` the input
            ``y`` is returned untouched.
        seed: Random seed forwarded to ``multiclass_noisify``.

    Returns:
        ``y`` itself when ``noise <= 0``, otherwise the result of
        ``multiclass_noisify``.
    """
    if noise <= 0.0:
        return y
    assert noise <= 1.0

    # Parallel arrays: sources[k] is flipped to targets[k].
    #                   truck bird cat dog deer
    sources = np.array([9, 2, 3, 5, 4])
    #                   auto  plane dog cat horse
    targets = np.array([1, 0, 5, 3, 7])

    P = np.eye(10)
    P[sources, sources] = 1. - noise
    P[sources, targets] = noise
    return multiclass_noisify(y, P=P, seed=seed)


def noisify_cifar100_asymmetric(y, noise, seed=123):
    """Apply circular label noise inside consecutive groups of five classes.

    The 100 classes are cut into 20 groups of 5 consecutive indices. With
    probability ``noise`` a label moves to the next index of its group, the
    last index wrapping around to the first; otherwise it is kept.

    NOTE(review): treating each run of 5 consecutive indices as one superclass
    presumes the labels are ordered by superclass -- confirm against the label
    encoding actually used.

    Args:
        y: Labels, forwarded to ``multiclass_noisify``.
        noise: Flip probability, at most 1. For values ``<= 0`` the input
            ``y`` is returned untouched.
        seed: Random seed forwarded to ``multiclass_noisify``.

    Returns:
        ``y`` itself when ``noise <= 0``, otherwise the result of
        ``multiclass_noisify``.
    """
    if noise <= 0.0:
        return y
    assert noise <= 1.0

    group_count = 20
    group_size = 5

    classes = np.arange(group_count * group_size)
    offset_in_group = classes % group_size
    group_start = classes - offset_in_group
    # Next class inside the same group, wrapping from the last to the first.
    successor = group_start + (offset_in_group + 1) % group_size

    P = np.eye(group_count * group_size)
    P[classes, classes] = 1 - noise
    P[classes, successor] = noise
    return multiclass_noisify(y, P=P, seed=seed)


def noisify(targets, dataset='MNIST', nb_classes=10, noise_type="symmetric", noise_rate=0.0, seed=123):
    """Dispatch to the label-noise routine matching a dataset and noise type.

    Args:
        targets: Clean labels, forwarded to the selected noise routine.
        dataset: Dataset name, matched case-insensitively against ``MNIST``,
            ``CIFAR10`` and ``CIFAR100``.
        nb_classes: Kept for backward compatibility. The value passed in is
            not used: symmetric noise derives the class count from
            ``dataset`` and asymmetric noise never reads it.
        noise_type: ``'symmetric'`` or ``'asymmetric'``.
        noise_rate: Flip probability forwarded to the noise routine.
        seed: Random seed forwarded to the noise routine.

    Returns:
        Whatever the selected noise routine returns.

    Raises:
        ValueError: If ``dataset`` or ``noise_type`` is not supported.
    """
    dataset = dataset.upper()
    if noise_type == 'symmetric':
        if dataset in ["MNIST", "CIFAR10"]:
            nb_classes = 10
        elif dataset == "CIFAR100":
            nb_classes = 100
        else:
            raise ValueError("Not supported dataset: {}".format(dataset))
        return noisify_multiclass_symmetric(targets, noise_rate, seed=seed, nb_classes=nb_classes)

    if noise_type == 'asymmetric':
        if dataset == 'MNIST':
            return noisify_mnist_asymmetric(targets, noise_rate, seed=seed)
        if dataset == 'CIFAR10':
            return noisify_cifar10_asymmetric(targets, noise_rate, seed=seed)
        if dataset == 'CIFAR100':
            return noisify_cifar100_asymmetric(targets, noise_rate, seed=seed)
        # Fail loudly, like the symmetric branch, instead of returning None.
        raise ValueError("Not supported dataset: {}".format(dataset))

    raise ValueError("Not supported noise type: {}".format(noise_type))


def torch_long(data):
    """Return ``data`` as a ``torch.long`` tensor.

    Tensors are cast with ``.long()``; anything else is handed to
    ``torch.tensor`` with ``dtype=torch.long``.
    """
    if isinstance(data, torch.Tensor):
        return data.long()
    return torch.tensor(data, dtype=torch.long)


class MNIST(datasets.MNIST):
    """MNIST whose training labels can be corrupted with synthetic noise.

    After construction ``targets`` holds the (possibly noisy) labels as a long
    tensor and ``is_noise`` holds, per sample, 1 where the label was changed
    and 0 elsewhere. Noise is only injected when ``train`` is true and
    ``noise_rate`` is positive.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=True, noise_rate=0.0, noise_type="symmetric", seed=123):
        """Load the dataset and optionally corrupt its training labels.

        ``root``, ``train``, ``transform``, ``target_transform`` and
        ``download`` go to the parent constructor; ``noise_rate``,
        ``noise_type`` and ``seed`` go to ``noisify``.
        """
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        clean = torch_long(self.targets)
        self.targets = clean

        if not (train and noise_rate > 0):
            # No corruption requested: every sample is flagged as clean.
            self.is_noise = torch.zeros_like(clean)
            return

        corrupted = torch_long(noisify(targets=clean,
                                       dataset="MNIST",
                                       noise_type=noise_type,
                                       noise_rate=noise_rate,
                                       seed=seed))
        self.is_noise = torch_long(clean != corrupted)
        self.targets = corrupted
        flipped = self.is_noise.count_nonzero().item()
        print("Noise rate: {}, num noise: {}".format(flipped / corrupted.shape[0], flipped))

    def __getitem__(self, index):
        """Return ``(img, target, is_noise)`` for the sample at ``index``."""
        sample, label = super().__getitem__(index)
        return sample, label, int(self.is_noise[index])


class CIFAR10(datasets.CIFAR10):
    """CIFAR-10 whose training labels can be replaced by noisy ones.

    Two noise sources are supported for the training split: labels read from
    ``CIFAR-10_human.pt`` under ``root`` (``noise_type == "human"``) or
    synthetic noise from ``noisify`` (``noise_rate > 0``). After construction
    ``targets`` holds the labels in use and ``is_noise`` holds, per sample,
    1 where the label differs from the clean one and 0 elsewhere.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=True, noise_rate=0.0, noise_type="symmetric", seed=123):
        """Load the dataset and optionally swap in noisy training labels.

        ``root``, ``train``, ``transform``, ``target_transform`` and
        ``download`` go to the parent constructor; ``noise_rate``,
        ``noise_type`` and ``seed`` select and configure the label noise.
        """
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        clean = torch_long(self.targets)
        self.targets = clean

        if train and noise_type == "human":
            human = torch.load(os.path.join(root, "CIFAR-10_human.pt"))
            # The stored clean labels must line up with the dataset's own.
            assert torch.tensor(human['clean_label']).allclose(clean)
            noisy = torch.tensor(human["worse_label"])
        elif train and noise_rate > 0:
            noisy = torch_long(noisify(targets=clean,
                                       dataset="CIFAR10",
                                       noise_type=noise_type,
                                       noise_rate=noise_rate,
                                       seed=seed))
        else:
            # No corruption requested: every sample is flagged as clean.
            self.is_noise = torch.zeros_like(clean)
            return

        self.is_noise = torch_long(clean != noisy)
        self.targets = noisy
        flipped = self.is_noise.count_nonzero().item()
        print("Noise rate: {}, num noise: {}".format(flipped / noisy.shape[0], flipped))

    def __getitem__(self, index):
        """Return ``(img, target, is_noise)`` for the sample at ``index``."""
        sample, label = super().__getitem__(index)
        return sample, label, int(self.is_noise[index])


class CIFAR100(datasets.CIFAR100):
    """CIFAR-100 whose training labels can be replaced by noisy ones.

    Two noise sources are supported for the training split: labels read from
    ``CIFAR-100_human.pt`` under ``root`` (``noise_type == "human"``) or
    synthetic noise from ``noisify`` (``noise_rate > 0``). After construction
    ``targets`` holds the labels in use and ``is_noise`` holds, per sample,
    1 where the label differs from the clean one and 0 elsewhere.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=True, noise_rate=0.0, noise_type="symmetric", seed=123):
        """Load the dataset and optionally swap in noisy training labels.

        ``root``, ``train``, ``transform``, ``target_transform`` and
        ``download`` go to the parent constructor; ``noise_rate``,
        ``noise_type`` and ``seed`` select and configure the label noise.
        """
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        clean = torch_long(self.targets)
        self.targets = clean

        use_human = train and noise_type == "human"
        if use_human:
            human = torch.load(os.path.join(root, "CIFAR-100_human.pt"))
            # The stored clean labels must line up with the dataset's own.
            assert torch.tensor(human['clean_label']).allclose(clean)
            noisy = torch.tensor(human["noisy_label"])
        elif train and noise_rate > 0:
            noisy = torch_long(noisify(targets=clean,
                                       dataset="CIFAR100",
                                       noise_type=noise_type,
                                       noise_rate=noise_rate,
                                       seed=seed))
        else:
            # No corruption requested: every sample is flagged as clean.
            self.is_noise = torch.zeros_like(clean)
            return

        self.is_noise = torch_long(clean != noisy)
        self.targets = noisy
        flipped = self.is_noise.count_nonzero().item()
        if use_human:
            print("Loaded human noise.")
        print("Noise rate: {}, num noise: {}".format(flipped / noisy.shape[0], flipped))

    def __getitem__(self, index):
        """Return ``(img, target, is_noise)`` for the sample at ``index``."""
        sample, label = super().__getitem__(index)
        return sample, label, int(self.is_noise[index])


class DatasetGenerator:
    """Build train/test ``DataLoader`` objects for MNIST, CIFAR10 or CIFAR100.

    The training split is created with the configured label noise; the test
    split is created without noise arguments. The loaders are built once in
    the constructor and exposed through ``getDataLoader``.
    """

    def __init__(self,
                 train_batch_size=128,
                 eval_batch_size=128,
                 data_path='data-bin',
                 seed=123,
                 num_of_workers=4,
                 noise_type='symmetric',
                 dataset='CIFAR10',
                 cutout_length=16,
                 noise_rate=0.4):
        """Store the configuration and build the loaders.

        Args:
            train_batch_size: Batch size of the training loader.
            eval_batch_size: Batch size of the test loader.
            data_path: Root directory handed to the dataset classes.
            seed: Seed for the label noise of the training split.
            num_of_workers: ``num_workers`` of both loaders.
            noise_type: Noise type for the training split.
            dataset: ``'MNIST'``, ``'CIFAR10'`` or ``'CIFAR100'`` (exact match).
            cutout_length: Stored on the instance; not used by this class.
            noise_rate: Noise rate for the training split.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        self.seed = seed
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.data_path = data_path
        self.num_of_workers = num_of_workers
        self.cutout_length = cutout_length
        self.noise_rate = noise_rate
        self.dataset = dataset
        self.noise_type = noise_type
        self.data_loaders = self.loadData()

    def getDataLoader(self):
        """Return the dict with keys ``'train_dataset'`` and ``'test_dataset'``."""
        return self.data_loaders

    def _make_loader(self, dataset, batch_size, shuffle):
        """Create a pinned-memory DataLoader using the configured worker count."""
        loader_kwargs = dict(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             pin_memory=True,
                             num_workers=self.num_of_workers)
        # DataLoader only accepts prefetch_factor together with worker
        # processes; passing it with num_workers == 0 raises ValueError.
        if self.num_of_workers > 0:
            loader_kwargs['prefetch_factor'] = 12
        return DataLoader(**loader_kwargs)

    def loadData(self):
        """Create the datasets for ``self.dataset`` and wrap them in loaders.

        Returns:
            Dict mapping ``'train_dataset'`` and ``'test_dataset'`` to their
            ``DataLoader``.

        Raises:
            ValueError: If ``self.dataset`` is not a supported name.
        """
        if self.dataset == 'MNIST':
            MEAN = [0.1307]
            STD = [0.3081]
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(MEAN, STD)])
            train_dataset = MNIST(root=self.data_path,
                                  train=True,
                                  transform=transform,
                                  noise_rate=self.noise_rate,
                                  noise_type=self.noise_type,
                                  seed=self.seed)
            test_dataset = MNIST(root=self.data_path,
                                 train=False,
                                 transform=transform)
        elif self.dataset == 'CIFAR10':
            CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
            CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])

            test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])

            train_dataset = CIFAR10(root=self.data_path,
                                    train=True,
                                    transform=train_transform,
                                    noise_rate=self.noise_rate,
                                    noise_type=self.noise_type,
                                    seed=self.seed)

            test_dataset = CIFAR10(root=self.data_path,
                                   train=False,
                                   transform=test_transform)

        elif self.dataset == 'CIFAR100':
            CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
            CIFAR_STD = [0.2673, 0.2564, 0.2762]
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(20),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])

            test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])

            train_dataset = CIFAR100(root=self.data_path,
                                     train=True,
                                     transform=train_transform,
                                     noise_rate=self.noise_rate,
                                     noise_type=self.noise_type,
                                     seed=self.seed)

            test_dataset = CIFAR100(root=self.data_path,
                                    train=False,
                                    transform=test_transform)

        else:
            # Raising a bare string is itself a TypeError; raise a real exception.
            raise ValueError("Unknown Dataset: {}".format(self.dataset))

        data_loaders = {}

        data_loaders['train_dataset'] = self._make_loader(train_dataset,
                                                          batch_size=self.train_batch_size,
                                                          shuffle=True)

        data_loaders['test_dataset'] = self._make_loader(test_dataset,
                                                         batch_size=self.eval_batch_size,
                                                         shuffle=False)

        print("Num of train %d" % (len(train_dataset)))
        print("Num of test %d" % (len(test_dataset)))

        return data_loaders


class WebVisionTrainDataset(torch.utils.data.Dataset):
    """Training images listed in a text file of ``<path> <label>`` lines.

    Only entries whose label is below ``num_class`` are kept. Each item is
    ``(img, target, 0)``; the trailing 0 is a constant third element.
    """

    def __init__(self, root, num_class=1000,
                 listfile='train_filelist_google.txt',
                 transform=None, target_transform=None):
        """Read the sample list.

        Args:
            root: Directory holding ``listfile``; image paths in the list are
                resolved relative to it.
            num_class: Entries with a label ``>= num_class`` are dropped.
            listfile: Name of the list file. Each non-blank line holds a path
                and an integer label separated by whitespace.
            transform: Optional callable applied to the loaded RGB image.
            target_transform: Optional callable applied to the label.
        """
        self.root = root
        self.num_class = num_class
        self.transform = transform
        self.target_transform = target_transform
        self.samples = []
        with open(os.path.join(root, listfile), "r") as fin:
            for line in fin:
                # A blank line (e.g. at the end of the file) carries no sample.
                if not line.strip():
                    continue
                train_file, label = line.split()
                label = int(label)
                if label < self.num_class:
                    self.samples.append((train_file, label))

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, target, 0)``."""
        impath, target = self.samples[index]
        img = Image.open(os.path.join(self.root, impath)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # target_transform is accepted by the constructor, so honour it here.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, 0


class WebVisionValDataset(datasets.ImageNet):
    """ImageNet ``val`` split restricted to labels below ``num_class``.

    Each item is ``(img, target, 0)``; the trailing 0 is a constant third
    element.
    """

    def __init__(self, root, num_class=1000, transform=None, target_transform=None):
        """Load the ``val`` split and drop every entry with label ``>= num_class``."""
        super().__init__(root=root, split='val', transform=transform,
                         target_transform=target_transform)

        def label_is_valid(pair):
            # The label is the last element of each (path, label) entry.
            return pair[-1] < num_class

        # eliminate invalid samples
        self.samples = list(filter(label_is_valid, self.samples))
        self.imgs = list(filter(label_is_valid, self.imgs))
        self.targets = list(filter(lambda label: label < num_class, self.targets))

    def __getitem__(self, index):
        """Return ``(img, target, 0)`` for the sample at ``index``."""
        image, label = super().__getitem__(index)
        return image, label, 0


class WebVisionDatasetLoader:
    """Build train/test ``DataLoader`` objects for the WebVision datasets.

    The loaders are created once in the constructor and exposed through
    ``getDataLoader``.
    """

    def __init__(self, num_class=1000,
                 train_batch_size=128,
                 eval_batch_size=256,
                 data_path='data/',
                 num_of_workers=4):
        """Store the configuration and build the loaders.

        Args:
            num_class: Forwarded to both dataset classes.
            train_batch_size: Batch size of the training loader.
            eval_batch_size: Batch size of the test loader.
            data_path: Root directory handed to both dataset classes.
            num_of_workers: ``num_workers`` of both loaders.
        """
        self.num_class = num_class
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.data_path = data_path
        self.num_of_workers = num_of_workers
        self.data_loaders = self.loadData()

    def getDataLoader(self):
        """Return the dict with keys ``'train_dataset'`` and ``'test_dataset'``."""
        return self.data_loaders

    def loadData(self):
        """Create both datasets, print their sizes and wrap them in loaders."""
        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD = [0.229, 0.224, 0.225]

        # Training: random crop/flip/colour jitter, then tensor + normalise.
        train_steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0.2),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]
        train_transform = transforms.Compose(train_steps)

        # Evaluation: deterministic resize + centre crop.
        test_steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]
        test_transform = transforms.Compose(test_steps)

        train_dataset = WebVisionTrainDataset(root=self.data_path,
                                              num_class=self.num_class,
                                              listfile='train_filelist_google.txt',
                                              transform=train_transform)
        test_dataset = WebVisionValDataset(root=self.data_path,
                                           num_class=self.num_class,
                                           transform=test_transform)

        print('Training Set Size %d' % (len(train_dataset)))
        print('Test Set Size %d' % (len(test_dataset)))

        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.train_batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=self.num_of_workers)
        test_loader = DataLoader(dataset=test_dataset,
                                 batch_size=self.eval_batch_size,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=self.num_of_workers)
        return {'train_dataset': train_loader, 'test_dataset': test_loader}
