# -*- coding: utf-8 -*-

import os
import numpy as np
import torch

### loader for gzipped MNIST-format (IDX) files, adapted from the reference reader in the Fashion-MNIST GitHub repository
def load_mnist(path, kind='train'):
    """Load one split of an MNIST-format dataset stored as gzipped IDX files.

    Reads ``<path>/<kind>-labels-idx1-ubyte.gz`` and
    ``<path>/<kind>-images-idx3-ubyte.gz`` (e.g. ``kind='train'`` or
    ``kind='t10k'``).

    Args:
        path: Directory containing the two ``.gz`` files.
        kind: File-name prefix selecting the split.

    Returns:
        ``(images, labels)``: ``images`` is a ``uint8`` array of shape
        ``(n, 784)`` (one flattened 28x28 image per row) and ``labels`` is a
        ``uint8`` array of shape ``(n,)``.  Both are read-only arrays backed
        by the decompressed file bytes (created with ``np.frombuffer``).

    Raises:
        OSError: If a file is missing or is not valid gzip data.
        ValueError: If a file is shorter than its header, or the image
            payload does not contain exactly 784 bytes per label.
    """
    import gzip

    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    # IDX label files start with an 8-byte header (magic number, item count).
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

    # IDX image files start with a 16-byte header (magic, count, rows, cols).
    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

### class for data loading and preprocessing
class MNIST(object):
    """MNIST train/test splits, normalized and laid out for a PyTorch CNN.

    On construction both splits are loaded with ``load_mnist``.  Images are
    scaled to ``[0, 1]`` and reshaped to ``(N, 1, 28, 28)``: channels-first
    with a single grey channel.  Only a random subset of the training split
    is kept (see ``n_train_samples``); the test split is kept in full.

    Attributes:
        n_classes: Number of label classes (10).
        X_train, Y_train: Subsampled training images (float64) and labels.
        X_test, Y_test: Full test-split images (float64) and labels.
        n_train, n_test: Number of examples held in each split.
    """

    def __init__(self, data_dir, n_train_samples=1500):
        """Load and preprocess the MNIST files found in ``data_dir``.

        Args:
            data_dir: Directory holding the gzipped MNIST IDX files.
            n_train_samples: Number of distinct training examples to keep.
                They are drawn uniformly at random without replacement from
                the global NumPy RNG.  The value is capped at the size of
                the training split; ``None`` keeps the whole split.
        """
        super(MNIST, self).__init__()
        self.n_classes = 10

        X_train, Y_train = load_mnist(data_dir, 'train')
        X_test, Y_test = load_mnist(data_dir, 't10k')

        # Subsample *before* normalizing so only the kept rows are converted
        # to float.  replace=False guarantees the subset has no duplicates.
        n_total = len(Y_train)
        if n_train_samples is None:
            n_keep = n_total
        else:
            n_keep = min(n_train_samples, n_total)
        idx = np.random.choice(n_total, n_keep, replace=False)
        X_train, Y_train = X_train[idx], Y_train[idx]

        self.X_train = self._preprocess(X_train)
        self.Y_train = Y_train
        self.X_test = self._preprocess(X_test)
        self.Y_test = Y_test

        self.n_test = len(self.Y_test)
        self.n_train = len(self.Y_train)

    @staticmethod
    def _preprocess(X):
        """Scale flattened uint8 images to [0, 1] and shape as (N, 1, 28, 28)."""
        # Same values as reshape(-1, 28, 28, 1) followed by a transpose to
        # channels-first, but produces a contiguous array in one step.
        return X.reshape(-1, 1, 28, 28) / 255.0

    def to_tensor(self, X, Y, cuda=True):
        """Convert a NumPy image/label batch to ``(float32, int64)`` tensors.

        Args:
            X: Image array; converted to ``torch.float32``.
            Y: Label array; converted to ``torch.int64``.
            cuda: If true, move both tensors to the current CUDA device.

        Returns:
            Tuple ``(X, Y)`` of tensors.
        """
        # torch.tensor always copies, so read-only arrays (e.g. slices of the
        # np.frombuffer-backed labels) convert without a non-writable warning.
        X = torch.tensor(X, dtype=torch.float32)
        Y = torch.tensor(Y, dtype=torch.long)
        if cuda:
            X = X.cuda()
            Y = Y.cuda()
        return X, Y

    def sample_train_data(self, batch_size=1000, cuda=True):
        """Draw a random training mini-batch of distinct examples.

        Args:
            batch_size: Number of examples.  Because sampling is without
                replacement, this must not exceed ``n_train``.
            cuda: Forwarded to ``to_tensor``.

        Returns:
            ``(X, Y)`` tensors as produced by ``to_tensor``.
        """
        idxs = np.random.choice(self.n_train, batch_size, replace=False)
        return self.to_tensor(self.X_train[idxs], self.Y_train[idxs], cuda)

    def load_test_list(self, batch_size=5000, cuda=True):
        """Split the whole test set into consecutive tensor batches.

        Every test example is included exactly once, in order.  When
        ``n_test`` is not a multiple of ``batch_size``, the final batch is
        smaller instead of being dropped.

        Args:
            batch_size: Maximum number of examples per batch.
            cuda: Forwarded to ``to_tensor``.

        Returns:
            List of ``(X, Y)`` tuples as produced by ``to_tensor``.
        """
        return [
            self.to_tensor(self.X_test[start:start + batch_size],
                           self.Y_test[start:start + batch_size], cuda)
            for start in range(0, self.n_test, batch_size)
        ]

