from typing import Tuple, List, Union

import numpy as np
import torch
import torchvision.datasets
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Subset

from xad.datasets.bases import TorchvisionDataset
from xad.utils.logger import Logger


class ADMNIST(TorchvisionDataset):
    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """
        AD dataset for MNIST. Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        The training set is reduced via ``self.create_subset`` using the training targets; the test set keeps
        every test sample (wrapped in a :class:`~torch.utils.data.Subset` over all indices).

        :param root: directory MNIST is downloaded to / loaded from; forwarded to the base class and read back
            as ``self.root``.
        :param normal_classes: class labels considered normal; :meth:`_get_raw_train_set` keeps only training
            samples whose raw target is one of these.
        :param nominal_label: forwarded to the base class; presumably the label ``self.target_transform`` assigns
            to normal samples — confirm in :class:`TorchvisionDataset`.
        :param train_transform: image transform for the training set (read back as ``self.train_transform``).
        :param test_transform: image transform for the test set (read back as ``self.test_transform``).
        :param raw_shape: raw image shape; its last entry is the resize size used in :meth:`_get_raw_train_set`.
        :param logger: optional logger, forwarded to the base class.
        :param limit_samples: forwarded to the base class. The default is ``np.inf`` (not the ``np.infty`` alias,
            which NumPy 2.0 removed); it presumably means "no limit".
        :param kwargs: further keyword arguments forwarded to the base class.
        """
        # The positional 10 presumably is the number of MNIST classes (digits 0-9) — see TorchvisionDataset.
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, 10, raw_shape, logger, limit_samples,
            **kwargs
        )

        self._train_set = MNIST(
            self.root, train=True, download=True, transform=self.train_transform,
            target_transform=self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)
        self._test_set = MNIST(
            root=self.root, train=False, download=True, transform=self.test_transform,
            target_transform=self.target_transform,
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    def _get_raw_train_set(self):
        """
        Return the normal-class MNIST training samples with only a resize and tensor conversion applied.

        ``self.train_transform`` is not used here. Samples are selected by their raw targets, i.e. before
        ``self.target_transform`` is applied to what ``__getitem__`` returns.
        """
        train_set = MNIST(
            self.root, train=True, download=True,
            # An int size makes Resize scale the smaller edge to raw_shape[-1]; square inputs stay square.
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor()]),
            target_transform=self.target_transform,
        )
        is_normal = np.isin(np.asarray(train_set.targets), self.normal_classes)
        return Subset(train_set, np.flatnonzero(is_normal).tolist())


class MNIST(torchvision.datasets.MNIST):
    """ torchvision's MNIST whose ``__getitem__`` additionally returns the sample index. """

    def __getitem__(self, index) -> Tuple[torch.Tensor, int, int]:
        """
        Return ``(image, target, index)`` for the sample at ``index``.

        Without a real transform (``None`` or an empty ``Compose``), the image is returned as a float tensor
        divided by 255 with a leading channel axis. Otherwise it becomes a grayscale PIL image that is then
        passed through ``self.transform``. ``self.target_transform`` (if any) is applied to the target.
        """
        pixels = self.data[index]
        label = self.targets[index]
        untransformed = self.transform is None or (
            isinstance(self.transform, transforms.Compose) and not self.transform.transforms
        )
        if untransformed:
            sample = pixels.float().div(255).unsqueeze(0)
        else:
            sample = Image.fromarray(pixels.numpy(), mode="L")
        if self.target_transform is not None:
            label = self.target_transform(label)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label, index


class ADEMNIST(TorchvisionDataset):
    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """
        AD dataset for EMNIST (``letters`` split). Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        The training set is reduced via ``self.create_subset`` using the training targets; the test set keeps
        every test sample (wrapped in a :class:`~torch.utils.data.Subset` over all indices).

        :param root: directory EMNIST is downloaded to / loaded from; forwarded to the base class and read back
            as ``self.root``.
        :param normal_classes: class labels considered normal; :meth:`_get_raw_train_set` keeps only training
            samples whose raw target is one of these.
        :param nominal_label: forwarded to the base class; presumably the label ``self.target_transform`` assigns
            to normal samples — confirm in :class:`TorchvisionDataset`.
        :param train_transform: image transform for the training set (read back as ``self.train_transform``).
        :param test_transform: image transform for the test set (read back as ``self.test_transform``).
        :param raw_shape: raw image shape; its last entry is the resize size used in :meth:`_get_raw_train_set`.
        :param logger: optional logger, forwarded to the base class.
        :param limit_samples: forwarded to the base class. The default is ``np.inf`` (not the ``np.infty`` alias,
            which NumPy 2.0 removed); it presumably means "no limit".
        :param kwargs: further keyword arguments forwarded to the base class.
        """
        # NOTE(review): 10 is forwarded positionally, presumably as the class count, but torchvision's EMNIST
        # 'letters' split has 26 letter classes (labels 1-26) — confirm this value is intended.
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, 10, raw_shape, logger, limit_samples,
            **kwargs
        )

        self._train_set = EMNIST(
            self.root, train=True, download=True, transform=self.train_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)
        self._test_set = EMNIST(
            root=self.root, train=False, download=True, transform=self.test_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    def _get_raw_train_set(self):
        """
        Return the normal-class EMNIST training samples with only a resize and tensor conversion applied.

        ``self.train_transform`` is not used here. Samples are selected by their raw targets, i.e. before
        ``self.target_transform`` is applied to what ``__getitem__`` returns.
        """
        train_set = EMNIST(
            self.root, train=True, download=True,
            # An int size makes Resize scale the smaller edge to raw_shape[-1]; square inputs stay square.
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor()]),
            target_transform=self.target_transform, split='letters',
        )
        is_normal = np.isin(np.asarray(train_set.targets), self.normal_classes)
        return Subset(train_set, np.flatnonzero(is_normal).tolist())


class EMNIST(torchvision.datasets.EMNIST):
    """ torchvision's EMNIST whose ``__getitem__`` un-swaps the image axes and also returns the sample index. """

    def __getitem__(self, index) -> Tuple[torch.Tensor, int, int]:
        """
        Return ``(image, target, index)`` for the sample at ``index``.

        The stored image has its first two axes swapped back before anything else happens. Without a real
        transform (``None`` or an empty ``Compose``), the image is returned as a float tensor divided by 255
        with a leading channel axis. Otherwise it becomes a grayscale PIL image that is then passed through
        ``self.transform``. ``self.target_transform`` (if any) is applied to the target.
        """
        # The raw EMNIST arrays are stored with height and width swapped; transpose them back.
        pixels = self.data[index].transpose(0, 1)
        label = self.targets[index]
        untransformed = self.transform is None or (
            isinstance(self.transform, transforms.Compose) and not self.transform.transforms
        )
        if untransformed:
            sample = pixels.float().div(255).unsqueeze(0)
        else:
            sample = Image.fromarray(pixels.numpy(), mode="L")
        if self.target_transform is not None:
            label = self.target_transform(label)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label, index


