#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
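"""Codebook metrics and utilities.

Helpers for evaluating codebook (cluster) assignments against labels: building
and updating per-cluster label-count metrics, and looking up the majority
class of each cluster.
"""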



import tensorflow.compat.v2 as tf


# Public API of this module. Previously only ``cluster_majority_class`` was
# listed, so ``from ... import *`` silently dropped the other public helpers.
__all__ = [
    'cluster_majority_class',
    'create_cluster_metrics',
    'eval_clusters',
    'get_cluster_class',
]

import numpy as np
from absl import logging,flags

# Global absl flag values; eval_clusters reads FLAGS.codebook_size from here.
FLAGS = flags.FLAGS

def create_cluster_metrics(num_centroids, num_labels, label_type):
    """Create a running-mean metric for every (cluster, label) combination.

    Args:
        num_centroids: number of clusters; ids ``0 .. num_centroids - 1`` are used.
        num_labels: number of labels; ids ``0 .. num_labels - 1`` are used.
        label_type: prefix placed at the start of every metric key.

    Returns:
        dict mapping ``'{label_type}_cluster_{c}_label_{l}'`` to a fresh
        ``tf.keras.metrics.Mean`` instance, inserted cluster-major
        (all labels of cluster 0 first, then cluster 1, ...).
    """
    return {
        f'{label_type}_cluster_{centroid}_label_{label}': tf.keras.metrics.Mean()
        for centroid in range(num_centroids)
        for label in range(num_labels)
    }

def eval_clusters(strategy, metrics, epoch, clusters, labels, label_type):
    """Record and log how many datapoints of each label fall in each cluster.

    Used for qualitative evaluation of the clustering algorithm. For every
    cluster id ``c`` in ``range(FLAGS.codebook_size)``, the labels of the
    datapoints assigned to ``c`` are counted with ``np.unique``. Each
    per-label count is passed to
    ``metrics[f'{label_type}_cluster_{c}_label_{l}'].update_state`` inside a
    ``tf.function`` executed via ``strategy.run``, and the counts are logged.

    Args:
        strategy: distribution strategy whose ``run`` method executes the
            metric update.
        metrics: dict of metric objects keyed as described above (e.g. the
            result of ``create_cluster_metrics``); a label without a matching
            key raises ``KeyError``.
        epoch: accepted for interface compatibility; not used here.
        clusters: tensor exposing ``.numpy()`` with the cluster id of each
            datapoint.
        labels: tensor exposing ``.numpy()`` with the label of each datapoint,
            aligned with ``clusters``.
        label_type: name used as the metric-key prefix and in the log header.
    """
    label_array = labels.numpy()
    cluster_array = clusters.numpy()
    logging.info(f"------ Cluster evaluation on {label_type}. ------")

    for cluster_id in range(FLAGS.codebook_size):
        member_rows = np.where(cluster_array == cluster_id)[0]
        # Distinct labels present in this cluster and how often each occurs.
        present_labels, label_counts = np.unique(
            label_array[member_rows], return_counts=True)

        # Defined inside the loop on purpose: it closes over this iteration's
        # numpy counts, and the Python loop below is unrolled while tracing.
        @tf.function
        def update_cluster_metrics_fn():
            for label_value, count in zip(present_labels, label_counts):
                metrics[f'{label_type}_cluster_{cluster_id}_label_{label_value}'].update_state(count)

        strategy.run(update_cluster_metrics_fn)

        # Report the label histogram of this cluster centroid.
        logging.info("Cluster %d: labels: %s, counts: %s",
                     cluster_id,
                     ' '.join(map(str, present_labels)),
                     ' '.join(map(str, label_counts)))

def get_cluster_class(clusters, labels, num_centroids):
    """Map each datapoint to the majority class of its closest centroid.

    Args:
        clusters: tensor with the closest-centroid id of each datapoint.
        labels: tensor with the class id of each datapoint.
        num_centroids: number of clusters passed on to
            ``cluster_majority_class``.

    Returns:
        Tensor with one entry per datapoint: the majority class that
        ``cluster_majority_class`` computed for that datapoint's cluster.
    """
    # One majority class per cluster id, indexed by cluster id.
    majority_by_cluster = cluster_majority_class(clusters, labels, num_centroids)

    # gather_nd needs a trailing index dimension: [..., 1].
    index_column = tf.expand_dims(clusters, axis=-1)
    return tf.gather_nd(majority_by_cluster, index_column)
    
  
    



def cluster_majority_class(clusters, labels, num_centroids):
    """Get the majority class of each cluster.

    Args:
        clusters: tensor with shape [batch_size] that contains the id of the
            closest centroid for each datapoint.
        labels: tensor with shape [batch_size] that contains the class id for
            each datapoint.
        num_centroids: number of clusters. Only ids ``0 .. num_centroids - 1``
            are evaluated; other ids in ``clusters`` are ignored.

    Returns:
        int32 tensor with shape [num_centroids]. Entry ``c`` is the most
        frequent label among the datapoints assigned to cluster ``c``, or
        ``-1`` when no datapoint is assigned to ``c``. Ties are resolved by
        ``tf.argmax``, which does not guarantee which tied label is chosen.
    """
    # Both inputs are cast to int32 (non-integer labels are truncated).
    labels = tf.cast(labels, tf.int32)
    clusters = tf.cast(clusters, tf.int32)

    # Sentinel for empty clusters, typed int32 so both tf.cond branches
    # produce the same dtype by construction.
    empty_cluster_class = tf.constant(-1, dtype=tf.int32)

    majority_classes = []
    for c in range(num_centroids):
        # [k, 1] indices of the datapoints assigned to cluster c.
        cluster_c_indices = tf.where(tf.equal(clusters, c))

        def find_majority_class():
            """Return the most frequent label among cluster c's datapoints."""
            labels_c = tf.gather_nd(labels, cluster_c_indices)
            unique_labels, _, label_counts = tf.unique_with_counts(labels_c)
            return tf.gather(unique_labels, tf.argmax(label_counts))

        # tf.cond invokes the branch immediately, so the closure sees this
        # iteration's cluster_c_indices.
        majority_class = tf.cond(tf.size(cluster_c_indices) > 0,
                                 find_majority_class,
                                 lambda: empty_cluster_class)
        majority_classes.append(majority_class)

    return tf.stack(majority_classes, axis=0)
    
    

     
