# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of losses for weakly-supervised disentanglement learning.

Implementation of weakly-supervised VAE based models from the paper
"Weakly-Supervised Disentanglement Without Compromises"
https://arxiv.org/pdf/2002.02886.pdf.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from disentanglement_lib.methods.shared import losses  # pylint: disable=unused-import
from disentanglement_lib.methods.shared import (
    optimizers,
)  # pylint: disable=unused-import
from disentanglement_lib.methods.unsupervised import vae
from six.moves import zip
import tensorflow.compat.v1 as tf

import sys

import gin.tf
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import TPUEstimatorSpec

import mvhg.tf_fmvhg as hg
import mvhg.tf_heaviside as heaviside


@gin.configurable("weak_loss", blacklist=["z1", "z2", "labels"])
def make_weak_loss(z1, z2, labels, loss_fn=gin.REQUIRED):
    """Evaluates the gin-configured weakly-supervised loss.

    Args:
      z1: First argument forwarded to `loss_fn`.
      z2: Second argument forwarded to `loss_fn`.
      labels: Third argument forwarded to `loss_fn`.
      loss_fn: Callable invoked as `loss_fn(z1, z2, labels)`; it has no usable
        default and must be bound through gin.

    Returns:
      Whatever `loss_fn` returns.
    """
    weak_loss = loss_fn(z1, z2, labels)
    return weak_loss


@gin.configurable("group_vae")
class GroupVAEBase(vae.BaseVAE):
    """Beta-VAE with averaging from https://arxiv.org/abs/1809.02383.

    This class does not define `aggregate`; `model_fn` calls
    `self.aggregate(...)`, so subclasses have to provide it and return the
    6-tuple `(mean_1, logvar_1, mean_2, logvar_2, kl_div_hg, logs)` where
    `logs` is a dict of values to report.
    """

    def __init__(self, beta=gin.REQUIRED, mc_samples=1):
        """Creates a beta-VAE model with additional averaging for weak supervision.

        Based on https://arxiv.org/abs/1809.02383.

        Args:
          beta: Hyperparameter for KL divergence.
          mc_samples: Number of aggregate/sample/decode passes `model_fn`
            averages the losses over.
        """
        self.beta = beta
        self.n_mc_samples = mc_samples

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Returns `beta * kl_loss`; all other arguments are ignored here."""
        del z_mean, z_logvar, z_sampled
        return self.beta * kl_loss

    def model_fn(self, features, labels, mode, params):
        """TPUEstimator compatible model function.

        Args:
          features: Tensor holding both observations of a pair stacked along
            axis 1; it is split into two halves along that axis.
          labels: Integer labels, converted to a one-hot encoding whose depth
            is the latent dimension and forwarded to `self.aggregate`.
          mode: One of the `tf.estimator.ModeKeys`.
          params: Unused.

        Returns:
          A `TPUEstimatorSpec` for the TRAIN or EVAL mode.

        Raises:
          NotImplementedError: If `mode` is neither TRAIN nor EVAL.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        data_shape = features.get_shape().as_list()[1:]
        # Axis 1 contains the two observations back to back; halve it.
        data_shape[0] = int(data_shape[0] / 2)
        features_1 = features[:, : data_shape[0], :, :]
        features_2 = features[:, data_shape[0] :, :, :]
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            z_mean, z_logvar = self.gaussian_encoder(
                features_1, is_training=is_training
            )
            z_mean_2, z_logvar_2 = self.gaussian_encoder(
                features_2, is_training=is_training
            )
        labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
        # Symmetrised per-dimension divergence between the two posteriors.
        kl_per_point = 0.5 * (
            compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)
            + compute_kl(z_mean_2, z_mean, z_logvar_2, z_logvar)
        )

        # Averaged posterior: mean of the means and log of the mean variance.
        new_mean = 0.5 * z_mean + 0.5 * z_mean_2
        var_1 = tf.exp(z_logvar)
        var_2 = tf.exp(z_logvar_2)
        new_log_var = tf.math.log(0.5 * var_1 + 0.5 * var_2)

        regularizer, kl_loss, reconstruction_loss = 0, 0, 0
        agg_logs = {}
        # Multiple MC samples to stabilize training
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            for i in range(self.n_mc_samples):
                (
                    mean_sample_1,
                    log_var_sample_1,
                    mean_sample_2,
                    log_var_sample_2,
                    kl_div_hg,
                    curr_agg_logs,
                ) = self.aggregate(
                    z_mean,
                    z_logvar,
                    z_mean_2,
                    z_logvar_2,
                    new_mean,
                    new_log_var,
                    labels,
                    kl_per_point,
                )

                # Collect the per-pass logs so they can be averaged below.
                for key, value in curr_agg_logs.items():
                    if key not in agg_logs:
                        agg_logs[key] = []
                    agg_logs[key].append(value)

                z_sampled_1 = self.sample_from_latent_distribution(
                    mean_sample_1, log_var_sample_1
                )
                z_sampled_2 = self.sample_from_latent_distribution(
                    mean_sample_2, log_var_sample_2
                )

                reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
                reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
                per_sample_loss_1 = losses.make_reconstruction_loss(
                    features_1, reconstructions_1
                )
                per_sample_loss_2 = losses.make_reconstruction_loss(
                    features_2, reconstructions_2
                )
                reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
                reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
                curr_reconstruction_loss = (
                    0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
                )
                reconstruction_loss = reconstruction_loss + curr_reconstruction_loss
                kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
                kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
                curr_kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
                kl_loss = kl_loss + curr_kl_loss
                # Regularize with this pass's KL only. Passing the running sum
                # `kl_loss` would count earlier passes again on every
                # iteration and inflate the KL weight when n_mc_samples > 1.
                curr_regularizer = self.regularizer(
                    curr_kl_loss, None, None, None, kl_div_hg
                )
                regularizer = regularizer + curr_regularizer
                agg_logs[f'kl_loss_mc_{i}'] = curr_kl_loss
                agg_logs[f'regularizer_mc_{i}'] = curr_regularizer
                agg_logs[f'reconstruction_loss_mc_{i}'] = curr_reconstruction_loss

        # Reduce every log entry (list of per-pass values or scalar) to a scalar.
        for key, value in agg_logs.items():
            agg_logs[key] = tf.reduce_mean(value)

        # Turn the accumulated sums into averages over the MC passes.
        regularizer = 1. / self.n_mc_samples * regularizer
        kl_loss = 1. / self.n_mc_samples * kl_loss
        reconstruction_loss = 1. / self.n_mc_samples * reconstruction_loss

        loss = tf.add(reconstruction_loss, regularizer, name="loss")

        elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = optimizers.make_vae_optimizer()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step()
            )
            train_op = tf.group([train_op, update_ops])
            tf.summary.scalar("reconstruction_loss", reconstruction_loss)
            tf.summary.scalar("elbo", -elbo)
            for metric, value in agg_logs.items():
                tf.summary.scalar(metric, value)
            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": loss,
                    "kl_loss": kl_loss,
                    "reconstruction_loss": reconstruction_loss,
                    "elbo": -elbo,
                },
                every_n_iter=100,
            )
            return TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook]
            )
        elif mode == tf.estimator.ModeKeys.EVAL:
            for metric, value in agg_logs.items():
                tf.summary.scalar(f"eval_{metric}", value)
            return TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(
                    make_metric_fn(
                        "reconstruction_loss",
                        "elbo",
                        "regularizer",
                        "kl_loss",
                        *list(agg_logs.keys()),
                    ),
                    [reconstruction_loss, -elbo, regularizer, kl_loss]
                    + list(agg_logs.values()),
                ),
            )
        else:
            # Reached for any mode other than TRAIN/EVAL (e.g. PREDICT).
            raise NotImplementedError("Only TRAIN and EVAL modes are supported.")


@gin.configurable("group_vae_labels")
class GroupVAELabels(GroupVAEBase):
    """Group-VAE variant in which the labels say which factor is shared."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Aggregates both posteriors with `aggregate_labels`.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2, None, logs)`; the fifth entry
          is always None and `logs` reports a constant `mean_k_hat` of 1.
        """
        common = (new_mean, new_log_var, labels, kl_per_point)
        first = aggregate_labels(z_mean1, z_logvar1, *common)
        second = aggregate_labels(z_mean2, z_logvar2, *common)
        mean_sample_1, log_var_sample_1 = first
        mean_sample_2, log_var_sample_2 = second
        logs = {"mean_k_hat": 1}
        return (
            mean_sample_1,
            log_var_sample_1,
            mean_sample_2,
            log_var_sample_2,
            None,
            logs,
        )

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Returns the KL term scaled by `self.beta`; other inputs are unused."""
        del z_mean, z_logvar, z_sampled
        weighted_kl = self.beta * kl_loss
        return weighted_kl


@gin.configurable("group_vae_argmax")
class GroupVAEArgmax(GroupVAEBase):
    """Group-VAE variant that does not use labels for the aggregation."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Aggregates both posteriors with `aggregate_argmax`.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2, None, logs)` where `logs`
          holds `mean_k_hat`, the average of the two batch-mean `k_hat`
          values returned by `aggregate_argmax`.
        """
        common = (new_mean, new_log_var, labels, kl_per_point)
        mean_a, logvar_a, k_hat_a = aggregate_argmax(z_mean1, z_logvar1, *common)
        mean_b, logvar_b, k_hat_b = aggregate_argmax(z_mean2, z_logvar2, *common)
        avg_k_a = tf.reduce_mean(k_hat_a)
        avg_k_b = tf.reduce_mean(k_hat_b)
        k_hat = 0.5 * (avg_k_a + avg_k_b)
        logs = {"mean_k_hat": k_hat}
        return mean_a, logvar_a, mean_b, logvar_b, None, logs

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Returns the KL term scaled by `self.beta`; other inputs are unused."""
        del z_mean, z_logvar, z_sampled
        weighted_kl = self.beta * kl_loss
        return weighted_kl


@gin.configurable("group_vae_hg")
class GroupVAEHyperGeom(GroupVAEBase):
    """Class implementing the group-VAE with hypergeometric shared factor selection"""

    def __init__(
        self,
        beta=gin.REQUIRED,
        num_latent=gin.REQUIRED,
        num_classes=gin.REQUIRED,
        gamma=gin.REQUIRED,
        mc_samples=gin.REQUIRED
    ):
        """Stores the hyperparameters and builds the `HgeomSelector`.

        Args:
          beta: Weight of the Gaussian KL term in `regularizer`.
          num_latent: Latent size; used to reshape the aggregated tensors.
          num_classes: Forwarded to `HgeomSelector`.
          gamma: Weight of the hypergeometric term in `regularizer`.
          mc_samples: Forwarded to `GroupVAEBase.__init__`.
        """
        super().__init__(beta, mc_samples)

        self.gamma = gamma
        self.num_latent = num_latent
        self.num_classes = num_classes
        self.hgeom_selector = HgeomSelector(self.num_latent, self.num_classes)

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_logvar,
        labels,
        kl_per_point,
    ):
        """Mixes averaged and individual posteriors using the selector masks.

        `labels` is accepted for interface compatibility but not used.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2, lp_hg_n, logs)` where
          `lp_hg_n` is `-mean(lp_prior - lp_posterior)` and `logs` is a dict
          of scalars to report.
        """

        def masked(diag_mask, vector):
            # Multiply the mask matrix with the vector viewed as a column.
            return tf.linalg.matmul(diag_mask, tf.expand_dims(vector, -1))

        def flat(tensor):
            # Back to a [-1, num_latent] layout.
            return tf.reshape(tensor, [-1, self.num_latent])

        # Masks for the shared and the independent factors.
        selection = self.hgeom_selector(new_mean, kl_per_point)
        diag_masks, hg_ws, k_mask, lps_hg, curr_tau = selection
        diag_mask_avg, diag_mask_ind = diag_masks
        log_hg_weights = hg_ws[0]

        # Shared part comes from the averaged posterior.
        shared_mean = masked(diag_mask_avg, new_mean)
        shared_logvar = masked(diag_mask_avg, new_logvar)

        # Independent part comes from each observation's own posterior.
        own_mean1 = masked(diag_mask_ind, z_mean1)
        own_logvar1 = masked(diag_mask_ind, z_logvar1)
        own_mean2 = masked(diag_mask_ind, z_mean2)
        own_logvar2 = masked(diag_mask_ind, z_logvar2)

        # Sum of both parts.
        mean_sample1 = tf.add(shared_mean, own_mean1)
        log_var_sample1 = tf.add(shared_logvar, own_logvar1)
        mean_sample2 = tf.add(shared_mean, own_mean2)
        log_var_sample2 = tf.add(shared_logvar, own_logvar2)

        mean_sample1 = flat(mean_sample1)
        mean_sample2 = flat(mean_sample2)
        log_var_sample1 = flat(log_var_sample1)
        log_var_sample2 = flat(log_var_sample2)

        # Log-probabilities under the prior and posterior weights.
        lp_hg_prior, lp_hg_post = lps_hg[0], lps_hg[1]
        lp_hg_n = -tf.reduce_mean(lp_hg_prior - lp_hg_post)

        # Batch average of the per-sample sum over `k_mask`.
        k_hat = tf.reduce_mean(tf.reduce_sum(k_mask, axis=-1))
        if log_hg_weights is None:
            log_w_hat = tf.zeros(2)
        else:
            log_w_hat = tf.reduce_mean(log_hg_weights, axis=0)
        kl_per_point_avg = tf.reduce_mean(kl_per_point)

        logs = {}
        logs["mean_k_hat"] = k_hat
        logs["mean_log_w1_hat"] = log_w_hat[0]
        logs["mean_log_w2_hat"] = log_w_hat[1]
        logs["log_prob_hg"] = lp_hg_n
        logs["kl_per_point"] = kl_per_point_avg
        logs["mean_lp_hg_prior"] = tf.reduce_mean(lp_hg_prior)
        logs["mean_lp_hg_posterior"] = tf.reduce_mean(lp_hg_post)
        logs["curr_temperature"] = curr_tau

        return (
            mean_sample1,
            log_var_sample1,
            mean_sample2,
            log_var_sample2,
            lp_hg_n,
            logs,
        )

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled, kl_loss_hg=None):
        """Returns `beta * kl_loss + gamma * kl_loss_hg`."""
        del z_mean, z_logvar, z_sampled
        gaussian_term = self.beta * kl_loss
        return gaussian_term + self.gamma * kl_loss_hg


@gin.configurable("mlvae")
class MLVae(vae.BaseVAE):
    """Beta-VAE with averaging from https://arxiv.org/abs/1705.08841.

    This class does not define `aggregate`; `model_fn` calls
    `self.aggregate(...)`, so subclasses have to provide it and return the
    4-tuple `(mean_1, logvar_1, mean_2, logvar_2)`.
    """

    def __init__(self, beta=gin.REQUIRED):
        """Creates a beta-VAE model with additional averaging for weak supervision.

        Based on ML-VAE https://arxiv.org/abs/1705.08841.

        Args:
          beta: Weight applied to the KL loss in `regularizer`.
        """
        self.beta = beta

    def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
        """Returns `beta * kl_loss`; the remaining arguments are ignored."""
        del z_mean, z_logvar, z_sampled
        return self.beta * kl_loss

    def model_fn(self, features, labels, mode, params):
        """TPUEstimator compatible model function.

        Args:
          features: Tensor holding both observations of a pair stacked along
            axis 1; it is split into two halves along that axis.
          labels: Integer labels, converted to a one-hot encoding whose depth
            is the latent dimension and forwarded to `self.aggregate`.
          mode: One of the `tf.estimator.ModeKeys`.
          params: Unused.

        Returns:
          A `TPUEstimatorSpec` for the TRAIN or EVAL mode.

        Raises:
          NotImplementedError: If `mode` is neither TRAIN nor EVAL.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        data_shape = features.get_shape().as_list()[1:]
        # Axis 1 contains the two observations back to back; halve it.
        data_shape[0] = int(data_shape[0] / 2)
        features_1 = features[:, : data_shape[0], :, :]
        features_2 = features[:, data_shape[0] :, :, :]
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            z_mean, z_logvar = self.gaussian_encoder(
                features_1, is_training=is_training
            )
            z_mean_2, z_logvar_2 = self.gaussian_encoder(
                features_2, is_training=is_training
            )
        labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
        kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)

        # Combined posterior: variance-weighted mean with
        # new_var = 2 * var_1 * var_2 / (var_1 + var_2).
        var_1 = tf.exp(z_logvar)
        var_2 = tf.exp(z_logvar_2)
        new_var = 2 * var_1 * var_2 / (var_1 + var_2)
        new_mean = (z_mean / var_1 + z_mean_2 / var_2) * new_var * 0.5

        new_log_var = tf.math.log(new_var)

        (
            mean_sample_1,
            log_var_sample_1,
            mean_sample_2,
            log_var_sample_2,
        ) = self.aggregate(
            z_mean,
            z_logvar,
            z_mean_2,
            z_logvar_2,
            new_mean,
            new_log_var,
            labels,
            kl_per_point,
        )

        z_sampled_1 = self.sample_from_latent_distribution(
            mean_sample_1, log_var_sample_1
        )
        z_sampled_2 = self.sample_from_latent_distribution(
            mean_sample_2, log_var_sample_2
        )
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
            reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
        per_sample_loss_1 = losses.make_reconstruction_loss(
            features_1, reconstructions_1
        )
        per_sample_loss_2 = losses.make_reconstruction_loss(
            features_2, reconstructions_2
        )
        reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
        reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
        reconstruction_loss = 0.5 * reconstruction_loss_1 + 0.5 * reconstruction_loss_2
        kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
        kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
        kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
        regularizer = self.regularizer(kl_loss, None, None, None)

        loss = tf.add(reconstruction_loss, regularizer, name="loss")
        elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = optimizers.make_vae_optimizer()
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss=loss, global_step=tf.train.get_global_step()
            )
            train_op = tf.group([train_op, update_ops])
            tf.summary.scalar("reconstruction_loss", reconstruction_loss)
            tf.summary.scalar("elbo", -elbo)
            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": loss,
                    "reconstruction_loss": reconstruction_loss,
                    "elbo": -elbo,
                },
                every_n_iter=100,
            )
            return TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook]
            )
        elif mode == tf.estimator.ModeKeys.EVAL:
            return TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(
                    make_metric_fn(
                        "reconstruction_loss", "elbo", "regularizer", "kl_loss"
                    ),
                    [reconstruction_loss, -elbo, regularizer, kl_loss],
                ),
            )
        else:
            # Reached for any mode other than TRAIN/EVAL (e.g. PREDICT).
            raise NotImplementedError("Only TRAIN and EVAL modes are supported.")


@gin.configurable("mlvae_labels")
class MLVaeLabels(MLVae):
    """ML-VAE variant in which the labels say which factor is shared."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Aggregates both posteriors with `aggregate_labels`.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2)`.
        """
        common = (new_mean, new_log_var, labels, kl_per_point)
        mean_a, logvar_a = aggregate_labels(z_mean1, z_logvar1, *common)
        mean_b, logvar_b = aggregate_labels(z_mean2, z_logvar2, *common)
        return mean_a, logvar_a, mean_b, logvar_b


@gin.configurable("mlvae_argmax")
class MLVaeArgmax(MLVae):
    """Class implementing the ML-VAE without any label."""

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_log_var,
        labels,
        kl_per_point,
    ):
        """Aggregates both posteriors with `aggregate_argmax`.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2)`.
        """
        # `aggregate_argmax` returns (mean, logvar, k_hat). `MLVae.model_fn`
        # expects only the four posterior tensors, so k_hat is dropped here;
        # unpacking the result into two names would raise a ValueError.
        mean_sample_1, log_var_sample_1, _ = aggregate_argmax(
            z_mean1, z_logvar1, new_mean, new_log_var, labels, kl_per_point
        )
        mean_sample_2, log_var_sample_2, _ = aggregate_argmax(
            z_mean2, z_logvar2, new_mean, new_log_var, labels, kl_per_point
        )
        return mean_sample_1, log_var_sample_1, mean_sample_2, log_var_sample_2


@gin.configurable("ml_vae_hg")
class MLVAEHyperGeom(MLVae):
    """Class implementing the ML-VAE with hypergeometric shared factor selection."""

    def __init__(
        self,
        beta=gin.REQUIRED,
        num_latent=gin.REQUIRED,
        num_classes=gin.REQUIRED,
        gs_temperature=gin.REQUIRED,
    ):
        """Stores the hyperparameters and builds the `HgeomSelector`.

        Args:
          beta: Weight of the KL term, forwarded to `MLVae.__init__`.
          num_latent: Latent size; used to reshape the aggregated tensors.
          num_classes: Forwarded to `HgeomSelector`.
          gs_temperature: Passed as the third positional argument (`tau`) of
            `HgeomSelector`.
        """
        super().__init__(beta)

        self.num_latent = num_latent
        self.num_classes = num_classes
        self.gs_temperature = gs_temperature
        self.hgeom_selector = HgeomSelector(
            self.num_latent, self.num_classes, self.gs_temperature
        )

    def aggregate(
        self,
        z_mean1,
        z_logvar1,
        z_mean2,
        z_logvar2,
        new_mean,
        new_logvar,
        labels,
        kl_per_point,
    ):
        """Mixes averaged and individual posteriors using the selector masks.

        `labels` is accepted for interface compatibility but not used.

        Returns:
          `(mean_1, logvar_1, mean_2, logvar_2)`, each reshaped to
          `[-1, num_latent]`.
        """
        # The selector returns a 5-tuple whose first entry is the pair
        # [shared mask, independent mask]; only that pair is needed here.
        # Unpacking the whole result into two names would raise a ValueError.
        diag_masks = self.hgeom_selector(new_mean, kl_per_point)[0]
        diag_mask_avg, diag_mask_ind = diag_masks

        # Shared part taken from the combined posterior.
        z_mean_avg = tf.linalg.matmul(tf.expand_dims(new_mean, 1), diag_mask_avg)
        z_logvar_avg = tf.linalg.matmul(tf.expand_dims(new_logvar, 1), diag_mask_avg)
        # Independent part taken from each observation's own posterior.
        z_mean1_ind = tf.linalg.matmul(tf.expand_dims(z_mean1, 1), diag_mask_ind)
        z_logvar1_ind = tf.linalg.matmul(tf.expand_dims(z_logvar1, 1), diag_mask_ind)
        z_mean2_ind = tf.linalg.matmul(tf.expand_dims(z_mean2, 1), diag_mask_ind)
        z_logvar2_ind = tf.linalg.matmul(tf.expand_dims(z_logvar2, 1), diag_mask_ind)
        mean_sample1 = tf.add(z_mean_avg, z_mean1_ind)
        log_var_sample1 = tf.add(z_logvar_avg, z_logvar1_ind)
        mean_sample2 = tf.add(z_mean_avg, z_mean2_ind)
        log_var_sample2 = tf.add(z_logvar_avg, z_logvar2_ind)

        mean_sample1 = tf.reshape(mean_sample1, [-1, self.num_latent])
        mean_sample2 = tf.reshape(mean_sample2, [-1, self.num_latent])
        log_var_sample1 = tf.reshape(log_var_sample1, [-1, self.num_latent])
        log_var_sample2 = tf.reshape(log_var_sample2, [-1, self.num_latent])
        return mean_sample1, log_var_sample1, mean_sample2, log_var_sample2


def aggregate_labels(z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point):
    """Aggregates the posterior parameters according to the labels.

    Wherever an entry of `labels` equals the maximum of its row, the original
    `z_mean` / `z_logvar` entry is kept; everywhere else the entry of
    `new_mean` / `new_log_var` is used instead.

    Args:
      z_mean: Mean of the encoder distribution for the original image.
      z_logvar: Logvar of the encoder distribution for the original image.
      new_mean: Average mean of the encoder distribution of the pair of images.
      new_log_var: Average logvar of the encoder distribution of the pair of
        images.
      labels: One-hot-encoding with the position of the dimension that should
        not be shared.
      kl_per_point: Distance between the two encoder distributions (unused).

    Returns:
      Mean and logvariance for the new observation.
    """
    del kl_per_point

    def at_row_maximum():
        # True where a label entry equals the largest value in its row.
        row_max = tf.reduce_max(labels, axis=1)
        return tf.math.equal(labels, tf.expand_dims(row_max, 1))

    mean_out = tf.where(at_row_maximum(), z_mean, new_mean)
    logvar_out = tf.where(at_row_maximum(), z_logvar, new_log_var)
    return mean_out, logvar_out


def aggregate_argmax(z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point):
    """Argmax aggregation with adaptive k.

    Each row of `kl_per_point` is discretized into two equal-width bins by
    `discretize_in_bins`. Entries that land in bin 1 keep their original
    `z_mean` / `z_logvar` value; all other entries take the value from
    `new_mean` / `new_log_var`.

    Args:
      z_mean: Mean of the encoder distribution for the original image.
      z_logvar: Logvar of the encoder distribution for the original image.
      new_mean: Average mean of the encoder distribution of the pair of images.
      new_log_var: Average logvar of the encoder distribution of the pair of
        images.
      labels: Unused.
      kl_per_point: Distance between the two encoder distributions.

    Returns:
      Mean and logvariance for the new observation, plus `k_hat`, the number
      of entries per row that were assigned to bin 1.
    """
    del labels
    bin_ids = tf.map_fn(discretize_in_bins, kl_per_point, tf.int32)
    keep_own = tf.equal(bin_ids, 1)
    k_hat = tf.reduce_sum(tf.cast(keep_own, tf.float32), axis=-1)
    mean_out = tf.where(keep_own, z_mean, new_mean)
    logvar_out = tf.where(keep_own, z_logvar, new_log_var)
    return mean_out, logvar_out, k_hat


def discretize_in_bins(x):
    """Returns, per entry of `x`, its index among two equal-width bins.

    The bins span the range from the minimum to the maximum of `x`.
    """
    value_range = [tf.reduce_min(x), tf.reduce_max(x)]
    return tf.histogram_fixed_width_bins(x, value_range, nbins=2)


def compute_kl(z_1, z_2, logvar_1, logvar_2):
    """Element-wise divergence term between two diagonal Gaussians.

    Evaluates
    `exp(logvar_1 - logvar_2) + (z_2 - z_1)^2 / (exp(logvar_2) + 1e-6) - 1
    + logvar_2 - logvar_1` and clips negative results to zero.

    Args:
      z_1: Mean of the first distribution.
      z_2: Mean of the second distribution.
      logvar_1: Log-variance of the first distribution.
      logvar_2: Log-variance of the second distribution.

    Returns:
      Tensor with the non-negative per-element values.
    """
    # Small offset keeps the division well defined.
    var_2 = tf.exp(logvar_2) + 1e-6
    variance_ratio = tf.exp(logvar_1 - logvar_2)
    scaled_sq_diff = tf.square(z_2 - z_1) / var_2
    divergence = variance_ratio + scaled_sq_diff - 1 + logvar_2 - logvar_1
    return tf.nn.relu(divergence)


def make_metric_fn(*names):
    """Builds a function that reports `tf.metrics.mean` under the given names.

    Args:
      *names: Metric names, paired positionally with the arguments later
        passed to the returned function.

    Returns:
      A callable mapping `*args` to a dict `{name: tf.metrics.mean(arg)}`;
      pairing stops at the shorter of `names` and `args`.
    """

    def metric_fn(*args):
        metrics = {}
        for name, vec in zip(names, args):
            metrics[name] = tf.metrics.mean(vec)
        return metrics

    return metric_fn


def print_t(*args):
    """Wraps each tensor in a `tf.Print` op for debugging.

    Each op prints the tensor together with a flag telling whether it
    contains any NaN.

    Returns:
      The single wrapped tensor when exactly one argument is given, otherwise
      a list with one wrapped tensor per argument.
    """
    wrapped = [
        tf.Print(x, [x, tf.math.reduce_any(tf.math.is_nan(x))], "Value of x: ")
        for x in args
    ]
    return wrapped[0] if len(wrapped) == 1 else wrapped


@gin.configurable("hgeom_selector")
class HgeomSelector(tf.keras.layers.Layer):
    """Layer that builds shared/independent latent masks.

    `call` combines a sample obtained from `hg.pmf_noncentral_fmvhg` with an
    (optional) neuralsort permutation and returns diagonal mask matrices plus
    diagnostics.
    """

    def __init__(
        self,
        num_latents,
        num_classes,
        tau=gin.REQUIRED,
        tau_init_mult=gin.REQUIRED,
        use_log_kl=gin.REQUIRED,
        use_tanh_act=gin.REQUIRED,
        clip_logits_w=gin.REQUIRED,
        neuralsort_fc_input=gin.REQUIRED,
        reparameterize_w=gin.REQUIRED,
        reparameterize_s=gin.REQUIRED,
        temperature_annealing=gin.REQUIRED,
        annealing_steps=gin.REQUIRED,
        start_annealing=gin.REQUIRED,
        shared_w_bias=gin.REQUIRED,
        mult_tau_neuralsort=gin.REQUIRED,
        use_neuralsort=gin.REQUIRED
    ):
        """Stores the configuration.

        Args:
          num_latents: Number of latent dimensions.
          num_classes: Number of outputs of the `logits_w` dense layer; also
            passed to the `hg` functions.
          tau: Final temperature.
          tau_init_mult: The initial temperature is `tau_init_mult * tau`.
          use_log_kl: If true, `log(kl + 1)` is used instead of the raw input.
          use_tanh_act: If true, `logits_w` is `5 * tanh(dense(...))`.
          clip_logits_w: If truthy (and tanh is off), `logits_w` is clipped to
            `[-clip_logits_w, clip_logits_w]`.
          neuralsort_fc_input: If true, a dense layer produces the neuralsort
            logits.
          reparameterize_w: Forwarded to `hg.pmf_noncentral_fmvhg`.
          reparameterize_s: Whether neuralsort adds Gumbel noise.
          temperature_annealing: Callable invoked as
            `(global_step, init_tau, final_tau, start_annealing,
            annealing_steps)` that yields the current temperature.
          annealing_steps: Forwarded to `temperature_annealing`.
          start_annealing: Forwarded to `temperature_annealing`.
          shared_w_bias: Constant used for the second column of the prior
            logits.
          mult_tau_neuralsort: Multiplier on the temperature given to
            neuralsort.
          use_neuralsort: If false, an identity permutation is used.
        """
        super(HgeomSelector, self).__init__()
        self.n_classes = num_classes
        self.num_latents = num_latents
        self.m = [float(num_latents), float(num_latents)]

        self.temperature_annealing = temperature_annealing
        self.final_tau = tau
        self.init_tau = float(tau_init_mult)*tau
        self.get_tau = temperature_annealing
        self.start_annealing = start_annealing
        self.annealing_steps = annealing_steps
        self.use_log_kl = use_log_kl
        self.use_tanh_act = use_tanh_act
        self.clip_logits_w = clip_logits_w
        self.use_neuralsort = use_neuralsort
        self.mult_tau_neuralsort = mult_tau_neuralsort
        self.neuralsort_fc_input = neuralsort_fc_input
        self.reparameterize_w = reparameterize_w
        self.reparameterize_s = reparameterize_s
        self.shared_w_bias = shared_w_bias

    # NOTE(review): the three annealing functions below take no `self`; they
    # are presumably meant to be bound to `temperature_annealing` through gin
    # rather than called on an instance - confirm.
    @gin.configurable
    def linear_annealing(step, init_v, final_v, start_step, num_steps):
        """Linearly decreases from `init_v` to `final_v` over `num_steps`.

        The decrease starts at `start_step`; the result never drops below
        `final_v`.
        """
        curr_step = tf.cast(tf.math.maximum(step-start_step,0), tf.float32)
        # Slope derived from the actual gap. The previous hard-coded
        # `99. * final_v` was only correct when init_v == 100 * final_v.
        rate = tf.cast((init_v - final_v)/float(num_steps), tf.float32)
        temp_sub = tf.cast(curr_step*rate, tf.float32)
        tau = tf.cast(
            tf.math.maximum(init_v - temp_sub, final_v),
            tf.float32
        )
        return tau

    @gin.configurable
    def exp_annealing(step, init_v, final_v, start_step, num_steps):
        """Exponentially decays from `init_v`, never dropping below `final_v`."""
        curr_step = tf.cast(tf.math.maximum(step-start_step,0), tf.float32)
        rate = tf.cast((tf.math.log(final_v) - tf.math.log(init_v))/float(num_steps), tf.float32)
        tau = tf.math.maximum(init_v*tf.math.exp(rate*curr_step),final_v)
        return tau

    @gin.configurable
    def no_annealing(step, init_v, final_v, start_step, rate):
        """Ignores the schedule arguments and returns `final_v`."""
        return final_v

    def bl_matmul(self, A, B):
        """Multiplies every matrix of the batch `A` (m, i, j) with `B` (j, k)."""
        return tf.einsum("mij,jk->mik", A, B)

    def sample_gumbel(self, shape, eps=1e-20):
        """Draws Gumbel noise as `-log(-log(U + eps) + eps)`, U uniform on [0, 1)."""
        U = tf.random.uniform(shape, minval=0, maxval=1)
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def neuralsort(self, s, tau=1.0, hard=True, add_gumbel_noise=False):
        """
        s: M x n x 1
        neuralsort(s): M x n x n
        Neuralsort algorithm from https://github.com/ermongroup/neuralsort

        With `hard=True` the forward value is the row-wise one-hot argmax
        while gradients flow through the softmax (straight-through).
        """

        if add_gumbel_noise:
            s = s + self.sample_gumbel(tf.shape(s))

        # Pairwise absolute differences between the scores.
        A_s = s - tf.transpose(s, perm=[0, 2, 1])
        A_s = tf.abs(A_s)

        n = tf.shape(s)[1]
        one = tf.ones((n, 1), dtype=tf.float32)

        B = self.bl_matmul(A_s, one @ tf.transpose(one))

        K = tf.range(n) + 1

        C = self.bl_matmul(
            s, tf.expand_dims(tf.cast(n + 1 - 2 * K, dtype=tf.float32), 0)
        )

        P = tf.transpose(C - B, perm=[0, 2, 1])

        P = tf.nn.softmax(P / tau, -1)

        if hard:
            P_hard = tf.cast(
                tf.equal(P, tf.reduce_max(P, axis=-1, keepdims=True)), P.dtype
            )
            P = tf.stop_gradient(P_hard - P) + P

        return P

    def call(self, z_mean, kl_ind):
        """Computes the masks for shared and independent latents.

        Args:
          z_mean: Tensor whose static shape provides the number of samples
            (axis 0) and of latents (axis 1).
          kl_ind: Per-dimension divergences used as input to the logits.

        Returns:
          A 5-tuple `([diag_sh, diag_ind], [logits_w, softmax(logits_w)],
          y_mask_ind, [lp_prior, lp_post], tau)`.
        """
        n_samples = z_mean.get_shape()[0]
        n_latents = float(z_mean.get_shape().as_list()[1])
        tau = self.get_tau(tf.train.get_global_step(), self.init_tau, self.final_tau, self.start_annealing, self.annealing_steps)

        if self.use_log_kl:
            inp_ns_funcs = tf.math.log(kl_ind + 1)
        else:
            inp_ns_funcs = kl_ind

        if self.neuralsort_fc_input:
            logits_neuralsort = tf.layers.dense(
                inp_ns_funcs, n_latents, name="logits_neuralsort"
            )
        else:
            logits_neuralsort = inp_ns_funcs

        if self.use_tanh_act:
            # logits estimation with tanh * 5: output is on log-domain
            logits_w = 5 * tf.layers.dense(
                logits_neuralsort, self.n_classes, activation=tf.nn.tanh, name="logits_w"
            )
        else:
            logits_w = tf.layers.dense(logits_neuralsort, self.n_classes, name="logits_w")
            if self.clip_logits_w:
                logits_w = tf.clip_by_value(logits_w, -self.clip_logits_w, self.clip_logits_w, name="logits_w_clipped")

        logits_w_n = tf.nn.softmax(logits_w, -1)

        n = tf.expand_dims(tf.repeat(n_latents, repeats=n_samples), 1)
        _, x_all, y_mask = hg.pmf_noncentral_fmvhg(
            self.m, n, logits_w, tau, self.n_classes, self.reparameterize_w
        )
        # Prior logits: zero for the first column, `shared_w_bias` for the second.
        logits_prior_w_independent = tf.zeros((n_samples, 1), dtype=tf.float32)
        logits_prior_w_shared = tf.constant(self.shared_w_bias,shape=(n_samples, 1), dtype=tf.float32)

        logits_prior_w = tf.concat(
            (logits_prior_w_independent,logits_prior_w_shared),
            axis=-1
        )
        _, lp_prior = hg.get_probability(
            x_all, self.m, n, logits_prior_w, self.n_classes
        )
        _, lp_post = hg.get_probability(x_all, self.m, n, logits_w, self.n_classes)

        # Compute mask containing the d-k shared factors for each sample
        y_mask_sh = y_mask[1][:, :, 1:]
        y_mask_sh = tf.reverse(y_mask_sh, axis=(-1,))
        y_mask_ind = y_mask[0][:, :, 1:]


        if self.use_neuralsort:
            logits_neuralsort = tf.expand_dims(logits_neuralsort, -1)
            sort_perm = self.neuralsort(
                logits_neuralsort,
                tau=self.mult_tau_neuralsort*tau,
                hard=True,
                add_gumbel_noise=self.reparameterize_s,
            )
        else:
            sort_perm = tf.eye(int(n_latents), batch_shape=[n_samples])
        y_sel_ind = y_mask_ind @ sort_perm
        y_sel_sh = y_mask_sh @ sort_perm

        y_diag_ind = tf.clip_by_value(y_sel_ind, 0, 1)
        y_diag_sh = tf.clip_by_value(y_sel_sh, 0, 1)
        # NOTE(review): squeeze without an axis removes every size-1
        # dimension, including a batch dimension of 1 - confirm this is safe.
        y_diag_sh = tf.squeeze(y_diag_sh)
        y_diag_ind = tf.squeeze(y_diag_ind)
        diag_ind = tf.linalg.diag(y_diag_ind, k=0, padding_value=0)
        diag_sh = tf.linalg.diag(y_diag_sh, k=0, padding_value=0)
        return (
            [diag_sh, diag_ind],
            [logits_w, logits_w_n],
            y_mask_ind,
            [lp_prior, lp_post],
            tau
        )

