import torch
from torchvision.datasets import CIFAR10 as torchCIFAR10
import torchvision.transforms as transforms
from mlsuite.pytorch.utils import to_one_hot


class CIFAR10(torchCIFAR10):
    """CIFAR-10 dataset that exposes every class as its own binary target.

    The ``tag`` argument selects the split:

    * ``'train'`` -- the first ``split`` images of the official training set;
    * ``'val'``   -- the remaining images of the official training set;
    * ``'test'``  -- the official test set.

    Any other tag loads the full official training set, unsliced.

    Random crop / horizontal-flip augmentation is applied only to training
    data. ``'val'`` and ``'test'`` use a deterministic transform so that
    evaluation results are reproducible.
    """

    # One entry per class: task name -> [class index, ['bce', 'f1']].
    # NOTE(review): the meaning of ['bce', 'f1'] is defined by the code that
    # consumes `tasks` (not visible here) -- presumably loss / metric names.
    tasks = {
        f'class{i+1}': [i, ['bce', 'f1']] for i in range(10)
    }

    def __init__(self, root, tag, split=40000):
        """Load (downloading if needed) CIFAR-10 and select the requested split.

        Args:
            root: Dataset directory; converted with ``str()`` before being
                passed to the torchvision base class.
            tag: ``'train'``, ``'val'`` or ``'test'`` (see class docstring).
            split: Index into the official training set where ``'train'``
                ends and ``'val'`` begins. Ignored for other tags.
        """
        # The same normalization (same constants) is used for every split.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        if tag in ('val', 'test'):
            # Evaluation splits: deterministic preprocessing only.
            steps = [transforms.ToTensor(), normalize]
        else:
            # Training: random augmentation before tensor conversion.
            steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]

        super().__init__(str(root), train=tag != 'test', download=True,
                         transform=transforms.Compose(steps))
        # Flattened size of one 3x32x32 image.
        self.input_size = 3 * 32 * 32

        # Carve train / val out of the official training set. The base class
        # populates self.data / self.targets, and they are sliced here.
        if tag == 'train':
            self.data, self.targets = self.data[:split], self.targets[:split]
        elif tag == 'val':
            self.data, self.targets = self.data[split:], self.targets[split:]

    def __getitem__(self, index):
        """Return ``(image, per_class_targets)`` for the sample at ``index``.

        ``image`` is whatever the base class returns after ``transform``.
        ``per_class_targets`` is the output of ``to_one_hot`` on the integer
        label (with ``10`` classes), cast to float and split along its last
        dimension with ``unbind``. NOTE(review): this presumably yields one
        tensor per class, in the same index order as ``tasks``. Confirm
        against ``to_one_hot``'s output shape.
        """
        data, target = super().__getitem__(index)
        one_hot = to_one_hot(torch.tensor([target]).float(), 10).float()
        return data, one_hot.unbind(dim=-1)

