# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess mnist data
"""
import numpy as np
import tensorflow as tf

import os
import sys
import shutil
import zipfile
from six.moves import urllib
import scipy.misc
import cv2
import pickle

# Download locations used by download_faces():
#   URL[0] -> saved as '3D_faces.npy'
#   URL[1] -> saved as '3D_faces_factors.npy'
URL = [
    'https://www.dropbox.com/s/r2wjhpb4w2tdz0j/3D_faces.npy?dl=1',
    'https://www.dropbox.com/s/eegryjkbg4u8tl6/3D_faces_factors.npy?dl=1'
]


def prepare_data_dir(path='/data/faces'):
    """Make sure the data directory exists.

    Missing parent directories are created as well. If something already
    exists at ``path`` the call is a no-op.

    Args:
      path: directory to create.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Either the path already exists (possibly created concurrently by
        # another process), which is fine, or creation genuinely failed.
        if not os.path.exists(path):
            raise


def download(directory, filename, url):
    """Download ``url`` to ``directory/filename`` unless it is already there.

    The payload is first written to ``<filepath>.part`` and only renamed to
    its final name once the transfer finished, so an interrupted download is
    never mistaken for a complete file on the next call.

    Args:
      directory: target directory; created if it does not exist.
      filename: name of the file inside ``directory``.
      url: location to fetch the file from.

    Returns:
      The path of the (existing or freshly downloaded) file.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    print("Downloading %s to %s" % (url, filepath))
    partial_path = filepath + '.part'
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.rename(partial_path, filepath)
    except BaseException:
        # Remove the incomplete file (also on Ctrl-C), then propagate.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    return filepath


def download_faces(dirpath):
    """Fetch both faces files into ``dirpath`` (skipping ones already there).

    Args:
      dirpath: directory the two ``.npy`` files are stored in.
    """
    filenames = ('3D_faces.npy', '3D_faces_factors.npy')
    # URL entries are paired with the file names by position.
    for source_url, filename in zip(URL, filenames):
        download(dirpath, filename, source_url)
    print('faces ready')


def load_faces(dirpath='/data/faces'):
    """Download the faces files if needed and load them with numpy.

    Args:
      dirpath: directory holding (or receiving) the ``.npy`` files.

    Returns:
      Tuple ``(x, y)``: the array loaded from '3D_faces.npy' and the array
      loaded from '3D_faces_factors.npy'.
    """
    download_faces(dirpath)
    images_path = os.path.join(dirpath, '3D_faces.npy')
    factors_path = os.path.join(dirpath, '3D_faces_factors.npy')
    return np.load(images_path), np.load(factors_path)


def data_generator_train(x, y, batch_size):
    """
    Endlessly yields random training batches.

    Each batch is drawn independently with replacement, using the global
    numpy random state.

    Args:
      x: training data, indexed along its first axis
      y: training labels, indexed along its first axis
      batch_size: number of examples per batch

    Yields:
      (x_batch, y_batch) tuples; x_batch is reshaped to
      (batch_size, 64, 64, 1)

    """
    n_samples = x.shape[0]
    batch_shape = (batch_size, 64, 64, 1)
    while True:
        # --- Sample indices uniformly from [0, n_samples)
        picks = np.random.randint(low=0, high=n_samples, size=batch_size)
        yield x[picks].reshape(batch_shape), y[picks]


def data_generator_eval(x, y, batch_size):
    """
    Generates one sequential, finite pass over the data

    Batches are taken in order. Only full batches are produced: a trailing
    remainder of fewer than batch_size examples is dropped.

    Args:
      x: test data, indexed along its first axis
      y: test labels, indexed along its first axis
      batch_size: batch size to yield

    Yields:
      tuples of x,y pairs each of size batch_size; the x batch is reshaped
      to (batch_size, 64, 64, 1)

    """
    n_data_points = x.shape[0]
    # Integer floor division avoids the float round trip of int(a / b).
    n_batches = n_data_points // batch_size
    for i in range(n_batches):
        # Index arrays (not slices) so the yielded batches are copies.
        idx = np.arange(i * batch_size, (i + 1) * batch_size)
        x_batch = x[idx].reshape(batch_size, 64, 64, 1)
        y_batch = y[idx]
        yield (x_batch, y_batch)


def build_input_fns(params):
    """Builds the train and eval input functions for the faces data.

    Args:
      params: dict providing "data_dir" (where the faces files live or are
        downloaded to) and "batch_size".

    Returns:
      Tuple (train_input_fn, eval_input_fn, n_examples) where each input
      function returns the next-element tensors of a one-shot iterator and
      n_examples is the size of the first axis of the loaded data.
    """
    faces, factors = load_faces(params["data_dir"])

    def _next_batch_tensors(generator_fn):
        # Shared dataset construction for both input functions.
        batch = params["batch_size"]
        # NOTE(review): labels are declared tf.int32, while the preprocessing
        # script in this file saves the factors as float32 -- confirm this
        # dtype is intended.
        output_types = (tf.float32, tf.int32)
        output_shapes = (tf.TensorShape([batch, 64, 64, 1]),
                         tf.TensorShape([batch, 4]))
        dataset = tf.data.Dataset.from_generator(generator_fn, output_types,
                                                 output_shapes)
        return dataset.prefetch(1).make_one_shot_iterator().get_next()

    def train_input_fn():
        # Random batches, sampled indefinitely.
        return _next_batch_tensors(lambda: data_generator_train(
            faces, factors, params["batch_size"]))

    def eval_input_fn():
        # NOTE(review): evaluation makes one sequential pass over the same
        # arrays used for training; there is no separate heldout split here.
        return _next_batch_tensors(lambda: data_generator_eval(
            faces, factors, params["batch_size"]))

    return train_input_fn, eval_input_fn, faces.shape[0]


def flatten_faces(faces):
    """Flatten a factor-indexed faces array into images plus factor labels.

    Args:
      faces: 6-D array indexed as [f0, f1, f2, f3, row, col], i.e. four
        factor axes followed by the two image axes.

    Returns:
      Tuple (images, factors):
        images: float32 array of shape (f0*f1*f2*f3, rows, cols) holding the
          input values divided by 255.0, in C (row-major) order of the four
          factor indices.
        factors: float32 array of shape (f0*f1*f2*f3, 4); entry [n, a] is
          the index along factor axis a divided by (size of axis a - 1). A
          factor axis of length 1 gets 0.0 instead of dividing by zero.

    Raises:
      ValueError: if ``faces`` is not 6-dimensional.
    """
    faces = np.asarray(faces)
    if faces.ndim != 6:
        raise ValueError(
            'expected a 6-D array (4 factor axes + 2 image axes), got %d-D'
            % faces.ndim)
    factor_sizes = faces.shape[0:4]
    n_images = int(np.prod(factor_sizes))

    # C-order reshape enumerates (i, j, k, m) exactly like four nested loops.
    images = faces.reshape((n_images,) + faces.shape[4:])
    images = images.astype('float32') / 255.0

    # Normalised coordinate of every index along each factor axis.
    axes = []
    for size in factor_sizes:
        denom = float(size - 1) if size > 1 else 1.0
        axes.append(np.arange(size, dtype='float64') / denom)
    # indexing='ij' keeps the same axis order as the image reshape above.
    grids = np.meshgrid(*axes, indexing='ij')
    factors = np.stack([grid.ravel() for grid in grids], axis=1)
    return images, factors.astype('float32')


if __name__ == '__main__':
    # Convert the raw latent-indexed file into the two arrays that
    # load_faces() expects.
    faces_reshaped, factors = flatten_faces(
        np.load('3D_faces_incl_latents.npy'))
    np.save('3D_faces.npy', faces_reshaped)
    np.save('3D_faces_factors.npy', factors)
