# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess d-sprites data
"""
import numpy as np
import tensorflow as tf
import os
from six.moves import urllib

# Base URL of the DeepMind dSprites repository; `download` fetches files
# relative to this location.
ROOT_PATH = "https://github.com/deepmind/dsprites-dataset/raw/master/"
# Name of the dSprites .npz archive that is downloaded and loaded. Despite the
# name, this is a fixed filename, not a format template.
FILE_TEMPLATE = "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"


def download(directory, filename):
    """Download ``filename`` from ``ROOT_PATH`` into ``directory`` if absent.

    The result is cached on disk: if ``directory/filename`` already exists it
    is returned immediately without touching the network.

    Args:
      directory: directory to store the file in; created if it is missing.
      filename: name of the file, relative to ``ROOT_PATH``, to fetch.

    Returns:
      The local path ``os.path.join(directory, filename)``.

    Raises:
      Whatever ``urllib.request.urlretrieve`` raises on a failed download;
      in that case no file is left at the returned path.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    # os.path.join is for filesystem paths, not URLs; join with "/" explicitly.
    url = "%s/%s" % (ROOT_PATH.rstrip("/"), filename)
    print("Downloading %s to %s" % (url, filepath))
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted download would leave a truncated file at `filepath`, which
    # the Exists() check above would then treat as a valid cached copy.
    tmp_path = filepath + ".incomplete"
    try:
        urllib.request.urlretrieve(url, tmp_path)
    except BaseException:
        # BaseException so a Ctrl-C mid-download also cleans up; re-raised.
        if tf.gfile.Exists(tmp_path):
            tf.gfile.Remove(tmp_path)
        raise
    tf.gfile.Rename(tmp_path, filepath, overwrite=True)
    return filepath


def load_dsprites_data(params):
    """Load dSprites images and latent class labels.

    The archive is downloaded into ``params["data_dir"]`` on first use.

    Args:
      params: dict-like; only ``params["data_dir"]`` is read here.

    Returns:
      A tuple ``(x_train, y_train, x_test, y_test)``:
        x_*: the archive's ``imgs`` array cast to float32, with a new axis
          inserted at position 3 (a channel axis, assuming ``imgs`` is
          (N, H, W) -- TODO confirm).
        y_*: the archive's ``latents_classes`` array cast to int32.
      No held-out split is made: the test arrays are the very same objects
      as the training arrays.
    """
    file_path = download(params["data_dir"], FILE_TEMPLATE)
    # np.load on an .npz returns an NpzFile that keeps the file open; close it
    # once the needed arrays have been copied out (astype returns a new array).
    with np.load(file_path, encoding='latin1') as dataset_zip:
        x_train = dataset_zip['imgs'].astype('float32')[:, :, :, np.newaxis]
        y_train = dataset_zip["latents_classes"].astype('int32')
    return x_train, y_train, x_train, y_train


def data_generator_train(x, y, batch_size):
    """Yield random ``(x, y)`` minibatches forever.

    Every batch is drawn uniformly *with replacement* from the rows of ``x``
    and ``y`` using NumPy's global random state. A batch may therefore repeat
    an example, and consecutive batches may overlap.

    Args:
      x: array of training inputs, indexed along axis 0.
      y: array of training labels, aligned row-for-row with ``x``.
      batch_size: number of examples in each yielded batch.

    Yields:
      ``(x_batch, y_batch)`` tuples whose leading dimension is ``batch_size``.
    """
    n_examples = x.shape[0]
    while 1:
        # randint(n, size=k) draws k indices from [0, n), same as randint(0, n, k).
        picks = np.random.randint(n_examples, size=batch_size)
        yield x[picks], y[picks]


def build_input_fns(params):
    """Build the training and evaluation input functions for an Estimator.

    Args:
      params: dict-like; reads ``params["batch_size"]`` here and
        ``params["data_dir"]`` via ``load_dsprites_data``.

    Returns:
      A tuple ``(train_input_fn, eval_input_fn, num_train_examples)``:
        train_input_fn: returns the next ``(images, labels)`` batch from the
          infinite stream produced by ``data_generator_train``.
        eval_input_fn: a ``numpy_input_fn`` making one shuffled pass
          (``num_epochs=1``) over the evaluation arrays.
        num_train_examples: ``x_train.shape[0]``.
    """
    x_train, y_train, x_test, y_test = load_dsprites_data(params)
    batch_size = params["batch_size"]

    # Derive the static batch shapes from the loaded arrays instead of
    # hard-coding them, so they always match what the generator yields.
    x_batch_shape = tf.TensorShape((batch_size,) + tuple(x_train.shape[1:]))
    y_batch_shape = tf.TensorShape((batch_size,) + tuple(y_train.shape[1:]))

    def gen_train():
        return data_generator_train(x_train, y_train, batch_size)

    def train_input_fn():
        # Build an iterator over training batches.
        dataset = tf.data.Dataset.from_generator(
            gen_train, (tf.float32, tf.int32), (x_batch_shape, y_batch_shape))
        dataset = dataset.prefetch(1)
        return dataset.make_one_shot_iterator().get_next()

    # Build an input function over the held-out set (a single epoch).
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x_test,
        y=y_test,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=True,
        queue_capacity=10000,
        num_threads=1)

    return train_input_fn, eval_input_fn, x_train.shape[0]
