# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess mnist data
"""
import numpy as np
import tensorflow as tf

import os
import sys
import shutil
import zipfile
from six.moves import urllib
import scipy.misc
import cv2
import pickle

# Download locations of the pre-processed Celeb-A arrays, consumed by
# download_celebA: URL[0] -> celebA.npy (images),
# URL[1] -> celebA_factors.npy (attribute factors).
_CELEBA_IMAGES_URL = 'https://www.dropbox.com/s/6rb9s2s2aky4ad2/celebA.npy?dl=1'
_CELEBA_FACTORS_URL = 'https://www.dropbox.com/s/j3avffpwe8rkssx/celebA_factors.npy?dl=1'
URL = [_CELEBA_IMAGES_URL, _CELEBA_FACTORS_URL]


def unzip(filepath):
    """Extract a zip archive into its own directory and delete the archive.

    Args:
      filepath: path to the zip file. Its members are written into
        ``os.path.dirname(filepath)``; the archive is removed afterwards.
    """
    destination = os.path.dirname(filepath)
    print("Extracting: " + filepath)
    with zipfile.ZipFile(filepath) as archive:
        archive.extractall(destination)
    os.remove(filepath)


def load_celeb_a(dirpath):
    """Unpack a locally present ``img_align_celeba.zip`` into ``dirpath/celebA``.

    Does nothing if ``dirpath/celebA`` already exists. If the zip is not in
    ``dirpath`` either, a message is printed and the function returns without
    downloading anything. Otherwise the archive is extracted into
    ``dirpath``, deleted, and its top-level folder is renamed to ``celebA``.

    Args:
      dirpath: directory that holds the zip and receives the extracted data.
    """
    if os.path.exists(os.path.join(dirpath, 'celebA')):
        print('Found Celeb-A - skip')
        return
    filename = "img_align_celeba.zip"
    save_path = os.path.join(dirpath, filename)
    if not os.path.exists(save_path):
        print('Need to download Celeb-A')
        return
    print('[*] {} already exists'.format(save_path))
    with zipfile.ZipFile(save_path) as zf:
        # The first member is not guaranteed to be the folder entry itself;
        # it may be a file inside it (e.g. 'img_align_celeba/000001.jpg').
        # Zip member names always use '/', so keep only the leading
        # component to get the top-level folder.
        zip_dir = zf.namelist()[0].split('/')[0]
        zf.extractall(dirpath)
    os.remove(save_path)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, 'celebA'))


def prepare_data_dir(path='/data/celeba'):
    """Make sure ``path`` exists, creating any missing parent directories.

    Args:
      path: directory to create. Nothing happens if it already exists.
    """
    if not os.path.exists(path):
        # makedirs (not mkdir) so a missing parent such as '/data' does not
        # make this fail.
        os.makedirs(path)


def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Cut a centred ``crop_h`` x ``crop_w`` window out of ``x`` and resize it.

    Args:
      x: image array whose first two axes are treated as height and width.
      crop_h: height of the window.
      crop_w: width of the window; ``None`` means "same as ``crop_h``".
      resize_h: output height passed to ``scipy.misc.imresize``.
      resize_w: output width passed to ``scipy.misc.imresize``.

    Returns:
      Whatever ``scipy.misc.imresize`` returns for the window.

    NOTE(review): ``scipy.misc.imresize`` was removed in SciPy 1.3, so this
    needs an old SciPy or a replacement resize routine to run.
    """
    win_w = crop_h if crop_w is None else crop_w
    img_h, img_w = x.shape[:2]
    top = int(round((img_h - crop_h) / 2.))
    left = int(round((img_w - win_w) / 2.))
    window = x[top:top + crop_h, left:left + win_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])


def crop(image, input_height, input_width, resize_height=64, resize_width=64):
    """Centre-crop and resize ``image``, then rescale it with ``v / 127.5 - 1``.

    With pixel values in [0, 255] the result lies in [-1, 1].

    Args:
      image: image array handed to ``center_crop``.
      input_height: crop height.
      input_width: crop width (``None`` -> square crop, see ``center_crop``).
      resize_height: height after resizing.
      resize_width: width after resizing.

    Returns:
      The rescaled ndarray.
    """
    window = np.array(center_crop(image, input_height, input_width,
                                  resize_height, resize_width))
    scaled = window / 127.5
    return scaled - 1.


def inverse_transform(images):
    """Map values from [-1, 1] back to [0, 1] via ``(v + 1) / 2``."""
    shifted = images + 1.
    return shifted / 2.


def load_and_crop_one_image(path_to_image):
    """Read one image file and return it cropped, resized and rescaled.

    The image is read with OpenCV, converted from BGR to RGB channel order,
    cast to float64, centre-cropped to 108x108, resized to 64x64 and rescaled
    by ``crop``.

    Args:
      path_to_image: path to an image file readable by ``cv2.imread``.

    Returns:
      The array produced by ``crop``.

    Raises:
      IOError: if OpenCV cannot read the file.
    """
    img_bgr = cv2.imread(path_to_image)
    if img_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(path_to_image))
    # Reverse the channel axis to turn OpenCV's BGR into RGB.
    # Reference: https://stackoverflow.com/a/15074748/
    img_rgb = img_bgr[..., ::-1]
    # np.float was only an alias of float64 and was removed in NumPy 1.24.
    return crop(img_rgb.astype(np.float64),
                input_height=108,
                input_width=108,
                resize_height=64,
                resize_width=64)


def crop_file_dir(dirpath, files):
    """Load and crop every image listed in ``files``.

    Args:
      dirpath: directory containing the images.
      files: file names relative to ``dirpath``.

    Returns:
      A list with one ``load_and_crop_one_image`` result per name, in the
      same order as ``files``.
    """
    image_paths = (os.path.join(dirpath, name) for name in files)
    return [load_and_crop_one_image(path) for path in image_paths]


def create_celebA_latent_factors(dirpath):
    """Parse the attribute table ``list_attr_celeba.txt`` in ``dirpath``.

    Layout assumed by the parser:
    - line 1 is skipped (presumably the image count);
    - line 2 holds the whitespace-separated attribute names;
    - every later line is ``<filename> <v_1> ... <v_k>``.

    The filename column is dropped.

    Args:
      dirpath: directory containing ``list_attr_celeba.txt``.

    Returns:
      Float array from ``np.loadtxt``. It has one row per data line and one
      column per attribute name; a single data line is squeezed to 1-D by
      ``np.loadtxt``.
    """
    attr_path = os.path.join(dirpath, 'list_attr_celeba.txt')
    with open(attr_path) as f:
        next(f)
        # split() rather than split(' '): a trailing space or the newline on
        # the header line must not be counted as an extra attribute column,
        # which would push usecols one past the last column.
        num_attrs = len(f.readline().split())

    factors = np.loadtxt(attr_path,
                         skiprows=2,
                         usecols=range(1, num_attrs + 1))
    return factors


def order_filenames(files):
    """Return the sorted six-character stems of ``files`` with ``.jpg`` appended.

    Only the first six characters of each name are kept, so every output
    name ends in ``.jpg`` regardless of the input extension. Sorting is
    lexicographic on those six characters.
    """
    stems = sorted(name[:6] for name in files)
    return ['{}.jpg'.format(stem) for stem in stems]


def download(directory, filename, url):
    """Download ``url`` to ``directory/filename`` unless that file exists.

    The data is first written to ``<filepath>.part`` and renamed to the final
    name only after ``urlretrieve`` returns. An interrupted download therefore
    never leaves a truncated file that later calls would accept as complete.

    Args:
      directory: target directory; created via ``tf.gfile.MakeDirs`` if
        missing.
      filename: name of the file inside ``directory``.
      url: source URL.

    Returns:
      ``os.path.join(directory, filename)``.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    print("Downloading %s to %s" % (url, filepath))
    partial_path = filepath + '.part'
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.rename(partial_path, filepath)
    finally:
        # The partial file only survives here if the download or the rename
        # failed; remove it so it is never mistaken for real data.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return filepath


def download_celebA(dirpath):
    """Fetch the pre-processed Celeb-A arrays into ``dirpath``.

    Downloads ``celebA.npy`` from ``URL[0]`` and ``celebA_factors.npy`` from
    ``URL[1]``. ``download`` skips files that already exist.
    """
    targets = (('celebA.npy', URL[0]), ('celebA_factors.npy', URL[1]))
    for target_name, source_url in targets:
        download(dirpath, target_name, source_url)
    print('Celeb A ready')


def load_celebA(dirpath):
    """Download the Celeb-A arrays if needed and load them into memory.

    Args:
      dirpath: directory holding (or receiving) the ``.npy`` files.

    Returns:
      ``(images, factors)``: the arrays stored in ``celebA.npy`` and
      ``celebA_factors.npy``.
    """
    download_celebA(dirpath)
    images, factors = (np.load(os.path.join(dirpath, name))
                       for name in ('celebA.npy', 'celebA_factors.npy'))
    return images, factors


def load_celebA_gs(dirpath, num_samples):
    """Return up to ``num_samples`` leading Celeb-A images as grayscale.

    Args:
      dirpath: directory holding (or receiving) ``celebA.npy``.
      num_samples: number of leading images to convert.

    Returns:
      Array with the channel axis removed, holding the BT.709 luma of each
      image. Assumes channels-last RGB, which is how the ``__main__``
      preprocessing stores them (it converts BGR to RGB).
    """
    download_celebA(dirpath)
    # Memory-map so the whole array does not have to be loaded when only a
    # prefix is used.
    x = np.load(os.path.join(dirpath, 'celebA.npy'), mmap_mode='r')
    x = x[:num_samples]
    # --- BT.709 HDTV luma: Y = 0.2126 R + 0.7152 G + 0.0722 B.
    # The blue weight must apply to channel 2, not channel 0 (red).
    x = 0.2126 * x[:, :, :, 0] + 0.7152 * x[:, :, :, 1] + 0.0722 * x[:, :, :, 2]
    return x


def data_generator_train(x, y, batch_size):
    """
    Yield random (x, y) mini-batches forever

    Every batch uses ``batch_size`` row indices drawn with
    ``np.random.randint`` (uniform, with replacement) and indexes ``x`` and
    ``y`` with the same indices, so pairs stay aligned.

    Args:
      x: training data, indexed along its first axis
      y: training labels, aligned with ``x``
      batch_size: number of rows per batch

    Yields:
      (x_batch, y_batch) tuples with batch_size rows each

    """
    n_examples = x.shape[0]
    while True:
        picks = np.random.randint(0, n_examples, batch_size)
        yield x[picks], y[picks]


def data_generator_eval(x, y, batch_size):
    """
    Generates a single random batch of evaluation data

    Unlike ``data_generator_train``, this generator is finite: it yields
    exactly one batch and then stops.

    Args:
      x: test data, indexed along its first axis
      y: test labels, aligned with ``x``
      batch_size: number of rows in the batch

    Yields:
      one (x_batch, y_batch) tuple whose batch_size rows are drawn with
      ``np.random.randint`` (uniform, with replacement)

    """
    num = x.shape[0]
    idx = np.random.randint(0, num, batch_size)
    x_batch = x[idx]
    y_batch = y[idx]
    yield (x_batch, y_batch)


def build_input_fns(params):
    """Build train and eval input functions over the Celeb-A arrays.

    Both input functions sample from the arrays returned by ``load_celebA``.
    NOTE(review): the eval input function also draws from the *training*
    arrays, so there is no held-out split here. Its generator
    (``data_generator_eval``) yields a single batch.

    Args:
      params: dict with at least ``"data_dir"`` (passed to ``load_celebA``)
        and ``"batch_size"``.

    Returns:
      ``(train_input_fn, eval_input_fn, num_train_examples)``. Each input
      function returns the next ``(images, factors)`` tensors, declared as
      float32 ``[batch_size, 64, 64, 3]`` and int32 ``[batch_size, 40]``.
    """
    x_train, y_train = load_celebA(params["data_dir"])

    def gen_train():
        return data_generator_train(x_train, y_train, params["batch_size"])

    def gen_eval():
        return data_generator_eval(x_train, y_train, params["batch_size"])

    def _make_input_fn(generator_fn):
        # Wrap a generator factory in a prefetching tf.data pipeline. The
        # train and eval input functions differ only in this generator.
        def input_fn():
            batch_size = params["batch_size"]
            dataset = tf.data.Dataset.from_generator(
                generator_fn, (tf.float32, tf.int32),
                (tf.TensorShape([batch_size, 64, 64, 3]),
                 tf.TensorShape([batch_size, 40])))
            dataset = dataset.prefetch(1)
            return dataset.make_one_shot_iterator().get_next()

        return input_fn

    train_input_fn = _make_input_fn(gen_train)
    eval_input_fn = _make_input_fn(gen_eval)
    return train_input_fn, eval_input_fn, x_train.shape[0]


if __name__ == '__main__':
    # One-off preprocessing of the raw aligned Celeb-A JPEGs in `dirpath`.
    # Writes celebA.npy (cropped images), celebA_filenames.pkl (image order)
    # and celebA_factors.npy (attributes) under /data.
    dirpath = '/data/img_align_celeba'
    files = [
        f for f in os.listdir(dirpath)
        if os.path.isfile(os.path.join(dirpath, f))
    ]
    files = order_filenames(files)
    # Convert straight to float32. Building a float64 array first and then
    # calling .astype briefly keeps an extra full-size copy in memory.
    cropped = np.array(crop_file_dir(dirpath, files), dtype=np.float32)
    np.save('/data/celebA.npy', cropped)
    with open('/data/celebA_filenames.pkl', 'wb') as f:
        pickle.dump(files, f)
    # NOTE(review): assumes the rows of list_attr_celeba.txt are in the same
    # sorted order as `files`, so images and factors line up by index.
    factors = create_celebA_latent_factors(dirpath)
    factors = factors.astype('float32')
    np.save('/data/celebA_factors.npy', factors)
