import itertools
from typing import Tuple, List, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Subset

from xad.datasets.bases import TorchvisionDataset
from xad.utils.logger import Logger
from xad.datasets.mnist import MNIST, EMNIST


class ADColoredMNIST(TorchvisionDataset):
    """ Anomaly-detection dataset built on :class:`ColoredMNIST` (10 digits x 7 colors = 70 classes). """
    digits = ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine")
    colors = ("gray", "red", "yellow", "green", "cyan", "blue", "pink")
    # Color-major class names ("zero-gray", "one-gray", ..., "nine-gray", "zero-red", ...), so that the class index
    # is ``10 * colors.index(color) + digit`` -- the label layout produced by ColoredMNIST.
    # Only the outermost iterable of a generator expression is evaluated in class scope, which is why the class
    # attributes ``colors``/``digits`` are usable here (they would not be visible inside a nested lambda).
    classes = tuple(f"{d}-{c}" for c, d in itertools.product(colors, digits))

    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """ AD dataset for a modified MNIST with colors. Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        All arguments are forwarded to the base class together with the number of classes (``len(self.classes)``).
        ``limit_samples`` defaults to ``np.inf`` (``np.infty`` no longer exists in NumPy >= 2.0).
        The training set is reduced via ``self.create_subset``; the test set is wrapped in a Subset covering all
        of its indices.
        """
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, len(self.classes),
            raw_shape, logger, limit_samples, **kwargs
        )

        self._train_set = ColoredMNIST(
            self.root, train=True, download=True, transform=self.train_transform,
            target_transform=self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)
        self._test_set = ColoredMNIST(
            root=self.root, train=False, download=True, transform=self.test_transform,
            target_transform=self.target_transform,
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    def _get_raw_train_set(self):
        """ Return the training samples of ``self.normal_classes``, only resized and converted to tensors.

        Images are passed through ``transforms.Resize(self.raw_shape[-1])`` and ``transforms.ToTensor()``; the
        selection uses the untransformed ``train_set.targets``.
        """
        train_set = ColoredMNIST(
            self.root, train=True, download=True,
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor()]),
            target_transform=self.target_transform
        )
        normal_indices = np.flatnonzero(np.isin(np.asarray(train_set.targets), self.normal_classes))
        return Subset(train_set, normal_indices.tolist())


class ColoredMNIST(MNIST):
    """ MNIST with every image replicated once per color, yielding 10 x 7 = 70 classes.

    The split is repeated seven times, once per color ("gray", "red", "yellow", "green", "cyan", "blue", "pink",
    in this order). In the ``i``-th copy, the grayscale image is written into the RGB channels enabled by the
    ``i``-th color and the digit label is shifted by ``10 * i``. The target is therefore
    ``10 * color_index + digit``. ``__getitem__`` returns ``(img, target, index)``.
    """

    def __init__(self, root: str, train: bool = True, transform=None,
                 target_transform=None, download: bool = False, ):
        super().__init__(root, train, transform, target_transform, download)
        # RGB channel multipliers for "gray", "red", "yellow", "green", "cyan", "blue", "pink" (in this order).
        colors = [(1., 1., 1.), (1., 0, 0), (1., 1., 0), (0, 1., 0), (0, 1., 1.), (0, 0, 1.), (1., 0, 1.)]
        n_digits = 10  # label offset between consecutive color blocks
        n = self.data.size(0)
        # (n, H, W) -> (len(colors) * n, 3, H, W): one full copy of the split per color, gray replicated to 3 channels.
        self.data = self.data[:, None, :, :].repeat(len(colors), 3, 1, 1)
        self.targets = self.targets.repeat(len(colors))
        for i, rgb in enumerate(colors):
            block = slice(n * i, n * (i + 1))
            for j, c in enumerate(rgb):
                self.data[block, j, :, :] = (self.data[block, j, :, :] * c).type(torch.uint8)
            self.targets[block] = self.targets[block] + (n_digits * i)

    def __getitem__(self, index) -> Tuple[torch.Tensor, int, int]:
        """ Return ``(img, target, index)`` for the given sample index.

        Without a real transform (``None`` or an empty ``Compose``), ``img`` is the stored channel-first uint8
        tensor converted to float and scaled to [0, 1]. Otherwise it is converted to an RGB PIL image and passed
        through ``self.transform``.
        """
        img, target = self.data[index], self.targets[index]  # img: (3, H, W) uint8 tensor
        no_transform = self.transform is None or (
            isinstance(self.transform, transforms.Compose) and len(self.transform.transforms) == 0
        )
        if no_transform:
            img = img.float().div(255)
        else:
            # PIL builds RGB images from channel-last (H, W, 3) arrays; a channel-first (3, H, W) array makes
            # Image.fromarray raise "Cannot handle this data type".
            img = Image.fromarray(img.permute(1, 2, 0).contiguous().numpy())
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target, index


class ADThreeEMNIST(TorchvisionDataset):
    """ Anomaly-detection dataset built on the EMNIST ``letters`` split with 3-channel images (26 classes). """

    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """ AD dataset for EMNIST. Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        All arguments are forwarded to the base class together with the number of classes (26).
        ``limit_samples`` defaults to ``np.inf`` (``np.infty`` no longer exists in NumPy >= 2.0).
        The training set is reduced via ``self.create_subset``; the test set is wrapped in a Subset covering all
        of its indices.
        """
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, 26, raw_shape, logger, limit_samples,
            **kwargs
        )

        self._train_set = ThreeEMNIST(
            self.root, train=True, download=True, transform=self.train_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)
        self._test_set = ThreeEMNIST(
            root=self.root, train=False, download=True, transform=self.test_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    def _get_raw_train_set(self):
        """ Return the training samples of ``self.normal_classes``, only resized and converted to tensors.

        Images are passed through ``transforms.Resize(self.raw_shape[-1])`` and ``transforms.ToTensor()``; the
        selection uses the untransformed ``train_set.targets``.
        """
        train_set = ThreeEMNIST(
            self.root, train=True, download=True,
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor()]),
            target_transform=self.target_transform, split='letters',
        )
        normal_indices = np.flatnonzero(np.isin(np.asarray(train_set.targets), self.normal_classes))
        return Subset(train_set, normal_indices.tolist())


class ThreeEMNIST(EMNIST):
    """ EMNIST whose images are tiled to three identical channels after the parent's loading/transform step. """

    def __getitem__(self, index) -> Tuple[torch.Tensor, int, int]:
        """ Return ``(img, target, index)`` from the parent class with ``img`` repeated three times along dim 0. """
        sample_img, sample_target, sample_idx = super().__getitem__(index)
        # NOTE(review): assumes the parent returns a tensor image of shape (1, H, W) -- TODO confirm.
        return sample_img.repeat(3, 1, 1), sample_target, sample_idx


class ADColoredEMNIST(TorchvisionDataset):
    """ Anomaly-detection dataset built on :class:`ColoredEMNIST` (26 letters x 7 colors = 182 classes). """

    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """ AD dataset for colored EMNIST letters. Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        All arguments are forwarded to the base class together with the number of classes (``26 * 7``).
        ``limit_samples`` defaults to ``np.inf`` (``np.infty`` no longer exists in NumPy >= 2.0).
        The training set is reduced via ``self.create_subset``; the test set is wrapped in a Subset covering all
        of its indices.
        """
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, 26 * 7, raw_shape, logger, limit_samples,
            **kwargs
        )

        self._train_set = ColoredEMNIST(
            self.root, train=True, download=True, transform=self.train_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)
        self._test_set = ColoredEMNIST(
            root=self.root, train=False, download=True, transform=self.test_transform, split='letters',
            target_transform=self.target_transform,
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    def _get_raw_train_set(self):
        """ Return the training samples of ``self.normal_classes``, only resized and converted to tensors.

        Images are passed through ``transforms.Resize(self.raw_shape[-1])`` and ``transforms.ToTensor()``; the
        selection uses the untransformed ``train_set.targets``.
        """
        train_set = ColoredEMNIST(
            self.root, train=True, download=True,
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor()]),
            target_transform=self.target_transform, split='letters',
        )
        normal_indices = np.flatnonzero(np.isin(np.asarray(train_set.targets), self.normal_classes))
        return Subset(train_set, normal_indices.tolist())


class ColoredEMNIST(EMNIST):
    """ EMNIST with every image replicated once per color.

    The split is repeated seven times, once per color ("gray", "red", "yellow", "green", "cyan", "blue", "pink",
    in this order). In the ``i``-th copy, the grayscale image is written into the RGB channels enabled by the
    ``i``-th color and the letter label is shifted by ``26 * i``. Each (letter, color) pair therefore gets its own
    target, consistent with the ``26 * 7`` classes used for the ``letters`` split.
    ``__getitem__`` returns ``(img, target, index)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # RGB channel multipliers for "gray", "red", "yellow", "green", "cyan", "blue", "pink" (in this order).
        colors = [(1., 1., 1.), (1., 0, 0), (1., 1., 0), (0, 1., 0), (0, 1., 1.), (0, 0, 1.), (1., 0, 1.)]
        # Label offset between consecutive color blocks. It must be at least the number of letter classes (26);
        # the previous offset of 10 (copied from MNIST) made e.g. letter 15/gray and letter 5/red share a target.
        # NOTE(review): assumes the letter targets of xad's EMNIST lie in 0..25 -- confirm.
        n_letters = 26
        n = self.data.size(0)
        # (n, H, W) -> (len(colors) * n, 3, H, W): one full copy of the split per color, gray replicated to 3 channels.
        self.data = self.data[:, None, :, :].repeat(len(colors), 3, 1, 1)
        self.targets = self.targets.repeat(len(colors))
        for i, rgb in enumerate(colors):
            block = slice(n * i, n * (i + 1))
            for j, c in enumerate(rgb):
                self.data[block, j, :, :] = (self.data[block, j, :, :] * c).type(torch.uint8)
            self.targets[block] = self.targets[block] + (n_letters * i)

    def __getitem__(self, index) -> Tuple[torch.Tensor, int, int]:
        """ Return ``(img, target, index)`` for the given sample index.

        The last two image dims are swapped first. Without a real transform (``None`` or an empty ``Compose``),
        ``img`` is a channel-first float tensor scaled to [0, 1]. Otherwise it is converted to an RGB PIL image
        and passed through ``self.transform``.
        """
        img, target = self.data[index], self.targets[index]  # img: (3, H, W) uint8 tensor
        img = img.transpose(1, 2)  # the raw EMNIST images appear stored with H and W swapped; swap them back
        no_transform = self.transform is None or (
            isinstance(self.transform, transforms.Compose) and len(self.transform.transforms) == 0
        )
        if no_transform:
            img = img.float().div(255)
        else:
            # PIL builds RGB images from channel-last (H, W, 3) arrays; a channel-first (3, H, W) array makes
            # Image.fromarray raise "Cannot handle this data type".
            img = Image.fromarray(img.permute(1, 2, 0).contiguous().numpy())
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target, index
