# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import open_clip
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from datasets.seq_cars196 import JointCars196
from datasets.seq_dtd import JointDTD
from datasets.seq_eurosat_rgb import JointEurosatRgb
from datasets.seq_gtsrb import JointGTSRB
from datasets.seq_mnist_224 import JointMNIST224
from datasets.seq_resisc45 import JointResisc45
from datasets.seq_sun397 import JointSUN397
from datasets.seq_svhn import JointSVHN
from itertools import chain


from datasets.utils.continual_dataset import (ContinualDataset)
from datasets.utils import set_default_from_args
from utils.prompt_templates import templates

class Sequential8Vision(ContinualDataset):
    """Class-incremental benchmark made of eight joint vision datasets.

    Each task corresponds to one entry of ``DATASETS`` (in order). The labels
    of every task are shifted by the number of classes of the preceding tasks
    so that all tasks share a single, contiguous label space.
    """

    # NOTE: `NAME` has to be defined before any `set_default_from_args` use
    # below (presumably the decorator looks it up in the class body — confirm).
    NAME = 'seq-8vision'
    SETTING = 'class-il'
    DATASET_NAMES = ["joint-cars196", "joint-dtd", "joint-eurosat-rgb", "joint-gtsrb", "joint-mnist-224", "joint-resisc45", "joint-sun397", "joint-svhn"]
    DATASETS = [JointCars196, JointDTD, JointEurosatRgb, JointGTSRB, JointMNIST224, JointResisc45, JointSUN397, JointSVHN]
    # Number of classes of each task, aligned index-by-index with DATASETS.
    N_CLASSES_PER_TASK = [196, 47, 10, 43, 10, 45, 397, 10]
    # Per-task value returned by `get_pca_proj_dim`, aligned with DATASETS.
    PCA_PROJ_DIMS = [128, 128, 64, 96, 64, 96, 192, 64]
    # Per-dataset epoch budget looked up by `get_task_epochs`.
    TASK_EPOCHS = {
        "joint-cars196": 35,
        "joint-dtd": 76,
        "joint-eurosat-rgb": 12,
        "joint-gtsrb": 11,
        "joint-mnist-224": 5,
        "joint-resisc45": 15,
        "joint-sun397": 14,
        "joint-svhn": 4,
    }
    N_TASKS = len(DATASETS)
    N_CLASSES = sum(N_CLASSES_PER_TASK)
    SIZE = (224, 224)
    # Presumably the OpenAI CLIP normalization statistics, so that the training
    # transform matches TEST_TRANSFORM — TODO confirm.
    MEAN = (0.48145466, 0.4578275, 0.40821073)
    STD = (0.26862954, 0.26130258, 0.27577711)

    TRANSFORM = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    # Third element of the tuple returned by open_clip (the validation
    # preprocessing, per the open_clip documentation).
    # NOTE(review): this call runs when the class body is evaluated, i.e. the
    # pretrained model is created at import time only to read its transform.
    TEST_TRANSFORM = open_clip.create_model_and_transforms(
        'ViT-B-16', pretrained='openai', cache_dir='checkpoints/ViT-B-16/cachedir/open_clip')[2]

    def __init__(self, args):
        """Instantiate every sub-dataset and force the shared transforms on it.

        Args:
            args: the run arguments, forwarded to the base class and to each
                dataset class listed in ``DATASETS``.
        """
        super().__init__(args)
        self.dataset_instances = []
        for dataset_cls in self.DATASETS:
            instance = dataset_cls(self.args)
            # Override the per-dataset transforms with the ones of this class.
            instance.TRANSFORM = self.TRANSFORM
            instance.TEST_TRANSFORM = self.TEST_TRANSFORM
            self.dataset_instances.append(instance)
        # Test loaders of all the tasks seen so far, in task order.
        self.test_loaders = []

    def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, list]:
        """Advance to the next task and return its loaders.

        Returns:
            A tuple ``(train_loader, test_loaders)`` where ``train_loader`` is
            the loader of the new task and ``test_loaders`` is the list holding
            the test loader of every task seen so far (new one included).
        """
        self.c_task += 1

        cur_dataset = self.dataset_instances[self.c_task]
        self.train_loader, test_loader = cur_dataset.get_data_loaders()

        # Shift the labels past the classes of all the previous tasks.
        # Assumes `targets` supports in-place addition of an int (e.g. a numpy
        # array) — TODO confirm.
        class_offset = sum(self.N_CLASSES_PER_TASK[:self.c_task])
        self.train_loader.dataset.targets += class_offset
        test_loader.dataset.targets += class_offset

        self.test_loaders.append(test_loader)
        return self.train_loader, self.test_loaders

    @staticmethod
    def get_transform():
        """Not supported by this dataset: prints a warning and returns None."""
        print("YOU SHOULDN'T BE HERE")
        return None

    # `staticmethod` is applied last (outermost) so that `set_default_from_args`
    # still receives the plain function, exactly as before; it only makes the
    # zero-argument method callable from an instance as well as from the class.
    @staticmethod
    @set_default_from_args("backbone")
    def get_backbone():
        """Return the name of the default backbone."""
        return "vit"

    @staticmethod
    def get_loss():
        """Return the loss function (cross-entropy)."""
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """Not supported by this dataset: prints a warning and returns None."""
        print("YOU SHOULDN'T BE HERE")
        return None

    @staticmethod
    def get_denormalization_transform():
        """Not supported by this dataset: prints a warning and returns None."""
        print("YOU SHOULDN'T BE HERE")
        return None

    @set_default_from_args('n_epochs')
    def get_epochs(self):
        """Return the number of training epochs.

        With ``self`` set to None (presumably how `set_default_from_args`
        queries the default — confirm) the value is 20. Otherwise an explicit
        ``args.n_epochs`` wins, and the sub-dataset of the current task is
        asked as a fallback.
        """
        if self is None:
            return 20
        if self.args.n_epochs is not None:
            return self.args.n_epochs
        return self.dataset_instances[self.c_task].get_epochs()

    def get_task_epochs(self, t):
        """Return the epoch budget of task ``t`` (index into ``DATASET_NAMES``)."""
        return self.TASK_EPOCHS[self.DATASET_NAMES[t]]

    def get_iters(self):
        """Return 2000 iterations, multiplied by ``args.chunks`` when it is set."""
        iters = 2000
        if self.args.chunks is not None:
            iters *= self.args.chunks
        return iters

    @set_default_from_args('batch_size')
    def get_batch_size(self):
        """Return the default batch size."""
        return 32

    def get_class_names(self):
        """Return the class names of all tasks, concatenated in task order.

        The result is computed once and cached in ``self.class_names``.
        """
        if self.class_names is not None:
            return self.class_names
        # EuroSAT is queried with an extra positional `True`; its meaning is
        # defined by `JointEurosatRgb.get_class_names` (not visible here).
        self.class_names = list(chain.from_iterable(
            dataset.get_class_names(True) if isinstance(dataset, JointEurosatRgb)
            else dataset.get_class_names()
            for dataset in self.dataset_instances
        ))
        return self.class_names

    @staticmethod
    def get_prompt_templates():
        """Return the prompt templates registered under 'seq-8vision'."""
        return templates['seq-8vision']

    def get_pca_proj_dim(self):
        """Return the ``PCA_PROJ_DIMS`` entry of the current task."""
        return self.PCA_PROJ_DIMS[self.c_task]