from typing import Tuple

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from backbone.MNISTMLP import MNISTMLP
from PIL import Image
from torchvision.datasets import MNIST

from datasets.transforms.denormalization import DeNormalize
from datasets.utils.continual_dataset import (ContinualDataset,
                                              store_masked_loaders)
from datasets.utils.validation import get_train_val
from utils.conf import base_path_dataset as base_path


class MyMNIST(MNIST):
    """
    Overrides the MNIST dataset to change the getitem function.

    Each item also carries a non-augmented tensor copy of the image. Optionally,
    selected classes can be resized by a per-class factor at construction time
    (``multiplicate_classes`` / ``multiplicate_by``): factors below 1 subsample a
    class, factors above 1 replicate its samples.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False, multiplicate_classes=None, multiplicate_by=None) -> None:
        """
        :param root: root directory passed to the torchvision MNIST dataset
        :param train: whether to load the training split
        :param transform: transform applied to the image returned by ``__getitem__``
        :param target_transform: transform applied to the returned target
        :param download: whether torchvision should download the data if missing
        :param multiplicate_classes: optional sequence of distinct class labels to resize
        :param multiplicate_by: one factor per entry of ``multiplicate_classes``; class
            ``c`` ends up with ``int(factor * count(c))`` samples
        :raises ValueError: if ``multiplicate_by`` is missing or its length differs from
            ``multiplicate_classes``, or if ``multiplicate_classes`` has duplicates
        """
        self.not_aug_transform = transforms.ToTensor()
        super(MyMNIST, self).__init__(root, train,
                                      transform, target_transform, download)

        if multiplicate_classes:
            # Explicit checks (not asserts): asserts vanish under -O and zip() would
            # then silently truncate mismatched inputs.
            if multiplicate_by is None or len(multiplicate_by) != len(multiplicate_classes):
                raise ValueError('multiplicate_classes and multiplicate_by must have the same length')
            if len(multiplicate_classes) != len(set(multiplicate_classes)):
                raise ValueError('multiplicate_classes must not contain duplicates')

            # Samples of the untouched classes are kept first, as-is.
            selected = torch.isin(self.targets, torch.tensor(multiplicate_classes))
            x = [self.data[~selected]]
            y = [self.targets[~selected]]

            # Repeat only along the sample axis; keep the remaining dims unchanged.
            trailing_ones = [1] * (self.data.dim() - 1)
            for c, mult in zip(multiplicate_classes, multiplicate_by):
                class_mask = self.targets == c
                class_data = self.data[class_mask]
                class_targets = self.targets[class_mask]

                # Count with a tensor reduction (not builtin sum over elements) and
                # scale in Python floats, then truncate to the target size.
                length = int(mult * int(class_mask.sum()))
                # int(mult + 1) tiled copies always cover `length` samples; the
                # surplus is cut by the slice.
                copies = int(mult + 1)
                x.append(class_data.repeat(copies, *trailing_ones)[:length])
                y.append(class_targets.repeat(copies)[:length])

            self.data = torch.cat(x)
            self.targets = torch.cat(y)

    def __getitem__(self, index: int) -> Tuple:
        """
        Gets the requested element from the dataset.

        :param index: index of the element to be returned
        :returns: ``(img, target, not_aug_img)`` where ``img`` is the image after
            ``self.transform`` (a PIL image if no transform is set), ``target`` is
            the label after ``self.target_transform``, and ``not_aug_img`` is the
            image converted with ``ToTensor`` only. If the dataset has a ``logits``
            attribute (presumably attached by external code — not set here),
            ``self.logits[index]`` is appended as a fourth element.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
        original_img = self.not_aug_transform(img.copy())

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if hasattr(self, 'logits'):
            return img, target, original_img, self.logits[index]

        return img, target, original_img


class SequentialMNIST(ContinualDataset):
    """
    Sequential MNIST benchmark (``NAME = 'seq-mnist'``, ``SETTING = 'class-il'``).

    Works on 28x28 single-channel digits over 10 classes. Inputs are only
    converted with ``ToTensor``: no augmentation and no normalization are used.
    """

    NAME = 'seq-mnist'
    SETTING = 'class-il'
    TRANSFORM = None
    NUM_CLASSES = 10
    IMG_SIZE = 28
    NUM_CHANNELS = 1

    def get_data_loaders(self):
        """
        Build the training and test loaders for this benchmark.

        The training split is a :class:`MyMNIST`, with classes optionally resized
        according to ``args.multiplicate_classes`` / ``args.multiplicate_by``.
        When ``args.validation`` is set, the test set is carved out of the
        training data by ``get_train_val``. Otherwise the official MNIST test
        split is used.

        :returns: the ``(train, test)`` pair produced by ``store_masked_loaders``
        """
        to_tensor = transforms.ToTensor()
        args = self.args
        train_dataset = MyMNIST(base_path() + 'MNIST',
                                train=True,
                                download=True,
                                transform=to_tensor,
                                multiplicate_classes=args.multiplicate_classes,
                                multiplicate_by=args.multiplicate_by)

        if not args.validation:
            test_dataset = MNIST(base_path() + 'MNIST', train=False,
                                 download=True, transform=to_tensor)
        else:
            train_dataset, test_dataset = get_train_val(
                train_dataset, to_tensor, self.NAME)

        train_loader, test_loader = store_masked_loaders(
            train_dataset, test_dataset, self)
        return train_loader, test_loader

    def get_backbone(self):
        """Return an ``MNISTMLP`` over flattened 28x28 inputs with one output per class."""
        input_size = 28 * 28
        return MNISTMLP(input_size, SequentialMNIST.NUM_CLASSES,
                        hidden_size=self.args.mlp_hidden_size)

    @staticmethod
    def get_transform():
        """No training-time augmentation transform: always ``None``."""
        return None

    @staticmethod
    def get_loss():
        """Return the loss function used for training (``F.cross_entropy``)."""
        return F.cross_entropy

    @staticmethod
    def get_normalization_transform():
        """No normalization transform: always ``None``."""
        return None

    @staticmethod
    def get_denormalization_transform():
        """
        Return ``DeNormalize`` with mean 0 and std 1.

        Presumably an identity mapping, matching the absence of normalization.
        """
        return DeNormalize((0.0,), (1.0,))

    @staticmethod
    def get_scheduler(model, args):
        """No learning-rate scheduler is provided: always ``None``."""
        return None

    @staticmethod
    def get_batch_size():
        """Batch size used for this benchmark."""
        return 64

    @staticmethod
    def get_minibatch_size():
        """Minibatch size; identical to :meth:`get_batch_size`."""
        return SequentialMNIST.get_batch_size()
