from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset
import pickle


class CIFAR10Pair(CIFAR10):
    """CIFAR-10 dataset that yields two transformed views of each image.

    Same storage as ``torchvision.datasets.CIFAR10`` (``self.data`` /
    ``self.targets``), but ``__getitem__`` returns ``(pos_1, pos_2, target)``
    instead of a single image and label.
    """

    def __getitem__(self, index):
        """Return two views of the image at ``index`` together with its label.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: ``(pos_1, pos_2, target)``. ``pos_1`` and ``pos_2`` come from
            two separate calls to ``self.transform`` on the same image, so a
            random transform produces two different views. When
            ``self.transform`` is None, both are the untransformed PIL image.
            ``target`` is passed through ``self.target_transform`` if it is set.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            # Two independent calls: with random augmentations this yields two
            # distinct views of the same underlying image.
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        else:
            # Without a transform, fall back to the raw image so both names are
            # always bound (otherwise the return below would raise).
            pos_1 = pos_2 = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return pos_1, pos_2, target


class CorruptedCIFAR10(Dataset):
    """Paired-view dataset backed by a pickled corrupted CIFAR-10 file.

    The unpickled object is indexed as ``data['data'][i]`` (image) and
    ``data['labels'][i]`` (label), so it is expected to be a dict-like object
    with those two entries indexed in parallel. Presumably each image is an
    array accepted by ``PIL.Image.fromarray`` (e.g. HxWxC uint8) — TODO
    confirm against whatever produces the pickle.
    """

    def __init__(self, data_path, transform=None):
        """
        Args:
            data_path (string): Path to the pickle file containing the corrupted
                CIFAR-10 dataset.
            transform (callable, optional): Optional transform applied to each
                PIL image. It is called twice per sample to produce two views.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load pickles from trusted sources.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)
        self.transform = transform

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data['data'])``."""
        return len(self.data['data'])

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the data to get.

        Returns:
            tuple: ``(pos_1, pos_2, target)``. ``pos_1`` and ``pos_2`` come from
            two separate calls to ``self.transform`` on the same PIL image (so
            a random transform gives two different views). When no transform
            is set, both are the untransformed PIL image. ``target`` is the
            raw label stored at ``self.data['labels'][idx]``.
        """
        image = self.data['data'][idx]
        target = self.data['labels'][idx]
        img = Image.fromarray(image)

        if self.transform is not None:
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        else:
            # Keep both names bound when no transform is configured.
            pos_1 = pos_2 = img

        return pos_1, pos_2, target


def _cifar10_normalize():
    """Build a fresh ``Normalize`` step with this module's CIFAR-10 channel stats.

    NOTE(review): the std values here differ from the per-pixel std often
    quoted for CIFAR-10 (~0.247, 0.243, 0.261); they are kept unchanged so
    existing models see identical inputs — confirm this is intended.
    """
    return transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )


# Training pipeline: random crop/flip/colour/grayscale augmentation, then
# tensor conversion and per-channel normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    _cifar10_normalize(),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    _cifar10_normalize(),
])
