import torch
import torchvision
from torchvision import transforms

from skimage.filters import gaussian
from skimage.transform import resize

import numpy as np
import json

# Root folder under which every dataset gets its own sub-directory.
DATA_DIR = './data/'

# Per-dataset directories, each formed by appending a name to DATA_DIR.
MNIST_DIR = f'{DATA_DIR}mnist'
EMNIST_DIR = f'{DATA_DIR}emnist'
FASHION_MNIST_DIR = f'{DATA_DIR}fashion_mnist'
CIFAR10_DIR = f'{DATA_DIR}cifar10'
OMNIGLOT_DIR = f'{DATA_DIR}omniglot'

def mnist():
    """Load MNIST with the train and test splits merged.

    Returns:
        A tuple ``(U, Y, classes)``: ``U`` is a float tensor reshaped to
        ``(-1, 1, 28, 28)`` and divided by a fixed constant, ``Y`` holds the
        matching targets (train first, then test) and ``classes`` is the
        ``classes`` attribute of the torchvision dataset.
    """
    scale = 78.5675  # precomputed pixel standard deviation (U.std())

    images, targets = [], []
    for is_train in (True, False):  # train/val first, then test
        ds = torchvision.datasets.MNIST(MNIST_DIR, download=True, train=is_train)
        pixels = ds.data.float().reshape(-1, 1, 28, 28)
        pixels /= scale
        images.append(pixels)
        targets.append(ds.targets)

    return torch.cat(images), torch.cat(targets), ds.classes

def emnist_letters():
    """Load the EMNIST ``letters`` split with train and test merged.

    The last two image axes are swapped after reshaping, and the targets are
    shifted down by one so that they start at zero; the first entry of the
    dataset's class list is dropped to match.

    Returns:
        A tuple ``(U, Y, classes)``: ``U`` is a float tensor with layout
        ``(-1, 1, 28, 28)`` divided by a fixed constant, ``Y`` holds the
        shifted targets (train first, then test) and ``classes`` is
        ``ds.classes[1:]``.
    """
    scale = 84.3914  # precomputed pixel standard deviation (U.std())

    images, targets = [], []
    for is_train in (True, False):  # train/val first, then test
        ds = torchvision.datasets.EMNIST(
            EMNIST_DIR, download=True, train=is_train, split='letters')
        pixels = ds.data.float().reshape(-1, 1, 28, 28).permute(0, 1, 3, 2)
        pixels /= scale
        images.append(pixels)
        targets.append(ds.targets - 1)

    return torch.cat(images), torch.cat(targets), ds.classes[1:]

def cifar10(color=True):
    """Load CIFAR-10 with the train and test splits merged.

    Pixel values are converted to float and divided by 256.0, and the image
    tensors are permuted from channel-last to channel-first layout.

    Args:
        color: When True (default) keep all three colour channels. When
            False, collapse them into one luminance channel using the
            weights 0.2989 / 0.5870 / 0.1140.

    Returns:
        A tuple ``(U, Y, classes)``: ``U`` is a channel-first float tensor
        (three channels, or a single channel when ``color`` is False), ``Y``
        holds the targets (train first, then test) and ``classes`` is the
        ``classes`` attribute of the torchvision dataset.
    """
    images, targets = [], []
    for is_train in (True, False):  # train/val first, then test
        ds = torchvision.datasets.CIFAR10(CIFAR10_DIR, download=True, train=is_train)

        # ds.data is a channel-last numpy array; move the channels to dim 1
        U = torch.from_numpy(ds.data).float() / 256.0
        U = U.permute(0, 3, 1, 2)

        # convert to black and white
        if not color:
            # After the permute the colour channels live on dim 1, so the
            # weighted sum must slice that axis (keeping it as a size-1 dim),
            # not the trailing width axis.
            U = 0.2989*U[:, 0:1] + 0.5870*U[:, 1:2] + 0.1140*U[:, 2:3]

        images.append(U)
        targets.append(torch.from_numpy(np.array(ds.targets)))

    return torch.cat(images), torch.cat(targets), ds.classes

def fashion_mnist():
    """Load Fashion-MNIST with the train and test splits merged.

    Returns:
        A tuple ``(U, Y, classes)``: ``U`` is a float tensor reshaped to
        ``(-1, 1, 28, 28)`` and divided by a fixed constant, ``Y`` holds the
        matching targets (train first, then test) and ``classes`` is the
        ``classes`` attribute of the torchvision dataset.
    """
    scale = 90.0212  # precomputed pixel standard deviation (U.std())

    images, targets = [], []
    for is_train in (True, False):  # train/val first, then test
        ds = torchvision.datasets.FashionMNIST(
            FASHION_MNIST_DIR, download=True, train=is_train)
        pixels = ds.data.float().reshape(-1, 1, 28, 28)
        pixels /= scale
        images.append(pixels)
        targets.append(ds.targets)

    return torch.cat(images), torch.cat(targets), ds.classes

def omniglot(affine_correct=True):
    """Load the preprocessed Omniglot tensors, generating them if missing.

    Args:
        affine_correct: When True (default) load ``x-affine_corrected.pt``;
            otherwise load ``x.pt``.

    Returns:
        A tuple ``(x, labels, char_map)``: the image tensor, the label
        tensor and the character map read from ``char_map.json``.

    Raises:
        FileNotFoundError: If the preprocessed files are still missing after
            they have been regenerated once.
    """
    def _load():
        # Read the three cached artefacts from OMNIGLOT_DIR.
        x_file = '/x-affine_corrected.pt' if affine_correct else '/x.pt'
        x = torch.load(OMNIGLOT_DIR + x_file)
        labels = torch.load(OMNIGLOT_DIR + '/labels.pt')
        with open(OMNIGLOT_DIR + '/char_map.json', 'r') as f:
            char_map = json.load(f)
        return x, labels, char_map

    try:
        return _load()
    except FileNotFoundError:
        print('Preprocessed data not found. Generating it now')
        _make_omniglot(affine_correct)
        # Reload exactly once: retrying recursively would re-run the costly
        # generation over and over if the files still could not be found.
        return _load()

#################
# make omniglot #
#################
# def _make_omniglot():
#     transform = transforms.Compose([transforms.Resize(28), transforms.PILToTensor()])
#     ds = torchvision.datasets.Omniglot(
#         OMNIGLOT_DIR, background=True, download=True, transform=transform)
#     ds_test = torchvision.datasets.Omniglot(
#         OMNIGLOT_DIR, background=False, download=True, transform=transform)

#     alphabets = ds._alphabets + ds_test._alphabets
#     alphabets = dict((a,i) for i, a in enumerate(alphabets))
#     characters = ds._characters + ds_test._characters

#     x, labels = [], []
#     for i in range(len(ds)):
#         x_, label_ = ds[i]
#         alpha_ = alphabets[characters[label_].split('/')[0]]

#         x.append(x_)
#         labels.append([alpha_, label_])

#     for i in range(len(ds_test)):
#         x_, label_ = ds_test[i]
#         label_ += 964
#         alpha_ = alphabets[characters[label_].split('/')[0]]

#         x.append(x_)
#         labels.append([alpha_, label_])

#     x = 1.0 - torch.stack(x).float() / 255.0
#     labels = torch.tensor(labels)

#     print('save data')
#     torch.save(x, OMNIGLOT_DIR+'/x.pt')

#     print('save labels')
#     torch.save(labels, OMNIGLOT_DIR+'/labels.pt')

#     print('save char map')
#     with open(OMNIGLOT_DIR+'/char_map.json', 'w') as f:
#         json.dump(characters, f, indent=2)

def _make_omniglot(affine_correct=True):
    """Download Omniglot, preprocess it to 28 x 28 and cache it on disk.

    Both torchvision splits (``background=True`` and ``background=False``)
    are merged. Each label is a pair ``[alphabet_index, character_index]``
    where the character index refers to the merged character list that is
    written to ``char_map.json``.

    Args:
        affine_correct: When True (default) every image goes through
            ``standardize`` and the result is saved as
            ``x-affine_corrected.pt``. Otherwise images are only resized
            with anti-aliasing and saved as ``x.pt``.

    Side effects:
        Writes the image tensor, ``labels.pt`` and ``char_map.json`` into
        ``OMNIGLOT_DIR`` and prints progress messages.
    """
    # download data
    print('download data')
    ds = torchvision.datasets.Omniglot(
        OMNIGLOT_DIR, background=True, download=True, transform=transforms.PILToTensor())
    ds_test = torchvision.datasets.Omniglot(
        OMNIGLOT_DIR, background=False, download=True, transform=transforms.PILToTensor())

    # merge datasets
    alphabets = {a: i for i, a in enumerate(ds._alphabets + ds_test._alphabets)}
    characters = ds._characters + ds_test._characters

    # Labels of the second split index into its own character list, so they
    # are shifted by the length of the first list to address the merged one
    # (derived from the data instead of a hard-coded count).
    test_offset = len(ds._characters)

    # collect into tensors
    print('load and collect data into memory')
    x, labels = [], []
    for split, offset in ((ds, 0), (ds_test, test_offset)):
        for i in range(len(split)):
            x_, label_ = split[i]
            label_ += offset
            # character entries look like '<alphabet>/<character>'
            alpha_ = alphabets[characters[label_].split('/')[0]]

            x.append(x_)
            labels.append([alpha_, label_])

    # invert images
    x = 1.0 - torch.stack(x).float() / 255.0
    labels = torch.tensor(labels)

    # affine correct?
    if affine_correct:
        print('affine correct and downsample images')
        x = torch.stack([standardize(x_[0]).unsqueeze(0) for x_ in x])

        print('saving')
        torch.save(x, OMNIGLOT_DIR+'/x-affine_corrected.pt')
    else:
        # downsample
        print('downsample images')
        x = torch.stack([
            torch.from_numpy(resize(x_[0].numpy(), (28,28), anti_aliasing=True)).unsqueeze(0)
            for x_ in x
        ])

        # save
        print('saving')
        torch.save(x, OMNIGLOT_DIR+'/x.pt')

    torch.save(labels, OMNIGLOT_DIR+'/labels.pt')
    with open(OMNIGLOT_DIR+'/char_map.json', 'w') as f:
        json.dump(characters, f, indent=2)

def pad2square(x):
    """Pad a 2-D array so that both sides equal its longer side.

    The shorter axis is padded on both ends using ``np.pad`` defaults; when
    the required padding is odd, the extra element goes on the trailing end.

    Args:
        x: Array whose first two shape entries are its height and width.

    Returns:
        The padded square array.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    side = max(n_rows, n_cols)

    def _split(extra):
        # leading half rounded down, remainder trailing
        lead = extra // 2
        return (lead, extra - lead)

    return np.pad(x, (_split(side - n_rows), _split(side - n_cols)))

def standardize(x):
    """ follows the procedure documented in EMNIST
    1: we apply sig=1 gaussian
    2: extract ROI bounding box
    3: pad so it is square
    4: pad with 2 pixels on each side
    5: downsample to 28 x 28

    Args:
        x: 2-D torch tensor holding a single image.

    Returns:
        A 28 x 28 torch tensor with the processed image.

    Raises:
        ValueError: If no row or column sum of the blurred image exceeds the
            threshold (e.g. a blank image), so no bounding box exists.
    """
    # threshold applied to the row / column sums when locating the ROI
    eps = 255*1e-6
    assert(len(x.shape) == 2)

    # convert
    x = x.numpy()

    # blur
    x = gaussian(x, sigma=1, mode='reflect')

    # extract ROI: indices of the rows / columns whose sum exceeds eps
    rows = (x.sum(1) > eps).nonzero()[0]
    cols = (x.sum(0) > eps).nonzero()[0]
    if rows.size == 0 or cols.size == 0:
        raise ValueError('standardize: image has no content above threshold')

    ylow, yhigh = rows.min(), rows.max()
    xlow, xhigh = cols.min(), cols.max()
    # yhigh / xhigh are inclusive indices while slice stops are exclusive,
    # hence the + 1 so the last row and column of the ROI are kept
    x = x[ylow:yhigh + 1, xlow:xhigh + 1]

    # pad ROI
    x = pad2square(x)
    x = np.pad(x, 2)

    # resize
    x = resize(x,(28,28))

    # convert
    x = torch.from_numpy(x)

    return x

# def emnist(split='letters',device='cpu'):
#     # train/val
#     ds = torchvision.datasets.EMNIST(EMNIST_DIR, download=True, train=True, split=split)
#     U = ds.data.float().reshape(-1, 1, 28, 28).permute(0,1,3,2).to(device)
#     U /= 84.3914 # normalize by U.std()
#     Y = ds.targets.to(device)

#     U_train = U[:100000]
#     U_val = U[100000:]

#     Y_train = Y[:100000]
#     Y_val = Y[100000:]

#     # test
#     ds = torchvision.datasets.EMNIST(EMNIST_DIR, download=True, train=False, split=split)
#     U_test = ds.data.float().reshape(-1, 1, 28, 28).permute(0,1,3,2).to(device)
#     U_test /= 84.3914 # normalize by U.std()
#     Y_test = ds.targets.to(device)

#     # format for digits @ letters
#     label_map = ds.classes
#     if split == 'letters':
#         Y_train -= 1
#         Y_val -= 1
#         Y_test -= 1
#         label_map = label_map[1:]

#     return U_train, Y_train, U_val, Y_val, U_test, Y_test, label_map

# def emnist_byclass():
#     device = 'cpu'
#     split = 'byclass'
#     # train/val
#     ds = torchvision.datasets.EMNIST(EMNIST_DIR, download=True, train=True, split=split)
#     U = ds.data.float().reshape(-1, 1, 28, 28).permute(0,1,3,2).to(device)
#     Y = ds.targets.to(device)

#     # test
#     ds = torchvision.datasets.EMNIST(EMNIST_DIR, download=True, train=False, split=split)
#     U_test = ds.data.float().reshape(-1, 1, 28, 28).permute(0,1,3,2).to(device)
#     Y_test = ds.targets.to(device)

#     # merge
#     U = torch.cat([U,U_test])
#     Y = torch.cat([Y,Y_test])

#     # process
#     U /= 84.3914

#     return U,Y,ds.classes

# def fashion_mnist(device='cpu'):
#     # train/val
#     ds = torchvision.datasets.FashionMNIST(FASHION_MNIST_DIR, download=True, train=True)
#     U = ds.data.float().reshape(-1, 1, 28, 28).to(device)
#     std = U.float().std()

#     U /= std # normalize by U.std()
#     U_train = U[:50000]
#     U_val = U[50000:]

#     Y = ds.targets.to(device)
#     Y_train = Y[:50000]
#     Y_val = Y[50000:]

#     # test
#     ds = torchvision.datasets.FashionMNIST(FASHION_MNIST_DIR, download=True, train=False)
#     U_test = ds.data.float().reshape(-1, 1, 28, 28).to(device)
#     U_test /= std # normalize by U.std()
#     Y_test = ds.targets.to(device)

#     return U_train, Y_train, U_val, Y_val, U_test, Y_test

# def ring(n=64, sig=1.0):
#     """Return gaussian bumps on a ring"""
#     x = torch.arange(-n,n+1)
#     u = torch.zeros(len(x),len(x))
#     for i in x:
#         u[i] = (-1/2*(x.roll(i.item()+n+1)/sig)**2).exp()

#     return u

# def stl(split='train'):
#     # load
#     ds = torchvision.datasets.STL10(STL_DIR, split=split, download=False)
#     print('ds loaded')
#     U = torch.from_numpy(ds.data)
#     U = U.float() / 256.0
#     U = 0.2989*U[:,0,:,:] + 0.5870*U[:,1,:,:] + 0.1140*U[:,2,:,:]
#     U = U.unsqueeze(1)
#     # Y = torch.from_numpy(np.array(ds.labels))

#     return U
