# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of encoder in VAEs
"""
import functools

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


def make_conv_encoder(activation,
                      latent_size,
                      image_shape,
                      hidden_size=None,
                      x_dist='Gaussian',
                      eps=0):
    """Creates a convolutional encoder function.

    Args:
      activation: Activation applied after every hidden convolution (passed
        to `tf.keras.layers.Activation`).
      latent_size: Dimensionality of the encoding.
      image_shape: Shape each input example is reshaped to before the first
        convolution (the target of `tf.keras.layers.Reshape`, so it excludes
        the batch dimension).
      hidden_size: Accepted for API compatibility with `make_mlp_encoder`;
        this convolutional architecture does not use it.
      x_dist: Distribution of the decoder's final layer - 'Gaussian',
        'Bernoulli', or 'Logistic'. 'Bernoulli' selects a narrower network
        (32/32/64/64 filters); any other value uses 64/64/128/128 filters.
      eps: Constant added to the predicted scale to keep it away from zero.

    Returns:
      encoder: A callable `encoder(images, training=None)` returning a dict
        with keys:
          'net':  flattened convolutional features,
          'prob': the `tfd.Normal` posterior over encodings,
          'mu':   the posterior loc,
          'var':  the posterior *scale* (a standard deviation, despite the
                  key name).
        `training` is forwarded to the conv stack (it controls the
        BatchNormalization mode); `None` keeps the Keras default behaviour.
    """
    del hidden_size  # Unused by this architecture; kept for compatibility.

    if x_dist == 'Bernoulli':
        hidden_filters = (32, 32, 64, 64)
    else:
        hidden_filters = (64, 64, 128, 128)

    # Four 4x4 stride-2 'SAME' convolutions, then a 4x4 stride-1 'VALID'
    # convolution with 512 filters; each is followed by BN + activation.
    conv_specs = [(n, 2, "SAME") for n in hidden_filters] + [(512, 1, "VALID")]

    layers = [tf.keras.layers.Reshape(image_shape)]
    for n_filters, stride, padding in conv_specs:
        layers += [
            tf.keras.layers.Conv2D(n_filters, 4, stride, padding=padding),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation(activation=activation),
        ]
    layers.append(tf.keras.layers.Flatten())
    encoder_net = tf.keras.Sequential(layers, name='encoder')

    # Build the posterior heads once so that every call to `encoder` shares
    # (and trains) the same weights, just like `encoder_net`.
    mu_layer = tf.keras.layers.Dense(units=latent_size,
                                     activation=None,
                                     kernel_initializer='glorot_uniform',
                                     bias_initializer='zeros',
                                     name='encoder_mu')
    scale_layer = tf.keras.layers.Dense(units=latent_size,
                                        activation='softplus',
                                        kernel_initializer='glorot_uniform',
                                        bias_initializer='ones',
                                        name='encoder_var')

    def encoder(images, training=None):
        # Affine map x -> 2x - 1; presumably pixels arrive in [0, 1] so this
        # yields [-1, 1] (TODO confirm against the input pipeline).
        images = 2 * tf.cast(images, dtype=tf.float32) - 1
        net = encoder_net(images, training=training)
        mu = mu_layer(net)
        # softplus keeps the scale positive; eps bounds it away from zero.
        var = scale_layer(net) + eps
        return dict(net=net,
                    prob=tfd.Normal(loc=mu, scale=var, name='posterior'),
                    mu=mu,
                    var=var)

    return encoder


def make_mlp_encoder(activation, latent_size, hidden_size, eps=0):
    """Creates a fully-connected encoder function.

    Args:
      activation: Activation applied after each hidden dense layer (passed to
        `tf.keras.layers.Activation`).
      latent_size: The dimensionality of the encoding.
      hidden_size: Number of units in each of the two hidden dense layers.
      eps: Constant added to the predicted scale to keep it away from zero.

    Returns:
      encoder: A callable `encoder(images, training=None)` returning a dict
        with keys:
          'net':  output of the hidden MLP,
          'prob': the `tfd.Normal` posterior over encodings,
          'mu':   the posterior loc,
          'var':  the posterior *scale* (a standard deviation, despite the
                  key name).
        `training` is forwarded to the hidden MLP (it controls the
        BatchNormalization mode); `None` keeps the Keras default behaviour.
    """
    dense = functools.partial(tf.keras.layers.Dense,
                              units=hidden_size,
                              kernel_initializer='glorot_normal',
                              bias_initializer='zeros')

    def hidden_block():
        # Dense -> BatchNorm -> activation.
        return [dense(),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Activation(activation=activation)]

    # --- Don't apply BN to input like in https://arxiv.org/pdf/1511.06434.pdf
    encoder_net = tf.keras.Sequential(
        [tf.keras.layers.Flatten()] + hidden_block() + hidden_block(),
        name='encoder')

    # Build the posterior heads once so that every call to `encoder` shares
    # (and trains) the same weights, just like `encoder_net`.
    mu_layer = dense(units=latent_size,
                     activation=None,
                     kernel_initializer='glorot_uniform',
                     bias_initializer='zeros',
                     name='encoder_mu')
    scale_layer = dense(units=latent_size,
                        activation='softplus',
                        kernel_initializer='glorot_uniform',
                        bias_initializer='ones',
                        name='encoder_var')

    def encoder(images, training=None):
        # Affine map x -> 2x - 1; presumably pixels arrive in [0, 1] so this
        # yields [-1, 1] (TODO confirm against the input pipeline).
        images = 2 * tf.cast(images, dtype=tf.float32) - 1
        net = encoder_net(images, training=training)
        mu = mu_layer(net)
        # softplus keeps the scale positive; eps bounds it away from zero.
        var = scale_layer(net) + eps
        return dict(net=net,
                    prob=tfd.Normal(loc=mu, scale=var, name='posterior'),
                    mu=mu,
                    var=var)

    return encoder
