# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions specific to the construction of generative layers in VAEs
"""
import tensorflow as tf
import tensorflow_probability as tfp
from .utils import make_stochastic_layer
from .inference_layers import top_down_init

tfd = tfp.distributions


def make_p_networks(params):
    """
    Build the per-layer networks of the generative (p) model.

    One network is created for each latent layer below the top-most one,
    i.e. ``params['stochastic_depth'] - 1`` networks in total. Network ``i``
    is named ``'p_z_<i>'``, has ``params["hidden_size"][i + 1]`` hidden units
    and ``params["latent_size"][i]`` outputs.

    Args:
        params: dict of run parameters; reads 'stochastic_depth',
            'activation', 'hidden_size', 'latent_size' and 'lv_eps_z'

    Returns:
        list of networks produced by ``make_stochastic_layer``, where element
        ``i`` is the network for latent layer ``i``

    """
    depth = params['stochastic_depth']
    # --- Networks are constructed highest index first; the list is flipped
    # --- afterwards so that position i holds the network for layer i.
    highest_first = [
        make_stochastic_layer(activation=params["activation"],
                              hidden_size=params["hidden_size"][idx + 1],
                              output_size=params["latent_size"][idx],
                              name='p_z_' + str(idx),
                              eps=params["lv_eps_z"])
        for idx in range(depth - 2, -1, -1)
    ]
    return highest_first[::-1]


def gen_path_from_z(decoder, p_nets, samples, params, fixed_layers=2):
    """
    Run the generative model downwards from a given latent value and decode.

    Starting from ``samples``, the generative networks ``p_nets[i]`` are
    applied for ``i = stochastic_depth - 2`` down to ``i = fixed_layers - 1``.
    Each network receives the ``'mu'`` output of the network applied just
    before it (the first one receives ``samples``). No sampling is performed.

    Args:
        decoder: callable mapping latent values to a distribution whose
            ``.mean()`` is used as the reconstruction
        p_nets: list of callables for the generative model networks; each is
            expected to return a dict with a ``'mu'`` entry
        samples: latent value the generative pass starts from
        params: dict of model parameters; reads 'stochastic_depth' and
            'x_plus'
        fixed_layers: controls where generation stops; networks with index
            below ``fixed_layers - 1`` are not applied

    Returns:
        dict with key 'recon': the decoder mean averaged over axis 0

    """
    # --- Kept ordered lowest layer first: each newly generated value is
    # --- prepended, so `samples` stays the last element.
    z_samples = [samples]
    for i in reversed(range(fixed_layers - 1, params['stochastic_depth'] - 1)):
        # --- Condition on the layer generated just above, which sits at the
        # --- front of the list (z_samples[-1] would always be `samples`).
        z_samples = [p_nets[i](z_samples[0])['mu']] + z_samples
    if params['x_plus']:
        # --- Decoder receives every latent layer, concatenated along axis 1
        decoder_likelihood = decoder(tf.concat(z_samples, axis=1))
    else:
        # --- Decoder receives only the lowest generated layer
        decoder_likelihood = decoder(z_samples[0])
    return dict(recon=tf.reduce_mean(decoder_likelihood.mean(), axis=0))


def gen_inf_sharing(q_net, p_layer, d_layer, q_samples, params):
    """
    Build an inference layer whose input is shared with the generative model.

    The deterministic upward-pass features ``d_layer["net"]`` are repeated
    ``params["n_samples"]`` times along a new leading axis. They are then
    concatenated along axis 2 with the generative layer features
    ``p_layer["net"]``. When ``params["gen_inf_sharing"] == "shared+"``, the
    inference samples ``q_samples`` are concatenated as well. The result is
    passed through ``q_net`` and one sample is drawn from its ``"prob"``
    distribution.

    Args:
        q_net: callable for the inference network; must return a dict with a
            "prob" distribution
        p_layer: dict for the generative layer; its "net" entry is used
        d_layer: dict for the deterministic layer of the upward pass; its
            "net" entry is used and must be rank 2 (it is expanded and tiled
            with a 3-element multiples list)
        q_samples: inference samples fed in only in "shared+" mode
            (presumably those of the layer above - confirm with callers)
        params: dict of model parameters; reads "n_samples" and
            "gen_inf_sharing"

    Returns:
        tuple (q_layer, q_z): the output of ``q_net``, and one sample from its
        "prob" distribution with the size-1 leading sample axis removed

    """
    # --- Repeat the deterministic features along a new leading axis so that
    # --- all non-concatenated dimensions agree, as tf.concat requires.
    tiled_d = tf.tile(tf.expand_dims(d_layer["net"], axis=0),
                      multiples=[params["n_samples"], 1, 1])

    pieces = [tiled_d, p_layer["net"]]
    # --- "shared+" additionally feeds the inference samples into q_net
    if params["gen_inf_sharing"] == "shared+":
        pieces.append(q_samples)

    q_layer = q_net(tf.concat(pieces, axis=2))
    # --- Draw a single sample and drop the resulting size-1 leading axis
    q_z = tf.squeeze(q_layer["prob"].sample(1), axis=0)
    return q_layer, q_z


def make_skip_gen_path(q_nets,
                       p_nets,
                       q_layers,
                       q_z_samples,
                       latent_prior,
                       params,
                       d_layers=None):
    """
    Create a generative model with skip connections.

    ... z_2 -> z_1 -> z_0 -> x
    ... _|______|______|

    Layers are built top-down, for ``i = stochastic_depth - 2`` down to ``0``.
    The top-most generated layer is conditioned only on the inference sample
    of the layer above. Every lower layer additionally receives the last
    element of ``q_z_samples`` (presumably the top-most inference sample),
    concatenated along axis 2.

    Args:
      q_nets: list of callables to generate inference model networks
      p_nets: list of callables to generate generative model networks
      q_layers: list of dicts representing layers of the inference model
      q_z_samples: list of samples from inference layers
      latent_prior: prior for the top-most layer
      params: dict of parameters; reads 'stochastic_depth' and
        'gen_inf_sharing' (plus whatever ``gen_inf_sharing`` reads)
      d_layers: optional deterministic layers of the upward pass. When given
        and ``params["gen_inf_sharing"]`` is truthy, a new inference layer is
        built alongside each generative layer.

    Returns:
      - list of generative layers, lowest first, ending with ``latent_prior``
      - q_layers and q_z_samples, with any newly built inference layers and
        samples prepended

    """
    depth = params['stochastic_depth']
    # --- New inference samples are prepended (newest at index 0) only when
    # --- sharing is on AND deterministic layers are available. Only then is
    # --- index 0 the layer above; otherwise the caller's list is indexed by
    # --- layer. Both the choice of `samples` and the update use this flag.
    sharing = d_layers is not None and bool(params["gen_inf_sharing"])

    # --- Start from the prior of the top-most layer
    p_layers = [latent_prior]

    for i in reversed(range(depth - 1)):
        p_net = p_nets[i]
        samples = q_z_samples[0] if sharing else q_z_samples[i + 1]
        if i == depth - 2:
            # --- Top-most generated layer: conditioned on the layer above only
            p_z = p_net(samples)
        else:
            # --- Skip connection: also feed the last (top-most) inference sample
            p_z = p_net(tf.concat([q_z_samples[-1], samples], axis=2))
        # --- Use the new p layer's network to build the matching q layer
        if sharing:
            q_l, q_z_s = gen_inf_sharing(q_nets[i], p_z, d_layers[i], samples,
                                         params)
            q_layers = [q_l] + q_layers
            q_z_samples = [q_z_s] + q_z_samples

        p_layers.append(p_z)
    p_layers.reverse()

    return p_layers, q_layers, q_z_samples


def make_chain_gen_path(q_nets,
                        p_nets,
                        q_layers,
                        q_z_samples,
                        latent_prior,
                        params,
                        d_layers=None):
    """
    Create a chained generative model.

    ... z_2 -> z_1 -> z_0 -> x

    Layers are built top-down, for ``i = stochastic_depth - 2`` down to ``0``.
    Each layer is conditioned on the inference sample of the layer above.
    If ``params["gen_path"] == 'skip'``, every layer except the top-most
    generated one also receives the last element of ``q_z_samples``,
    concatenated along axis 2 (kept for callers that rely on it).

    Args:
      q_nets: list of callables to generate inference model networks
      p_nets: list of callables to generate generative model networks
      q_layers: list of dicts representing layers of the inference model
      q_z_samples: list of samples from inference layers
      latent_prior: prior for the top-most layer
      params: dict of parameters; reads 'stochastic_depth', 'gen_path' and
        'gen_inf_sharing' (plus whatever ``gen_inf_sharing`` reads)
      d_layers: optional deterministic layers from an 'inverse' model. When
        given and ``params["gen_inf_sharing"]`` is truthy, a new inference
        layer is built alongside each generative layer.

    Returns:
      - list of generative layers, lowest first, ending with ``latent_prior``
      - q_layers and q_z_samples, with any newly built inference layers and
        samples prepended

    """
    depth = params['stochastic_depth']
    # --- New inference samples are prepended (newest at index 0) only when
    # --- sharing is on AND deterministic layers are available. Only then is
    # --- index 0 the layer above; otherwise the caller's list is indexed by
    # --- layer. Both the choice of `samples` and the update use this flag.
    sharing = d_layers is not None and bool(params["gen_inf_sharing"])
    use_skip = params["gen_path"] == 'skip'

    # --- Start from the prior of the top-most layer
    p_layers = [latent_prior]

    for i in reversed(range(depth - 1)):
        p_net = p_nets[i]
        samples = q_z_samples[0] if sharing else q_z_samples[i + 1]
        if use_skip and i != depth - 2:
            # --- Skip connection: also feed the last (top-most) inference sample
            p_z = p_net(tf.concat([q_z_samples[-1], samples], axis=2))
        else:
            # --- Plain chain: conditioned on the layer above only
            p_z = p_net(samples)
        # --- Use the new p layer's network to build the matching q layer
        if sharing:
            q_l, q_z_s = gen_inf_sharing(q_nets[i], p_z, d_layers[i], samples,
                                         params)
            q_layers = [q_l] + q_layers
            q_z_samples = [q_z_s] + q_z_samples

        p_layers.append(p_z)

    p_layers.reverse()

    return p_layers, q_layers, q_z_samples


def make_generative_model(q_nets, p_nets, q_layers, q_z_samples, d_layers,
                          latent_prior, params):
    """
    Create the generative model.

    Args:
      q_nets: list of callables to generate inference model networks
      p_nets: list of callables to generate generative model networks
      q_layers: list of dicts representing layers of the inference model
      q_z_samples: list of samples from inference layers
      d_layers: deterministic layers of the upward pass, or None
      latent_prior: prior for the top-most layer
      params: dict of parameters containing terms for the construction of the
        generative model; reads 'gen_path' ('chain' or 'skip') and
        'gen_inf_sharing'

    Returns:
      - layers of the generative model
      - layers and samples from the inference model, which may have been
        updated during the construction of the generative model

    Raises:
      ValueError: if ``params["gen_path"]`` is neither 'chain' nor 'skip'

    """
    gen_path = params["gen_path"]
    # --- Fail early with a clear message; otherwise an unknown path would only
    # --- surface later as an unbound local variable.
    if gen_path not in ('chain', 'skip'):
        raise ValueError(
            "Unknown params['gen_path'] {!r}; expected 'chain' or 'skip'"
            .format(gen_path))

    # --- If we're sharing connections between the generative and inference
    # --- model, q_layers and q_z_samples are re-initialised from the
    # --- deterministic layers of the upward pass via top_down_init.
    if (d_layers is not None) and (params["gen_inf_sharing"]):
        q_layers, q_z_samples = top_down_init(d_layers, params)

    if gen_path == 'chain':
        # --- 'chain':  ... z_2 -> z_1 -> z_0 -> x  (no skip connections)
        build_path = make_chain_gen_path
    else:
        # --- 'skip':   ... z_2 -> z_1 -> z_0 -> x
        #               ... _|______|______|
        build_path = make_skip_gen_path

    return build_path(q_nets, p_nets, q_layers, q_z_samples, latent_prior,
                      params, d_layers)
