# coding=utf-8
import os
import sys
import gzip
import struct
import numpy as np
from PIL import Image as processor


# Gzipped IDX files that make up the MNIST dataset (training and test splits).
TRAINING_SET_IMAGES_FILE = 'train-images-idx3-ubyte.gz'
TRAINING_SET_LABELS_FILE = 'train-labels-idx1-ubyte.gz'
TESTING_SET_IMAGES_FILE = 't10k-images-idx3-ubyte.gz'
TESTING_SET_LABELS_FILE = 't10k-labels-idx1-ubyte.gz'

# Download locations for the files above, all served from the same base URL,
# in the order: training images, training labels, test images, test labels.
URLS = ['http://yann.lecun.com/exdb/mnist/' + file_name for file_name in (
    TRAINING_SET_IMAGES_FILE,
    TRAINING_SET_LABELS_FILE,
    TESTING_SET_IMAGES_FILE,
    TESTING_SET_LABELS_FILE,
)]


# download all the dataset files to a specific directory
# Uses the third-party `wget` package when it is installed and falls back to
# urllib.request otherwise.
# :param dir: target directory; it should already exist (wget treats a
#     non-directory `out` as a file name)
def download(dir):
    import urllib.request

    # Guard only the optional import, so an ImportError can never be confused
    # with a failure of the download itself.
    try:
        import wget
    except ImportError:
        wget = None

    total = len(URLS)
    # 1-based counter so progress reads "1 / 4" ... "4 / 4".
    for index, url in enumerate(URLS, start=1):
        print('\nDownloading Dataset File %d / %d...' % (index, total))
        if wget is not None:
            wget.download(url, out=dir)
        else:
            # Keep the remote file name (last URL component).
            urllib.request.urlretrieve(url, os.path.join(dir, url.split('/')[-1]))
    print('\nAll Downloaded.')


# function for parsing idx3 files
# :param idx3_ubyte_file: the raw, already-decompressed bytes of an idx3-ubyte
#     file -- NOT a path; load() passes the bytes read from the .gz archive
# :return: float64 numpy array of shape (num_images, num_rows, num_cols)
#     holding the unsigned-byte (0-255) pixel values
# :raises ValueError: if the magic number is not that of an unsigned-byte,
#     3-dimensional IDX file, or the data is shorter than the header announces
def decode_idx3_ubyte(idx3_ubyte_file):
    bin_data = idx3_ubyte_file

    # Header: magic number, number of images, rows and columns per image,
    # each stored as a big-endian 32-bit integer ('>iiii').
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, 0)
    print('\nmagic number:%d, number of images: %d, image size: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 0x00000803 (2051): element type 0x08 = unsigned byte, 3 dimensions.
    # Anything else would be decoded as garbage by the uint8 read below.
    if magic_number != 0x00000803:
        raise ValueError('not an idx3-ubyte file: magic number %d' % magic_number)

    # The pixels follow the 16-byte header as one contiguous run of unsigned
    # bytes; read them in a single vectorised call instead of one struct call
    # per image. np.frombuffer raises ValueError if the data is truncated.
    offset = struct.calcsize(fmt_header)
    pixels = np.frombuffer(bin_data, dtype=np.uint8,
                           count=num_images * num_rows * num_cols, offset=offset)
    # astype copies into a writable array and keeps the historical float64 dtype.
    return pixels.reshape((num_images, num_rows, num_cols)).astype(np.float64)


# function for parsing idx1 files
# :param idx1_ubyte_file: the raw, already-decompressed bytes of an idx1-ubyte
#     file -- NOT a path; load() passes the bytes read from the .gz archive
# :return: 1-D float64 numpy array of length num_labels holding the
#     unsigned-byte label values
# :raises ValueError: if the magic number is not that of an unsigned-byte,
#     1-dimensional IDX file, or the data is shorter than the header announces
def decode_idx1_ubyte(idx1_ubyte_file):
    bin_data = idx1_ubyte_file

    # Header: magic number and number of labels, each a big-endian 32-bit int.
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, 0)
    print('\nmagic number:%d, number of images: %d' % (magic_number, num_images))

    # 0x00000801 (2049): element type 0x08 = unsigned byte, 1 dimension.
    if magic_number != 0x00000801:
        raise ValueError('not an idx1-ubyte file: magic number %d' % magic_number)

    # Labels are one unsigned byte each, directly after the 8-byte header;
    # read them all at once (ValueError if the data is truncated).
    offset = struct.calcsize(fmt_header)
    labels = np.frombuffer(bin_data, dtype=np.uint8, count=num_images, offset=offset)
    # astype copies into a writable array and keeps the historical float64 dtype.
    return labels.astype(np.float64)


# load the files correspondingly
# Read a gzip-compressed IDX file and decode it.
# :param file_path: path to the .gz dataset file
# :param mode: 'idx3' for image files, 'idx1' for label files
# :return: the numpy array produced by decode_idx3_ubyte / decode_idx1_ubyte
# :raises ValueError: if mode is neither 'idx3' nor 'idx1'
def load(file_path, mode):
    # Fail fast (before touching the file) rather than silently returning None.
    if mode not in ('idx3', 'idx1'):
        raise ValueError("mode must be 'idx3' or 'idx1', got %r" % (mode,))

    # `with` closes the gzip handle deterministically instead of leaking it.
    with gzip.open(file_path, 'rb') as archive:
        data = archive.read()

    if mode == 'idx3':
        return decode_idx3_ubyte(data)
    return decode_idx1_ubyte(data)


# check whether the directory already exists;
# if yes, return True; if not, create it (including any missing parent
# directories) and return False
# :param path: directory path to check / create
# :return: True if `path` already existed, False if it was created here
def check_path(path):
    try:
        # makedirs (unlike mkdir) also creates missing parents, so nested
        # defaults such as './_DOWNLOADS/MNIST' work on a fresh checkout.
        # Trying first and catching FileExistsError also avoids the race
        # between a separate exists() check and the creation.
        os.makedirs(path)
    except FileExistsError:
        return True
    return False


# save all the images to the corresponding directories
# Each image is written as '<dir>/<label>/<label>_<n>.png' (n is 1-based),
# converted to RGB.
# :param dir: output directory; one sub-directory per label is created in it
# :param images: sequence of 2-D pixel arrays accepted by PIL's fromarray
#     (e.g. the output of decode_idx3_ubyte)
# :param labels: sequence of numeric labels, one per image
# :raises ValueError: if images and labels differ in length
def save_images(dir, images, labels):
    total_images = len(images)
    # Check up front so a mismatch cannot leave a half-written output tree.
    if len(labels) != total_images:
        raise ValueError('got %d images but %d labels' % (total_images, len(labels)))

    for i, (image_array, label_value) in enumerate(zip(images, labels), start=1):
        print("\rSaving image %d / %d" % (i, total_images), end="", flush=True)
        label = str(int(label_value))
        save_dir = os.path.join(dir, label)
        check_path(save_dir)
        image = processor.fromarray(image_array).convert('RGB')
        # Format only the file name: formatting the joined path would break
        # (or silently corrupt) a directory name containing '%'.
        image.save(os.path.join(save_dir, '%s_%d.png' % (label, i)))
    print('...OK!')


# main body
# Download the dataset into origin_dir if that directory does not exist yet,
# then decode the training and testing sets and write them as PNG images to
# target_dir/train/<label>/ and target_dir/test/<label>/.
# NOTE: files are only downloaded when origin_dir did not exist; an existing
# but incomplete origin_dir is not repaired and loading will then fail.
# :param origin_dir: directory holding (or to receive) the .gz dataset files
# :param target_dir: directory the extracted images are written to
def main(origin_dir, target_dir):
    if not check_path(origin_dir):
        # The directory was just created, so it cannot contain the files yet.
        download(origin_dir)

    check_path(target_dir)
    train_dir = os.path.join(target_dir, 'train')
    check_path(train_dir)
    test_dir = os.path.join(target_dir, 'test')
    check_path(test_dir)

    print("\nLoading Training Data:")
    save_images(train_dir, load(os.path.join(origin_dir, TRAINING_SET_IMAGES_FILE), 'idx3'),
                load(os.path.join(origin_dir, TRAINING_SET_LABELS_FILE), 'idx1'))

    print("\nLoading Testing Data:")
    save_images(test_dir, load(os.path.join(origin_dir, TESTING_SET_IMAGES_FILE), 'idx3'),
                load(os.path.join(origin_dir, TESTING_SET_LABELS_FILE), 'idx1'))

    print('\nFinished!')


########################################################################################################################


if __name__ == '__main__':
    # Command-line entry point: parse the two directory options and run main().
    import argparse

    arg_parser = argparse.ArgumentParser(
        description='Extract all the data from MNIST dataset into image folders')
    arg_parser.add_argument(
        '--origin_dir', type=str, default='./_DOWNLOADS/MNIST',
        help='Directory storing / to store the original dataset files')
    arg_parser.add_argument(
        '--target_dir', type=str, default='./MNIST_Dataset',
        help='Directory to store the output image data')
    cli_args = arg_parser.parse_args()

    main(origin_dir=cli_args.origin_dir, target_dir=cli_args.target_dir)