import os
import numpy as np
import pickle

import keras.backend as K
from keras.utils import np_utils


def one_hot(x, n):
    """
    Convert index representation to one-hot representation.

    # Arguments
        x: one-dimensional sequence or array of integer class indices.
            Negative values follow NumPy indexing and count from the end.
        n: number of classes, i.e. the width of every one-hot row.

    # Returns
        float64 array of shape `(len(x), n)` where row `i` is all zeros
        except for a 1 in column `x[i]`.

    # Raises
        ValueError: if `x` is not one-dimensional.
        IndexError: if an index is out of range for `n` classes or `x`
            is non-empty and not of an integer type.
    """
    x = np.asarray(x)
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # which would let N-d input silently produce an (N+1)-d result.
    if x.ndim != 1:
        raise ValueError(
            'one_hot expects a 1-D sequence of indices, got ndim=%d' % x.ndim)
    if x.size == 0:
        # An empty list becomes a float64 array, which NumPy refuses to use
        # as an index; return the empty encoding directly.
        return np.zeros((0, n))
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(n)[x]


def load_mnist(path='/home1/leishiye/dataset/mnist.npz', flatten=False):
    """Load MNIST from a `.npz` archive, scaled and one-hot encoded.

    # Arguments
        path: location of an archive holding the arrays `x_train`,
            `y_train`, `x_test` and `y_test`. The default is a
            machine-specific absolute path, so callers on other machines
            must pass their own.
        flatten: if True, every image is flattened to a single row so the
            data can feed a fully-connected layer.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        Images are `float32` divided by 255 (presumably mapping 8-bit
        pixels into [0, 1]); labels are one-hot encoded over 10 classes.
    """
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']

    # Adapt the data as an input of a fully-connected (flatten to 1D).
    # Sizes are inferred from the arrays instead of hard-coding
    # 60000/10000 samples of 784 pixels, so subsets load as well.
    if flatten:
        x_train = x_train.reshape(x_train.shape[0], -1)
        x_test = x_test.reshape(x_test.shape[0], -1)

    # Normalize data
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    # Adapt the labels to the one-hot vector syntax required by the softmax
    from keras.utils import np_utils
    y_train = np_utils.to_categorical(y_train, 10)
    y_test = np_utils.to_categorical(y_test, 10)

    return (x_train, y_train), (x_test, y_test)


def _load_batch(file):
    """Read one pickled CIFAR-10 batch file.

    The file is unpickled with `encoding='bytes'`, so the dictionary keys
    arrive as `bytes`; they are decoded to `str` before the lookups.

    # Arguments
        file: path of the batch file to open.

    # Returns
        Tuple `(data, labels)`: `data` is the `'data'` entry reshaped to
        `(num_samples, 3, 32, 32)`, `labels` is the `'labels'` entry as
        stored in the file.
    """
    with open(file, 'rb') as handle:
        raw = pickle.load(handle, encoding='bytes')

    batch = {key.decode('utf8'): value for key, value in raw.items()}
    images = batch['data']
    targets = batch['labels']
    # Each row is unpacked into a (channel, height, width) image.
    return images.reshape(images.shape[0], 3, 32, 32), targets


def load_cifar10(path='datasets/cifar-10-batches-py'):
    """Loads CIFAR10 dataset.

    Reads the five files `data_batch_1` .. `data_batch_5` (10000 samples
    each) and `test_batch` from `path`. Images are converted to `float32`
    and divided by 255; labels are one-hot encoded over 10 classes.

    # Arguments
        path: directory containing the batch files.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        Images are transposed to channels-last when the Keras backend's
        image data format is `'channels_last'`, otherwise left
        channels-first.
    """
    per_batch = 10000
    total_train = 50000

    # Pre-allocate the training arrays and fill them one batch at a time.
    x_train = np.empty((total_train, 3, 32, 32), dtype='uint8')
    y_train = np.empty((total_train,), dtype='uint8')

    for batch_no in range(1, 6):
        start = (batch_no - 1) * per_batch
        stop = batch_no * per_batch
        images, targets = _load_batch(
            os.path.join(path, 'data_batch_' + str(batch_no)))
        x_train[start:stop, :, :, :] = images
        y_train[start:stop] = targets

    x_test, y_test = _load_batch(os.path.join(path, 'test_batch'))

    # Labels become column vectors before the one-hot encoding below.
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        # (N, C, H, W) -> (N, H, W, C)
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    y_train = np_utils.to_categorical(y_train, 10)
    y_test = np_utils.to_categorical(y_test, 10)
    return (x_train, y_train), (x_test, y_test)



def _grayscale(a):
    """Average the three colour planes of a batch of 32x32 images.

    The input is reinterpreted as `(num_samples, 3, 32, 32)`, i.e. the
    three channel planes must be stored one after another (channel-first)
    within each sample; channels-last data would be averaged incorrectly.

    # Arguments
        a: array with `a.shape[0]` samples of 3 * 32 * 32 values each.

    # Returns
        Array of shape `(num_samples, 1024)` holding the per-pixel mean
        over the three channels.
    """
    n_samples = a.shape[0]
    # Computed once, with no debug printing of the (potentially huge) result.
    # The explicit 32 * 32 keeps the final reshape valid for an empty batch,
    # where -1 could not be inferred.
    return a.reshape(n_samples, 3, 32, 32).mean(axis=1).reshape(n_samples, 32 * 32)


def _load_batch_cifar100(path="/home1/leishiye/dataset/cifar-100-python", dtype='float64', containing_channel=True):
    """
    Load a batch in the CIFAR-100 format.

    # Arguments
        path: file to read with `np.load` (pickle loading enabled).
        dtype: dtype both returned arrays are cast to.
        containing_channel: if True, images are returned as
            `(num_samples, 32, 32, 3)`; otherwise the rows of the stored
            `data` entry keep their original shape.

    # Returns
        Tuple `(data, labels)`: pixel data divided by 255.0, and the
        `fine_labels` entry one-hot encoded over 100 classes.
    """
    raw = np.load(path, encoding='bytes', allow_pickle=True)
    pixels = raw[b'data'] / 255.0
    targets = one_hot(raw[b'fine_labels'], n=100)

    if not containing_channel:
        return pixels.astype(dtype), targets.astype(dtype)

    # Unpack each row as (channel, height, width), then move channels last.
    n_samples = pixels.shape[0]
    images = pixels.reshape(n_samples, 3, 32, 32).transpose((0, 2, 3, 1))
    return images.astype(dtype), targets.astype(dtype)


def load_cifar100(path="/home1/leishiye/dataset/cifar-100-python", dtype='float64', grayscale=False, containing_channel=True):
    """Load the CIFAR-100 train and test splits.

    # Arguments
        path: directory containing the files `train` and `test`. The
            default is a machine-specific absolute path.
        dtype: dtype of all returned arrays.
        grayscale: if True, the three colour channels are averaged and
            each image is returned as a flat row of 1024 values.
        containing_channel: if True (and `grayscale` is False), images
            are returned as `(num_samples, 32, 32, 3)`; otherwise as the
            flat rows stored in the file. Has no effect on the output
            shape when `grayscale` is True.

    # Returns
        `x_train, t_train, x_test, t_test`, where the `t_*` arrays are
        one-hot fine labels over 100 classes.
    """
    train_path = os.path.join(path, "train")
    test_path = os.path.join(path, "test")

    # _grayscale reinterprets its input as (N, 3, 32, 32), so it needs the
    # channel-first rows as stored on disk. Handing it the channels-last
    # (N, 32, 32, 3) array would average unrelated pixels, so the channel
    # axis is only materialised when no grayscale conversion follows.
    with_channel_axis = containing_channel and not grayscale

    x_train, t_train = _load_batch_cifar100(train_path, dtype=dtype, containing_channel=with_channel_axis)
    x_test, t_test = _load_batch_cifar100(test_path, dtype=dtype, containing_channel=with_channel_axis)

    if grayscale:
        x_train = _grayscale(x_train)
        x_test = _grayscale(x_test)
    return x_train, t_train, x_test, t_test