# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Functions to help calculate adversarial losses
"""
import tensorflow as tf


def calculate_adversarial_kl(q_layers, zt_mu, zt_var, params):
    """
    Calculates the adversarial KL divergence KL(z_a || z_t), where z_a is the
    embedding of the adversarial example and z_t is the embedding of the target.

    :param q_layers: per-stochastic-layer latent variables for the adversarial
        example; q_layers[i]['mu'] is used as the mean and q_layers[i]['var']
        is squared before use (presumably a standard deviation -- TODO confirm)
    :param zt_mu: latent variable means for the target, with every stochastic
        layer's latents concatenated along axis 2
    :param zt_var: latent variable variances for the target, laid out like zt_mu
    :param params: dict of tf.estimator params; reads "x_plus",
        "stochastic_depth" and "latent_size"
    :return: sum over the matched stochastic layers of tf.reduce_mean of the
        per-layer KL divergence (0.0 if no layer is matched)

    """

    def kld(mean1, var1, mean2, var2):
        """
        KL divergence KL(N(mean1, var1) || N(mean2, var2)) between two diagonal
        Normal distributions parameterised by their variances, summed over
        axis 2.
        """
        # Closed form per dimension:
        #   0.5 * [log(var2 / var1) + (var1 + (mean1 - mean2)^2) / var2 - 1]
        # which is exactly 0 when the two distributions coincide.
        ratio_term = (var1 + (mean1 - mean2)**2.0) / var2
        log_term = tf.log(var2) - tf.log(var1)
        return tf.reduce_sum(0.5 * (ratio_term + log_term - 1.0), axis=2)

    # --- If x_plus mode is activated the adversarial attack must match ALL
    # --- latent variables, i.e. adv_kl = sum_i KL(z_ai || z_ti) over the
    # --- stochastic layers. Otherwise only the bottom-most stochastic layer
    # --- (index 0) needs to be matched to successfully attack the model.
    n_layers = params["stochastic_depth"] if params["x_plus"] else 1

    adv_kl = 0.0
    start_idx = 0
    for i in range(n_layers):
        # The target latents of all layers are concatenated along axis 2, so
        # layer i occupies [start_idx, start_idx + latent_size[i]).
        end_idx = start_idx + int(params["latent_size"][i])
        target_mu = zt_mu[:, :, start_idx:end_idx]
        target_var = zt_var[:, :, start_idx:end_idx]
        adv_kl += tf.reduce_mean(
            kld(q_layers[i]['mu'], q_layers[i]['var']**2, target_mu,
                target_var))
        start_idx = end_idx
    return adv_kl
