# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess mnist data
"""
import numpy as np
import tensorflow as tf


def binarize(x):
    """
    Draws a random binary array shaped like x.

    Every element is an independent Bernoulli sample whose success
    probability is the matching entry of x, so x must hold values in [0, 1].
    """
    probabilities = x
    return np.random.binomial(n=1, p=probabilities)


def drop_dimensions(x_train, x_test, threshold=0.1):
    """
    Removes dimensions whose standard deviation is too small

    The per-dimension standard deviation is computed on the training data
    only. A dimension is kept when its std is strictly greater than
    `threshold`. The same columns are then selected from both arrays.

    Args:
      x_train: training data, one row per example
      x_test: test data with the same columns as x_train
      threshold: standard-deviation cut-off; columns at or below it are dropped

    Returns:
      x_train: training data restricted to the retained columns
      x_test: test data restricted to the retained columns
      good_dims: indices of the retained columns

    """
    keep_mask = np.std(x_train, axis=0) > threshold
    good_dims = np.flatnonzero(keep_mask)
    return x_train[:, good_dims], x_test[:, good_dims], good_dims


def load_mnist_data(n_datapoints, threshold):
    """
    Loads MNIST, flattens and rescales it, and drops near-constant pixels

    Args:
      n_datapoints: number of leading training examples to keep (the test
        set is not truncated)
      threshold: standard-deviation cut-off forwarded to drop_dimensions

    Returns:
      x_train: flattened, rescaled, column-filtered training images
      y_train: one-hot training labels (10 classes)
      x_test: flattened, rescaled, column-filtered test images
      y_test: one-hot test labels (10 classes)
      good_dims: indices of the pixel columns that were retained

    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    n_classes = 10
    # NOTE(review): assuming uint8 pixels in [0, 255], dividing by 256 caps
    # values at 255/256 rather than 1.0; /255 is the more common choice.
    pixel_scale = 256.

    # Flatten every image into a single row vector before rescaling.
    x_train = x_train.reshape([x_train.shape[0], -1]) / pixel_scale
    x_test = x_test.reshape([x_test.shape[0], -1]) / pixel_scale

    # Only the training split is subsampled.
    x_train = x_train[0:n_datapoints]
    y_train = y_train[0:n_datapoints]

    y_train = tf.keras.utils.to_categorical(y_train, n_classes)
    y_test = tf.keras.utils.to_categorical(y_test, n_classes)

    x_train, x_test, good_dims = drop_dimensions(x_train, x_test, threshold)
    return x_train, y_train, x_test, y_test, good_dims


def data_generator_eval(x, y, batch_size):
    """
    Generates an infinite sequence of test data

    Each step draws `batch_size` row indices uniformly at random, with
    replacement. It yields the matching rows of x and y unchanged; unlike
    data_generator_train, x is not binarized.

    Args:
      x: test data, indexed along its first axis
      y: test labels, aligned with x along the first axis
      batch_size: batch size to yield

    Yields:
      tuples of x,y pairs each of size batch_size

    """
    num = x.shape[0]
    # Loop forever so the generator never raises StopIteration, as the
    # docstring promises and as data_generator_train already does.
    while True:
        idx = np.random.randint(0, num, batch_size)
        x_batch = x[idx]
        y_batch = y[idx]
        yield (x_batch, y_batch)


def data_generator_train(x, y, batch_size):
    """
    Endlessly yields random, binarized training batches

    Each step samples `batch_size` row indices uniformly with replacement.
    The selected inputs are passed through binarize; the labels are returned
    as-is.

    Args:
      x: training data, indexed along its first axis
      y: training labels, aligned with x along the first axis
      batch_size: number of rows per yielded batch

    Yields:
      (binarized x batch, y batch) tuples of batch_size rows each

    """
    n_rows = x.shape[0]
    while True:
        picks = np.random.randint(0, n_rows, batch_size)
        yield binarize(x[picks]), y[picks]


def build_input_fns(params):
    """
    Creates the training and evaluation input functions

    Args:
      params: dict providing "n_datapoints", "threshold" and "batch_size"

    Returns:
      train_input_fn: callable returning the next (features, labels) tensors
        from an endless stream of randomly sampled, binarized training batches
      eval_input_fn: numpy_input_fn over the held-out set (a single shuffled
        epoch; the inputs are not binarized)
      good_dims: indices of the pixel columns kept by drop_dimensions

    """
    x_train, y_train, x_test, y_test, good_dims = load_mnist_data(
        params["n_datapoints"], params["threshold"])

    def gen_train():
        return data_generator_train(x_train, y_train, params["batch_size"])

    def train_input_fn():
        # Wrap the Python generator in a tf.data pipeline with fixed shapes.
        batch = params["batch_size"]
        output_types = (tf.float64, tf.int64)
        output_shapes = (tf.TensorShape([batch, len(good_dims)]),
                         tf.TensorShape([batch, 10]))
        dataset = tf.data.Dataset.from_generator(
            gen_train, output_types, output_shapes)
        dataset = dataset.prefetch(1)
        return dataset.make_one_shot_iterator().get_next()

    # Held-out pipeline: one shuffled pass over the test arrays.
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x_test,
        y=y_test,
        batch_size=params["batch_size"],
        num_epochs=1,
        shuffle=True,
        queue_capacity=10000,
        num_threads=1)

    return train_input_fn, eval_input_fn, good_dims
