# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of stochastic layers in LVAEs
"""
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from .utils import make_stochastic_layer, make_deterministic_layer


def make_deterministic_upward_pass(encoder, features, params, eps=0):
    """ Build the bottom-up chain of deterministic 'd' layers.

    The features are flattened to (params["batch_size"], params["n_x"]) and
    passed through `encoder`; every further 'd' layer is built with
    make_deterministic_layer and fed the 'net' output of the layer below it.

    Args:
        encoder: encoder network from input data to first stochastic layer
        features: feature data
        params: dict of parameters that define the model; reads "batch_size",
            "n_x", 'stochastic_depth', "activation", "hidden_size" and
            "latent_size"
        eps: minimum to apply to variances to avoid underflow

    Returns:
        upward_pass: list with one entry per stochastic layer; entry 0 is the
            encoder output and entry k is the output of layer 'd_<k>'

    """
    flat_features = tf.reshape(features,
                               (params["batch_size"], params["n_x"]))
    # --- The encoder output is the bottom entry of the upward pass
    upward_pass = [encoder(flat_features)]

    for layer_idx in range(1, params['stochastic_depth']):
        d_layer = make_deterministic_layer(
            activation=params["activation"],
            hidden_size=params["hidden_size"][layer_idx],
            output_size=params["latent_size"][layer_idx],
            name='d_' + str(layer_idx),
            eps=eps)
        # --- Each new layer consumes the hidden 'net' output of the previous
        d_out = d_layer(upward_pass[-1]['net'])
        upward_pass.append(d_out)

        # NOTE(review): the summary tag index is one less than the layer's
        # name index (layer 'd_k' is logged as 'layer_{k-1}'); kept unchanged
        # so existing TensorBoard tags stay stable.
        tag_suffix = str(layer_idx - 1)
        tf.summary.histogram("d/mu/layer_" + tag_suffix,
                             tf.reshape(d_out['mu'], [-1]))
        tf.summary.histogram("d/sig/layer_" + tag_suffix,
                             tf.reshape(d_out['var'], [-1]))

    return upward_pass


def make_q_layer(p_params, d_params, name):
    """ Combine a downward 'p' layer and an upward 'd' layer into a 'q' layer.

    The two Gaussians are merged by precision weighting:

        sigma_q = sqrt(1 / (sigma_d^-2 + sigma_p^-2))
        mu_q    = (mu_d * sigma_d^-2 + mu_p * sigma_p^-2)
                  / (sigma_d^-2 + sigma_p^-2)

    The 'var' entries are treated as the Normal *scale* (standard deviation).
    They are squared before being inverted, and the combined value is
    square-rooted before being returned under 'var'.

    Args:
        p_params: dict with 'mu' and 'var' from the ith p (downward) layer
        d_params: dict with 'mu' and 'var' from the ith d (upward) layer
        name: str to assign to the resulting distribution

    Returns:
        new_q_params: dict with 'prob' (a tfd.Normal built from the combined
            parameters), 'mu' (combined mean) and 'var' (combined scale).

    """
    # --- Precisions (inverse squared scales) of the d and p Gaussians
    inv_d_var_sq = tf.reciprocal(d_params['var']**2)
    inv_p_var_sq = tf.reciprocal(p_params['var']**2)
    total_precision = inv_d_var_sq + inv_p_var_sq
    # --- Combined scale: sqrt of the inverse of the summed precisions
    new_var = tf.sqrt(tf.reciprocal(total_precision))
    # --- Combined mean: precision-weighted average of both means. The whole
    # --- --- weighted sum must be divided by the total precision.
    new_mu = (d_params['mu'] * inv_d_var_sq +
              p_params['mu'] * inv_p_var_sq) / total_precision

    return dict(
        prob=tfd.Normal(loc=new_mu, scale=new_var, name=name),
        mu=new_mu,
        var=new_var)


def make_stochastic_downward_pass(upward_pass, latent_prior, params, eps=0):
    """
    Iteratively generates downward pass 'p' parameters and inference 'q'
        parameters.

    Starting from the top stochastic layer, each generative 'p' layer (built
    with make_stochastic_layer) is fed a sample from the q layer above it. Its
    output is then combined with the matching upward 'd' layer by
    make_q_layer to form the next q layer.

    Args:
        upward_pass: list of dicts detailing the deterministic layers
            (presumably the output of make_deterministic_upward_pass); each
            entry must provide 'mu' and 'var'
        latent_prior: prior for the final (top) stochastic layer
        params: dict of parameters that define the model; reads
            'stochastic_depth', "n_samples", "activation", "hidden_size" and
            "latent_size"
        eps: minimum to apply to variances to avoid underflow

    Returns:
        downward_pass: layers of the generative model, ordered from layer 0
            up to latent_prior as the last entry
        q_layers: inference layers generated from the combination of the
            generative model and the 'd' deterministic layers, ordered by
            stochastic layer index
        q_z_samples: MC samples from q_layers, ordered by stochastic layer
            index. The top entry holds params["n_samples"] samples along a
            leading axis; each lower entry is one draw from its q layer.

    Note:
        The top q layer is upward_pass[-1] itself. Its dict is updated in
        place with a 'prob' entry.

    """
    depth = params['stochastic_depth']
    # --- Prior is the first layer of the generative model
    downward_pass = [latent_prior]
    # --- Top q layer is the Normal defined by the final deterministic
    # --- --- parameters (this mutates upward_pass[-1] in place)
    q_layers = [upward_pass[-1]]
    q_layers[-1]['prob'] = tfd.Normal(
        loc=q_layers[-1]['mu'],
        scale=q_layers[-1]['var'],
        name='q_z_' + str(depth - 1))
    # --- Draw n_samples MC samples from the top q layer
    q_z_samples = [q_layers[-1]['prob'].sample(params["n_samples"])]

    # --- Traverse stochastic layers in 'reverse' / top-down fashion
    for i in reversed(range(depth - 1)):
        # --- Make a new generative p layer
        new_p_layer = make_stochastic_layer(
            activation=params["activation"],
            hidden_size=params["hidden_size"][i + 1],
            output_size=params["latent_size"][i],
            name='p_z_' + str(i),
            eps=eps)

        # --- Feed the sample from the q layer above into the new p layer
        p_output = new_p_layer(q_z_samples[-1])
        downward_pass.append(p_output)

        # --- Use the p params and the deterministic 'd' outputs to create
        # --- --- the q layer for this stochastic level.
        q_layers.append(
            make_q_layer(p_output, upward_pass[i], name='q_z_' + str(i)))

        # --- Draw one sample and drop only the leading sample axis. A bare
        # --- --- tf.squeeze would also remove any other size-1 axis (e.g.
        # --- --- when n_samples or batch_size is 1) and change the rank.
        q_z_samples.append(
            tf.squeeze(q_layers[-1]['prob'].sample(1), axis=0))

    # --- Now reverse to index by stochastic layer index
    q_layers.reverse()
    q_z_samples.reverse()
    downward_pass.reverse()

    return downward_pass, q_layers, q_z_samples


def make_deterministic_downward_pass(upward_pass, latent_prior, params, eps=0):
    """
    Iteratively generates downward pass 'p' parameters and inference 'q'
        parameters using deterministic p layers.

    Each p layer is built with make_deterministic_layer. The layer directly
    below the top is fed only the top q sample. Every lower layer is fed the
    previous p layer's 'net' output concatenated (on axis 2) with the q sample
    from the level above. Each p output is given a tfd.Normal under 'prob' and
    combined with the matching 'd' layer by make_q_layer.

    Args:
        upward_pass: list of dicts detailing the deterministic layers
            (presumably the output of make_deterministic_upward_pass); each
            entry must provide 'mu' and 'var'
        latent_prior: prior for the final (top) stochastic layer
        params: dict of parameters that define the model; reads
            'stochastic_depth', "n_samples", "activation", "hidden_size" and
            "latent_size"
        eps: minimum to apply to variances to avoid underflow

    Returns:
        downward_pass: deterministic layers of the generative model, ordered
            from layer 0 up to latent_prior as the last entry
        q_layers: inference layers generated from the combination of the
            generative model and the 'd' deterministic layers, ordered by
            stochastic layer index
        q_z_samples: MC samples from q_layers, ordered by stochastic layer
            index. The top entry holds params["n_samples"] samples along a
            leading axis; each lower entry is one draw from its q layer.

    Note:
        The top q layer is upward_pass[-1] itself. Its dict is updated in
        place with a 'prob' entry.

    """
    depth = params['stochastic_depth']
    # --- Prior is the first layer of the generative model
    downward_pass = [latent_prior]
    # --- Top q layer is the Normal defined by the final deterministic
    # --- --- parameters (this mutates upward_pass[-1] in place)
    q_layers = [upward_pass[-1]]
    q_layers[-1]['prob'] = tfd.Normal(
        loc=q_layers[-1]['mu'],
        scale=q_layers[-1]['var'],
        name='q_z_' + str(depth - 1))
    # --- Draw n_samples MC samples from the top q layer
    q_z_samples = [q_layers[-1]['prob'].sample(params["n_samples"])]

    # --- Traverse stochastic layers in 'reverse' / top-down fashion
    for i in reversed(range(depth - 1)):
        new_p_layer = make_deterministic_layer(
            activation=params["activation"],
            hidden_size=params["hidden_size"][i + 1],
            output_size=params["latent_size"][i],
            name='p_z_' + str(i),
            eps=eps)

        if i == depth - 2:
            # --- First p layer below the top: only the top q sample is input
            p_input = q_z_samples[-1]
        else:
            # --- Lower layers also see the previous p layer's hidden output
            p_input = tf.concat([downward_pass[-1]["net"], q_z_samples[-1]],
                                axis=2)
        p_output = new_p_layer(p_input)

        p_output["prob"] = tfd.Normal(
            loc=p_output["mu"], scale=p_output["var"], name='p_z_' + str(i))

        downward_pass.append(p_output)

        # --- Use the p params and the deterministic 'd' outputs to create
        # --- --- the q layer for this stochastic level.
        q_layers.append(
            make_q_layer(p_output, upward_pass[i], name='q_z_' + str(i)))

        # --- Draw one sample and drop only the leading sample axis. A bare
        # --- --- tf.squeeze would also remove other size-1 axes (e.g. when
        # --- --- n_samples is 1), changing the rank and breaking the
        # --- --- axis=2 concat above on the next iteration.
        q_z_samples.append(
            tf.squeeze(q_layers[-1]['prob'].sample(1), axis=0))

    # --- Now reverse to index by stochastic layer index
    q_layers.reverse()
    q_z_samples.reverse()
    downward_pass.reverse()

    return downward_pass, q_layers, q_z_samples
