# coding=utf-8
# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



from keras import backend

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.keras.utils import control_flow_util
import tensorflow as tf
import tensorflow_probability as tfp

class DenseUAIB(tf.keras.layers.Dense):
  """Uncertainty-Aware Information Bottleneck (UA-IB) dense layer.

  The layer is a `Dense` projection whose outputs parameterize a full
  covariance Gaussian encoder (a mean vector plus
  `uaib_dim * (uaib_dim + 1) / 2` covariance parameters). The encoder is
  compared, through the KL divergence, against a codebook of
  `codebook_size` Gaussian centroids. The centroid means are trainable
  weights; the centroid scale matrices and the prior centroid probabilities
  are non-trainable and are updated with moving averages during training.

  Two copies of the centroid scale matrices and of the prior centroid
  probabilities are kept: the "current" copy is read by `call`, the "nxt"
  copy accumulates the batch-wise updates. `set_centroids` /
  `set_centroid_probs` copy "nxt" into "current"; `reset_centroids` /
  `reset_centroid_probs` re-initialize the "nxt" copies.

  Args:
    uaib_dim: Dimension of the latent features.
    codebook_size: Number of centroids in the codebook.
    uaib_tau: Lagrange multiplier (temperature) in the rate-distortion
      function of the regularization term.
    momentum: Momentum of the moving averages used for the batch-wise update
      of the centroid scale matrices and of the centroid probabilities.
    activation: Non-linearity applied to the input features before they are
      passed to the encoder's projection.
    use_bias, kernel_initializer, bias_initializer, kernel_regularizer,
    bias_regularizer, activity_regularizer, kernel_constraint,
    bias_constraint, **kwargs: Forwarded unchanged to
      `tf.keras.layers.Dense`.

  Output:
    Tensor of shape `[None, uaib_dim + 1]`. The first `uaib_dim` columns are
    the latent features (a sample of the encoder in training, its mean
    otherwise); the last column is the uncertainty score, i.e. the expected
    KL divergence from the encoder to the centroids under the conditional
    centroid probabilities.

  References:
    [1] Alemi AA, Fischer I, Dillon JV, Murphy K. Deep variational
        information bottleneck. arXiv preprint arXiv:1612.00410. 2016.
    [2] Alemi AA, Fischer I, Dillon JV. Uncertainty in the variational
        information bottleneck. arXiv preprint arXiv:1807.00906. 2018.
  """

  def __init__(self,
               uaib_dim,
               codebook_size,
               uaib_tau=1.0,
               momentum=0.99,
               activation='relu',
               use_bias=False,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    # The dense projection emits the encoder mean (uaib_dim values) followed
    # by the parameters of the covariance (uaib_dim * (uaib_dim + 1) / 2).
    # Its own activation is disabled: `activation` is applied to the layer
    # *inputs* instead (see `call`).
    super().__init__(
        units=uaib_dim + (uaib_dim * (uaib_dim + 1)) // 2,
        activation=None,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        **kwargs)

    # Variational bottleneck arguments.
    self.uaib_activation = tf.keras.activations.get(activation)
    self.uaib_dim = uaib_dim
    self.codebook_size = codebook_size
    self.uaib_tau = uaib_tau
    self.momentum = momentum

  def _uniform_cen_probs(self):
    """Returns the uniform distribution over the `codebook_size` centroids."""
    return tf.ones(shape=[self.codebook_size]) / self.codebook_size

  def build(self, input_shape):
    """Creates the dense kernel and the codebook variables."""
    input_shape = tf.TensorShape(input_shape)

    super().build(input_shape)

    # Centroid (reference distribution) parameters.
    self.centroid_means = []
    self.centroid_covariance = []
    self.centroid_covariance_nxt = []

    for i in range(self.codebook_size):
      self.centroid_means.append(self.add_weight(
          name='centroid_mean_' + str(i),
          shape=[self.uaib_dim],
          initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1),
          trainable=True,
          dtype=self.dtype))

      # Two copies are maintained for the EM-style updates. Despite their
      # name, these variables are consumed as lower-triangular *scale*
      # factors (see `call` and `_calculate_centroid_covariance`).
      #
      # Copy used during the current epoch, so that all inference steps see
      # the same value.
      self.centroid_covariance.append(self.add_weight(
          name='centroid_covariance_' + str(i),
          shape=[self.uaib_dim, self.uaib_dim],
          initializer="zeros",
          trainable=False,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
          dtype=self.dtype))
      # Copy updated batch by batch, for use in the next epoch. It gets its
      # own name so the two copies can be told apart by name.
      self.centroid_covariance_nxt.append(self.add_weight(
          name='centroid_covariance_nxt_' + str(i),
          shape=[self.uaib_dim, self.uaib_dim],
          initializer="zeros",
          trainable=False,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
          dtype=self.dtype))
      # Start the current copy from the identity matrix.
      precision_matrix_reset_op = self.centroid_covariance[i].assign(
          tf.eye(self.uaib_dim))
      self.add_update(precision_matrix_reset_op)

    # Centroid probabilities (two copies, as for the scale matrices).
    # Prior used during the current epoch.
    self.cen_probs = tf.Variable(
        initial_value=self._uniform_cen_probs(),
        trainable=False,
        name="centroid_probabilities",
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[self.codebook_size],
    )
    # Prior updated batch by batch (Bayes rule) for use in the next epoch.
    self.cen_probs_nxt = tf.Variable(
        initial_value=self._uniform_cen_probs(),
        trainable=False,
        name="centroid_probabilities_nxt",
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[self.codebook_size],
    )

    # Flag read by `call`: while False, uniform conditional centroid
    # probabilities are used. Nothing in this class sets it to True.
    self.initialized = self.add_weight(
        name='init',
        dtype=tf.bool,
        shape=(),
        initializer="zeros",
        trainable=False,
    )

    self.built = True

  def call(self, inputs, training=None):
    """Encodes `inputs` and scores their uncertainty against the codebook.

    Args:
      inputs: Input features.
      training: Whether the layer runs in training mode. Defaults to the
        Keras learning phase when `None`.

    Returns:
      Tensor whose last axis concatenates the `uaib_dim` latent features
      (encoder sample in training, encoder mean otherwise) with one
      uncertainty score.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    # Encoder parameters.
    inputs = self.uaib_activation(inputs)
    params = super().call(inputs)

    num_tril = (self.uaib_dim * (self.uaib_dim + 1)) // 2
    mu, perturb_factor = tf.split(params, [self.uaib_dim, num_tril], axis=-1)

    # Arrange the covariance parameters into a square matrix by averaging a
    # lower- and an upper-triangular fill.
    # NOTE(review): the two fills order the elements differently, so this
    # average is not necessarily symmetric; the covariance built below is
    # symmetric regardless, because it is reconstructed from the SVD.
    perturb_factor_low = tfp.math.fill_triangular(perturb_factor)
    perturb_factor_upper = tfp.math.fill_triangular(perturb_factor, upper=True)
    perturb_factor = 0.5 * (perturb_factor_upper + perturb_factor_low)

    # Build a positive-definite covariance U diag(d^2) U^T from the left
    # singular vectors U and the (softplus-transformed) singular values d.
    perturb_diag, perturb_factor, _ = tf.linalg.svd(
        perturb_factor, full_matrices=True)
    perturb_diag = tf.math.softplus(perturb_diag - 5.0) + 1e-5

    covariance = tf.linalg.matmul(
        perturb_factor, tf.linalg.diag(perturb_diag * perturb_diag))
    covariance = tf.linalg.matmul(covariance, perturb_factor, transpose_b=True)

    encoder = tfp.distributions.MultivariateNormalFullCovariance(
        loc=mu, covariance_matrix=covariance)

    # KL divergence of the encoder from every centroid.
    # TODO: make loop efficient
    cluster_distances = []
    for i in range(self.codebook_size):
      scale_tril = tf.linalg.LinearOperatorLowerTriangular(
          self.centroid_covariance[i]).to_dense()
      centroid = tfp.distributions.MultivariateNormalTriL(
          loc=self.centroid_means[i], scale_tril=scale_tril)
      cluster_distances.append(encoder.kl_divergence(centroid))

    cluster_distances = tf.stack(cluster_distances, axis=-1)

    cluster_distances_loss = tf.identity(
        cluster_distances, name=self.name + "_cluster_distances")
    self.add_loss(cluster_distances_loss)

    # E-step: conditional centroid probabilities (no gradient flows through
    # them).
    log_cen_probs = tf.reshape(
        tf.math.log(self.cen_probs), shape=[1, self.codebook_size])
    posterior_cen_probs = tf.nn.softmax(
        tf.stop_gradient(log_cen_probs - self.uaib_tau * cluster_distances),
        axis=-1)
    uniform_cen_probs = (
        tf.ones(shape=[tf.shape(inputs)[0], self.codebook_size])
        * 1 / self.codebook_size)

    # Fall back to uniform probabilities until `initialized` is set.
    cond_cen_probs = tf.cond(
        self.initialized,
        lambda: posterior_cen_probs,
        lambda: uniform_cen_probs)

    uncertainty = tf.math.reduce_sum(
        cond_cen_probs * cluster_distances, axis=-1)

    # M-step: update prior centroid probabilities.
    self.add_update(self._update_cen_probs(training, cond_cen_probs))

    # M-step: update centroid scale matrices.
    self.add_update(
        self._update_centroids(
            training, cond_cen_probs, mu, encoder.covariance()))

    # Latent features: a sample in training, the mean otherwise.
    outputs_sample = encoder.sample()
    outputs = control_flow_util.smart_cond(
        training, lambda: outputs_sample, lambda: mu)

    # Return latent features and uncertainty.
    return tf.concat(
        [outputs, tf.expand_dims(uncertainty, axis=-1)], axis=-1)

  def compute_output_shape(self, input_shape):
    """Returns the shape produced by `call`: `input_shape[:-1] + [uaib_dim + 1]`.

    The inherited `Dense` implementation would report the number of internal
    projection units instead of the size of the tensor returned by `call`.
    """
    input_shape = tf.TensorShape(input_shape)
    return input_shape[:-1].concatenate(self.uaib_dim + 1)

  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ###
  ###                     Utility functions                       ###
  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ###

  def _assign_moving_average(self, variable, value, momentum):
    """Updates `variable` in place towards `value` with a moving average.

    Computes `variable -= (variable - value) * (1 - momentum)`.

    Args:
      variable: Variable holding the running value.
      value: New value; cast to the dtype of `variable`.
      momentum: Momentum of the moving average.

    Returns:
      The result of the `assign_sub` operation.
    """

    def calculate_update_delta():
      decay = tf.convert_to_tensor(1.0 - momentum, name="decay")
      if decay.dtype != variable.dtype.base_dtype:
        decay = tf.cast(decay, variable.dtype.base_dtype)
      update_delta = (variable - tf.cast(value, variable.dtype)) * decay
      return update_delta

    with backend.name_scope("AssignMovingAvg") as scope:
      if tf.compat.v1.executing_eagerly_outside_functions():
        return variable.assign_sub(calculate_update_delta(), name=scope)
      else:
        with tf.compat.v1.colocate_with(variable):
          return tf.compat.v1.assign_sub(
              variable, calculate_update_delta(), name=scope)

  ###      Utility functions for updating centroid probabilities     ###

  def _update_cen_probs(self, training, cond_cen_probs):
    """Updates the "nxt" copy of the prior centroid probabilities.

    In training, the prior centroid probabilities computed on the current
    batch (Bayes rule) are folded into `cen_probs_nxt` with a moving average.
    The "nxt" copy is the one updated, so that `cen_probs` keeps the same
    value throughout the whole epoch. Outside training nothing is updated
    and `cen_probs` is returned.

    Args:
      training: Whether the layer runs in training mode. Defaults to the
        Keras learning phase when `None`.
      cond_cen_probs: Conditional centroid probabilities of the batch.

    Returns:
      The moving-average update in training, `cen_probs` otherwise.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    def _do_update_cen_probs():
      new_cen_probs = self._calculate_cen_probs(cond_cen_probs)
      return self._assign_moving_average(
          self.cen_probs_nxt, new_cen_probs, self.momentum)

    return control_flow_util.smart_cond(
        training, _do_update_cen_probs, lambda: self.cen_probs)

  def reset_centroid_probs(self):
    """Resets the "nxt" copy of the prior centroid probabilities to uniform.

    Useful at the beginning of a new epoch.
    """
    cen_probs_reset_op = self.cen_probs_nxt.assign(self._uniform_cen_probs())
    self.add_update(cen_probs_reset_op)
    return

  def set_centroid_probs(self):
    """Copies the "nxt" prior centroid probabilities into the current copy.

    Useful for making the probabilities accumulated during an epoch the ones
    used in the next epoch.
    """
    cen_probs_set_op = self.cen_probs.assign(self.cen_probs_nxt)
    self.add_update(cen_probs_set_op)
    return

  def _calculate_cen_probs(self, cond_cen_probs):
    """Computes prior centroid probabilities from a batch (Bayes rule).

    The conditional centroid probabilities are averaged over the datapoints
    of the batch, across all replicas when running inside a replica context.
    It computes the equation of p(h|x) (page 1731) in [1].

    Args:
      cond_cen_probs: Conditional centroid probabilities, one row per
        datapoint and one column per centroid.

    Returns:
      The averaged (non-negative) centroid probabilities.

    References:
      [1] Banerjee A, Merugu S, Dhillon IS, Ghosh J, Lafferty J. Clustering
          with Bregman divergences. Journal of machine learning research.
          2005 Oct 1;6(10).
    """
    batch_size = tf.cast(tf.shape(cond_cen_probs)[0], tf.float32)
    replica_ctx = distribution_strategy_context.get_replica_context()

    probs_sum = tf.reduce_sum(cond_cen_probs, axis=0, keepdims=False)
    if replica_ctx is not None:
      # Gather sums and batch sizes from all replicas.
      probs_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, probs_sum)
      batch_size = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, batch_size)

    # Apply relu in case floating point rounding causes it to go negative.
    return tf.nn.relu(probs_sum / batch_size)

  ###      Utility functions for updating centroid scale matrices    ###

  def reset_centroids(self):
    """Resets the "nxt" copy of the centroid scale matrices to zeros.

    Useful at the beginning of a new epoch.
    """
    for i in range(self.codebook_size):
      centroid_covariance_reset_op = self.centroid_covariance_nxt[i].assign(
          tf.zeros(shape=[self.uaib_dim, self.uaib_dim]))
      self.add_update(centroid_covariance_reset_op)
    return

  def set_centroids(self):
    """Copies the "nxt" centroid scale matrices into the current copy.

    Useful for making the matrices accumulated during an epoch the ones used
    in the next epoch.
    """
    for i in range(self.codebook_size):
      centroid_covariance_set_op = self.centroid_covariance[i].assign(
          self.centroid_covariance_nxt[i])
      self.add_update(centroid_covariance_set_op)
    return

  def _update_centroids(self, training, cond_cen_probs, mu, covariance):
    """Updates the "nxt" copy of the centroid scale matrices.

    In training, the matrix computed for each centroid on the current batch
    is folded into `centroid_covariance_nxt` with a moving average. The
    "nxt" copy is the one updated, so that `centroid_covariance` keeps the
    same value throughout the whole epoch. Outside training nothing is
    updated and `centroid_covariance` is returned.

    Args:
      training: Whether the layer runs in training mode. Defaults to the
        Keras learning phase when `None`.
      cond_cen_probs: Conditional centroid probabilities of the batch.
      mu: Encoder means of the batch.
      covariance: Encoder covariance matrices of the batch.

    Returns:
      The list of moving-average updates (one per centroid) in training,
      `centroid_covariance` otherwise.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    def _do_update_centroids():
      updates = []
      for i in range(self.codebook_size):
        new_centroid_covariance = self._calculate_centroid_covariance(
            i, cond_cen_probs, mu, covariance)
        updates.append(self._assign_moving_average(
            self.centroid_covariance_nxt[i],
            new_centroid_covariance,
            self.momentum))
      return updates

    return control_flow_util.smart_cond(
        training, _do_update_centroids, lambda: self.centroid_covariance)

  def _calculate_centroid_covariance(self, i, cond_cen_probs, mu, covariance):
    """Computes the batch estimate of the scale matrix of centroid `i`.

    Equation (9) in [1] is computed: a weighted sum over the datapoints of
    the encoder covariance plus the outer product of the difference between
    the encoder mean and the centroid mean. The Cholesky factor of the
    resulting covariance matrix is returned.

    Args:
      i: Index of the centroid.
      cond_cen_probs: Conditional centroid probabilities,
        `[batch_size, codebook_size]`.
      mu: Encoder means, `[batch_size, uaib_dim]`.
      covariance: Encoder covariance matrices,
        `[batch_size, uaib_dim, uaib_dim]`.

    Returns:
      Lower-triangular Cholesky factor, `[uaib_dim, uaib_dim]`.

    References:
      [1] Davis J, Dhillon I. Differential entropic clustering of
          multivariate gaussians. Advances in Neural Information Processing
          Systems. 2006;19.
    """
    replica_ctx = distribution_strategy_context.get_replica_context()
    batch_size = tf.cast(tf.shape(cond_cen_probs)[0], tf.float32)

    # Gather the encoders to be clustered across all replicas.
    if replica_ctx is not None:
      global_mu = replica_ctx.all_gather(mu, axis=0)
      global_covariance = replica_ctx.all_gather(covariance, axis=0)
      global_cond_cen_probs = replica_ctx.all_gather(cond_cen_probs, axis=0)
      global_batch_size = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, batch_size)
    else:
      global_mu = mu
      global_covariance = covariance
      global_cond_cen_probs = cond_cen_probs
      global_batch_size = batch_size

    # Normalize across datapoints to get the contribution of each one to
    # the new estimate. The additive constant 5 smooths the weights towards
    # uniform; the weights sum to one.
    centroid_probs = global_cond_cen_probs[:, i]
    weights = (centroid_probs + 5.0) / (
        tf.math.reduce_sum(centroid_probs, axis=0) + 5 * global_batch_size)

    # Outer product (mu_n - m_i)(mu_n - m_i)^T for every datapoint n.
    diff = global_mu - self.centroid_means[i]
    mi_miT = tf.linalg.matmul(
        tf.expand_dims(diff, axis=-1),
        tf.expand_dims(diff, axis=1),
        transpose_a=False,
        transpose_b=False)

    new_covariance = global_covariance + mi_miT

    # Broadcast the per-datapoint weights over the two matrix axes.
    weights = tf.expand_dims(weights, axis=-1)
    weights = tf.expand_dims(weights, axis=-1)

    new_covariance = tf.reduce_sum(weights * new_covariance, axis=0)

    return tf.linalg.cholesky(new_covariance)

  def get_config(self):
    """Returns a config from which the layer can be re-created.

    The keys match the constructor arguments: `activation` holds the input
    non-linearity of this layer (not the always-linear activation of the
    underlying `Dense`), and `units` is dropped because the constructor
    derives it from `uaib_dim`.
    """
    config = super().get_config()
    config.pop('units', None)
    config.update({
        'activation': tf.keras.activations.serialize(self.uaib_activation),
        'uaib_dim': self.uaib_dim,
        'codebook_size': self.codebook_size,
        'uaib_tau': self.uaib_tau,
        'momentum': self.momentum,
    })
    return config

  @classmethod
  def from_config(cls, config):
    """Creates the layer from the output of `get_config`.

    Configs written by earlier versions of this class are also accepted:
    they stored the input non-linearity under `uaib_activation` and carried
    the `units` of the underlying `Dense`, which the constructor does not
    take.
    """
    config = dict(config)
    config.pop('units', None)
    legacy_activation = config.pop('uaib_activation', None)
    if legacy_activation is not None:
      config['activation'] = legacy_activation
    return cls(**config)


