# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from backbone.ResNet18 import resnet18
from backbone.CLIP import Clip
from PIL import Image
from torch.utils.data import Dataset
import torch
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from dataset.transforms.denormalization import DeNormalize
from dataset.utils.continual_dataset import ContinualDataset
from dataset.utils.validation import get_train_val
from utils.conf import base_path_dataset as base_path
from torch.utils.data import DataLoader, Dataset

class TinyImagenet(Dataset):
    """
    Tiny Imagenet exposed through the usual pytorch ``Dataset`` interface.

    Samples are discovered with ``ImageFolder`` under ``root``; the folder
    names it finds are mapped to readable class names through ``get_text``.
    """

    def __init__(self, root: str, train: bool = True, transform: Optional[nn.Module] = None,
                 target_transform: Optional[nn.Module] = None, download: bool = False) -> None:
        self.root = root
        self.train = train
        self.download = download
        self.transform = transform
        self.target_transform = target_transform
        # Plain tensor conversion, kept available for non-augmented views.
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])

        self.Dataset = ImageFolder(root=self.root, transform=self.transform)
        self.classes = get_text(self.Dataset.classes)

    def __len__(self):
        """Number of (path, target) samples currently held by the folder index."""
        return len(self.Dataset.samples)

    def __getitem__(self, index):
        """
        Load the sample at ``index``.
        :param index: position inside ``self.Dataset.samples``
        :return: ``(img, target)``; when a ``logits`` attribute has been
                 attached to the dataset, ``(img, target, original_img, logits[index])``
        """
        path, target = self.Dataset.samples[index]
        # Round-trip through numpy so that a PIL Image is always produced.
        pil_img = Image.fromarray(np.array(default_loader(path)))
        original_img = pil_img.copy()

        img = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if not hasattr(self, 'logits'):
            return img, target
        return img, target, original_img, self.logits[index]



class MyTinyImagenet(TinyImagenet):
    """
    Tiny Imagenet variant whose items also carry a non-augmented tensor
    version of the image next to the transformed one.
    """

    def __init__(self, root: str, train: bool = True, transform: Optional[nn.Module] = None,
                 target_transform: Optional[nn.Module] = None, download: bool = False) -> None:
        super().__init__(root, train, transform, target_transform, download)

    def __getitem__(self, index):
        """
        Load the sample at ``index``.
        :param index: position inside ``self.Dataset.samples``
        :return: ``(img, target, not_aug_img)``, with ``logits[index]`` appended
                 when a ``logits`` attribute has been attached to the dataset
        """
        path, target = self.Dataset.samples[index]
        # Round-trip through numpy so that, consistently with all the other
        # datasets, a PIL Image is what gets handed to the transforms.
        pil_img = Image.fromarray(np.array(default_loader(path)))

        # The non-augmented view is computed from an untouched copy.
        not_aug_img = self.not_aug_transform(pil_img.copy())

        img = pil_img if self.transform is None else self.transform(pil_img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        item = (img, target, not_aug_img)
        if hasattr(self, 'logits'):
            item += (self.logits[index],)
        return item


class SequentialTinyImagenet(ContinualDataset):
    """
    Class-incremental Tiny Imagenet: N_TASKS tasks of N_CLASSES_PER_TASK
    classes each (10 x 20 = 200 classes overall).
    """

    NAME = 'seq-tinyimg'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 20
    N_TASKS = 10
    # No Normalize step here: the dataset mean/std are handed to the backbone
    # in get_backbone instead (presumably applied inside the network - confirm).
    TRANSFORM = transforms.Compose(
        [transforms.RandomCrop(64, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()
         ])

    def get_data_loaders(self):
        """
        Builds the train/test datasets and returns the loaders of the current task.

        Reads ``self.args.aug`` ('aua' -> AutoAugment, 'ra' -> RandAugment, anything
        else -> no policy augmentation), ``self.args.model_type`` ('clip' switches to
        224x224 inputs) and ``self.args.validation``.
        :return: train and test loaders for the current task
        """
        if self.args.aug == 'aua':
            transform_aug = [transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10)]
        elif self.args.aug == 'ra':
            transform_aug = [transforms.RandAugment(2, 8)]
        else:
            # 'none' and any unrecognised value: no policy-based augmentation.
            transform_aug = []

        if self.args.model_type == 'clip':
            # CLIP branch: upscale to 224x224 before the padded random crop.
            transform_aug.append(transforms.Resize((224, 224)))
            transform_aug.append(transforms.RandomCrop(224, padding=4))
            test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        else:
            transform_aug.append(transforms.RandomCrop(64, padding=4))
            test_transform = transforms.Compose([transforms.ToTensor()])
        transform = transforms.Compose(
            transform_aug + [transforms.RandomHorizontalFlip(), transforms.ToTensor()])

        train_dataset = MyTinyImagenet(base_path() + './tinyImageNet/tiny-imagenet-200/train/', train=True,
                                       download=True, transform=transform)
        if self.args.validation:
            # NOTE(review): store_masked_loaders below reads `.Dataset.targets` on both
            # datasets - confirm that what get_train_val returns exposes it.
            train_dataset, test_dataset = get_train_val(train_dataset,
                                                        test_transform, self.NAME)
        else:
            test_dataset = TinyImagenet(base_path() + './tinyImageNet/tiny-imagenet-200/val/',
                                        train=False, download=True, transform=test_transform)

        train, test = self.store_masked_loaders(train_dataset, test_dataset, self)
        return train, test

    def store_masked_loaders(self, train_dataset: Dataset, test_dataset: Dataset,
                            setting: ContinualDataset) -> Tuple[DataLoader, DataLoader]:
        """
        Divides the dataset into tasks.

        Keeps only the samples whose target lies in
        ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``, rewriting
        ``Dataset.samples`` / ``Dataset.targets`` of both datasets in place, then
        appends the test loader to ``setting.test_loaders``, stores the train loader
        in ``setting.train_loader`` and advances ``setting.i`` to the next task.
        :param train_dataset: train dataset
        :param test_dataset: test dataset
        :param setting: continual learning setting
        :return: train and test loaders
        """
        first_class = setting.i
        end_class = setting.i + setting.N_CLASSES_PER_TASK

        # Convert each target list to an array once and reuse it for both the
        # mask computation and the filtered targets.
        train_targets = np.array(train_dataset.Dataset.targets)
        test_targets = np.array(test_dataset.Dataset.targets)

        train_mask = np.logical_and(train_targets >= first_class, train_targets < end_class)
        test_mask = np.logical_and(test_targets >= first_class, test_targets < end_class)

        train_dataset.Dataset.samples = [s for s, keep in zip(train_dataset.Dataset.samples, train_mask) if keep]
        test_dataset.Dataset.samples = [s for s, keep in zip(test_dataset.Dataset.samples, test_mask) if keep]

        train_dataset.Dataset.targets = train_targets[train_mask]
        test_dataset.Dataset.targets = test_targets[test_mask]

        train_loader = DataLoader(train_dataset,
                                batch_size=setting.args.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset,
                                batch_size=setting.args.batch_size, shuffle=False, num_workers=4)
        setting.test_loaders.append(test_loader)
        setting.train_loader = train_loader

        setting.i += setting.N_CLASSES_PER_TASK
        return train_loader, test_loader

    @staticmethod
    def get_backbone(args):
        """
        Returns the backbone selected by ``args``.

        ``args.architecture == "RES-18"`` is checked first, then
        ``args.model_type == "clip"``. The per-channel mean/std tensors are created
        on the CUDA device and passed to the backbone constructor.
        :param args: run arguments
        :raises NotImplementedError: if neither selector matches
        """
        mean = torch.tensor((0.4802, 0.4480, 0.3975)).cuda().view(-1, 1, 1)
        std = torch.tensor((0.2770, 0.2691, 0.2821)).cuda().view(-1, 1, 1)
        if args.architecture == "RES-18":
            return resnet18(mean, std, SequentialTinyImagenet.N_CLASSES_PER_TASK
                        * SequentialTinyImagenet.N_TASKS)
        elif args.model_type == "clip":
            return Clip(args, mean, std)
        else:
            raise NotImplementedError

    @staticmethod
    def get_loss():
        """Returns the loss function (cross entropy)."""
        return F.cross_entropy

    def get_transform(self):
        """Returns ``TRANSFORM`` preceded by a conversion to PIL Image."""
        transform = transforms.Compose(
            [transforms.ToPILImage(), self.TRANSFORM])
        return transform

    @staticmethod
    def get_normalization_transform():
        """Returns the Normalize transform built from the dataset mean/std."""
        transform = transforms.Normalize((0.4802, 0.4480, 0.3975),
                                         (0.2770, 0.2691, 0.2821))
        return transform

    @staticmethod
    def get_denormalization_transform():
        """Returns the DeNormalize transform built from the dataset mean/std."""
        transform = DeNormalize((0.4802, 0.4480, 0.3975),
                                (0.2770, 0.2691, 0.2821))
        return transform

    @staticmethod
    def get_epochs():
        """Returns the number of epochs (50)."""
        return 50

    @staticmethod
    def get_batch_size():
        """Returns the batch size (32)."""
        return 32

    @staticmethod
    def get_minibatch_size():
        """Returns the minibatch size, equal to the batch size."""
        return SequentialTinyImagenet.get_batch_size()

    # NOTE: this is the only get_scheduler definition. An earlier one returning
    # None was shadowed by this method (later class-body definitions win) and
    # has been dropped as dead code.
    @staticmethod
    def get_scheduler(model, args) -> torch.optim.lr_scheduler:
        """
        Returns the learning-rate scheduler, (re)creating ``model.opt`` when needed.

        For ``args.model_type == 'clip'``: whatever ``model.get_optimizer(args)``
        returns if the model defines it, otherwise SGD with a MultiStepLR stepping at
        48%, 62% and 80% of ``args.n_epochs``. For every other model type: SGD with a
        MultiStepLR stepping at epochs 35 and 45.
        :param model: the continual model (``model.net`` holds the parameters)
        :param args: run arguments (lr, optim_wd, optim_mom, n_epochs, model_type)
        :return: the scheduler
        """
        if args.model_type == 'clip':
            if hasattr(model, 'get_optimizer'):
                scheduler = model.get_optimizer(args)
                return scheduler
            else:
                model.opt = torch.optim.SGD(model.net.parameters(), lr=args.lr, weight_decay=args.optim_wd, momentum=args.optim_mom)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [int(args.n_epochs * 0.48), int(args.n_epochs * 0.62), int(args.n_epochs * 0.80)], gamma=0.1, verbose=False)
        else:
            model.opt = torch.optim.SGD(model.net.parameters(), lr=args.lr, weight_decay=args.optim_wd, momentum=args.optim_mom)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [35, 45], gamma=0.1, verbose=False)
        return scheduler

    @staticmethod
    def get_robust_scheduler(model, args) -> torch.optim.lr_scheduler:
        """
        Returns the scheduler variant that ignores ``args.model_type``.

        Uses ``model.get_optimizer(args)`` if the model defines it, otherwise SGD
        with a MultiStepLR stepping at 48%, 62% and 80% of ``args.n_epochs``.
        :param model: the continual model
        :param args: run arguments (lr, optim_wd, optim_mom, n_epochs)
        :return: the scheduler
        """
        if hasattr(model, 'get_optimizer'):
            scheduler = model.get_optimizer(args)
            return scheduler
        else:
            # NOTE(review): parameters() receives `args` here, unlike get_scheduler
            # which uses model.net.parameters() - confirm the model's signature.
            model.opt = torch.optim.SGD(model.parameters(args), lr=args.lr, weight_decay=args.optim_wd, momentum=args.optim_mom)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(model.opt, [int(args.n_epochs * 0.48), int(args.n_epochs * 0.62), int(args.n_epochs * 0.80)], gamma=0.1, verbose=False)
        return scheduler



def pil_loader(path: str) -> Image.Image:
    """Open the image file at ``path`` with PIL and return it converted to RGB."""
    with open(path, "rb") as handle:
        # convert() produces the RGB image while the file handle is still open.
        return Image.open(handle).convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to the PIL loader on OSError."""
    import accimage
    try:
        loaded = accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return loaded

def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's current image backend."""
    from torchvision import get_image_backend

    # accimage only when it is the configured backend; PIL otherwise.
    loader = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return loader(path)
    
def get_text(class_names, names_path='./imagenet_classes_names.txt'):
    """
    Translate dataset folder identifiers into class names.

    :param class_names: iterable of folder identifiers (the keys of the mapping file)
    :param names_path: path of the mapping file read by ``load_imagenet_folder2name``;
                       the default is resolved relative to the current working directory
    :return: list of class names, in the same order as ``class_names``
    :raises KeyError: if an identifier is missing from the mapping file
    """
    folder2name = load_imagenet_folder2name(names_path)
    return [folder2name[folder] for folder in class_names]

def load_imagenet_folder2name(path):
    """
    Read a whitespace-separated mapping file into a dictionary.

    For every non-blank line, the first field becomes the key (folder
    identifier) and the third field becomes the value (class name).
    Blank or whitespace-only lines are skipped.
    :param path: path of the mapping file
    :return: dict mapping folder identifier -> class name
    :raises IndexError: if a non-blank line has fewer than three fields
    """
    folder2name = {}
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # e.g. a trailing empty line: nothing to parse, not an error.
                continue
            folder2name[fields[0]] = fields[2]
    return folder2name