import os
import struct
import argparse
import numpy as np


def convert(path_lab, path_img, save=False, delimiter=',', path_out=None):
    """Load a label file and an image file produced by infimnist.

    Both files start with a header of big-endian unsigned 32-bit integers:
    the label file has ``(magic, count)`` with magic ``2049``; the image
    file has ``(magic, count, rows, cols)`` with magic ``2051``. Everything
    after the header is read as raw ``uint8`` values.

    Args:
        path_lab: Path to the binary label file.
        path_img: Path to the binary image file.
        save: If true, also write the data as delimited text, one instance
            per row with the label in the first column followed by the
            flattened pixels.
        delimiter: Column separator used when ``save`` is true.
        path_out: Destination of the text file when ``save`` is true;
            defaults to ``'infimnist_data'``.

    Returns:
        A tuple ``(labels, images)`` of writable ``uint8`` arrays with
        shapes ``(count, 1)`` and ``(count, rows * cols)``.

    Raises:
        ValueError: If a magic number is wrong, the two instance counts
            differ, or a payload size does not match its header.
        struct.error: If a file is too short to contain its header.
    """
    with open(path_lab, 'rb') as f, open(path_img, 'rb') as g:
        magic1, size1 = struct.unpack('>II', f.read(8))
        if magic1 != 2049:
            raise ValueError('Magic number mismatch for the label file: '
                             'expected 2049, got {}'.format(magic1))
        magic2, size2, rows, cols = struct.unpack('>IIII', g.read(16))
        if magic2 != 2051:
            raise ValueError('Magic number mismatch for the image file: '
                             'expected 2051, got {}'.format(magic2))
        if size1 != size2:
            raise ValueError('Number of instances do not match: '
                             'got {} for the label file and {} for the image file'.format(size1, size2))
        # np.fromstring's binary mode is deprecated (and rejected by recent
        # NumPy releases); np.frombuffer is the supported replacement. It
        # returns a read-only view of the bytes object, so copy to hand back
        # writable arrays exactly as the previous implementation did.
        labels = np.frombuffer(f.read(), dtype=np.uint8).reshape(size1, 1).copy()
        images = np.frombuffer(g.read(), dtype=np.uint8).reshape(size1, rows * cols).copy()
    if save:
        data = np.concatenate((labels, images), axis=1)
        if path_out is None:
            path_out = 'infimnist_data'
        np.savetxt(path_out, data, fmt='%u', delimiter=delimiter)
    return labels, images

def main():
    """Shuffle an infimnist label/image pair and write it out in splits.

    Parses the command line, loads both binary files with ``convert``,
    applies one random permutation to labels and images together, divides
    them into ``--num_splits`` parts with ``np.array_split`` and, for each
    part ``i``, writes ``mnist8m_<i>_labels.bin`` and
    ``mnist8m_<i>_features.bin`` into the ``--output`` directory.

    Each output file is a three-value ``int32`` header followed by the
    data cast to ``int32`` and flattened: ``[0, 1, N]`` for labels and
    ``[1, D, N]`` for features, where ``N`` is the number of instances in
    the split and ``D`` the number of pixels per image.
    NOTE(review): the meaning of the leading header value is defined by the
    consumer of these files and is not visible here -- confirm before
    changing it.
    """
    parser = argparse.ArgumentParser(description='Convert infinite MNIST binary data to a readable format.')
    parser.add_argument('label_data', help='binary data for labels generated by the infimnist executable')
    parser.add_argument('image_data', help='binary data for images generated by the infimnist executable')
    # The value is joined with the per-split file names below, so it is a
    # directory rather than a single file.
    parser.add_argument('-o', '--output', default='infimnist_data', metavar='', help='output directory')
    # type=int is required: without it a value given on the command line
    # arrives as a string and breaks np.array_split / the split loop.
    parser.add_argument('-n', '--num_splits', type=int, default=8, metavar='', help='number of splits')
    args = parser.parse_args()

    labels, images = convert(args.label_data, args.image_data, save=False)

    out_dir = args.output
    # Create the destination on demand; an empty string means the current
    # working directory and needs no creation.
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # A single permutation keeps every label aligned with its image.
    order = np.random.permutation(labels.shape[0])
    label_splits = np.array_split(labels[order], args.num_splits)
    image_splits = np.array_split(images[order], args.num_splits)

    for i, (label_part, image_part) in enumerate(zip(label_splits, image_splits)):
        label_part = label_part.astype('int32')
        image_part = image_part.astype('int32')
        n_rows, n_cols = image_part.shape

        # Context managers guarantee the files are flushed and closed even
        # if a write fails part-way through.
        labels_path = os.path.join(out_dir, 'mnist8m_{0}_labels.bin'.format(i))
        with open(labels_path, mode='wb') as f_labels:
            np.array([0, 1, n_rows], dtype='int32').tofile(f_labels)
            label_part.reshape(-1).tofile(f_labels)

        features_path = os.path.join(out_dir, 'mnist8m_{0}_features.bin'.format(i))
        with open(features_path, mode='wb') as f_images:
            np.array([1, n_cols, n_rows], dtype='int32').tofile(f_images)
            image_part.reshape(-1).tofile(f_images)

# Run the command-line converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
