from __future__ import print_function

import numbers

import torch
import torch.utils.data as data_utils
import pickle
from scipy.io import loadmat
from copy import deepcopy

import numpy as np

import os
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as vf
from torch.utils.data import ConcatDataset

from PIL import Image

import os
import os.path
from os.path import join
import sys
import tarfile


class ToTensorNoNorm():
    """Turn an array-like image into a tensor without rescaling its values.

    The last axis of the input is moved to the front (e.g. HWC -> CHW), so the
    input must be convertible to a 3-D NumPy array.
    """

    def __call__(self, X_i):
        """
        Args:
            X_i: 3-D array-like (anything ``np.asarray`` accepts).
        Returns:
            Tensor: the same values with axes reordered as (2, 0, 1).
        """
        # np.asarray copies only when it has to. The previous
        # np.array(..., copy=False) raises ValueError on NumPy >= 2.0 whenever
        # a copy cannot be avoided.
        return torch.from_numpy(np.asarray(X_i)).permute(2, 0, 1)


class PadToMultiple(object):
    """Pad an image so that its width and height become multiples of a number.

    Args:
        multiple (number): both output dimensions are rounded up to the next
            multiple of this value.
        fill (number, str or tuple): fill value handed to ``vf.pad``.
        padding_mode (str): one of 'constant', 'edge', 'reflect', 'symmetric'.
    """

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        # Round each dimension up to the next multiple of m.
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h

        # The padding tuple is (left, top, right, bottom): only the right and
        # bottom edges grow, the original pixels keep their coordinates.
        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        # Fixed: this used to read the misspelled attribute `self.mulitple`
        # and therefore raised AttributeError.
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)


class CustomTensorDataset(Dataset):
    """Dataset wrapping an (inputs, targets) pair of tensors plus a transform.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension. ``__getitem__`` unpacks exactly two of them.
        transform (callable, optional): applied to the input sample only. Its
            result is converted with ``np.array`` and its last axis is moved to
            the front, so it must produce a 3-D array-like (e.g. an HWC image).
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(X[index], y[index])``, with the transform applied to X."""
        # The unused per-call `from PIL import Image` was dropped: it did
        # nothing here and made indexing fail when Pillow is not installed.
        X, y = self.tensors
        X_i, y_i = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)
            # Copy into a fresh array, then reorder axes as (2, 0, 1).
            X_i = torch.from_numpy(np.array(X_i, copy=True))
            X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        """Number of samples, i.e. the size of the first dimension."""
        return self.tensors[0].size(0)
    
    
def load_mnist(args, **kwargs):
    """Build train/valid/test DataLoaders from the MNIST ``.npz`` archive.

    The archive is read from a fixed relative path and must contain the arrays
    ``train_data``, ``valid_data`` and ``test_data``. Every array is cast to
    uint8 and reshaped to (-1, 1, 28, 28); the targets are all-zero
    placeholders with one entry per image.

    Args:
        args: namespace providing ``batch_size``; ``input_size``,
            ``input_type`` and ``dynamic_binarization`` are set on it.
        **kwargs: forwarded to every ``DataLoader``.
    Returns:
        (train_loader, valid_loader, test_loader, args); all three loaders
        shuffle.
    """
    # set args
    args.input_size = [1, 28, 28]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    archive = np.load("../../binarizedMNIST/data/mnist.npz")

    # Read and reshape every split before any loader is built.
    splits = [
        archive[key].astype(np.uint8).reshape(-1, 1, 28, 28)
        for key in ("train_data", "valid_data", "test_data")
    ]

    loaders = []
    for images in splits:
        dummy_targets = np.zeros([images.shape[0]])
        dataset = CustomTensorDataset(
            torch.from_numpy(images), torch.from_numpy(dummy_targets), transform=None)
        loaders.append(
            data_utils.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs))

    train_loader, valid_loader, test_loader = loaders
    return train_loader, valid_loader, test_loader, args


def load_custom_data(args, file_name, **kwargs):
    """Build train/valid/test DataLoaders from a custom ``.npz`` archive.

    Every array is cast to uint8 and reshaped to (-1, 3, 2, 2); the targets
    are all-zero placeholders with one entry per sample.

    Args:
        args: namespace providing ``batch_size``; ``input_size``,
            ``input_type`` and ``dynamic_binarization`` are set on it.
        file_name: path of an archive containing ``train_data`` and
            ``test_data`` and, optionally, ``valid_data``.
        **kwargs: forwarded to every ``DataLoader``.
    Returns:
        (train_loader, valid_loader, test_loader, args); all three loaders
        shuffle.
    """
    # set args
    args.input_size = [3, 2, 2]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    data = np.load(file_name)

    def images(key):
        return data[key].astype(np.uint8).reshape(-1, 3, 2, 2)

    x_train = images("train_data")
    # The validation split used to be read from "test_data" unconditionally,
    # which made validation and test the same data. Use a dedicated
    # "valid_data" array when the archive has one; archives without it keep
    # the previous behaviour.
    valid_key = "valid_data" if "valid_data" in data.files else "test_data"
    x_valid = images(valid_key)
    x_test = images("test_data")

    def loader(x):
        y = np.zeros([x.shape[0]])
        dataset = CustomTensorDataset(torch.from_numpy(x), torch.from_numpy(y), transform=None)
        return data_utils.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    train_loader = loader(x_train)
    valid_loader = loader(x_valid)
    test_loader = loader(x_test)

    return train_loader, valid_loader, test_loader, args


class RandomCropToTensor():
    """Cut a randomly positioned patch of fixed size out of an image.

    Despite the name, the result is whatever ``x.crop`` returns; no tensor
    conversion happens here.
    """

    def __init__(self, patch_size):
        # patch_size is indexed as (width, height)
        self.patch_size = patch_size

    def __call__(self, x):
        """Return ``x.crop`` of a uniformly drawn patch-sized box inside ``x``."""
        img_w, img_h = x.size
        patch_w = self.patch_size[0]
        patch_h = self.patch_size[1]

        # randint excludes the upper bound, hence the +1 to reach the last offset
        left = np.random.randint(0, img_w - patch_w + 1)
        top = np.random.randint(0, img_h - patch_h + 1)

        return x.crop(box=(left, top, left + patch_w, top + patch_h))
    
class Norm():
    """Scale the input by 255 and convert the result to a uint8 tensor."""

    def __init__(self):
        pass

    def __call__(self, x):
        """Return ``x * 255.0`` as a ``torch.uint8`` tensor."""
        scaled = x * 255.0
        return torch.tensor(scaled, dtype=torch.uint8)


def load_svhn(args, patch_size = (32, 32), **kwargs):
    """Build train and test DataLoaders for SVHN with random patch cropping.

    Both splits are stored under ``data/svhn`` (downloaded on demand) and use
    the same pipeline: random crop, ``ToTensor``, then ``Norm``.

    Args:
        args: namespace providing ``batch_size``; ``input_size``,
            ``input_type`` and ``dynamic_binarization`` are set on it.
        patch_size: pair of crop sizes handed to ``RandomCropToTensor``.
        **kwargs: accepted but not forwarded to the DataLoaders.
    Returns:
        (train_loader, test_loader, args) -- note there is no validation
        loader. Only the train loader shuffles.
    """
    # set args
    args.input_size = [3, patch_size[0], patch_size[1]]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    from torchvision.datasets import SVHN

    pipeline = [RandomCropToTensor(patch_size), transforms.ToTensor(), Norm()]
    data_transform = transforms.Compose(pipeline)

    def build_loader(split, shuffle):
        # One dataset + loader per split; everything but split/shuffle is shared.
        dataset = SVHN(
            root="data/svhn",
            split=split,
            transform=data_transform,
            download=True,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
        )

    svhn_train_loader = build_loader('train', True)
    svhn_test_loader = build_loader('test', False)

    return svhn_train_loader, svhn_test_loader, args


class FourierTrans():
    """Replace each of the first three channels by its centred 2-D FFT."""

    def __init__(self):
        pass

    def __call__(self, img):
        """Transform ``img`` in place, channel by channel, and return it.

        For channels 0, 1 and 2, ``img[c]`` is overwritten with
        ``fftshift(fft2(img[c]))``. The values are written back into ``img``
        itself, so they are stored in whatever dtype ``img`` already has.
        """
        for channel in range(3):
            spectrum = np.fft.fft2(img[channel, :, :])
            img[channel, :, :] = np.fft.fftshift(spectrum)
        return img


def load_cifar10(args, fourier = False, **kwargs):
    """Build train/val/test DataLoaders for CIFAR-10 loaded through Keras.

    The last 10000 training images are held out as the validation set. Only
    the training set goes through a transform, chosen by
    ``args.data_augmentation_level``:

    * 2 -- horizontal flip, edge padding, random translation, centre crop;
    * 1 -- horizontal flip only;
    * anything else -- no augmentation, optionally preceded by
      ``FourierTrans`` when ``fourier`` is true.

    Validation and test data are wrapped as plain ``TensorDataset``s.

    Args:
        args: namespace providing ``batch_size`` and
            ``data_augmentation_level``; ``input_size``, ``input_type`` and
            ``dynamic_binarization`` are set on it.
        fourier: only consulted when no augmentation level applies.
        **kwargs: forwarded to every ``DataLoader``.
    Returns:
        (train_loader, val_loader, test_loader, args); only the train loader
        shuffles.
    """
    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    import math
    from keras.datasets import cifar10

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Move the last axis to position 1 so images match input_size ordering.
    x_train, x_test = (images.transpose(0, 3, 1, 2) for images in (x_train, x_test))

    level = args.data_augmentation_level
    if level == 2:
        steps = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32),
        ]
    elif level == 1:
        steps = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
        ]
    elif fourier:
        steps = [
            FourierTrans(),
            transforms.ToPILImage(),
        ]
    else:
        steps = [
            transforms.ToPILImage(),
        ]
    data_transform = transforms.Compose(steps)

    # Hold out the tail of the training set for validation.
    holdout = 10000
    x_train, x_val = x_train[:-holdout], x_train[-holdout:]
    y_train, y_val = y_train[:-holdout], y_train[-holdout:]

    def as_tensors(x, y):
        return torch.from_numpy(x), torch.from_numpy(y)

    train = CustomTensorDataset(*as_tensors(x_train, y_train), transform=data_transform)
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(*as_tensors(x_val, y_val))
    val_loader = data_utils.DataLoader(
        validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(*as_tensors(x_test, y_test))
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args

def load_custom_cifar10_patch(args, file_name, **kwargs):
    """Placeholder: not implemented yet, ignores its arguments and returns None."""
    return None

class RandomCrop():
    """Cut a randomly positioned patch of fixed size out of an image.

    Args:
        patch_size: indexable as (width, height) of the patch.
    """

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, x):
        """Return ``x.crop`` of a uniformly drawn patch-sized box inside ``x``.

        Args:
            x: object exposing ``size`` as (width, height) and ``crop(box=...)``,
                e.g. a PIL image.
        Raises:
            ValueError: (from ``np.random.randint``) if the patch is larger
                than the image.
        """
        width, height = x.size
        patch_w, patch_h = self.patch_size[0], self.patch_size[1]

        # np.random.randint excludes its upper bound. Without the +1 the last
        # valid offset was never drawn, and a patch exactly as large as the
        # image raised ValueError (empty range).
        x_idx = np.random.randint(0, width - patch_w + 1)
        y_idx = np.random.randint(0, height - patch_h + 1)

        return x.crop(box=(x_idx, y_idx, x_idx + patch_w, y_idx + patch_h))
    
    
class RepTrans():
    """Optionally copy pixel (0, 0) over its three neighbours in the 2x2 corner."""

    def __init__(self, rep):
        # When falsy, the transform is a no-op.
        self.rep = rep

    def __call__(self, x):
        """Return ``x``; if ``rep`` is set, first overwrite (0,1), (1,0), (1,1)
        with the value at (0,0) through ``x.load()``."""
        if not self.rep:
            return x

        pixels = x.load()
        anchor = pixels[0, 0]
        for coord in ((0, 1), (1, 0), (1, 1)):
            pixels[coord] = anchor
        return x
    

def load_cifar10_patch(args, patch_size, rep = False, batch_size = None, **kwargs):
    """Build train/val/test DataLoaders yielding random patches of CIFAR-10.

    CIFAR-10 is loaded through Keras and the last 10000 training images are
    held out for validation. All three splits share one transform: random
    horizontal flip, random crop to ``patch_size``, then ``RepTrans(rep)``.

    Args:
        args: namespace providing ``batch_size``, or None. When given,
            ``input_size``, ``input_type`` and ``dynamic_binarization`` are
            set on it.
        patch_size: pair of crop sizes handed to ``RandomCrop``.
        rep: forwarded to ``RepTrans``.
        batch_size: overrides ``args.batch_size`` when not None (required if
            ``args`` is None).
        **kwargs: forwarded to every ``DataLoader``.
    Returns:
        (train_loader, val_loader, test_loader, args); only the train loader
        shuffles.
    """
    # set args
    if args is not None:
        args.input_size = [3, patch_size[0], patch_size[1]]
        args.input_type = 'continuous'
        args.dynamic_binarization = False

    from keras.datasets import cifar10

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Move the last axis to position 1.
    x_train, x_test = (images.transpose(0, 3, 1, 2) for images in (x_train, x_test))

    data_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        RandomCrop(patch_size),
        RepTrans(rep),
    ])

    # Hold out the tail of the training set for validation.
    holdout = 10000
    x_train, x_val = x_train[:-holdout], x_train[-holdout:]
    y_train, y_val = y_train[:-holdout], y_train[-holdout:]

    def make_loader(images, labels, shuffle):
        dataset = CustomTensorDataset(
            torch.from_numpy(images), torch.from_numpy(labels), transform=data_transform)
        # An explicit batch_size wins over the one stored on args.
        effective_batch = batch_size if batch_size is not None else args.batch_size
        return data_utils.DataLoader(
            dataset, batch_size=effective_batch, shuffle=shuffle, **kwargs)

    train_loader = make_loader(x_train, y_train, True)
    val_loader = make_loader(x_val, y_val, False)
    test_loader = make_loader(x_test, y_test, False)

    return train_loader, val_loader, test_loader, args


def extract_tar(tarpath, files_per_dir=50000):
    """Extract the files of a ``.tar`` archive into numbered sub-directories.

    The archive ``<name>.tar`` is unpacked into ``<name>/images0``,
    ``<name>/images1``, ... with at most ``files_per_dir`` files each.
    Directory entries in the archive are skipped and every file is stored
    under its base name only (the path inside the archive is dropped).

    If ``<name>/`` already exists, nothing is extracted and the existing
    directory is returned as is.

    Args:
        tarpath (str): path of the archive; must end with ``.tar``.
        files_per_dir (int): maximum number of files per ``images{t}``
            directory.
    Returns:
        str: the extraction root, ``tarpath`` without ``.tar`` plus ``/``.
    Raises:
        ValueError: if ``files_per_dir`` is smaller than 1.
    """
    assert tarpath.endswith('.tar')
    if files_per_dir < 1:
        raise ValueError('files_per_dir must be a positive integer')

    startdir = tarpath[:-4] + '/'

    if os.path.exists(startdir):
        return startdir

    print('Extracting', tarpath)

    with tarfile.open(name=tarpath) as tar:
        # Create the root only once the archive opened successfully, so a
        # failed open does not leave a directory that short-circuits the
        # next call.
        os.makedirs(startdir, exist_ok=True)

        extracted = 0
        path = None
        while True:
            member = tar.next()
            if member is None:
                break

            # Skip directories. The old nested skip loop dereferenced `None`
            # when the archive ended on a directory entry.
            if member.isdir():
                continue

            # Open a new chunk directory lazily, so no empty trailing
            # directory is left behind when the file count is an exact
            # multiple of files_per_dir.
            if extracted % files_per_dir == 0:
                path = join(startdir, 'images{}'.format(extracted // files_per_dir))
                os.makedirs(path, exist_ok=True)
                print(path)

            member.name = member.name.split('/')[-1]
            tar.extract(member, path=path)
            extracted += 1

    return startdir


def load_imagenet(resolution, args, **kwargs):
    """Build train/val/test DataLoaders for downsampled ImageNet (32 or 64).

    The train and valid archives are unpacked with ``extract_tar`` and read
    through ``ImageFolder``. 20000 randomly chosen images of the train archive
    form the validation set and the remainder is used for training; the valid
    archive serves as the test set.

    Args:
        resolution: 32 or 64.
        args: namespace providing ``batch_size``; ``input_size`` is set on it.
        **kwargs: forwarded to every ``DataLoader``.
    Returns:
        (train_loader, val_loader, test_loader, args); only the train loader
        shuffles.
    """
    assert resolution in (32, 64)

    args.input_size = [3, resolution, resolution]

    trainpath = extract_tar(
        '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution))
    valpath = extract_tar(
        '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution))

    data_transform = transforms.Compose([ToTensorNoNorm()])

    def as_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            **kwargs)

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(trainpath, transform=data_transform)
    total = len(imagenet_data)

    print('Number of data images', total)

    # Random, non-overlapping split of the train archive into val and train.
    val_idcs = np.random.choice(total, size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(total), val_idcs)

    subset = torch.utils.data.dataset.Subset
    train_dataset = subset(imagenet_data, train_idcs)
    val_dataset = subset(imagenet_data, val_idcs)

    train_loader = as_loader(train_dataset, True)
    val_loader = as_loader(val_dataset, False)

    # The official "valid" archive plays the role of the test set here.
    test_dataset = torchvision.datasets.ImageFolder(valpath, transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = as_loader(test_dataset, False)

    return train_loader, val_loader, test_loader, args


def load_dataset(args, **kwargs):
    """Dispatch to the loader selected by ``args.dataset``.

    Supported names: 'cifar10', 'imagenet32', 'imagenet64' and 'mnist'.

    Args:
        args: namespace providing ``dataset`` plus whatever the chosen loader
            reads.
        **kwargs: forwarded to the chosen loader.
    Returns:
        (train_loader, val_loader, test_loader, args) as produced by the
        chosen loader.
    Raises:
        Exception: if ``args.dataset`` is not one of the supported names.
    """
    name = args.dataset

    if name == 'cifar10':
        loaded = load_cifar10(args, **kwargs)
    elif name in ('imagenet32', 'imagenet64'):
        resolution = 32 if name == 'imagenet32' else 64
        loaded = load_imagenet(resolution, args, **kwargs)
    elif name == 'mnist':
        loaded = load_mnist(args, **kwargs)
    else:
        raise Exception('Wrong name of the dataset!')

    train_loader, val_loader, test_loader, args = loaded
    return train_loader, val_loader, test_loader, args


if __name__ == '__main__':
    # Smoke run: load 32x32 ImageNet with a minimal stand-in for the argument
    # namespace (only `batch_size` is read by load_imagenet).
    class Args():
        def __init__(self):
            self.batch_size = 128
    # Fixed: this used to call the undefined name `load_imagenet32`, which
    # raised NameError; the loader is load_imagenet(resolution, args).
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())
