import tensorflow as tf
import numpy as np

def different_weights_loss(weights, nb_classes, nb_features):
    """Build a zero-argument penalty on the per-class spread of ``weights``.

    ``weights`` is reshaped to ``(-1, nb_classes, nb_features)``. The standard
    deviation is taken over the feature axis and averaged over the leading
    axis, giving one spread value ``s_c`` per class. The penalty is
    ``sum_c 2 * exp(-3 * s_c)``. It is largest when the spread is zero and
    decays as the spread grows.

    Args:
        weights: Tensor/variable whose size is divisible by
            ``nb_classes * nb_features``. It is read every time the returned
            callable runs, not captured as a snapshot.
        nb_classes: Size of the middle axis after reshaping.
        nb_features: Size of the last axis after reshaping.

    Returns:
        A callable taking no arguments that returns a scalar tensor.
        Presumably meant for ``model.add_loss``; confirm with the caller.
    """
    def _custom_loss():
        grouped = tf.reshape(weights, (-1, nb_classes, nb_features))
        # Per-class spread: std over features, then mean over the leading axis.
        spread = tf.math.reduce_mean(tf.math.reduce_std(grouped, axis=2), axis=0)
        return tf.math.reduce_sum(2 * tf.math.exp(-3 * spread))

    return _custom_loss

def loss_sim_dist(class_anchors, max_angle, max_dist):
    """Build a per-sample hinge loss on angle and distance to the class anchor.

    For each sample, the anchor of its class is looked up, and two hinge terms
    are added:

    * ``max(cos_term + max_angle, 0)``, where ``cos_term`` is
      ``tf.keras.losses.cosine_similarity(anchor, y_pred)``. Per the Keras
      docs this is the *negated* cosine similarity (-1 for identical
      directions). Verify against the installed TF version.
    * ``max(||y_pred - anchor|| - max_dist, 0)``.

    Args:
        class_anchors: Anchor matrix indexed by class id.
        max_angle: Offset added to the cosine term before the hinge.
        max_dist: Distance below which no distance penalty is applied.

    Returns:
        A ``tf.function`` ``(y_true, y_pred) -> per-sample loss``.
        ``y_true`` holds class indices, not one-hot vectors.
    """
    @tf.function
    def _loss(y_true, y_pred):
        # Replicas can occasionally receive an empty shard of the batch;
        # return a length-1 zero tensor in that case.
        if tf.shape(y_true)[0] == 0:
            return tf.reshape(tf.constant(0, dtype=tf.float32), (-1,))

        targets = tf.gather(class_anchors, y_true)

        similarity = tf.keras.losses.cosine_similarity(targets, y_pred, axis=1)
        distance = tf.norm(y_pred - targets, axis=1)

        angle_penalty = tf.maximum(similarity + max_angle, 0)
        distance_penalty = tf.maximum(distance - max_dist, 0)
        return angle_penalty + distance_penalty

    return _loss

def loss_dist(class_anchors, max_dist):
    """Build a per-sample hinge loss on the distance to the own-class anchor.

    The loss for each sample is ``max(||y_pred - anchor[y_true]|| - max_dist, 0)``.
    There is no penalty while the embedding lies within ``max_dist`` of its
    class anchor.

    Args:
        class_anchors: Anchor matrix indexed by class id.
        max_dist: Radius around each anchor inside which the loss is zero.

    Returns:
        A ``tf.function`` ``(y_true, y_pred) -> per-sample loss``.
        ``y_true`` holds class indices, not one-hot vectors.
    """
    @tf.function
    def _loss(y_true, y_pred):
        # Replicas can occasionally receive an empty shard of the batch;
        # return a length-1 zero tensor in that case.
        batch_size = tf.shape(y_true)[0]
        if batch_size == 0:
            return tf.reshape(tf.constant(0, dtype=tf.float32), (-1,))

        representation_to_learn = tf.gather(class_anchors, y_true)

        dist = tf.norm(y_pred - representation_to_learn, axis=1)
        # Only distances beyond max_dist are penalized.
        return tf.maximum(dist - max_dist, 0)

    return _loss

def loss_cac(class_anchors):
    """Build an anchor + tuplet distance loss.

    This is presumably the Class Anchor Clustering (CAC) loss; the formula
    below matches it. For a sample with label ``y`` and distances ``d_j``
    from ``y_pred`` to every anchor ``j``, the loss is::

        0.1 * d_y + log(1 + sum_{j != y} exp(d_y - d_j))

    Args:
        class_anchors: Anchor matrix indexed by class id; the first axis
            enumerates classes.

    Returns:
        A ``tf.function`` ``(y_true, y_pred) -> per-sample loss``.
        ``y_true`` holds class indices (not one-hot). It is cast to int64
        internally, so int32 labels work too.
    """
    @tf.function
    def _loss(y_true, y_pred):
        # Replicas can occasionally receive an empty shard of the batch;
        # return a length-1 zero tensor in that case.
        batch_size = tf.shape(y_true)[0]
        if batch_size == 0:
            return tf.reshape(tf.constant(0, dtype=tf.float32), (-1,))

        # int64 so the labels can be stacked with the int64 row indices below.
        labels = tf.cast(y_true, tf.int64)
        representation_to_learn = tf.gather(class_anchors, labels)

        anchor_loss = tf.norm(y_pred - representation_to_learn, axis=1)
        distance_to_all_anchors = tf.norm(
            tf.expand_dims(y_pred, 1) - class_anchors, axis=2
        )

        dist_difference = tf.expand_dims(anchor_loss, axis=1) - distance_to_all_anchors
        exp_difference = tf.exp(dist_difference)
        # The (sample, own class) entries must not contribute to the tuplet
        # sum, so overwrite them with 0. The result must be written back
        # into exp_difference, which is the tensor that gets summed.
        indices_to_ignore = tf.stack(
            [tf.range(batch_size, dtype=tf.int64), labels], axis=1
        )
        exp_difference = tf.tensor_scatter_nd_update(
            exp_difference, indices_to_ignore, tf.zeros_like(anchor_loss)
        )
        difference_sums = tf.reduce_sum(exp_difference, axis=1)

        tuplet_loss = tf.math.log(1 + difference_sums)

        return 0.1 * anchor_loss + tuplet_loss

    return _loss

def penalize_wrong_classification(class_anchors, max_dist):
    """Build a hinge + tuplet loss that alters the hinge for misclassified samples.

    For a sample with label ``y`` and distances ``d_j`` from ``y_pred`` to
    every anchor ``j``:

    * ``relax = max(d_y - max_dist, 0)``. When the nearest anchor is not
      ``y`` (a misclassification), ``relax ** 2`` is used instead.
    * ``tuplet = log(1 + sum_{j != y} exp(d_y - d_j))``.

    The per-sample loss is ``relax + tuplet``.

    Args:
        class_anchors: Anchor matrix indexed by class id; the first axis
            enumerates classes.
        max_dist: Radius around each anchor inside which the hinge is zero.

    Returns:
        A ``tf.function`` ``(y_true, y_pred) -> per-sample loss``.
        ``y_true`` holds class indices (not one-hot). It is cast to int64
        internally, so int32 labels work too.
    """
    @tf.function
    def _loss(y_true, y_pred):
        # Replicas can occasionally receive an empty shard of the batch;
        # return a length-1 zero tensor in that case.
        batch_size = tf.shape(y_true)[0]
        if batch_size == 0:
            return tf.reshape(tf.constant(0, dtype=tf.float32), (-1,))

        # int64 to match tf.argmin's output dtype in the comparison below.
        labels = tf.cast(y_true, tf.int64)
        representation_to_learn = tf.gather(class_anchors, labels)

        dist = tf.norm(y_pred - representation_to_learn, axis=1)
        distance_to_all_anchors = tf.norm(
            tf.expand_dims(y_pred, 1) - class_anchors, axis=2
        )

        relax_dist = tf.maximum(dist - max_dist, 0)

        # Misclassified = nearest anchor is not the sample's own class.
        # NOTE(review): squaring only enlarges the penalty when relax_dist > 1;
        # for values in (0, 1) it shrinks it. Confirm this is intended.
        predicted = tf.argmin(distance_to_all_anchors, axis=1)
        relax_dist = tf.where(tf.equal(predicted, labels), relax_dist, relax_dist**2)

        dist_difference = tf.expand_dims(dist, axis=1) - distance_to_all_anchors
        # Exclude the own-class column by index rather than by testing the
        # float difference for 0. The value test also dropped exactly-tied
        # anchors and relied on two separately computed norms being
        # bit-identical.
        nb_anchors = tf.shape(distance_to_all_anchors, out_type=tf.int64)[1]
        is_other_class = tf.not_equal(
            tf.expand_dims(tf.range(nb_anchors), 0), tf.expand_dims(labels, 1)
        )
        difference_sums = tf.reduce_sum(
            tf.where(
                is_other_class,
                tf.math.exp(dist_difference),
                tf.zeros_like(dist_difference),
            ),
            axis=1,
        )

        tuplet_loss = tf.math.log(1 + difference_sums)

        return relax_dist + tuplet_loss

    return _loss