# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to preprocess all data
"""
import numpy as np

from .random import build_input_fns as build_input_fns_random
from .mnist import build_input_fns as build_input_fns_mnist
from .dsprites import build_input_fns as build_input_fns_dsprites
from .chairs import build_input_fns as build_input_fns_chairs
from .faces import build_input_fns as build_input_fns_faces
from .celeba import build_input_fns as build_input_fns_celeba


def load_dataset(dataset, params):
    """Build the train/eval input functions for a named dataset.

    Besides building the input functions, this fills in dataset-specific
    entries of ``params`` in place: ``"IMAGE_SHAPE"``, ``"good_dims"`` and
    ``"n_x"`` for every dataset, and ``"n_datapoints"`` for every dataset
    except ``'mnist'``.

    Args:
        dataset: Name of the dataset; one of ``'random'``, ``'mnist'``,
            ``'dsprites'``, ``'celeba'``, ``'chairs'`` or ``'faces'``.
        params: Mutable dict of hyper-parameters. It is passed to the
            dataset's ``build_input_fns`` and updated in place.

    Returns:
        A tuple ``(train_input_fn, eval_input_fn, params)`` where ``params``
        is the same (mutated) dict that was passed in.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. In that
            case ``params`` is left unmodified.
    """
    # NOTE: within each branch the order of "set IMAGE_SHAPE" vs. "call the
    # builder" is deliberate and differs per dataset; the builders receive
    # ``params`` and may read what has already been set.
    if dataset == 'random':
        # --- Good ol' random data to test the pipeline.
        params["IMAGE_SHAPE"] = [64, 64, 3]
        train_input_fn, eval_input_fn, params[
            "n_datapoints"] = build_input_fns_random(params)
        params["good_dims"], params["n_x"] = range(64 * 64 * 3), 64 * 64 * 3
    elif dataset == 'mnist':
        # --- The MNIST database of handwritten digits has a training set of
        # --- 60,000 examples and a test set of 10,000 examples. It is a
        # --- subset of a larger set available from NIST. The digits have
        # --- been size-normalized and centered in a 28x28x1 image.
        train_input_fn, eval_input_fn, good_dims = build_input_fns_mnist(
            params)
        params["IMAGE_SHAPE"] = [28, 28, 1]
        # --- We filter the image for directions of low variance, so the
        # --- builder returns the dimensions to keep and n_x is their count.
        # NOTE(review): "n_datapoints" is not set on this path; confirm the
        # mnist builder stores it in params itself if downstream code needs it.
        params["good_dims"], params["n_x"] = good_dims, len(good_dims)
    elif dataset == 'dsprites':
        # --- dSprites is a dataset of 2D shapes procedurally generated from
        # --- 6 ground truth independent latent factors.
        # --- These factors are color, shape, scale, rotation, x and y
        # --- positions of a sprite. All possible combinations of these latents
        # --- are present exactly once, generating N = 737280 total images.
        train_input_fn, eval_input_fn, params[
            "n_datapoints"] = build_input_fns_dsprites(params)
        params["IMAGE_SHAPE"] = [64, 64, 1]
        params["good_dims"], params["n_x"] = range(64 * 64), 64 * 64
    elif dataset == 'celeba':
        # --- CelebFaces Attributes Dataset (CelebA) is a large-scale face
        # --- attributes dataset with more than 200K celebrity images, each
        # --- with 40 attribute annotations. The images cover large pose
        # --- variations and background clutter. CelebA has large diversity,
        # --- large quantity and rich annotations: 10,177 identities,
        # --- 202,599 face images, 5 landmark locations and 40 binary
        # --- attribute annotations per image.
        params["IMAGE_SHAPE"] = [64, 64, 3]
        train_input_fn, eval_input_fn, params[
            "n_datapoints"] = build_input_fns_celeba(params)
        params["good_dims"], params["n_x"] = range(64 * 64 * 3), 64 * 64 * 3
    elif dataset == 'chairs':
        # --- Chairs is a standard ML dataset composed of 64x64 greyscale
        # --- images of chairs. There are 1393 different chair designs and
        # --- there are ground truth factors for azimuth, elevation and
        # --- distance. There are 62 renders per chair design.
        params["IMAGE_SHAPE"] = [64, 64, 1]
        train_input_fn, eval_input_fn, params[
            "n_datapoints"] = build_input_fns_chairs(params)
        params["good_dims"], params["n_x"] = range(64 * 64), 64 * 64
    elif dataset == 'faces':
        # --- 3D faces is a standard ML dataset composed of 50 people where
        # --- for each there are 21 steps over azimuth and 11 over each of
        # --- elevation and lighting. Images are 64x64 greyscale.
        params["IMAGE_SHAPE"] = [64, 64, 1]
        train_input_fn, eval_input_fn, params[
            "n_datapoints"] = build_input_fns_faces(params)
        params["good_dims"], params["n_x"] = range(64 * 64), 64 * 64
    else:
        # Without this branch an unknown name would reach the return below
        # with unbound locals and fail with an opaque UnboundLocalError.
        supported = ('random', 'mnist', 'dsprites', 'celeba', 'chairs',
                     'faces')
        raise ValueError("Unknown dataset {!r}; expected one of: {}".format(
            dataset, ", ".join(supported)))
    return train_input_fn, eval_input_fn, params
