# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of stochastic variables in VAEs
"""
import functools

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


def make_deterministic_layer(activation, hidden_size, output_size, name,
                             eps=0):
    """Creates an MLP network in the deterministic part of the inference model.

    The trunk is two blocks of Dense -> BatchNormalization -> Activation,
    followed by two Dense heads: a linear `mu` head and a softplus `var` head.

    Args:
      activation: Activation function used in the hidden layers.
      hidden_size: Number of units in each hidden layer.
      output_size: Number of units in the `mu` and `var` heads.
      name: Prefix used for the Keras layer names.
      eps: Constant added to the softplus output of the `var` head.

    Returns:
      A function `mapping(deterministic_variable, training=None)` returning a
      dict with keys `net` (trunk output), `mu` and `var`. All layers,
      including the output heads, are created once here, so every call to
      `mapping` reuses the same weights.
    """
    dense = functools.partial(tf.keras.layers.Dense,
                              units=hidden_size,
                              kernel_initializer='glorot_normal',
                              bias_initializer='zeros')

    mlp = tf.keras.Sequential([
        dense(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation(activation=activation),
        dense(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation(activation=activation),
    ], name=name + '_network')

    # Build the heads once, outside `mapping`: constructing them per call
    # would create new, unshared variables each time `mapping` is invoked.
    mu_layer = tf.keras.layers.Dense(units=output_size,
                                     activation=None,
                                     kernel_initializer='glorot_uniform',
                                     bias_initializer='zeros',
                                     name=name + '_mu')
    var_layer = tf.keras.layers.Dense(units=output_size,
                                      activation='softplus',
                                      kernel_initializer='glorot_uniform',
                                      bias_initializer='ones',
                                      name=name + '_var')

    def mapping(deterministic_variable, training=None):
        # `training` is only forwarded when given, so the default call is
        # exactly `mlp(x)` and BatchNormalization keeps Keras's default mode.
        if training is None:
            net = mlp(deterministic_variable)
        else:
            net = mlp(deterministic_variable, training=training)
        mu = mu_layer(net)
        var = var_layer(net) + eps
        return dict(net=net, mu=mu, var=var)

    return mapping


def make_stochastic_layer(activation, hidden_size, output_size, name, eps=0):
    """Creates an MLP network to generate a stochastic variable.

    The trunk is two blocks of Dense -> BatchNormalization -> Activation,
    followed by a linear `mu` head and a softplus `var` head. The heads
    parameterize a `tfd.Normal` distribution.

    Args:
      activation: Activation function used in the hidden layers.
      hidden_size: Number of units in each hidden layer.
      output_size: Number of units in the `mu` and `var` heads.
      name: Prefix used for the Keras layer names and the distribution name.
      eps: Constant added to the softplus output of the `var` head.

    Returns:
      A function `mapping(stochastic_variable, training=None)` returning a
      dict with keys `net` (trunk output), `prob` (a `tfd.Normal`), `mu` and
      `var`. All layers, including the output heads, are created once here,
      so every call to `mapping` reuses the same weights.
    """
    dense = functools.partial(tf.keras.layers.Dense,
                              units=hidden_size,
                              kernel_initializer='glorot_normal',
                              bias_initializer='zeros')

    mlp = tf.keras.Sequential([
        dense(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation(activation=activation),
        dense(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation(activation=activation),
    ], name=name + '_network')

    # Build the heads once, outside `mapping`: constructing them per call
    # would create new, unshared variables each time `mapping` is invoked.
    mu_layer = tf.keras.layers.Dense(units=output_size,
                                     activation=None,
                                     name=name + '_mu',
                                     kernel_initializer='glorot_uniform',
                                     bias_initializer='zeros')
    var_layer = tf.keras.layers.Dense(units=output_size,
                                      activation='softplus',
                                      kernel_initializer='glorot_uniform',
                                      bias_initializer='ones',
                                      name=name + '_var')

    def mapping(stochastic_variable, training=None):
        # `training` is only forwarded when given, so the default call is
        # exactly `mlp(x)` and BatchNormalization keeps Keras's default mode.
        if training is None:
            net = mlp(stochastic_variable)
        else:
            net = mlp(stochastic_variable, training=training)
        mu = mu_layer(net)
        var = var_layer(net) + eps
        # NOTE(review): `var` is passed as `scale`, which tfd.Normal treats
        # as a standard deviation, not a variance. Confirm that downstream
        # consumers of `var` (e.g. KL terms) use the same interpretation.
        return dict(net=net,
                    prob=tfd.Normal(loc=mu, scale=var, name=name),
                    mu=mu,
                    var=var)

    return mapping
