from typing import Tuple, Any
from collections.abc import Iterable

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import check_random_state

from PIL import Image
import torch
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from torch.utils.data import random_split, Subset


def check_pil_input(img):
    """Coerce tensor / ndarray images to a PIL image.

    A ``torch.Tensor`` is squeezed and converted to a NumPy array first;
    any NumPy array is then wrapped with ``Image.fromarray``. Inputs of any
    other type are returned untouched.
    """
    candidate = (
        img.squeeze().numpy() if isinstance(img, torch.Tensor) else img
    )
    if isinstance(candidate, np.ndarray):
        return Image.fromarray(candidate)
    return candidate


class MNIST(datasets.MNIST):
    """MNIST dataset with an extra preprocessing hook and image transform.

    On top of the ``transform`` / ``target_transform`` attributes handled
    by the parent class, two optional callables are stored:

    * ``preproc``: applied first, on the PIL version of the raw image;
    * ``img_transform``: applied on the image only, after ``transform``.

    Note that ``transform`` is invoked here with both the image and the
    target, i.e. ``transform(img, target)``.
    """

    def __init__(self, *args, preproc=None, img_transform=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.preproc = preproc
        self.img_transform = img_transform

    def _process_img(self, img, target):
        """Run the full pipeline on one sample and return ``(img, target)``.

        Order of operations: PIL conversion, ``preproc``, ``transform``,
        ``img_transform``, ``target_transform``. The returned image is
        always a ``torch.Tensor``.
        """
        # Start from a PIL image so that the behaviour is consistent with
        # the other datasets, which hand PIL images to their transforms.
        img = check_pil_input(img)
        if self.preproc is not None:
            img = self.preproc(img)

        # Each stage may have produced a tensor / array, hence the
        # conversion back to PIL before every transform.
        if self.transform is not None:
            img = self.transform(check_pil_input(img), target)

        if self.img_transform is not None:
            img = self.img_transform(check_pil_input(img))

        if self.target_transform is not None:
            target = self.target_transform(target)

        if isinstance(img, torch.Tensor):
            return img, target
        return TF.to_tensor(np.array(img)), target

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the processed ``(image, label)`` pair stored at ``index``."""
        raw_img = self.data[index]
        label = int(self.targets[index])
        return self._process_img(raw_img, label)


class MNISTSubset(MNIST):
    """Wrap a subset of an MNIST dataset so that it supports transforms.

    Samples are fetched from ``subset`` and then passed through the
    processing pipeline of :class:`MNIST` (``transform``, ``img_transform``
    and ``target_transform``), which a bare subset object does not offer.

    The parent constructor is intentionally not called: only the attributes
    read by ``_process_img`` are set.
    """
    def __init__(self, subset, transform=None, img_transform=None,
                 target_transform=None):
        self.subset = subset
        # Samples are obtained through ``subset.__getitem__``; when the
        # subset wraps an ``MNIST`` instance they have already been
        # preprocessed there, so preprocessing is disabled here to avoid
        # applying it twice.
        self.preproc = None
        self.transform = transform
        self.img_transform = img_transform
        self.target_transform = target_transform

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Fetch sample ``index`` from the subset and process it."""
        sample, label = self.subset[index]
        # Work on a copy of the fetched image (assumed to be a tensor,
        # since ``clone`` is called on it).
        return super()._process_img(sample.clone(), label)

    def __len__(self):
        """Number of samples in the wrapped subset."""
        return len(self.subset)


def filter_ds_by_class(ds, classes_to_keep):
    """Keep only the samples whose target is in ``classes_to_keep``.

    ``ds.targets`` and ``ds.data`` are overwritten in place with the
    filtered tensors, and the same ``ds`` object is returned.
    """
    keep_mask = torch.zeros(len(ds), dtype=bool)
    for wanted in classes_to_keep:
        matches = ds.targets == wanted
        keep_mask = torch.logical_or(keep_mask, matches).type(torch.bool)

    ds.targets = ds.targets[keep_mask]
    ds.data = ds.data[keep_mask]
    return ds


def _structured_split(
    splitter_class,
    indices,
    ratio,
    groups,
    targets=None,
    random_state=None
):
    """Split ``indices`` in two parts with a scikit-learn style splitter.

    Parameters
    ----------
    splitter_class : callable
        Called as ``splitter_class(n_splits=1, train_size=ratio,
        random_state=random_state)``; the result must expose a
        ``split(X, y=..., groups=...)`` method yielding pairs of positional
        index arrays.
    indices : array-like
        Indices to split. Must support indexing by an integer array.
    ratio : float or 1
        Fraction of ``indices`` assigned to the first part. ``1`` keeps
        everything in the first part without calling the splitter.
    groups, targets :
        Forwarded to ``split`` as ``groups`` and ``y`` respectively.
    random_state :
        Forwarded to the splitter constructor.

    Returns
    -------
    tuple
        ``(first_part, second_part)``, both taken from ``indices``.

    Raises
    ------
    AssertionError
        If ``ratio`` is neither ``1`` nor a float strictly between 0 and 1.
    """
    if ratio == 1:
        # Nothing to split off. The empty complement is sliced out of
        # ``indices`` so that it keeps an index-compatible dtype: a bare
        # ``np.array([])`` is float64 and cannot be used as an index array.
        return indices, np.asarray(indices)[:0]

    assert (
        isinstance(ratio, float) and
        ratio > 0 and ratio < 1
    ), "When ratio is a float, it must be positive and <=1."
    splitter = splitter_class(
        n_splits=1,
        train_size=ratio,
        random_state=random_state
    )
    # Only one split is requested, so take the first (and only) pair of
    # positional indices and map them back onto ``indices``.
    train_idx, test_idx = next(iter(splitter.split(
        indices,
        y=targets,
        groups=groups
    )))
    return indices[train_idx], indices[test_idx]


def stratified_split(indices, ratio, targets, random_state=None):
    """Make a stratified split of ``indices`` based on ``targets``.

    Delegates to ``_structured_split`` with ``StratifiedShuffleSplit`` as
    the splitter and no groups; returns the two index parts it produces.
    """
    return _structured_split(
        StratifiedShuffleSplit,
        indices,
        ratio,
        None,
        targets=targets,
        random_state=random_state,
    )


def make_datasets(
    root,
    train_size=0.34,  # roughly 1000 images per class
    valid_size=0.5,   # 3000 images per class
    classes_to_keep=None,
    random_state=None,
):
    """Make the train, valid and test sets.

    The MNIST training split is first divided (stratified on the labels)
    into a validation part and a remainder; a stratified fraction of that
    remainder is then kept as the training subset.

    Parameters
    ----------
    root :
        Dataset root directory, forwarded to ``MNIST``.
    train_size : float or 1
        Fraction of the non-validation samples kept for training.
    valid_size : float
        Fraction of the MNIST training split held out for validation.
    classes_to_keep : iterable, optional
        When given, only samples of these classes are kept in all sets.
    random_state :
        Seed or random state, normalised with ``check_random_state`` and
        shared by both stratified splits.

    Returns
    -------
    tuple
        ``(train_subset, valid_set, test_set)``; the first two are
        ``MNISTSubset`` objects, the last one is an ``MNIST`` object.
    """
    rng = check_random_state(random_state)
    preproc = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_valid_set = MNIST(
        root=root,
        train=True,
        download=True,
        preproc=preproc,
    )
    test_set = MNIST(
        root=root,
        train=False,
        preproc=preproc,
    )

    if classes_to_keep is not None:
        assert isinstance(classes_to_keep, Iterable),\
            "classes_to_keep should be Iterable"
        train_valid_set = filter_ds_by_class(train_valid_set, classes_to_keep)
        test_set = filter_ds_by_class(test_set, classes_to_keep)

    # Splitting through Subset yields objects that do not have the methods
    # of the MNIST class (hence no transforms); see the wrapping below.
    train_set_idx, valid_set_idx = stratified_split(
        np.arange(len(train_valid_set)),
        1 - valid_size,
        train_valid_set.targets,
        random_state=rng
    )
    naked_valid_set = Subset(train_valid_set, valid_set_idx)
    # Second split: only the first part is kept, as the training subset.
    train_subset_idx, _ = stratified_split(
        train_set_idx,
        train_size,
        train_valid_set.targets[train_set_idx],
        random_state=rng
    )
    naked_train_subset = Subset(train_valid_set, train_subset_idx)

    # Wrapping the subsets gives them back the MNIST sampling pipeline.
    train_subset = MNISTSubset(naked_train_subset)
    valid_set = MNISTSubset(naked_valid_set)

    print(">>> Valid set targets <<<")
    # Preview at most 30 targets; bounding by the set length avoids an
    # IndexError when the validation set holds fewer than 30 samples.
    n_preview = min(30, len(valid_set))
    valid_first_targets = [valid_set[i][1] for i in range(n_preview)]
    print(valid_first_targets)

    return train_subset, valid_set, test_set


def make_dataloaders(
    root,
    train_kwargs,
    valid_kwargs,
    test_kwargs,
    valid_size=0.2,
):
    """Make the train, valid and test dataloaders.

    The datasets come from ``make_datasets(root, valid_size=valid_size)``
    (its other arguments keep their defaults). Each dataset is wrapped in
    a ``torch.utils.data.DataLoader`` configured with the matching kwargs
    dict, and the loaders are returned as
    ``(train_loader, valid_loader, test_loader)``.
    """
    splits = make_datasets(root, valid_size=valid_size)
    per_split_kwargs = (train_kwargs, valid_kwargs, test_kwargs)
    return tuple(
        torch.utils.data.DataLoader(split, **loader_kwargs)
        for split, loader_kwargs in zip(splits, per_split_kwargs)
    )
