import seq_tinyimagenet
import hetro_cifar100
import torchvision.transforms as transforms
from datasets.utils.validation import get_train_val
from utils.conf import base_path_dataset as base_path


class HetroTinyImagenet(seq_tinyimagenet.SequentialTinyImagenet):
    """Tiny-ImageNet dataset wrapper with a fixed list of per-task class counts.

    The training transform (``self.TRANSFORM``) and the normalization
    transform (``self.get_normalization_transform()``) are not defined in
    this class -- presumably they are inherited from
    ``seq_tinyimagenet.SequentialTinyImagenet``. The split into loaders is
    delegated to ``hetro_cifar100.get_hetro_split``.
    """

    NAME = 'hetro-tinyimg'
    SETTING = 'class-il'
    # NOTE(review): N_CLASSES_PER_TASK * N_TASKS is 2000, while the per-task
    # counts assigned in __init__ sum to 100 -- confirm how consumers of
    # these constants are expected to interpret them.
    N_CLASSES_PER_TASK = 100
    N_TASKS = 20

    def __init__(self, args):
        """Initialise the dataset state.

        Args:
            args: forwarded unchanged to the parent constructor as ``args``.
        """
        # Assigned before the parent constructor runs; keep this ordering.
        self.num_classes = 100
        super().__init__(args=args)

        self.task_num = 0

        # Hard-coded class count for each of the 20 tasks (values sum to 100).
        # Presumably this list was once drawn with an expression like the one
        # below; note that the expression samples 30 values, not 20 -- TODO
        # confirm.
        # np.random.choice(8, 30, p=[0.2, 0.15, 0.1, 0.1, 0.1, 0.1, 0.1, 0.15]) + 2
        self.task_class_nums = [
            9, 2, 7, 3, 4, 9, 8, 3, 3, 7,
            4, 4, 5, 9, 4, 5, 2, 8, 2, 2,
        ]

    def get_data_loaders(self):
        """Build the train and test datasets and split them into loaders.

        Returns:
            The ``(train_loader, test_loader)`` pair produced by
            ``hetro_cifar100.get_hetro_split``.
        """
        train_transform = self.TRANSFORM

        # Evaluation pipeline: tensor conversion followed by normalization.
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            self.get_normalization_transform(),
        ])

        train_set = seq_tinyimagenet.MyTinyImagenet(
            base_path() + 'TINYIMG',
            train=True,
            download=True,
            transform=train_transform,
        )

        if not self.args.validation:
            # No validation requested: evaluate on the train=False split.
            test_set = seq_tinyimagenet.TinyImagenet(
                base_path() + 'TINYIMG',
                train=False,
                download=True,
                transform=eval_transform,
            )
        else:
            # get_train_val returns the replacement (train, test) pair;
            # presumably it holds out part of the training data -- see
            # datasets.utils.validation for the exact behaviour.
            train_set, test_set = get_train_val(
                train_set, eval_transform, self.NAME)

        train_loader, test_loader = hetro_cifar100.get_hetro_split(
            self, train_set, test_set)
        return train_loader, test_loader
