from torchvision import datasets, transforms
import torch

import os
import warnings

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
import numpy as np


def load_digit_train_dataset():
    """Return the SVHN images and labels as a pair of tensors.

    The first call downloads SVHN through torchvision, center-crops every
    image to 28x28, normalizes it per channel and stores the whole split as
    ``../data_np/svhn_data.npy`` / ``../data_np/svhn_label.npy``.  Later calls
    only read those two files.

    Returns:
        tuple: ``(x_train, y_train)`` where ``x_train`` is the tensor built
        from the cached image array and ``y_train`` is the matching label
        tensor cast to ``torch.int64``.
    """
    cache_dir = '../data_np'
    data_path = '../data_np/svhn_data.npy'
    label_path = '../data_np/svhn_label.npy'

    # Rebuild the cache unless BOTH files are present; checking only the data
    # file would crash in np.load() below if the label file had gone missing.
    if not (os.path.exists(data_path) and os.path.exists(label_path)):
        ## load svhn
        transform_svhn = transforms.Compose([
            transforms.CenterCrop([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize([0.4362, 0.4432, 0.4744], [0.1973, 0.2003, 0.1963])
        ])

        # NOTE(review): the *test* split is loaded although this function is
        # named "train" -- kept as-is, confirm this is intentional.
        dset = datasets.SVHN("../data/svhn", split='test',
                             download=True, transform=transform_svhn)

        # A single batch spanning the whole dataset yields it as one tensor.
        loader = torch.utils.data.DataLoader(dset, batch_size=len(dset), shuffle=False)
        x_svhn, y_svhn = next(iter(loader))

        # np.save() does not create parent directories, so make sure the
        # cache folder exists before writing into it.
        os.makedirs(cache_dir, exist_ok=True)
        np.save(data_path, x_svhn.numpy())
        np.save(label_path, y_svhn.long().numpy())

    x_train = torch.tensor(np.load(data_path))
    y_train = torch.tensor(np.load(label_path)).long()

    return x_train, y_train


def _cache_digit_test_set(name, build_dataset):
    """Write dataset ``name`` into the ``../data_np`` cache if it is missing.

    Args:
        name: Prefix of the cache files (``<name>_data.npy`` and
            ``<name>_label.npy``).
        build_dataset: Zero-argument callable returning the dataset to cache.
            It is only invoked when the cache has to be (re)built, so no
            download happens when both files already exist.
    """
    data_path = '../data_np/%s_data.npy' % name
    label_path = '../data_np/%s_label.npy' % name
    if os.path.exists(data_path) and os.path.exists(label_path):
        return

    dset = build_dataset()
    # A single batch spanning the whole dataset yields it as one tensor.
    loader = torch.utils.data.DataLoader(dset, batch_size=len(dset), shuffle=False)
    images, labels = next(iter(loader))

    # Broadcast the images to 3 channels at 28x28 before saving.
    images = images.expand(images.shape[0], 3, 28, 28)

    # np.save() does not create parent directories.
    os.makedirs('../data_np', exist_ok=True)
    np.save(data_path, images.numpy())
    np.save(label_path, labels.long().numpy())


def load_digit_test_datasets():
    """Return the MNIST, USPS and MNIST-M test sets as lists of tensors.

    Each dataset is downloaded on first use, transformed to 28x28, normalized
    with mean 0.1307 / std 0.3081, expanded to 3 channels and cached under
    ``../data_np``.  Later calls only read the cached ``.npy`` files.

    Returns:
        tuple: ``(x_tests, y_tests)`` -- two lists ordered as
        ``['mnist', 'usps', 'mnistm']``.  ``x_tests[i]`` is the image tensor
        and ``y_tests[i]`` the ``torch.int64`` label tensor of dataset ``i``.
    """
    # Dataset roots live under '../data/', matching the other paths in this
    # file; the previous '.../data/' spelling pointed at a directory literally
    # named '...'.

    def build_mnist():
        transform_mnist = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return datasets.MNIST('../data/mnist', train=False, download=True,
                              transform=transform_mnist)

    def build_usps():
        transform_usps = transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return datasets.USPS('../data/usps', train=False,
                             download=True, transform=transform_usps)

    def build_mnistm():
        transform_mnistm = transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.Grayscale(1),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return MNISTM('../data/mnistm', train=False,
                      download=True, transform=transform_mnistm)

    builders = [('mnist', build_mnist), ('usps', build_usps), ('mnistm', build_mnistm)]

    # Make sure every cache exists first, then load them in a fixed order.
    for name, build_dataset in builders:
        _cache_digit_test_set(name, build_dataset)

    x_tests, y_tests = [], []
    for name, _ in builders:
        temp_x = np.load('../data_np/%s_data.npy' % name)
        temp_y = np.load('../data_np/%s_label.npy' % name)
        x_tests.append(torch.tensor(temp_x))
        y_tests.append(torch.tensor(temp_y).long())

    return x_tests, y_tests


# https://github.com/liyxi/mnist-m/blob/main/mnist_m.py
class MNISTM(VisionDataset):
    """MNIST-M dataset backed by pre-processed ``.pt`` files.

    The train and test archives listed in ``resources`` are downloaded into
    ``<root>/MNISTM/raw`` and extracted into ``<root>/MNISTM/processed``.
    Depending on ``train`` either ``training_file`` or ``test_file`` is loaded
    with ``torch.load`` and unpacked into ``self.data`` / ``self.targets``.
    """

    # (url, md5) pairs of the archives fetched by download().
    resources = [
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz',
         '191ed53db9933bd85cc9700558847391'),
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz',
         'e11cb4d7fff76d7ec588b1134907db59'),
    ]

    training_file = "mnist_m_train.pt"
    test_file = "mnist_m_test.pt"
    classes = [
        '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
        '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine',
    ]

    @property
    def train_labels(self):
        """Legacy alias of ``targets``; warns on every access."""
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        """Legacy alias of ``targets``; warns on every access."""
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        """Legacy alias of ``data``; warns on every access."""
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        """Legacy alias of ``data``; warns on every access."""
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Set up the dataset, optionally downloading it first.

        Args:
            root: Directory under which the ``MNISTM/raw`` and
                ``MNISTM/processed`` folders live.
            train: Load ``training_file`` when truthy, else ``test_file``.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable applied to each target.
            download: When truthy, call ``download()`` before loading.

        Raises:
            RuntimeError: If the processed files are not present.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        data_file = self.training_file if self.train else self.test_file
        tensor_path = os.path.join(self.processed_folder, data_file)

        print(tensor_path)

        self.data, self.targets = torch.load(tensor_path)

    def __getitem__(self, index):
        """Fetch one sample for the data loader.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        raw = self.data[index]
        target = int(self.targets[index])

        # Hand a PIL Image to the transform, as the other datasets do.
        img = Image.fromarray(raw.squeeze().numpy(), mode="RGB")

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.data)

    @property
    def raw_folder(self):
        """Directory the downloaded archives are stored in."""
        return os.path.join(self.root, type(self).__name__, 'raw')

    @property
    def processed_folder(self):
        """Directory the extracted ``.pt`` files are stored in."""
        return os.path.join(self.root, type(self).__name__, 'processed')

    @property
    def class_to_idx(self):
        """Mapping from each entry of ``classes`` to its position in the list."""
        return {label: position for position, label in enumerate(self.classes)}

    def _check_exists(self):
        """Return True when both the train and the test file are on disk."""
        folder = self.processed_folder
        return all(
            os.path.exists(os.path.join(folder, file_name))
            for file_name in (self.training_file, self.test_file)
        )

    def download(self):
        """Download the MNIST-M data unless it is already present."""
        if self._check_exists():
            return

        for folder in (self.raw_folder, self.processed_folder):
            os.makedirs(folder, exist_ok=True)

        # Fetch every archive; the file name is the last URL path component.
        for url, md5 in self.resources:
            archive_name = url.rpartition('/')[2]
            download_and_extract_archive(
                url,
                download_root=self.raw_folder,
                extract_root=self.processed_folder,
                filename=archive_name,
                md5=md5,
            )

        print('Done!')

    def extra_repr(self):
        """Return the split description ("Train" only when ``train is True``)."""
        split = "Train" if self.train is True else "Test"
        return "Split: {}".format(split)