# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of decoders in VAEs
"""
import functools

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
from .logistic_mix import MixtureLogits


def make_mlp_decoder(activation,
                     latent_sizes,
                     hidden_size,
                     output_shape,
                     l2_penalty=0.0,
                     x_plus=False):
    """Creates a fully-connected decoder function with a Bernoulli output.

    Args:
      activation: Activation function in hidden layers.
      latent_sizes: Sequence of latent dimensionalities. Only the first entry
        is used unless `x_plus` is True, in which case all entries are summed.
      hidden_size: Number of units in each of the two hidden Dense layers.
      output_shape: The output image shape. The network emits
        `prod(output_shape)` logits, i.e. a flattened image.
      l2_penalty: l2 activity penalty applied to every Dense layer.
      x_plus: boolean to activate seatbelt connections (all latent codes are
        fed to the decoder).

    Returns:
      decoder: A `callable` mapping a `Tensor` of encodings to a
        `tfd.Independent(tfd.Bernoulli)` distribution over flattened images.
    """
    dense = functools.partial(
        tf.keras.layers.Dense,
        units=hidden_size,
        activity_regularizer=tf.keras.regularizers.l2(l2_penalty),
        kernel_initializer='glorot_normal',
        bias_initializer='zeros')

    activ_layer = functools.partial(tf.keras.layers.Activation,
                                    activation=activation)

    batch_norm = tf.keras.layers.BatchNormalization

    # --- If x_plus is activated all Zs are fed as inputs to the decoder
    # ... z_2, z_1, z_0) -> x, hence we sum
    input_size = sum(latent_sizes) if x_plus else latent_sizes[0]

    # --- Don't apply BN to output like in https://arxiv.org/pdf/1511.06434.pdf
    # NOTE(review): Reshape's target shape excludes the leading (batch) axis,
    # so rank-2 codes (batch, input_size) become (batch, 1, input_size) and
    # the logits gain a size-1 axis; presumably callers pass
    # (samples, batch, input_size) codes, for which this is a no-op — confirm.
    decoder_net = tf.keras.Sequential([
        tf.keras.layers.Reshape((-1, input_size)),
        dense(),
        batch_norm(),
        activ_layer(),
        dense(),
        batch_norm(),
        activ_layer(),
        dense(units=int(np.prod(output_shape)), activation=None),
    ], name='decoder')

    def decoder(codes):
        """Maps latent codes to a Bernoulli distribution over flat images."""
        logits = decoder_net(codes)
        # --- Binary output; the last (pixel) axis is the event dimension.
        return tfd.Independent(tfd.Bernoulli(logits=logits, dtype=tf.float32),
                               reinterpreted_batch_ndims=1,
                               name="image")

    return decoder


def make_conv_decoder(activation,
                      latent_sizes,
                      image_shape,
                      l2_penalty,
                      x_dist='Gaussian',
                      n_x_mixture=None,
                      x_plus=False):
    """Creates a transposed-convolution decoder function.

    Args:
      activation: Activation function in hidden layers.
      latent_sizes: Sequence of latent dimensionalities. Only the first entry
        is used unless `x_plus` is True, in which case all entries are summed.
      image_shape: The output image shape. Its element count is the size of
        the flattened output for the 'Bernoulli' and 'Gaussian' heads.
      l2_penalty: l2 activity penalty applied to the Dense layer of the
        'Logistic' head (unused otherwise).
      x_dist: distribution of the final layer of the decoder -- 'Gaussian',
        'Bernoulli', or 'Logistic'.
      n_x_mixture: number of mixture components; required when
        `x_dist == 'Logistic'`.
      x_plus: boolean to activate seatbelt connections (all latent codes are
        fed to the decoder).

    Returns:
      decoder: A `callable` mapping a `Tensor` of encodings to a
        `tfd.Distribution` over flattened images ('Bernoulli', 'Gaussian'),
        or to a `MixtureLogits` instance ('Logistic').

    Raises:
      ValueError: if `x_dist` is not a supported value, or if
        `x_dist == 'Logistic'` and `n_x_mixture` is None.
    """
    # --- Validate up front: previously an unknown x_dist made `decoder`
    # --- silently return None, and a missing n_x_mixture failed at call time.
    if x_dist not in ('Gaussian', 'Bernoulli', 'Logistic'):
        raise ValueError(
            "x_dist must be one of 'Gaussian', 'Bernoulli', 'Logistic'; "
            "got %r" % (x_dist,))
    if x_dist == 'Logistic' and n_x_mixture is None:
        raise ValueError("n_x_mixture is required when x_dist == 'Logistic'")

    deconv = functools.partial(tf.keras.layers.Conv2DTranspose, padding="SAME")

    batch_norm = tf.keras.layers.BatchNormalization

    activ_layer = functools.partial(tf.keras.layers.Activation,
                                    activation=activation)

    # --- If x_plus is activated all Zs are fed as inputs to the decoder
    # ... z_2, z_1, z_0) -> x, hence we sum
    input_size = sum(latent_sizes) if x_plus else latent_sizes[0]

    # --- Both variants share the same geometry; they differ only in channel
    # --- widths and in the number of output channels.
    if x_dist == 'Bernoulli':
        hidden_filters, out_channels = (512, 64, 64, 32, 32), 1
    else:
        hidden_filters, out_channels = (512, 128, 128, 64, 64), 3
    # (kernel_size, strides, padding) per hidden deconvolution. Starting from
    # the 1x1 map built in `decoder`, the spatial size goes
    # 1 -> 1 -> 4 -> 8 -> 16 -> 32, and the output deconvolution doubles it.
    hidden_geometry = ((1, 1, "VALID"),
                       (4, 1, "VALID"),
                       (4, 2, "SAME"),
                       (4, 2, "SAME"),
                       (4, 2, "SAME"))

    layers = []
    for filters, (kernel, stride, padding) in zip(hidden_filters,
                                                  hidden_geometry):
        layers += [deconv(filters, kernel, stride, padding=padding),
                   batch_norm(),
                   activ_layer()]
    layers.append(deconv(out_channels, 4, 2))

    if x_dist == 'Logistic':
        # --- Network-in-Network head: a Dense layer over the channel axis.
        # --- It is built once here, as part of decoder_net, so that its
        # --- weights persist across calls and are tracked by the model;
        # --- previously a fresh, untracked layer was created on every call.
        layers.append(
            tf.keras.layers.Dense(
                10 * n_x_mixture,
                activation=None,
                activity_regularizer=tf.keras.regularizers.l2(l2_penalty)))

    decoder_net = tf.keras.Sequential(layers, name='decoder')

    num_pixels = int(np.prod(image_shape))

    def decoder(codes):
        """Maps latent codes to a distribution over images."""
        original_shape = tf.shape(codes)
        # --- Collapse the sample and batch dimension and convert to rank-4
        # --- tensor for use with a convolutional decoder network.
        codes = tf.reshape(codes, (-1, 1, 1, input_size))
        decoder_output = decoder_net(codes)
        if x_dist == 'Logistic':
            return MixtureLogits(decoder_output, n_x_mixture)
        # --- Restore the leading dimensions of `codes` and flatten the image.
        flat_shape = tf.concat([original_shape[:-1], [num_pixels]], axis=0)
        params = tf.reshape(decoder_output, shape=flat_shape)
        if x_dist == 'Bernoulli':
            return tfd.Independent(tfd.Bernoulli(logits=params,
                                                 dtype=tf.float32),
                                   reinterpreted_batch_ndims=1,
                                   name="image")
        # --- Gaussian: fixed observation noise with variance 0.1.
        return tfd.Independent(tfd.Normal(loc=params,
                                          scale=tf.constant(
                                              np.sqrt(0.1,
                                                      dtype=np.float32))),
                               reinterpreted_batch_ndims=1,
                               name="image")

    return decoder
