# -*- coding: utf-8 -*-
from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np

import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity


def define_svhn_folder(root, is_train, transform, target_transform, download):
    """Construct an :class:`SVHN` dataset from the given settings.

    Every argument is forwarded to the ``SVHN`` constructor by keyword;
    ``download`` is passed under the constructor's ``is_download`` name.

    Args:
        root: Forwarded as ``root``.
        is_train: Forwarded as ``is_train``.
        transform: Forwarded as ``transform``.
        target_transform: Forwarded as ``target_transform``.
        download: Forwarded as ``is_download``.

    Returns:
        The newly created ``SVHN`` instance.
    """
    dataset_kwargs = {
        "root": root,
        "is_train": is_train,
        "transform": transform,
        "target_transform": target_transform,
        "is_download": download,
    }
    return SVHN(**dataset_kwargs)


class SVHN(data.Dataset):
    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.

    Note: The SVHN dataset assigns the label `10` to the digit `0`.
    However, in this Dataset, we assign the label `0` to the digit `0`
    to be compatible with PyTorch loss functions which
    expect the class labels to be in the range `[0, C-1]`

    When ``is_train`` is truthy, both the 'train' and the 'extra' files are
    loaded, the first 400 samples of every class in 'train' and the first
    200 samples of every class in 'extra' are dropped, and the remainders
    are concatenated (see :meth:`build_training`). Otherwise the 'test'
    file is loaded unchanged.

    Args:
        root (string): Directory in which the ``.mat`` files are looked up
            (and into which they are downloaded). ``~`` is expanded.
        is_train (bool): Truthy selects the training data described above,
            falsy selects the test data.
            NOTE: the default is the string ``"train"``, which only works
            because a non-empty string is truthy. Any other non-empty
            string (including ``"test"``) selects the training data as
            well, so pass a real ``bool``.
        transform (callable, optional): A function/transform that
            takes in an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional):
            A function/transform that takes in the target and transforms it.
        is_download (bool, optional): If true, every required file is
            handed to torchvision's ``download_url`` (together with its
            expected MD5) before it is verified and loaded.

    Raises:
        RuntimeError: If a required file fails the MD5 integrity check.
    """

    # Not read by any method of this class; kept so that existing attribute
    # access from outside keeps working.
    url = ""
    filename = ""
    file_md5 = ""

    # split name -> [download url, file name inside ``root``, expected md5]
    split_list = {
        "train": [
            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
            "train_32x32.mat",
            "e26dedcc434d2e4c54c9b2d4a06d8373",
        ],
        "test": [
            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
            "test_32x32.mat",
            "eb5a983be6a315427106f1b164d9cef3",
        ],
        "extra": [
            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
            "extra_32x32.mat",
            "a93ce644f1a588dc4d68dda5feec44a7",
        ],
    }

    def __init__(
        self,
        root,
        is_train="train",
        transform=None,
        target_transform=None,
        is_download=False,
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        # Only the truthiness of this value is used below.
        self.is_train = is_train
        self.is_download = is_download

        if self.is_train:
            tr_data = self.load_svhn_data("train")
            ex_data = self.load_svhn_data("extra")
            self.data, self.labels = self.build_training(tr_data, ex_data)
        else:
            self.data, self.labels = self.load_svhn_data("test")

    def load_svhn_data(self, data_type):
        """Optionally download, verify and load one split.

        Args:
            data_type (string): A key of ``split_list``
                ('train', 'test' or 'extra').

        Returns:
            tuple: ``(data, labels)`` as produced by ``_load_svhn_data``.

        Raises:
            KeyError: If ``data_type`` is not a key of ``split_list``.
            RuntimeError: If the file fails the integrity check.
        """
        url, filename, file_md5 = self.split_list[data_type]

        if self.is_download:
            self.download(url, filename, file_md5)

        if not self._check_integrity(data_type, filename):
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        return self._load_svhn_data(filename)

    def _load_svhn_data(self, filename):
        """Read ``filename`` (relative to ``root``) and return (data, labels).

        Labels are returned as a 1D ``np.int64`` array in which the label
        ``10`` has been replaced by ``0``.
        """
        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(self.root, filename))

        data = loaded_mat["X"]
        # loading from the .mat file gives an np array of type np.uint8
        # converting to np.int64, so that we have a LongTensor after
        # the conversion from the numpy array
        # the squeeze is needed to obtain a 1D tensor
        labels = loaded_mat["y"].astype(np.int64).squeeze()

        # the svhn dataset assigns the class label "10" to the digit 0
        # this makes it inconsistent with several loss functions
        # which expect the class labels to be in the range [0, C-1]
        np.place(labels, labels == 10, 0)
        # Move the last (sample) axis to the front; presumably
        # (H, W, C, N) -> (N, C, H, W) -- TODO confirm against the .mat files.
        data = np.transpose(data, (3, 2, 0, 1))
        return data, labels

    @staticmethod
    def _drop_first_per_class(images, labels, size_per_class):
        """Drop the first ``size_per_class`` samples of every class.

        The selection is deterministic (no random choice) and the surviving
        samples keep their original relative order.
        """
        if size_per_class < 0:
            raise ValueError("size_per_class must be non-negative")

        images = np.asarray(images)
        labels = np.asarray(labels)

        # A boolean mask keeps the original order without relying on set
        # iteration order and without building Python sets over every index.
        keep = np.ones(images.shape[0], dtype=bool)
        for label in np.unique(labels):
            matched_indices = np.where(labels == label)[0]
            keep[matched_indices[:size_per_class]] = False
        return images[keep], labels[keep]

    def build_training(
        self,
        tr_data,
        ex_data,
        tr_exclude_per_class=400,
        ex_exclude_per_class=200,
    ):
        """Combine the 'train' and 'extra' splits into one training set.

        From each split the first N samples of every class are removed
        (presumably because they are reserved as a held-out subset
        elsewhere -- confirm with the caller); what remains of 'train' is
        followed by what remains of 'extra'.

        Args:
            tr_data (tuple): ``(images, labels)`` of the 'train' split.
            ex_data (tuple): ``(images, labels)`` of the 'extra' split.
            tr_exclude_per_class (int, optional): Samples per class dropped
                from ``tr_data``. Defaults to 400.
            ex_exclude_per_class (int, optional): Samples per class dropped
                from ``ex_data``. Defaults to 200.

        Returns:
            tuple: ``(images, labels)`` numpy arrays, concatenated along
            the first axis.

        Raises:
            ValueError: If an exclusion count is negative.
        """
        tr_images, tr_labels = tr_data
        ex_images, ex_labels = ex_data

        kept_tr_images, kept_tr_labels = self._drop_first_per_class(
            tr_images, tr_labels, tr_exclude_per_class
        )
        kept_ex_images, kept_ex_labels = self._drop_first_per_class(
            ex_images, ex_labels, ex_exclude_per_class
        )

        images = np.concatenate([kept_tr_images, kept_ex_images])
        labels = np.concatenate([kept_tr_labels, kept_ex_labels])
        return images, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        # (the first axis of the stored sample is moved to the end first)
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples held in ``self.data``."""
        return len(self.data)

    def _check_integrity(self, data_type, filename):
        """Check ``root/filename`` against the MD5 listed for ``data_type``."""
        md5 = self.split_list[data_type][2]
        fpath = os.path.join(self.root, filename)
        return check_integrity(fpath, md5)

    def download(self, url, filename, file_md5):
        """Fetch ``url`` into ``root`` as ``filename`` via ``download_url``."""
        download_url(url, self.root, filename, file_md5)

    def __repr__(self):
        """Return a multi-line, human-readable summary of the dataset."""
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(len(self))
        fmt_str += "    Split: {}\n".format(self.is_train)
        fmt_str += "    Root Location: {}\n".format(self.root)
        # Continuation lines of a multi-line transform repr are indented so
        # that they line up underneath the first line.
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, repr(self.transform).replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, repr(self.target_transform).replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str
