# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of losses for weakly-supervised disentanglement learning.

Implementation of weakly-supervised VAE based models from the paper
"Weakly-Supervised Disentanglement Without Compromises"
https://arxiv.org/pdf/2002.02886.pdf.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from disentanglement_lib.methods.shared import losses  # pylint: disable=unused-import
from disentanglement_lib.methods.shared import (
    optimizers,
)  # pylint: disable=unused-import
from disentanglement_lib.methods.unsupervised import vae
from six.moves import zip
import tensorflow.compat.v1 as tf

import sys

import gin.tf
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import TPUEstimatorSpec

from disentanglement_lib.hypergeom_dist.tf_fmvhg import pmf_noncentral_fmvhg
from disentanglement_lib.hypergeom_dist.tf_fmvhg import get_logits
from disentanglement_lib.hypergeom_dist.heaviside import heaviside

import mvhg


@gin.configurable("weak_loss", blacklist=["z1", "z2", "labels"])
def make_weak_loss(z1, z2, labels, loss_fn=gin.REQUIRED):
    """Build a weakly-supervised loss by delegating to the configured `loss_fn`.

    Args:
      z1: First input forwarded unchanged to `loss_fn`.
      z2: Second input forwarded unchanged to `loss_fn`.
      labels: Third input forwarded unchanged to `loss_fn`.
      loss_fn: Callable selected through gin; it must accept
        `(z1, z2, labels)` positionally.

    Returns:
      Whatever `loss_fn(z1, z2, labels)` returns.
    """
    weak_loss = loss_fn(z1, z2, labels)
    return weak_loss


@gin.configurable("group_vae")
class GroupVAEBase(vae.BaseVAE):
    """Beta-VAE with averaging from https://arxiv.org/abs/1809.02383.

    Every input example stacks two views along axis 1. Both halves are
    encoded with the same encoder. Subclasses implement ``aggregate`` to decide
    which latent dimensions take the pair-averaged posterior and which keep
    their own view's posterior.
    """

    def __init__(self, beta=gin.REQUIRED):
        """Creates a beta-VAE model with additional averaging for weak supervision.

        Based on https://arxiv.org/abs/1809.02383.

        Args:
          beta: Hyperparameter for KL divergence.
        """
        self.beta = beta

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Combine the posteriors of the two views (implemented by subclasses).

        Returns:
          A 6-tuple ``(mean_1, log_var_1, mean_2, log_var_2, kl_div_hg, logs)``.
          ``kl_div_hg`` is forwarded to ``regularizer`` and may be None.
          ``logs`` maps summary names to scalar values.
        """
        raise NotImplementedError(
            "Subclasses of GroupVAEBase must implement aggregate()."
        )

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Return ``beta * kl_loss``; the remaining arguments are ignored."""
        del z_mean, z_logvar, z_sampled
        return self.beta * kl_loss

    def model_fn(self, features, labels, mode, params):
        """TPUEstimator compatible model function.

        Args:
          features: Rank-4 tensor holding the two views concatenated along
            axis 1. The first half of that axis is view 1 and the second half
            is view 2.
          labels: One-hot encoded with depth equal to the latent size and
            passed to ``aggregate`` (presumably the index of the factor that
            differs between the views; confirm against the input pipeline).
          mode: A ``tf.estimator.ModeKeys`` value. Only TRAIN and EVAL are
            supported.
          params: Unused.

        Returns:
          A ``TPUEstimatorSpec`` for TRAIN or EVAL.

        Raises:
          NotImplementedError: For any other mode (i.e. PREDICT).
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        # Split the stacked pair along axis 1 into its two views.
        data_shape = features.get_shape().as_list()[1:]
        data_shape[0] = int(data_shape[0] / 2)
        features_1 = features[:, : data_shape[0], :, :]
        features_2 = features[:, data_shape[0] :, :, :]
        # AUTO_REUSE makes both views share the encoder weights.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            z_mean, z_logvar = self.gaussian_encoder(
                features_1, is_training=is_training
            )
            z_mean_2, z_logvar_2 = self.gaussian_encoder(
                features_2, is_training=is_training
            )
        labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
        # Symmetrised per-dimension distance between the two posteriors.
        kl_per_point = 0.5 * (
            compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)
            + compute_kl(z_mean_2, z_mean, z_logvar_2, z_logvar)
        )

        # Averaged posterior: mean of means, log of the mean variance.
        new_mean = 0.5 * z_mean + 0.5 * z_mean_2
        var_1 = tf.exp(z_logvar)
        var_2 = tf.exp(z_logvar_2)
        new_log_var = tf.math.log(0.5 * var_1 + 0.5 * var_2)

        (
            mean_sample_1,
            log_var_sample_1,
            mean_sample_2,
            log_var_sample_2,
            kl_div_hg,
            agg_logs,
        ) = self.aggregate(
            z_mean,
            z_logvar,
            z_mean_2,
            z_logvar_2,
            new_mean,
            new_log_var,
            labels,
            kl_per_point,
        )

        z_sampled_1 = self.sample_from_latent_distribution(
            mean_sample_1, log_var_sample_1
        )
        z_sampled_2 = self.sample_from_latent_distribution(
            mean_sample_2, log_var_sample_2
        )
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
            reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
        per_sample_loss_1 = losses.make_reconstruction_loss(
            features_1, reconstructions_1
        )
        per_sample_loss_2 = losses.make_reconstruction_loss(
            features_2, reconstructions_2
        )
        reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
        reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
        reconstruction_loss = 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
        kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
        kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
        kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
        regularizer = self.regularizer(kl_loss, None, None, None, kl_div_hg)

        loss = tf.add(reconstruction_loss, regularizer, name="loss")

        elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = optimizers.make_vae_optimizer()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step()
            )
            train_op = tf.group([train_op, update_ops])
            tf.summary.scalar("reconstruction_loss", reconstruction_loss)
            tf.summary.scalar("elbo", -elbo)
            for metric, value in agg_logs.items():
                tf.summary.scalar(metric, value)
            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": loss,
                    "kl_loss": kl_loss,
                    "reconstruction_loss": reconstruction_loss,
                    "elbo": -elbo,
                },
                every_n_iter=100,
            )
            return TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook]
            )
        elif mode == tf.estimator.ModeKeys.EVAL:
            for metric, value in agg_logs.items():
                tf.summary.scalar(f"eval_{metric}", value)
            return TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(
                    make_metric_fn(
                        "reconstruction_loss",
                        "elbo",
                        "regularizer",
                        "kl_loss",
                        *list(agg_logs.keys()),
                    ),
                    [reconstruction_loss, -elbo, regularizer, kl_loss]
                    + list(agg_logs.values()),
                ),
            )
        else:
            # Only PREDICT reaches this branch; the old message wrongly blamed EVAL.
            raise NotImplementedError("Predict mode not supported.")


@gin.configurable("group_vae_labels")
class GroupVAELabels(GroupVAEBase):
    """Group-VAE where labels say which factor differs between the two views."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Run ``aggregate_labels`` on each view of the pair.

        Returns:
          ``(mean_1, log_var_1, mean_2, log_var_2, None, logs)``. No
          hypergeometric KL term is produced, and ``logs`` reports a constant
          ``mean_k_hat`` of 1.
        """
        per_view = [
            aggregate_labels(mu, logvar, new_mean, new_log_var, labels, kl_per_point)
            for mu, logvar in ((z_mean1, z_logvar1), (z_mean2, z_logvar2))
        ]
        (mean_1, logvar_1), (mean_2, logvar_2) = per_view
        logs = {"mean_k_hat": 1}
        return mean_1, logvar_1, mean_2, logvar_2, None, logs

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Return the beta-weighted Gaussian KL; other arguments are ignored."""
        return self.beta * kl_loss


@gin.configurable("group_vae_argmax")
class GroupVAEArgmax(GroupVAEBase):
    """Group-VAE that infers the non-shared factors without labels."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Run ``aggregate_argmax`` on each view and log the mean k estimate.

        Returns:
          ``(mean_1, log_var_1, mean_2, log_var_2, None, logs)``, where
          ``logs["mean_k_hat"]`` averages the batch-mean k_hat of both views.
        """
        view_1 = aggregate_argmax(
            z_mean1, z_logvar1, new_mean, new_log_var, labels, kl_per_point
        )
        view_2 = aggregate_argmax(
            z_mean2, z_logvar2, new_mean, new_log_var, labels, kl_per_point
        )
        mean_1, logvar_1, k_1 = view_1
        mean_2, logvar_2, k_2 = view_2
        mean_k_hat = 0.5 * (tf.reduce_mean(k_1) + tf.reduce_mean(k_2))
        logs = {"mean_k_hat": mean_k_hat}
        return mean_1, logvar_1, mean_2, logvar_2, None, logs

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Return the beta-weighted Gaussian KL; other arguments are ignored."""
        return self.beta * kl_loss


@gin.configurable("group_vae_hg")
class GroupVAEHyperGeom(GroupVAEBase):
    """Group-VAE with hypergeometric shared-factor selection.

    An ``HgeomSelector`` layer decides, per pair, which latent dimensions are
    averaged across the two views and which stay view-specific. The loss adds
    ``gamma`` times ``compute_hg_kl_div`` to the beta-weighted Gaussian KL.
    """

    def __init__(
        self,
        beta=gin.REQUIRED,
        num_latent=gin.REQUIRED,
        num_classes=gin.REQUIRED,
        gamma=gin.REQUIRED,
    ):
        """Creates the model.

        Args:
          beta: Weight of the Gaussian KL term.
          num_latent: Number of latent dimensions. It is used to reshape the
            aggregated posteriors and to build the selector.
          num_classes: Number of classes handed to ``HgeomSelector``.
          gamma: Weight of the hypergeometric KL term.
        """
        super().__init__(beta)

        self.gamma = gamma
        self.num_latent = num_latent
        self.num_classes = num_classes
        self.hgeom_selector = HgeomSelector(self.num_latent, self.num_classes)

    def compute_hg_kl_div(self, n_latents, w):
        """Divergence between the selector's HG distribution and a fixed prior.

        Args:
          n_latents: Number of latent dimensions (Python number).
          w: Selector weights with a static batch dimension and at least two
            columns. The ratio ``w[:, 0] / w[:, 1]`` is passed to
            ``get_logits`` as the odds.

        Returns:
          Scalar tensor summing the divergence terms over all elements.
        """
        # Assemble get_logits inputs.
        n_samples = w.get_shape().as_list()[0]
        n = tf.expand_dims(tf.repeat(float(n_latents), repeats=n_samples), 1)
        heavy_x_rest = tf.cast(tf.range(int(n_latents + 1)), tf.float32)

        # Prior odds: the constant 1 / n_latents for every sample.
        prior_w = 1.0 / float(n_latents) * tf.ones(w.get_shape()[0])

        # Get logits
        logits_p_x_1 = get_logits(
            10.0,
            n,
            10.0,
            10.0,
            tf.reshape(w[:, 0] / w[:, 1], (-1, 1, 1)),
            heaviside(tf.reverse(heavy_x_rest, axis=(0,))) * heaviside(heavy_x_rest),
        )
        logits_prior_x_1 = get_logits(
            10.0,
            n,
            10.0,
            10.0,
            tf.reshape(prior_w, (-1, 1, 1)),
            heaviside(tf.reverse(heavy_x_rest, axis=(0,))) * heaviside(heavy_x_rest),
        )

        # Softmax probabilities and their complements.
        p_x_1 = tf.nn.softmax(logits_p_x_1)
        p_x_2 = 1 - p_x_1
        prior_x_1 = tf.nn.softmax(logits_prior_x_1)
        # Complement of the *prior*. This was previously `1 - p_x_1`, which
        # leaked the posterior into the prior term.
        prior_x_2 = 1 - prior_x_1

        # NOTE(review): tf.log(p_x_1 * p_x_2) is -inf when a factor underflows
        # to 0, which makes the product NaN; consider an epsilon guard.
        kl_div = tf.reduce_sum(
            p_x_1 * p_x_2 * (tf.log(p_x_1 * p_x_2) - tf.log(prior_x_1 * prior_x_2))
        )

        return kl_div

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_logvar,
        labels,
        kl_per_point,
    ):
        """Mix shared (averaged) and view-specific posteriors via the selector.

        Args:
          z_mean1, z_logvar1: Posterior of view 1.
          z_mean2, z_logvar2: Posterior of view 2.
          new_mean, new_logvar: Averaged posterior of the pair.
          labels: Unused; the selector infers the shared factors itself.
          kl_per_point: Per-dimension distance between the two posteriors.

        Returns:
          6-tuple ``(mean_1, log_var_1, mean_2, log_var_2, kl_div_w, logs)``.
        """
        # Diagonal masks for shared and independent factors.
        diag_masks, hg_ws, k_mask = self.hgeom_selector(new_mean, kl_per_point)
        diag_mask_avg, diag_mask_ind = diag_masks
        hg_weights, hg_weights_n = hg_ws

        # Shared mean and logvar: mask applied to the averaged posterior.
        z_mean_avg = tf.linalg.matmul(diag_mask_avg, tf.expand_dims(new_mean, -1))
        z_logvar_avg = tf.linalg.matmul(diag_mask_avg, tf.expand_dims(new_logvar, -1))

        # Independent factors: mask applied to each view's own posterior.
        z_mean1_ind = tf.linalg.matmul(diag_mask_ind, tf.expand_dims(z_mean1, -1))
        z_logvar1_ind = tf.linalg.matmul(diag_mask_ind, tf.expand_dims(z_logvar1, -1))
        z_mean2_ind = tf.linalg.matmul(diag_mask_ind, tf.expand_dims(z_mean2, -1))
        z_logvar2_ind = tf.linalg.matmul(diag_mask_ind, tf.expand_dims(z_logvar2, -1))

        # Combine shared and independent parts.
        mean_sample1 = tf.add(z_mean_avg, z_mean1_ind)
        log_var_sample1 = tf.add(z_logvar_avg, z_logvar1_ind)
        mean_sample2 = tf.add(z_mean_avg, z_mean2_ind)
        log_var_sample2 = tf.add(z_logvar_avg, z_logvar2_ind)

        # Back to (batch, num_latent).
        mean_sample1 = tf.reshape(mean_sample1, [-1, self.num_latent])
        mean_sample2 = tf.reshape(mean_sample2, [-1, self.num_latent])
        log_var_sample1 = tf.reshape(log_var_sample1, [-1, self.num_latent])
        log_var_sample2 = tf.reshape(log_var_sample2, [-1, self.num_latent])

        kl_div_w = self.compute_hg_kl_div(z_mean1.get_shape().as_list()[1], hg_weights)

        # Summaries: mean count of entries in the selector's third output
        # (its independent-factor mask) and mean selector weights.
        k_hat = tf.reduce_mean(tf.reduce_sum(k_mask, axis=-1))
        w_hat_n = tf.reduce_mean(hg_weights_n, axis=0)
        if hg_weights is not None:
            w_hat = tf.reduce_mean(hg_weights, axis=0)
        else:
            w_hat = (-1) * tf.ones((2))
        kl_per_point_avg = tf.reduce_mean(kl_per_point)
        logs = {
            "mean_k_hat": k_hat,
            "mean_w1_hat": w_hat[0],
            "mean_w2_hat": w_hat[1],
            "mean_w1_hat_n": w_hat_n[0],
            "mean_w2_hat_n": w_hat_n[1],
            "kl_div_w": kl_div_w,
            "kl_per_point": kl_per_point_avg,
        }

        return (
            mean_sample1,
            log_var_sample1,
            mean_sample2,
            log_var_sample2,
            kl_div_w,
            logs,
        )

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Return ``beta * kl_loss + gamma * kl_loss_hg``."""
        del z_mean, z_logvar, z_sampled
        return self.beta * kl_loss + self.gamma * kl_loss_hg


@gin.configurable("mlvae")
class MLVae(vae.BaseVAE):
    """Beta-VAE with averaging from https://arxiv.org/abs/1705.08841.

    Every input example stacks two views along axis 1. The shared posterior is
    the precision-weighted combination of both views' Gaussians. Subclasses
    implement ``aggregate`` to choose which dimensions use it.
    """

    def __init__(self, beta=gin.REQUIRED):
        """Creates a beta-VAE model with additional averaging for weak supervision.

        Based on ML-VAE https://arxiv.org/abs/1705.08841.

        Args:
          beta: Hyperparameter total correlation.
        """
        self.beta = beta

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
        """Return ``beta * kl_loss``; the remaining arguments are ignored."""
        del z_mean, z_logvar, z_sampled
        return self.beta * kl_loss

    def model_fn(self, features, labels, mode, params):
        """TPUEstimator compatible model function.

        Args:
          features: Rank-4 tensor holding the two views concatenated along
            axis 1.
          labels: One-hot encoded with depth equal to the latent size and
            passed to ``aggregate``.
          mode: A ``tf.estimator.ModeKeys`` value. Only TRAIN and EVAL are
            supported.
          params: Unused.

        Returns:
          A ``TPUEstimatorSpec`` for TRAIN or EVAL.

        Raises:
          NotImplementedError: For any other mode (i.e. PREDICT).
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        # Split the stacked pair along axis 1 into its two views.
        data_shape = features.get_shape().as_list()[1:]
        data_shape[0] = int(data_shape[0] / 2)
        features_1 = features[:, : data_shape[0], :, :]
        features_2 = features[:, data_shape[0] :, :, :]
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            z_mean, z_logvar = self.gaussian_encoder(
                features_1, is_training=is_training
            )
            z_mean_2, z_logvar_2 = self.gaussian_encoder(
                features_2, is_training=is_training
            )
        labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
        kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)

        # Precision-weighted combination: new_var is twice the product-of-
        # Gaussians variance, so new_mean is the precision-weighted mean.
        var_1 = tf.exp(z_logvar)
        var_2 = tf.exp(z_logvar_2)
        new_var = 2 * var_1 * var_2 / (var_1 + var_2)
        new_mean = (z_mean / var_1 + z_mean_2 / var_2) * new_var * 0.5

        new_log_var = tf.math.log(new_var)

        (
            mean_sample_1,
            log_var_sample_1,
            mean_sample_2,
            log_var_sample_2,
        ) = self.aggregate(
            z_mean,
            z_logvar,
            z_mean_2,
            z_logvar_2,
            new_mean,
            new_log_var,
            labels,
            kl_per_point,
        )

        z_sampled_1 = self.sample_from_latent_distribution(
            mean_sample_1, log_var_sample_1
        )
        z_sampled_2 = self.sample_from_latent_distribution(
            mean_sample_2, log_var_sample_2
        )
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
            reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
        per_sample_loss_1 = losses.make_reconstruction_loss(
            features_1, reconstructions_1
        )
        per_sample_loss_2 = losses.make_reconstruction_loss(
            features_2, reconstructions_2
        )
        reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
        reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
        reconstruction_loss = 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
        kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
        kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
        kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
        regularizer = self.regularizer(kl_loss, None, None, None)

        loss = tf.add(reconstruction_loss, regularizer, name="loss")
        elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = optimizers.make_vae_optimizer()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step()
            )
            train_op = tf.group([train_op, update_ops])
            tf.summary.scalar("reconstruction_loss", reconstruction_loss)
            tf.summary.scalar("elbo", -elbo)
            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": loss,
                    "reconstruction_loss": reconstruction_loss,
                    "elbo": -elbo,
                },
                every_n_iter=100,
            )
            return TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook]
            )
        elif mode == tf.estimator.ModeKeys.EVAL:
            return TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(
                    make_metric_fn(
                        "reconstruction_loss", "elbo", "regularizer", "kl_loss"
                    ),
                    [reconstruction_loss, -elbo, regularizer, kl_loss],
                ),
            )
        else:
            # Only PREDICT reaches this branch; the old message wrongly blamed EVAL.
            raise NotImplementedError("Predict mode not supported.")


@gin.configurable("mlvae_labels")
class MLVaeLabels(MLVae):
    """ML-VAE variant where labels say which factor differs between views."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Run ``aggregate_labels`` on each view of the pair.

        Returns:
          ``(mean_1, log_var_1, mean_2, log_var_2)``.
        """
        per_view = [
            aggregate_labels(mu, logvar, new_mean, new_log_var, labels, kl_per_point)
            for mu, logvar in ((z_mean1, z_logvar1), (z_mean2, z_logvar2))
        ]
        (mean_1, logvar_1), (mean_2, logvar_2) = per_view
        return mean_1, logvar_1, mean_2, logvar_2


@gin.configurable("mlvae_argmax")
class MLVaeArgmax(MLVae):
    """ML-VAE variant that infers the non-shared factors without labels."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Run ``aggregate_argmax`` on each view of the pair.

        Returns:
          ``(mean_1, log_var_1, mean_2, log_var_2)``, the 4-tuple that
          ``MLVae.model_fn`` unpacks.
        """
        # aggregate_argmax returns (mean, logvar, k_hat). Unpacking it into two
        # names raised ValueError, so k_hat (not logged by MLVae) is discarded.
        mean_sample_1, log_var_sample_1, _ = aggregate_argmax(
            z_mean1, z_logvar1, new_mean, new_log_var, labels, kl_per_point
        )
        mean_sample_2, log_var_sample_2, _ = aggregate_argmax(
            z_mean2, z_logvar2, new_mean, new_log_var, labels, kl_per_point
        )
        return mean_sample_1, log_var_sample_1, mean_sample_2, log_var_sample_2


@gin.configurable("ml_vae_hg")
class MLVAEHyperGeom(MLVae):
    """ML-VAE with hypergeometric shared-factor selection."""

    def __init__(
        self, beta=gin.REQUIRED, num_latent=gin.REQUIRED, num_classes=gin.REQUIRED
    ):
        """Creates the model.

        Args:
          beta: Weight of the Gaussian KL term.
          num_latent: Number of latent dimensions.
          num_classes: Number of classes handed to ``HgeomSelector``.
        """
        super().__init__(beta)

        self.num_latent = num_latent
        self.num_classes = num_classes
        self.hgeom_selector = HgeomSelector(self.num_latent, self.num_classes)

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_logvar,
        labels,
        kl_per_point,
    ):
        """Mix shared and view-specific posteriors using the selector's masks.

        Returns:
          ``(mean_1, log_var_1, mean_2, log_var_2)``, each reshaped to
          ``(batch, num_latent)``.
        """
        # HgeomSelector returns ([diag_shared, diag_ind], [w, w_n], k_mask).
        # Unpacking that 3-tuple into two names raised ValueError.
        (diag_mask_avg, diag_mask_ind), _, _ = self.hgeom_selector(
            new_mean, kl_per_point
        )
        z_mean_avg = tf.linalg.matmul(tf.expand_dims(new_mean, 1), diag_mask_avg)
        z_logvar_avg = tf.linalg.matmul(tf.expand_dims(new_logvar, 1), diag_mask_avg)
        z_mean1_ind = tf.linalg.matmul(tf.expand_dims(z_mean1, 1), diag_mask_ind)
        z_logvar1_ind = tf.linalg.matmul(tf.expand_dims(z_logvar1, 1), diag_mask_ind)
        z_mean2_ind = tf.linalg.matmul(tf.expand_dims(z_mean2, 1), diag_mask_ind)
        z_logvar2_ind = tf.linalg.matmul(tf.expand_dims(z_logvar2, 1), diag_mask_ind)
        mean_sample1 = tf.add(z_mean_avg, z_mean1_ind)
        log_var_sample1 = tf.add(z_logvar_avg, z_logvar1_ind)
        mean_sample2 = tf.add(z_mean_avg, z_mean2_ind)
        log_var_sample2 = tf.add(z_logvar_avg, z_logvar2_ind)

        mean_sample1 = tf.reshape(mean_sample1, [-1, self.num_latent])
        mean_sample2 = tf.reshape(mean_sample2, [-1, self.num_latent])
        log_var_sample1 = tf.reshape(log_var_sample1, [-1, self.num_latent])
        log_var_sample2 = tf.reshape(log_var_sample2, [-1, self.num_latent])
        return mean_sample1, log_var_sample1, mean_sample2, log_var_sample2


def aggregate_labels(z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point):
    """Use labels to aggregate.

    ``labels`` is a one-hot encoding whose single 1 marks the factor that is
    *not* shared between the two observations. That dimension keeps the
    observation's own posterior, and every other dimension takes the averaged
    posterior. This fixes which latent dimension learns which factor
    (dimension i learns factor i) and forces each factor of variation into a
    single dimension.

    Args:
      z_mean: Mean of the encoder distribution for the original image.
      z_logvar: Logvar of the encoder distribution for the original image.
      new_mean: Average mean of the encoder distribution of the pair of images.
      new_log_var: Average logvar of the encoder distribution of the pair of
        images.
      labels: One-hot-encoding with the position of the dimension that should not
        be shared.
      kl_per_point: Distance between the two encoder distributions (unused).

    Returns:
      Tuple ``(mean, log_var)`` for the new observation.
    """
    del kl_per_point
    # True where an entry equals its row maximum, i.e. at the one-hot position.
    # A row of all zeros matches everywhere and keeps the whole own posterior.
    keep_own = tf.math.equal(labels, tf.expand_dims(tf.reduce_max(labels, axis=1), 1))
    z_mean_averaged = tf.where(keep_own, z_mean, new_mean)
    z_logvar_averaged = tf.where(keep_own, z_logvar, new_log_var)
    return z_mean_averaged, z_logvar_averaged


def aggregate_argmax(z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point):
    """Argmax aggregation with adaptive k.

    Each row of ``kl_per_point`` is split into two equal-width bins between
    its minimum and maximum. Dimensions in the upper bin (largest distance
    between the two encoder distributions) keep the observation's own
    posterior. All other dimensions take the averaged posterior. k is the
    number of dimensions in the upper bin, so it is estimated adaptively per
    sample.

    Args:
      z_mean: Mean of the encoder distribution for the original image.
      z_logvar: Logvar of the encoder distribution for the original image.
      new_mean: Average mean of the encoder distribution of the pair of images.
      new_log_var: Average logvar of the encoder distribution of the pair of
        images.
      labels: Unused; present for interface parity with ``aggregate_labels``.
      kl_per_point: Distance between the two encoder distributions.

    Returns:
      Tuple ``(mean, log_var, k_hat)``: the mean and log-variance for the new
      observation, and the per-sample float count of non-averaged dimensions.
    """
    del labels
    # Bin index 1 is the upper half of the [min, max] range of each row.
    mask = tf.equal(tf.map_fn(discretize_in_bins, kl_per_point, tf.int32), 1)
    k_hat = tf.reduce_sum(tf.cast(mask, tf.float32), axis=-1)
    z_mean_averaged = tf.where(mask, z_mean, new_mean)
    z_logvar_averaged = tf.where(mask, z_logvar, new_log_var)
    return z_mean_averaged, z_logvar_averaged, k_hat


def discretize_in_bins(x):
    """Assign each entry of ``x`` to bin 0 or 1 of two equal-width bins.

    The bins span ``[min(x), max(x)]``; the result is an integer tensor of
    bin indices with the same shape as ``x``.
    """
    value_range = [tf.reduce_min(x), tf.reduce_max(x)]
    return tf.histogram_fixed_width_bins(x, value_range, nbins=2)


def compute_kl(z_1, z_2, logvar_1, logvar_2):
    """Elementwise, un-halved KL divergence between two diagonal Gaussians.

    Evaluates, per element,
    ``var_1 / var_2 + (z_2 - z_1)**2 / (var_2 + 1e-6) - 1 + logvar_2 - logvar_1``.
    This is twice KL(N(z_1, var_1) || N(z_2, var_2)), with a 1e-6 offset on
    var_2 in the squared-difference term only. The result is clamped at zero
    with a ReLU.
    """
    safe_var_2 = tf.exp(logvar_2) + 1e-6
    ratio_term = tf.exp(logvar_1 - logvar_2)
    gap_term = tf.square(z_2 - z_1) / safe_var_2
    return tf.nn.relu(ratio_term + gap_term - 1 + logvar_2 - logvar_1)


def make_metric_fn(*names):
    """Build a metric function that reports ``tf.metrics.mean`` per name.

    The returned callable pairs its positional arguments with ``names`` in
    order, stopping at the shorter of the two. It returns a dict mapping each
    name to the ``tf.metrics.mean`` of the corresponding tensor.
    """

    def metric_fn(*args):
        metrics = {}
        for metric_name, tensor in zip(names, args):
            metrics[metric_name] = tf.metrics.mean(tensor)
        return metrics

    return metric_fn


def print_t(*args):
    """Wrap each tensor in a ``tf.Print`` debug op.

    Each op prints the tensor and whether it contains NaNs. A single argument
    returns its wrapped tensor directly; otherwise a list of wrapped tensors
    is returned (empty when called with no arguments).
    """
    printed = [
        tf.Print(x, [x, tf.math.reduce_any(tf.math.is_nan(x))], "Value of x: ")
        for x in args
    ]
    return printed[0] if len(printed) == 1 else printed


class HgeomSelector(tf.keras.layers.Layer):
    """Selects shared vs. independent latent dimensions via a hypergeometric sampler.

    The layer works in four steps:
      1. A dense layer (``logits_w``) maps the per-dimension distances between
         the two views' posteriors to class logits.
      2. ``mvhg.pmf_noncentral_fmvhg`` turns those logits into per-class masks.
         Their exact semantics are defined in the external ``mvhg`` module.
      3. A hard NeuralSort permutation of the distances maps the masks onto
         latent dimensions.
      4. The masks are returned as diagonal matrices.
    """

    def __init__(self, num_latents, num_classes):
        """Stores the configuration.

        Args:
          num_latents: Number of latent dimensions; also used for both entries
            of ``self.m``.
          num_classes: Output size of the ``logits_w`` dense layer and the class
            count passed to ``mvhg.pmf_noncentral_fmvhg``.
        """
        super(HgeomSelector, self).__init__()
        self.n_classes = num_classes
        self.num_latents = num_latents
        # First argument of mvhg.pmf_noncentral_fmvhg (presumably per-class
        # population sizes, both equal to num_latents — confirm in mvhg).
        self.m = [float(num_latents), float(num_latents)]
        # Temperature shared by the sampler and NeuralSort.
        self.tau = 1.0

    def bl_matmul(self, A, B):
        """Batched-left matmul: (M, i, j) x (j, k) -> (M, i, k)."""
        return tf.einsum("mij,jk->mik", A, B)

    def neuralsort(self, s, tau=1.0, hard=True):
        """
        s: M x n x 1
        neuralsort(s): M x n x n
        Neuralsort algorithm from https://github.com/ermongroup/neuralsort

        Returns a (relaxed) permutation matrix per batch element. Per the
        NeuralSort formulation, row k selects the k-th largest entry of s.
        With ``hard=True`` the forward pass uses a one-hot row (1 wherever a
        row equals its max, so ties produce several 1s), while gradients flow
        through the soft matrix (straight-through estimator).
        """
        # A_s[i, j] = |s_i - s_j|
        A_s = s - tf.transpose(s, perm=[0, 2, 1])
        A_s = tf.abs(A_s)

        n = tf.shape(s)[1]
        one = tf.ones((n, 1), dtype=tf.float32)

        # B[:, i, k] = sum_j |s_i - s_j| (row sums broadcast over k).
        B = self.bl_matmul(A_s, one @ tf.transpose(one))

        K = tf.range(n) + 1

        # C[:, i, k] = (n + 1 - 2k) * s_i
        C = self.bl_matmul(
            s, tf.expand_dims(tf.cast(n + 1 - 2 * K, dtype=tf.float32), 0)
        )

        P = tf.transpose(C - B, perm=[0, 2, 1])

        P = tf.nn.softmax(P / tau, -1)

        if hard:
            P_hard = tf.cast(
                tf.equal(P, tf.reduce_max(P, axis=-1, keepdims=True)), P.dtype
            )
            # Straight-through: forward value P_hard, gradient of soft P.
            P = tf.stop_gradient(P_hard - P) + P

        return P

    def call(self, z_mean, kl_ind):
        """Compute the shared/independent selection for a batch.

        Args:
          z_mean: Latent tensor; only its static shape (batch, n_latents) is
            read.
          kl_ind: Per-dimension distance between the two views' posteriors,
            presumably (batch, n_latents) — confirm against callers.

        Returns:
          ``([diag_sh, diag_ind], [logits_w, logits_w_n], y_mask_ind)``:
          diagonal matrices selecting the shared and independent dimensions,
          the raw and softmax-normalised class logits, and the independent
          mask before sorting.
        """
        # NOTE(review): requires a statically known batch size.
        n_samples = z_mean.get_shape()[0]
        n_latents = float(z_mean.get_shape().as_list()[1])

        # Compress the distance scale: log(1 + kl).
        kl_ind = tf.math.log(kl_ind + 1)

        # Logits estimated as 3 * tanh(dense(.)), so they lie in [-3, 3].
        # The output is in the log domain (an exp would be needed for weights).
        logits_w = 3 * tf.layers.dense(
            kl_ind, self.n_classes, activation=tf.nn.tanh, name="logits_w"
        )
        # logits_w = tf.math.exp(logits_w)
        logits_w_n = tf.nn.softmax(logits_w, -1)

        # Number of draws per sample: n_latents, shaped (batch, 1).
        n = tf.expand_dims(tf.repeat(n_latents, repeats=n_samples), 1)
        _, _, y_mask = mvhg.pmf_noncentral_fmvhg(
            self.m, n, logits_w, self.tau, self.n_classes
        )

        # Compute mask containing the d-k shared factors for each sample.
        # The leading entry of the last axis is dropped, and the shared mask is
        # reversed, presumably so that shared slots align with the
        # smallest-distance dimensions after sorting — confirm.
        y_mask_sh = y_mask[1][:, :, 1:]
        y_mask_sh = tf.reverse(y_mask_sh, axis=(-1,))
        # y_mask_ind = tf.concat([y_mask[0][:,:,1:], tf.zeros((n_samples, 1, 1))],
        #                         axis=2)
        y_mask_ind = y_mask[0][:, :, 1:]

        # Sort dimensions by (log) distance and map mask slots onto them.
        kl_ind = tf.expand_dims(kl_ind, -1)
        sort_perm = self.neuralsort(kl_ind, tau=self.tau, hard=True)
        y_sel_ind = y_mask_ind @ sort_perm
        y_sel_sh = y_mask_sh @ sort_perm

        y_diag_ind = tf.clip_by_value(y_sel_ind, 0, 1)
        y_diag_sh = tf.clip_by_value(y_sel_sh, 0, 1)
        # NOTE(review): tf.squeeze without an axis also drops the batch axis
        # when the batch size is 1 — confirm that is never the case.
        y_diag_sh = tf.squeeze(y_diag_sh)
        y_diag_ind = tf.squeeze(y_diag_ind)
        diag_ind = tf.linalg.diag(y_diag_ind, k=0, padding_value=0)
        diag_sh = tf.linalg.diag(y_diag_sh, k=0, padding_value=0)
        return [diag_sh, diag_ind], [logits_w, logits_w_n], y_mask_ind
