# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of inference layers in VAEs
"""
import tensorflow as tf
import tensorflow_probability as tfp
from .utils import make_stochastic_layer
tfd = tfp.distributions


def make_q_networks(params):
    """
    Generates the lists of callables used to build the inference model

    Two lists are produced:

    * d_nets, the deterministic path. It is only built when 'deterministic'
      appears in params["inf_path"] or params["gen_inf_sharing"] is truthy;
      otherwise it is None. One network is created per i in
      1..stochastic_depth-1, named 'd_<i>', using hidden_size[i] and
      latent_size[i].
    * q_nets, the stochastic layers, in one of two layouts:
        - params["inf_path"] == 'deterministic_inverse' or weight sharing:
          one network per i in 0..stochastic_depth-2, named 'q_z_<i>', using
          hidden_size[i + 1] and latent_size[i].
        - anything else: one network per i in 1..stochastic_depth-1, named
          'q_z_<i>', using hidden_size[i] and latent_size[i].

    Args:
        params: dict of run params. "inf_path", "gen_inf_sharing" and
            "stochastic_depth" are always read; "activation", "hidden_size",
            "latent_size" and "lv_eps_z" are read for every network created.

    Returns:
        q_nets: list of networks for the stochastic layers, ordered by
            increasing layer index
        d_nets: list of networks for the deterministic path, or None when no
            deterministic path is requested

    Raises:
        ValueError: if params["inf_path"] contains 'deterministic', weights
            are not shared with the generative model, and the path is neither
            'deterministic_inverse' nor 'deterministic_up'.

    """
    inf_path = params["inf_path"]
    sharing = params["gen_inf_sharing"]
    depth = params['stochastic_depth']
    deterministic_paths = ('deterministic_inverse', 'deterministic_up')

    d_nets = None
    if ('deterministic' in inf_path) or sharing:
        # --- Ensure we have a correct name for the deterministic path if
        # --- we are not sharing weights. This is an explicit raise rather
        # --- than an assert so the check survives `python -O`, and it runs
        # --- before any network is built so a bad config fails fast.
        if not sharing and inf_path not in deterministic_paths:
            raise ValueError(
                "Unknown deterministic inference path {!r}; expected one of "
                "{}".format(inf_path, list(deterministic_paths)))
        d_nets = [
            make_stochastic_layer(activation=params["activation"],
                                  hidden_size=params["hidden_size"][i],
                                  output_size=params["latent_size"][i],
                                  name='d_' + str(i),
                                  eps=params["lv_eps_z"])
            for i in range(1, depth)
        ]

    # --- If the inference path is 'inverse' x -> z_0 <- z_1 <- z_2 etc...
    #                                        |____________|______|
    if (inf_path == 'deterministic_inverse') or sharing:
        # --- Layers are still constructed top-down and then flipped, so the
        # --- construction order is unchanged while the returned list is
        # --- indexed by layer (q_nets[i] is 'q_z_<i>').
        q_nets = [
            make_stochastic_layer(params["activation"],
                                  params["hidden_size"][i + 1],
                                  params["latent_size"][i],
                                  name='q_z_' + str(i),
                                  eps=params["lv_eps_z"])
            for i in reversed(range(depth - 1))
        ]
        q_nets.reverse()
    # --- Else create a regular 'stochastic' inference path
    else:
        q_nets = [
            make_stochastic_layer(params["activation"],
                                  params["hidden_size"][i],
                                  params["latent_size"][i],
                                  name='q_z_' + str(i),
                                  eps=params["lv_eps_z"])
            for i in range(1, depth)
        ]

    return q_nets, d_nets


def make_deterministic_inf_path(d_nets, encoder, features, params):
    """ Builds the deterministic upward pass, one 'd' layer per level.

    Args:
        d_nets: list of callables for the deterministic upward path;
            d_nets[k] is applied to the 'net' entry of layer k
        encoder: encoder network, called on the reshaped feature batch
        features: feature data, reshaped here to
            (params["batch_size"], params["n_x"])
        params: dict of parameters that define the model

    Returns:
        d_layers: list holding the encoder output first, followed by the
            output of each deterministic network in order

    """
    flat_shape = (params["batch_size"], params["n_x"])
    flat_features = tf.reshape(features, flat_shape)
    # --- Level 0 of the path is whatever the encoder returns
    d_layers = [encoder(flat_features)]
    # --- Each further level pushes the previous level's 'net' entry
    # --- --- through the next deterministic network.
    for level in range(params['stochastic_depth'] - 1):
        upward_net = d_nets[level]
        d_layers.append(upward_net(d_layers[level]['net']))

    return d_layers


def top_down_init(d_layers, params):
    """
    Seeds a top-down pass: turns the last deterministic layer into the top
    most stochastic layer and draws samples from it.

    The last entry of d_layers is reused as-is, not copied, so its 'prob'
    entry is overwritten in place.

    Args:
        d_layers: deterministic layers in inference model; the last entry is
                read through its 'mu' and 'var' keys and receives a new
                'prob' entry.
        params: dict of parameters that define the model; 'stochastic_depth'
                and 'n_samples' are read.

    Returns:
        q_layers: single-element list holding the top most stochastic layer
        q_z_samples: single-element list holding params["n_samples"] samples
                drawn from that layer

    """
    top_layer = d_layers[-1]
    # NOTE(review): tfd.Normal's `scale` is a standard deviation -- confirm
    # that the 'var' entry stores one rather than a variance.
    top_layer['prob'] = tfd.Normal(
        loc=top_layer['mu'],
        scale=top_layer['var'],
        name='q_z_{}'.format(params['stochastic_depth'] - 1))

    top_samples = top_layer['prob'].sample(params["n_samples"])

    return [top_layer], [top_samples]


def bottom_up_init(d_layers, params):
    """
    Seeds a bottom-up pass: turns the first deterministic layer into the
    bottom most stochastic layer and draws samples from it.

    The first entry of d_layers is reused as-is, not copied, so its 'prob'
    entry is overwritten in place.

    Args:
        d_layers: deterministic layers in inference model; the first entry is
                read through its 'mu' and 'var' keys and receives a new
                'prob' entry.
        params: dict of parameters that define the model; 'n_samples' is
                read.

    Returns:
        q_layers: single-element list holding the bottom most stochastic layer
        q_z_samples: single-element list holding params["n_samples"] samples
                drawn from that layer

    """
    bottom_layer = d_layers[0]
    # NOTE(review): tfd.Normal's `scale` is a standard deviation -- confirm
    # that the 'var' entry stores one rather than a variance.
    bottom_layer['prob'] = tfd.Normal(loc=bottom_layer['mu'],
                                      scale=bottom_layer['var'],
                                      name='q_z_0')

    bottom_samples = bottom_layer['prob'].sample(params["n_samples"])

    return [bottom_layer], [bottom_samples]


def inverse_inf(q_nets, d_layers, params):
    """
    Builds the 'inverse' inference model: starting from the top most layer,
    each lower stochastic layer is conditioned on its own deterministic
    layer and on the samples of the layer above it.

    Args:
        q_nets: list of callables to generate inference model networks;
                q_nets[i] produces stochastic layer i
        d_layers: deterministic layers in inference model; the 'net' entry of
                each lower layer is read, and the last entry is turned into
                the top most stochastic layer by top_down_init.
        params: dict of parameters that define the model

    Returns:
        layers and samples from the inference model, both ordered from the
        bottom most layer to the top most layer

    """
    # --- Start from the top-most q layer and its samples
    q_layers, q_z_samples = top_down_init(d_layers, params)

    tile_multiples = [params["n_samples"], 1, 1]

    # --- Walk down from the layer just below the top to layer 0
    for level in range(params['stochastic_depth'] - 2, -1, -1):
        q_net = q_nets[level]

        # --- Repeat the deterministic 'net' entry n_samples times along a
        # --- new leading axis so it can be concatenated with the samples
        expanded_net = tf.expand_dims(d_layers[level]["net"], axis=0)
        tiled_net = tf.tile(expanded_net, multiples=tile_multiples)

        samples_from_above = q_z_samples[-1]
        q_input = tf.concat([tiled_net, samples_from_above], axis=2)

        current_layer = q_net(q_input)
        q_layers.append(current_layer)

        # --- Draw a single sample and drop the extra leading axis
        drawn = current_layer["prob"].sample(1)
        q_z_samples.append(tf.squeeze(drawn, axis=0))

    # --- Both lists were filled top-down; flip them to bottom-up order
    q_layers.reverse()
    q_z_samples.reverse()

    return q_layers, q_z_samples


def upwards_inf(q_nets, d_layers, params):
    """
    Builds the 'upwards' inference model: starting from the bottom most
    layer, each higher stochastic layer is conditioned on its own
    deterministic layer and on the samples of the layer below it.

    Args:
        q_nets: list of callables to generate inference model networks;
                q_nets[k] produces stochastic layer k + 1
        d_layers: deterministic layers in inference model; the 'net' entry of
                each higher layer is read, and the first entry is turned into
                the bottom most stochastic layer by bottom_up_init.
        params: dict of parameters that define the model

    Returns:
        layers and samples from the inference model, ordered from the bottom
        most layer to the top most layer

    """
    # --- Start from the bottom-most q layer and its samples
    q_layers, q_z_samples = bottom_up_init(d_layers, params)

    tile_multiples = [params["n_samples"], 1, 1]

    # --- Walk up from layer 1 to the top most layer
    for step in range(params['stochastic_depth'] - 1):
        level = step + 1
        q_net = q_nets[step]

        # --- Repeat the deterministic 'net' entry n_samples times along a
        # --- new leading axis so it can be concatenated with the samples
        expanded_net = tf.expand_dims(d_layers[level]["net"], axis=0)
        tiled_net = tf.tile(expanded_net, multiples=tile_multiples)

        q_input = tf.concat([tiled_net, q_z_samples[-1]], axis=2)
        current_layer = q_net(q_input)
        q_layers.append(current_layer)

        # --- Draw a single sample and drop the extra leading axis
        drawn = current_layer["prob"].sample(1)
        q_z_samples.append(tf.squeeze(drawn, axis=0))

    return q_layers, q_z_samples


def make_stochastic_inf_path(q_nets, encoder, features, params):
    """
    Builds a purely stochastic inference model, either chained or with skip
    connections from the input.

    vanilla/chained model:
    x -> z_0 -> z_1 -> z_2 etc...

    skip model:
    x -> z_0 -> z_1 -> z_2 etc...
     ____________|______|

    The skip variant is selected when params["inf_path"] == 'skip'; every
    other value gives the chained variant.

    Args:
      q_nets: list of callables to generate inference model networks;
        q_nets[k] produces stochastic layer k + 1
      encoder: encoder model -- called on the reshaped input data and
        expected to return a layer with a "prob" entry.
      features: feature data, reshaped here to
        (params["batch_size"], params["n_x"])
      params: dict of parameters containing terms for the construction of the inference model

    Returns:
      layers and samples from the inference model, ordered from the first
      stochastic layer upwards

    """

    # --- The encoder output is the first stochastic layer
    flat_features = tf.reshape(features,
                               (params["batch_size"], params["n_x"]))
    first_layer = encoder(flat_features)
    q_layers = [first_layer]
    # --- n_samples draws are taken from that first layer
    q_z_samples = [first_layer["prob"].sample(params["n_samples"])]

    # --- For the skip model, repeat the input n_samples times along a new
    # --- leading axis so it can be concatenated with each sample tensor
    use_skip = params["inf_path"] == 'skip'
    if use_skip:
        expanded_features = tf.expand_dims(flat_features, axis=0)
        tiled_features = tf.tile(expanded_features,
                                 multiples=[params["n_samples"], 1, 1])

    for step in range(params['stochastic_depth'] - 1):
        q_net = q_nets[step]
        samples_from_below = q_z_samples[-1]
        # --- Skip model: every layer also sees the (tiled) input
        if use_skip:
            q_input = tf.concat([tiled_features, samples_from_below], axis=2)
        else:
            q_input = samples_from_below
        current_layer = q_net(q_input)
        q_layers.append(current_layer)
        # --- Draw a single sample and drop the extra leading axis
        drawn = current_layer["prob"].sample(1)
        q_z_samples.append(tf.squeeze(drawn, axis=0))

    return q_layers, q_z_samples


def make_inference_model(q_nets, d_nets, encoder, features, params):
    """
    Assembles the inference model from its pre-built networks

    Dispatch:
      * d_nets is None: regular 'stochastic' path, no deterministic layers.
      * d_nets given and params["inf_path"] == 'deterministic_inverse':
        x -> z_0 <- z_1 <- z_2 etc...
             |____________|______|
      * d_nets given and params["inf_path"] == 'deterministic_up': upwards
        model on top of the deterministic layers.
      * d_nets given and any other path: only the deterministic layers
        x -> d_0 -> d_1 -> d_2 ... are built, and the stochastic layers and
        samples are returned as None.

    Args:
      q_nets: list of callables to generate inference model networks for stochastic layers
      d_nets: list of callables to generate inference model networks for deterministic layers,
        or None when there is no deterministic path
      encoder: encoder model -- takes in input data when called.
      features: feature data
      params: dict of parameters containing terms for the construction of the inference model

    Returns:
      layers and samples from the inference model, followed by the
      deterministic layers (None when d_nets is None)

    """
    # --- No deterministic networks: regular 'stochastic' inference path
    if d_nets is None:
        q_layers, q_z_samples = make_stochastic_inf_path(
            q_nets, encoder, features, params)
        return q_layers, q_z_samples, None

    # --- Deterministic upward pass, shared by both deterministic variants
    d_layers = make_deterministic_inf_path(d_nets, encoder, features, params)

    inf_path = params["inf_path"]
    if inf_path == 'deterministic_inverse':
        q_layers, q_z_samples = inverse_inf(q_nets, d_layers, params)
    elif inf_path == 'deterministic_up':
        q_layers, q_z_samples = upwards_inf(q_nets, d_layers, params)
    else:
        # --- Deterministic layers only; no stochastic layers are built here
        q_layers, q_z_samples = None, None

    return q_layers, q_z_samples, d_layers
