import keras
import keras.backend as K
import numpy as np
import tensorflow as tf

#https://github.com/handongfeng/MNIST-center-loss/blob/master/centerLoss_MNIST.py
class CenterLossLayer(keras.layers.Layer):
    """Center-loss layer that maintains one non-trainable center per class.

    The layer is called on a pair ``[features, labels]`` where ``features`` is
    ``N x num_features`` and ``labels`` is an ``N x num_classes`` one-hot
    matrix. It returns the squared distance between every feature vector and
    the center of its class, summed over the whole batch (a scalar), and
    registers an update that moves the centers on every call.

    Args:
        num_features: Length of each feature vector, i.e. of each center.
        num_classes: Number of classes, i.e. number of centers.
        alpha: Step size applied to the center update.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, num_features, num_classes=2, alpha=0.5, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.num_classes = num_classes
        self.num_features = num_features

    def build(self, input_shape):
        """Create the ``num_classes x num_features`` matrix of centers."""
        # Centers are not trained by the optimizer; they are only changed by
        # the update registered in ``call``.
        self.centers = self.add_weight(name='centers',
                                       shape=(self.num_classes, self.num_features),
                                       initializer='uniform',
                                       trainable=False)
        super().build(input_shape)

    def call(self, x, mask=None):
        """Register the center update and return the summed squared distance.

        Args:
            x: Sequence whose first element is the ``N x num_features``
                feature tensor and whose second element is the
                ``N x num_classes`` one-hot label tensor.
            mask: Unused; accepted for compatibility with the Keras call
                signature.

        Returns:
            Scalar tensor: sum over the batch of the squared distance between
            each feature vector and the center of its class.
        """
        features, labels = x[0], x[1]
        # Row i is the center of the class sample i belongs to (N x num_features).
        assigned_centers = K.dot(labels, self.centers)

        # Per-class sum of (center - feature): num_classes x num_features.
        delta_centers = K.dot(K.transpose(labels), assigned_centers - features)
        # Samples per class plus one, so a class that is absent from the batch
        # divides by 1 instead of 0 (num_classes x 1).
        center_counts = K.sum(K.transpose(labels), axis=1, keepdims=True) + 1
        delta_centers = delta_centers / center_counts
        new_centers = self.centers - self.alpha * delta_centers
        self.add_update((self.centers, new_centers), x)

        # Squared distance per sample (N x 1), then reduced over the batch.
        squared_distance = K.sum((features - assigned_centers) ** 2,
                                 axis=1, keepdims=True)
        # Still stored on the instance in case external code reads it.
        self.result = K.sum(squared_distance)
        return self.result

    def compute_output_shape(self, input_shape):
        """Return the output shape: ``()`` because the loss is a scalar."""
        # Stated statically rather than read from ``self.result`` so that this
        # method does not require ``call`` to have run first.
        return ()

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'num_features': self.num_features,
            'num_classes': self.num_classes,
            'alpha': self.alpha,
        })
        return config

class ZC_Layer(keras.layers.Layer):
    """Left-multiplies its input by a trainable square kernel.

    The kernel has shape ``(batch_size, batch_size)`` and the output is
    ``kernel . x``, so the first dimension of the input has to equal
    ``batch_size``.

    Args:
        batch_size: Side length of the square kernel.
        kernel_regularizer: Optional regularizer applied to the kernel. Any
            identifier accepted by ``keras.regularizers.get`` is allowed.
        no_diag: If true, the diagonal of the kernel is replaced by zeros
            before the product is computed.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, batch_size, kernel_regularizer=None, no_diag=True, **kwargs):
        super().__init__(**kwargs)
        # Normalise the regularizer so that the serialized form produced by
        # ``get_config`` is accepted when the layer is re-created.
        self.regularizer = keras.regularizers.get(kernel_regularizer)
        self.batch_size = batch_size
        self.no_diag = no_diag

    def build(self, input_shape):
        """Create the trainable ``batch_size x batch_size`` kernel."""
        # TODO: initialize gaussian kernel (note carried over from the original).
        self.kernel = self.add_weight(
            name='kernel',
            shape=(self.batch_size, self.batch_size),
            initializer=keras.initializers.RandomNormal(mean=1e-4, stddev=0.01, seed=43),
            regularizer=self.regularizer,
            trainable=True)
        super().build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        """Return ``kernel . x``, with the kernel diagonal zeroed if ``no_diag``."""
        if not self.no_diag:
            return K.dot(self.kernel, x)
        # Build the zero diagonal in the kernel's own dtype instead of relying
        # on an implicit conversion from NumPy's default float64.
        zero_diagonal = np.zeros(self.batch_size, dtype=K.dtype(self.kernel))
        kernel_without_diag = tf.linalg.set_diag(self.kernel, zero_diagonal)
        return K.dot(kernel_without_diag, x)

    def compute_output_shape(self, input_shape):
        """Return ``(batch_size, input_shape[1])``."""
        return (self.batch_size, input_shape[1])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'batch_size': self.batch_size,
            'kernel_regularizer': keras.regularizers.serialize(self.regularizer),
            'no_diag': self.no_diag,
        })
        return config
