# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess mnist data
"""
import numpy as np
import tensorflow as tf
import os
from six.moves import urllib

# Base URL under which the binarized MNIST files are hosted.
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
# File name pattern; `{split}` is filled in with the split name
# (this module uses "train" and "valid").
FILE_TEMPLATE = "binarized_mnist_{split}.amat"


def download(directory, filename):
    """Downloads a file from ROOT_PATH unless a local copy already exists.

    Args:
      directory: Directory holding the local copy; created if it is missing.
      filename: Name of the file, appended to ROOT_PATH to form the URL.

    Returns:
      The path of the local file, `os.path.join(directory, filename)`.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    # URLs always use forward slashes; os.path.join would insert the
    # platform separator instead (a backslash on Windows).
    url = ROOT_PATH.rstrip("/") + "/" + filename
    print("Downloading %s to %s" % (url, filepath))
    # Fetch into a temporary name first so that an interrupted transfer never
    # leaves a truncated file that later calls would treat as already
    # downloaded.
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
    except Exception:
        if tf.gfile.Exists(partial_path):
            tf.gfile.Remove(partial_path)
        raise
    tf.gfile.Rename(partial_path, filepath, overwrite=True)
    return filepath


def static_mnist_dataset(directory, split_name):
    """Builds a tf.data.Dataset over one split of binarized MNIST.

    The split's `.amat` file is fetched via `download` and read line by line;
    every line is parsed into one image.

    Args:
      directory: Directory where the data file is stored / downloaded to.
      split_name: Split identifier substituted into FILE_TEMPLATE.

    Returns:
      A dataset of `(image, label)` pairs, where `image` is a float32 tensor
      of shape [28, 28, 1] and `label` is the constant int32 scalar 0.
    """

    def _decode_line(line):
        # A whitespace-separated token equal to b"1" becomes True, any other
        # token becomes False.
        return np.array([token == b"1" for token in line.split()])

    def _to_example(line):
        pixels = tf.py_func(_decode_line, [line], tf.bool)
        image = tf.cast(tf.reshape(pixels, [28, 28, 1]), dtype=tf.float32)
        label = tf.constant(0, tf.int32)
        return image, label

    local_path = download(directory, FILE_TEMPLATE.format(split=split_name))
    return tf.data.TextLineDataset(local_path).map(_to_example)


def build_input_fns(data_dir, batch_size):
    """Creates input functions for the training and heldout splits.

    Args:
      data_dir: Directory where the data files are stored / downloaded to.
      batch_size: Number of examples per batch, used for both splits.

    Returns:
      A `(train_input_fn, eval_input_fn)` tuple of zero-argument callables.
      Each one builds its dataset and returns the `get_next()` result of a
      one-shot iterator over it.
    """

    def train_input_fn():
        # Training batches: shuffled and repeated indefinitely.
        train_data = static_mnist_dataset(data_dir, "train")
        train_data = train_data.shuffle(50000)
        train_data = train_data.repeat()
        train_data = train_data.batch(batch_size)
        train_data = train_data.prefetch(1)
        train_iterator = train_data.make_one_shot_iterator()
        return train_iterator.get_next()

    def eval_input_fn():
        # Heldout batches: a single pass in file order.
        heldout = static_mnist_dataset(data_dir, "valid")
        heldout = heldout.batch(batch_size).prefetch(1)
        heldout_iterator = heldout.make_one_shot_iterator()
        return heldout_iterator.get_next()

    return train_input_fn, eval_input_fn
