# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess chairs data
"""
import numpy as np
import tensorflow as tf

import os
import sys
import shutil
import zipfile
from six.moves import urllib
import scipy.misc
import cv2
import pickle
from glob import glob
import re

URL = [
    'https://www.dropbox.com/s/jv0q1kg66xduzmj/chairs.npy?dl=1',
    'https://www.dropbox.com/s/60epwk5butnnhf2/chairs_factors.npy?dl=1'
]


def prepare_data_dir(path='/data/chairs'):
    """Ensure that the directory `path` exists.

    Missing parent directories are created as well, and calling the function
    for a directory that already exists is a no-op.

    Args:
      path: directory to create.

    Raises:
      OSError: if the directory cannot be created, e.g. because `path`
        already exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        # "Already exists" is fine (this also covers another process creating
        # the directory concurrently); anything else is a genuine failure.
        if not os.path.isdir(path):
            raise


def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Cut a centered window out of an image and resize it.

    Args:
      x: image array; its first two axes are used as height and width.
      crop_h: height of the centered window.
      crop_w: width of the centered window; `None` means "same as crop_h".
      resize_h: height passed to the resize call.
      resize_w: width passed to the resize call.

    Returns:
      The result of `scipy.misc.imresize` applied to the cropped window
      with target size `[resize_h, resize_w]`.
    """
    # NOTE(review): scipy.misc.imresize was deprecated in SciPy 1.0 and is
    # absent from SciPy >= 1.3, so this function needs an older SciPy.
    window_w = crop_h if crop_w is None else crop_w
    height, width = x.shape[:2]
    # Top-left corner of the window so that it sits in the image center.
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - window_w) / 2.))
    window = x[top:top + crop_h, left:left + window_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])


def crop(image, input_height, input_width, resize_height=64, resize_width=64):
    """Center-crop and resize `image`, then divide its values by 255.

    Args:
      image: image array handed to `center_crop`.
      input_height: height of the centered crop window.
      input_width: width of the centered crop window.
      resize_height: height of the resized output.
      resize_width: width of the resized output.

    Returns:
      numpy array holding the resized crop divided by 255.
    """
    resized = center_crop(image, input_height, input_width, resize_height,
                          resize_width)
    scaled = np.array(resized) / 255.
    return scaled


def load_and_crop_one_image(path_to_image):
    """Read one image as greyscale, center-crop it to 400x400 and resize to 64x64.

    Args:
      path_to_image: path of the image file to read.

    Returns:
      The array returned by `crop` for the loaded image.

    Raises:
      IOError: if OpenCV cannot read the file.
    """
    # Flag 0 asks OpenCV for a single-channel greyscale image.
    img_grey = cv2.imread(path_to_image, 0)
    if img_grey is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: %s' % path_to_image)
    # Reference: https://stackoverflow.com/a/15074748/
    # np.float64 replaces the `np.float` alias, which was removed in
    # NumPy 1.24; it is the same dtype.
    return crop(img_grey.astype(np.float64),
                input_height=400,
                input_width=400,
                resize_height=64,
                resize_width=64)


def crop_file_dir(files):
    """Load and crop every image path in `files`.

    Args:
      files: iterable of image paths.

    Returns:
      List with one `load_and_crop_one_image` result per path, in input order.
    """
    return list(map(load_and_crop_one_image, files))


def download(directory, filename, url):
    """Download `url` to `directory/filename` unless that file already exists.

    The data is first written to a temporary `<filename>.part` file and only
    renamed to its final name once the transfer has finished, so an
    interrupted download is never mistaken for a complete file on later calls.

    Args:
      directory: destination directory; created if it does not exist.
      filename: name of the file inside `directory`.
      url: address to download from.

    Returns:
      Path of the downloaded (or already present) file.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    print("Downloading %s to %s" % (url, filepath))
    # A stale .part file from an earlier failed attempt is simply overwritten.
    partial_path = filepath + '.part'
    urllib.request.urlretrieve(url, partial_path)
    tf.gfile.Rename(partial_path, filepath, overwrite=True)
    return filepath


def download_chairs(dirpath):
    """Fetch `chairs.npy` and `chairs_factors.npy` into `dirpath`.

    Args:
      dirpath: directory the two files are downloaded to.
    """
    # The file names are listed in the same order as the entries of URL.
    for position, filename in enumerate(('chairs.npy', 'chairs_factors.npy')):
        download(dirpath, filename, URL[position])
    print('Chairs ready')


def load_chairs(dirpath='/data/chairs'):
    """Download the chairs files if needed and load them with numpy.

    Args:
      dirpath: directory that holds (or will hold) the two `.npy` files.

    Returns:
      Tuple `(x, y)`: the contents of `chairs.npy` and of
      `chairs_factors.npy`, in that order.
    """
    download_chairs(dirpath)
    images = np.load(os.path.join(dirpath, 'chairs.npy'))
    factors = np.load(os.path.join(dirpath, 'chairs_factors.npy'))
    return images, factors


def data_generator_train(x, y, batch_size):
    """
    Yields random training batches forever

    Args:
      x: training data
      y: training labels
      batch_size: number of examples per yielded batch

    Yields:
      tuples (x_batch, y_batch), both indexed by the same batch_size
      randomly drawn positions

    """
    n_examples = x.shape[0]
    while True:
        # --- Positions are drawn independently, so one batch may contain
        # --- the same example more than once
        picks = np.random.randint(0, n_examples, batch_size)
        yield (x[picks], y[picks])


def data_generator_eval(x, y, batch_size):
    """
    Generates one sequential pass over the test data

    The sequence is finite: consecutive, non-overlapping batches are yielded
    in order, and trailing examples that do not fill a whole batch are
    skipped.

    Args:
      x: test data
      y: test labels
      batch_size: batch size to yield

    Yields:
      tuples of x,y pairs each of size batch_size

    """
    full_batches = int(x.shape[0] / batch_size)
    for start in range(0, full_batches * batch_size, batch_size):
        positions = range(start, start + batch_size)
        yield (x[positions], y[positions])


def build_input_fns(params):
    """Builds the train and eval input functions for the chairs data.

    Args:
      params: mapping that provides "data_dir" (where the chairs files are
        stored or downloaded to) and "batch_size".

    Returns:
      Tuple `(train_input_fn, eval_input_fn, n_examples)`. Each input
      function returns the `get_next()` op of a one-shot iterator over
      batches; `n_examples` is the size of the first axis of the loaded data.
      Both input functions read from the same loaded arrays: no separate
      heldout set is loaded here.
    """
    x_train, y_train = load_chairs(params["data_dir"])

    def gen_train():
        # Endless stream of randomly sampled batches.
        return data_generator_train(x_train, y_train, params["batch_size"])

    def gen_eval():
        # Single sequential pass over the same arrays.
        return data_generator_eval(x_train, y_train, params["batch_size"])

    def _next_batch_op(generator_fn):
        # Wrap a batch generator in a tf.data pipeline with static shapes.
        # NOTE(review): labels are declared as tf.int32, while the __main__
        # script of this file saves float32 factors - confirm the dtype.
        batch_size = params["batch_size"]
        output_types = (tf.float32, tf.int32)
        output_shapes = (tf.TensorShape([batch_size, 64, 64, 1]),
                         tf.TensorShape([batch_size, 3]))
        dataset = tf.data.Dataset.from_generator(generator_fn, output_types,
                                                 output_shapes)
        return dataset.prefetch(1).make_one_shot_iterator().get_next()

    def train_input_fn():
        # Iterator over randomly sampled training batches.
        return _next_batch_op(gen_train)

    def eval_input_fn():
        # Iterator over sequential batches of the same data.
        return _next_batch_op(gen_eval)

    return train_input_fn, eval_input_fn, x_train.shape[0]


if __name__ == '__main__':
    # Build chairs.npy / chairs_factors.npy / chairs_names.npy from the PNG
    # files found below `dirpath`.
    dirpath = 'rendered_chairs'
    # Collect every PNG in dirpath and all of its sub-directories.
    files = [
        png_path for walked in os.walk(dirpath)
        for png_path in glob(os.path.join(walked[0], '*.png'))
    ]
    cropped = np.array(crop_file_dir(files))
    # Append a trailing channel axis. The attribute is `np.newaxis`;
    # the previous spelling `np.new_axis` does not exist in numpy.
    cropped = cropped.astype('float32')[:, :, :, np.newaxis]
    # The chair identifier is the path component between dirpath and
    # "renders".
    chair_names = np.array([
        re.search('rendered_chairs/(.+?)/renders', path).group(1)
        for path in files
    ])
    # Integer id per chair name; computed here but not written to disk.
    chair_indices = np.unique(chair_names, return_inverse=True)[1].tolist()
    # The three factors are parsed from the "_p<phi>_t<theta>_r<rho>.p..."
    # part of each file name.
    chair_phi = np.array(
        [re.search('_p(.+?)_t', path).group(1) for path in files])
    chair_theta = np.array(
        [re.search('_t(.+?)_r', path).group(1) for path in files])
    # The dot is escaped so that it only matches the literal "." that
    # starts the file extension.
    chair_rho = np.array(
        [re.search(r'_r(.+?)\.p', path).group(1) for path in files])
    chair_factors = np.array([chair_phi, chair_theta, chair_rho],
                             dtype='float32').T
    # Scale every factor column by its maximum value.
    chair_factors = chair_factors / chair_factors.max(0)
    np.save(os.path.join(dirpath, 'chairs.npy'), cropped)
    np.save(os.path.join(dirpath, 'chairs_factors.npy'), chair_factors)
    np.save(os.path.join(dirpath, 'chairs_names.npy'), chair_names)
