import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import os

class CIFAR10Noisy(Dataset):
    """CIFAR-10 dataset whose training split carries pre-generated noisy labels.

    The training split replaces the torchvision targets with the object stored
    in ``noisy_labels.pt`` next to this file (written by this module's
    ``__main__`` block). It also applies Cutout / random-crop /
    horizontal-flip augmentation. The test split keeps the clean labels and is
    only converted to a normalized tensor, so evaluation is deterministic.

    Parameters
    ----------
    train : bool
        Load the training split if True, otherwise the test split.
    download : bool
        Forwarded to ``torchvision.datasets.CIFAR10``.
    """

    # Mean / std passed to ``transforms.Normalize`` for both splits.
    _MEAN = (0.49139968, 0.48215841, 0.44653091)
    _STD = (0.24703223, 0.24348513, 0.26158784)

    def __init__(self, train=True, download=False):
        self.name = 'cifar10_noisy'
        self.train = train

        normalize = transforms.Normalize(self._MEAN, self._STD)
        if train:
            self.transform = transforms.Compose([
                Cutout(num_cutouts=2, size=8, p=0.8),
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize])
        else:
            # Evaluation must not see random augmentation, otherwise test
            # metrics change from run to run.
            self.transform = transforms.Compose([transforms.ToTensor(), normalize])

        # NOTE(review): split('data/') cuts at the FIRST 'data/' substring of the
        # absolute path (it would also match e.g. '/mydata/'). This assumes no
        # such component precedes the project's data package; confirm for
        # your checkout location.
        datasets_dir = os.path.abspath(__file__).split('data/')[0] + 'data/'
        path = os.path.join(datasets_dir, 'cifar10', 'data')
        self.dataset = torchvision.datasets.CIFAR10(root=path, train=train,
                                                    download=download, transform=self.transform)
        if train:
            # Noisy labels are generated once (see the __main__ block) and
            # stored next to this file.
            folder = os.path.dirname(os.path.abspath(__file__))
            noisy_labels = torch.load(os.path.join(folder, 'noisy_labels.pt'))
            # A stale or truncated label file would silently misalign labels
            # and images, so fail loudly instead.
            if len(noisy_labels) != len(self.dataset.targets):
                raise ValueError(
                    f"noisy_labels.pt has {len(noisy_labels)} labels but the "
                    f"CIFAR-10 training split has {len(self.dataset.targets)} samples")
            self.dataset.targets = noisy_labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

class Cutout(object):
    """
    Implements Cutout regularization as proposed by DeVries and Taylor (2017), https://arxiv.org/pdf/1708.04552.pdf.

    With probability ``p`` the transform zeroes ``num_cutouts`` square patches
    of side ``2 * (size // 2)`` centred at uniformly random pixels. Patches are
    clipped at the image border (so they can be smaller there) and may overlap.
    Random numbers are drawn from NumPy's global generator.
    """

    def __init__(self, num_cutouts, size, p=0.5):
        """
        Parameters
        ----------
        num_cutouts : int
            Number of patches erased per image when the transform fires.
        size : int
            Nominal side length of each square patch, in pixels.
        p : float (0 <= p <= 1)
            Probability that the transform is applied to a given image.
            Unlike Dropout's keep_prob, a larger p means more erasing.
        """
        self.num_cutouts = num_cutouts
        self.size = size
        self.p = p

    def __call__(self, img):
        """
        Apply Cutout to ``img``.

        Parameters
        ----------
        img : PIL.Image.Image
            Input image, e.g. mode 'RGB', 'RGBA' or 'L'. Every channel of an
            erased pixel is set to 0. Palette ('P') images come back as 'L',
            so convert them to RGB first.

        Returns
        -------
        PIL.Image.Image
            ``img`` itself when the transform does not fire. Otherwise a new
            image of the same size; the input is never modified.
        """
        # Fires with probability p. Checked first so skipped images cost nothing.
        if np.random.uniform() < 1 - self.p:
            return img

        # np.array copies, so erasing in place leaves ``img`` untouched. The
        # dimensions come from the array because PIL's ``img.size`` is
        # (width, height), not (height, width).
        pixels = np.array(img)
        height, width = pixels.shape[:2]
        half = self.size // 2

        for _ in range(self.num_cutouts):
            y_center = np.random.randint(0, height)
            x_center = np.random.randint(0, width)

            y1 = np.clip(y_center - half, 0, height)
            y2 = np.clip(y_center + half, 0, height)
            x1 = np.clip(x_center - half, 0, width)
            x2 = np.clip(x_center + half, 0, width)

            # Zeroes all channels of the patch (works for 2-D and 3-D arrays).
            pixels[y1:y2, x1:x2] = 0

        # The array keeps the source dtype and channel count, so fromarray
        # infers the matching mode (e.g. 'RGB' for uint8 (H, W, 3)).
        return Image.fromarray(pixels)

if __name__ == "__main__":
    # Script mode: regenerate the label-noise files read by CIFAR10Noisy.
    # 10% of the clean training labels are reassigned to a different random
    # class. The seed is fixed so the generated files are reproducible.
    from data import CIFAR10
    from data.utils import set_seed

    clean_trainset = CIFAR10(train=True, download=True)
    num_samples = len(clean_trainset)

    set_seed(42)
    permutation = torch.randperm(num_samples)
    flipped_indices = permutation[:int(num_samples * 0.1)]

    clean_labels = clean_trainset.dataset.targets
    noisy_labels = torch.tensor(clean_labels).clone()
    for position in flipped_indices:
        # Redraw until the label differs from the clean one (rejection sampling).
        while noisy_labels[position] == clean_labels[position]:
            noisy_labels[position] = torch.randint(0, 10, (1,))

    output_dir = os.path.dirname(os.path.abspath(__file__))
    for filename, payload in (('flipped_indices.pt', flipped_indices),
                              ('noisy_labels.pt', noisy_labels)):
        torch.save(payload, os.path.join(output_dir, filename))