# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Defines ELBO loss functions
"""
import tensorflow as tf
import tensorflow_probability as tfp
from metrics.mig import mutual_info_metric_shapes
import numpy as np
tfd = tfp.distributions


def calculate_decoder_likelihood(decoder, features, q_z_samples, x_plus=False):
    """ Runs the decoder on latent samples and scores the features under it

    Args:
        decoder: callable that takes the latent samples and returns an object
            exposing `log_prob` (presumably a tfd distribution)
        features: feature data passed to the decoder output's `log_prob`
        q_z_samples: list of samples from inference model's stochastic layers
        x_plus: when True, every entry of `q_z_samples` is concatenated along
            axis 2 and fed to the decoder; when False only `q_z_samples[0]`
            is used

    Returns:
        decoder_likelihood: the object returned by `decoder`
        log_p_x_z: the NEGATED `log_prob` of `features` (a negative log
            likelihood, despite the name)
        mean_log_p_x_z: mean of log_p_x_z over all of its elements

    Side effects:
        Records `mean_log_p_x_z` as the scalar summary "log_p_x_z".
    """
    # --- With x_plus every z feeds the decoder: (..., z_2, z_1, z_0) -> x;
    # otherwise only the first entry of the sample list is decoded.
    decoder_input = tf.concat(q_z_samples, axis=2) if x_plus else q_z_samples[0]

    likelihood = decoder(decoder_input)

    # --- Sign is flipped here: the returned value is a negative log likelihood.
    negative_log_prob = -likelihood.log_prob(features,
                                             name='decoder_likelihood')
    mean_negative_log_prob = tf.reduce_mean(negative_log_prob)
    tf.summary.scalar("log_p_x_z", mean_negative_log_prob)

    return likelihood, negative_log_prob, mean_negative_log_prob


def _layer_distribution(layer):
    """ Returns the distribution of a stochastic-layer entry: the value stored
        under 'prob' when the entry is a dict, otherwise the entry itself
    """
    return layer['prob'] if isinstance(layer, dict) else layer


def calculate_kl_divs(p_layers,
                      q_layers,
                      q_z_samples,
                      analytic_kl=False,
                      beta=1.0):
    """ Calculates KL divergence across stochastic layers between inference 'q'
        layers and generative 'p' layers

    Args:
        p_layers: list of generative model layers; each entry is either a dict
            holding its distribution under 'prob' or the distribution itself
        q_layers: list of inference model layers, same layout as p_layers
        q_z_samples: list of samples from inference model's stochastic layers
            (only read when analytic_kl is False)
        analytic_kl: if True the per-layer KL is `tfd.kl_divergence(q, p)`;
            if False it is estimated from the samples as the difference of the
            two `log_prob` values, each summed over axis 2
        beta: multiplier applied to every layer's KL term to encourage
            disentanglement (a scalar, or any value that broadcasts against
            the per-layer KL)

    Returns:
        kl_z: element-wise sum of the (beta-scaled) per-layer KL terms
        mean_kl_z: mean of kl_z

    Side effects:
        Records scalar summaries "kl_z_og/layer_<i>" for each layer and
        "kl_z_og" for the overall mean.
    """
    kl_z_list = []
    for i, (p, q) in enumerate(zip(p_layers, q_layers)):
        # --- Both estimators read the distribution the same way, so the same
        # layer entries work whichever value of `analytic_kl` is used.
        p_dist = _layer_distribution(p)
        q_dist = _layer_distribution(q)

        if analytic_kl:
            # NOTE(review): the analytic KL is not reduced over any axis here;
            # confirm its shape matches the sampled estimate (which sums
            # axis 2) and is the same for every layer before relying on it.
            kl = beta * tfd.kl_divergence(q_dist, p_dist)
        else:
            samples = q_z_samples[i]
            kl = beta * (tf.reduce_sum(q_dist.log_prob(samples), axis=2) -
                         tf.reduce_sum(p_dist.log_prob(samples), axis=2))

        tf.summary.scalar("kl_z_og/layer_" + str(i), tf.reduce_mean(kl))
        kl_z_list.append(kl)

    # --- Sum list of kl divergences
    kl_z = tf.add_n(kl_z_list)

    # --- Now take mean over minibatch
    mean_kl_z = tf.reduce_mean(kl_z)
    tf.summary.scalar("kl_z_og", mean_kl_z)

    return kl_z, mean_kl_z


def _tc_decomposed_kl(log_q_dims, log_q_joint, log_p_joint, params, layer_idx,
                      alpha, beta, gamma):
    """ Builds the weighted index-code-MI / total-correlation /
        dimension-wise-KL term for one layer and records its "tc" summary
    """
    # --- minibatch weighted sampling
    log_norm = tf.log(params["batch_size"] * float(params["n_datapoints"]))

    # --- per-dimension logsumexp over axis 1, summed over the remaining
    # latent axis, then averaged over axis 0
    log_q_prod_marginals = tf.reduce_mean(
        tf.reduce_sum(tf.reduce_logsumexp(log_q_dims, axis=1) - log_norm,
                      axis=1),
        axis=0)

    # --- compute log q(z)
    log_q_aggregate = tf.reduce_mean(
        tf.reduce_logsumexp(log_q_joint, axis=1) - log_norm, axis=0)

    # --- KL = index-code MI + β*total-correlation + dimension-wise-kl
    total_correlation = log_q_aggregate - log_q_prod_marginals
    dimension_wise_kl = log_q_prod_marginals - log_p_joint
    index_code_mi = log_q_joint - log_q_aggregate

    tf.summary.scalar("tc/layer_" + str(layer_idx), total_correlation)

    return (alpha * index_code_mi + beta * total_correlation +
            gamma * dimension_wise_kl)


def calculate_tc_decomp_kl_divs(p_layers,
                                q_layers,
                                q_z_samples,
                                params,
                                beta,
                                lamb,
                                alpha=1,
                                gamma=1,
                                factors_batch=None):
    """ Calculates the per-layer KL between inference 'q' layers and generative
        'p' layers, using a total-correlation (TC) decomposition where selected.

        A layer uses the plain sampled KL (log q - log p) when it is not the
        last layer (index < params['stochastic_depth'] - 1) and alpha > 0.
        Otherwise the KL is decomposed with minibatch weighted sampling into
        alpha * index-code MI + beta * total correlation + gamma *
        dimension-wise KL. Every layer's KL is then floored at
        lamb * params['latent_size'][i] (free bits).

    Args:
        p_layers: list of dicts from generative model, read via the keys
            'prob' (distribution), 'mu' and 'var'
        q_layers: list of dicts from inference model, same keys as p_layers
        q_z_samples: list of samples from inference model's stochastic layers
            (log_prob of a sample is reduced over axis 2, so it is assumed to
            be rank 3 with latent dimensions last - TODO confirm)
        params: model run params; reads 'stochastic_depth', 'batch_size',
            'n_datapoints', 'latent_size' and 'dataset'
        beta: weight of the total-correlation term
        lamb: free-bits rate; multiplied by the layer's latent size to give
            the lower bound applied to the KL
        alpha: weight of the index-code mutual information term
        gamma: weight of the dimension-wise KL term
        factors_batch: optional ground-truth factors; when given and
            params['dataset'] is one of 'dsprites', 'celeba', 'chairs',
            'faces', the MIG metric is computed for each layer

    Returns:
        kl_z: element-wise sum of the per-layer KL terms
        mean_kl_z: mean of kl_z
        eval_migs: list with one MIG value per layer (empty if MIG not run)
        mutual_infos: `me - ce` of the LAST layer evaluated, or an empty list
            if MIG was never run

    Side effects:
        Records histogram summaries of the q/p 'mu' and 'var' entries and
        scalar summaries "tc/layer_<i>", "kl_z/layer_<i>", "MIG/layer_<i>"
        and "kl_z".
    """
    mig_datasets = ['dsprites', 'celeba', 'chairs', 'faces']
    kl_per_layer, eval_migs, mutual_infos = [], [], []

    for i, (p, q) in enumerate(zip(p_layers, q_layers)):
        layer_tag = str(i)
        latents = q_z_samples[i]

        # --- log q(z_i|x) per latent dimension, then summed to log q(z|x)
        log_q_dims = q["prob"].log_prob(latents)
        log_q_joint = tf.reduce_sum(log_q_dims, axis=2)

        # --- same samples scored under the generative layer
        log_p_dims = p["prob"].log_prob(latents)
        log_p_joint = tf.reduce_sum(log_p_dims, axis=2)

        # --- the ".../sig/..." histograms are fed from the 'var' entries
        histogram_specs = (
            ("q/mu/layer_", q, 'mu'),
            ("q/sig/layer_", q, 'var'),
            ("p/mu/layer_", p, 'mu'),
            ("p/sig/layer_", p, 'var'),
        )
        for prefix, layer, key in histogram_specs:
            tf.summary.histogram(prefix + layer_tag,
                                 tf.reshape(layer[key], [-1]))

        below_top = i < params['stochastic_depth'] - 1
        if below_top and (alpha > 0):
            layer_kl = log_q_joint - log_p_joint
        else:
            layer_kl = _tc_decomposed_kl(log_q_dims, log_q_joint, log_p_joint,
                                         params, i, alpha, beta, gamma)

        # --- Free bits encoding: the KL never drops below the floor
        free_bits_floor = np.float32(lamb * params["latent_size"][i])
        layer_kl = tf.maximum(free_bits_floor, layer_kl)

        kl_per_layer.append(layer_kl)
        tf.summary.scalar("kl_z/layer_" + layer_tag, tf.reduce_mean(layer_kl))

        if (params["dataset"] in mig_datasets) and factors_batch is not None:
            mig, me, ce = mutual_info_metric_shapes(log_q_dims, factors_batch,
                                                    params["latent_size"][i],
                                                    params["n_datapoints"])

            # NOTE(review): this rebinds `mutual_infos` on every layer instead
            # of appending like `eval_migs`, so only the last layer's value is
            # returned - confirm against callers whether that is intended.
            mutual_infos = me - ce

            tf.summary.scalar("MIG/layer_" + layer_tag, mig)
            eval_migs.append(mig)

    # --- Sum list of kl divergences
    kl_z = tf.add_n(kl_per_layer)

    # --- Now take mean over minibatch
    mean_kl_z = tf.reduce_mean(kl_z)
    tf.summary.scalar("kl_z", mean_kl_z)

    return kl_z, mean_kl_z, eval_migs, mutual_infos


def calculate_elbo(kl_z, log_p_x_z, n_samples, warm_up_beta=1.0):
    """ Calculates the ELBO and its importance weighted variant
        elbo = -(warm_up_beta * kl_z + log_p_x_z)
        `log_p_x_z` is added to the KL term before the sum is negated, so it
        is treated as a reconstruction cost.

    Args:
        kl_z: kl divergence summed over stochastic layers
        log_p_x_z: reconstruction term added to the weighted KL
        n_samples: number of MC samples used in approximations
        warm_up_beta: weight applied to kl_z

    Returns:
        elbo: mean of the local elbo over all of its elements
        importance_weighted_elbo: mean of (logsumexp of the local elbo over
            axis 0, minus log(n_samples))

    Side effects:
        Records the scalar summaries "elbo" and "elbo/importance_weighted".
    """
    # --- Local ELBO
    weighted_cost = warm_up_beta * kl_z + log_p_x_z
    local_elbo = -weighted_cost

    # --- Mean ELBO
    elbo = tf.reduce_mean(local_elbo)

    # --- Importance weighted ELBO: log-mean-exp over axis 0
    log_sum_exp = tf.reduce_logsumexp(local_elbo, axis=0)
    log_n_samples = tf.log(tf.cast(n_samples, dtype=tf.float32))
    importance_weighted_elbo = tf.reduce_mean(log_sum_exp - log_n_samples)

    tf.summary.scalar("elbo", elbo)
    tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

    return elbo, importance_weighted_elbo
