from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from utils.conf import base_path
from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
#import torch.nn.functional as F
from PIL import Image
#from datasets.utils.validation import get_train_val
#from datasets.utils.continual_dataset import get_previous_train_loader
#from datasets.transforms.denormalization import DeNormalize
from augmentations import get_aug
#from typing import Tuple
#import torch


class SequentialCIFAR100(ContinualDataset):
    """CIFAR-100 wrapped as a ``ContinualDataset``.

    The class constants declare ``N_TASKS`` tasks of ``N_CLASSES_PER_TASK``
    classes each (20 * 5 = 100). How samples are assigned to tasks is
    decided by ``store_masked_loaders``, not by this class.
    """

    NAME = 'seq-cifar100'
    #SETTING = 'class-il'
    N_CLASSES_PER_TASK = 5
    N_TASKS = 20

    @staticmethod
    def _load_split(train, transform):
        """Build the CIFAR-100 train (``train=True``) or test split.

        The data lives under ``base_path() + 'CIFAR100'`` and is requested
        with ``download=True``.
        """
        return CIFAR100(base_path() + 'CIFAR100', train=train,
                        download=True, transform=transform)

    def get_data_loaders(self):
        """Create the loaders through ``store_masked_loaders``.

        The training split uses ``get_aug(is_train=True,
        transform_single=False)`` and the test split uses
        ``get_aug(is_train=False, transform_single=True)``. No validation
        dataset is supplied (``None`` is passed in its place).

        Returns:
            The ``(train, valid, test)`` triple returned by
            ``store_masked_loaders``.

        Raises:
            NotImplementedError: if ``self.args.validation`` is truthy.
        """
        # Fail fast: reject the unsupported option before any dataset is
        # built (and possibly downloaded).
        if self.args.validation:
            raise NotImplementedError("Validation set is not implemented yet.")

        transform = get_aug(is_train=True, transform_single=False)
        eval_transform = get_aug(is_train=False, transform_single=True)

        train_dataset = self._load_split(True, transform)
        test_dataset = self._load_split(False, eval_transform)

        train, valid, test = store_masked_loaders(train_dataset, None, test_dataset, self)
        return train, valid, test

    def get_transform(self, is_train=True):
        """Return the single-view transform from ``get_aug``.

        Args:
            is_train: forwarded to ``get_aug`` as ``is_train``.

        Returns:
            The result of ``get_aug(is_train=is_train, transform_single=True,
            to_pil_image=True)``.
        """
        return get_aug(is_train=is_train, transform_single=True, to_pil_image=True)

    def not_aug_dataloader(self, batch_size):
        """Return a training loader that applies only ``ToTensor`` + ``Normalize``.

        Args:
            batch_size: forwarded to ``get_previous_train_loader``.

        Returns:
            Whatever ``get_previous_train_loader`` returns for the
            un-augmented CIFAR-100 training split.
        """
        # The module-level import of this helper is commented out in this
        # file, so the call below used to raise NameError. Import it locally.
        from datasets.utils.continual_dataset import get_previous_train_loader

        # NOTE(review): these values look like the commonly quoted CIFAR-10
        # channel statistics rather than CIFAR-100 ones -- confirm this is
        # intended (e.g. to stay consistent with get_aug) before changing.
        cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(*cifar_norm),
        ])

        train_dataset = self._load_split(True, transform)
        return get_previous_train_loader(train_dataset, batch_size, self)
