import tensorflow as tf

class OrthogonalRegularizer(tf.keras.regularizers.Regularizer):
    """Soft orthogonality penalty on a weight tensor.

    Computes ``0.5 * alpha * ||W^T W - I||_F^2``, which is zero exactly when
    the columns of ``W`` are orthonormal.

    Weights with more than two dimensions (e.g. convolution kernels) are
    first flattened to ``(-1, last_dim)``, i.e. the last axis is treated as
    the column axis. For 2-D dense kernels this reshape is a no-op.

    Args:
        alpha: Scale factor applied to the penalty.
    """

    def __init__(self, alpha=0.01):
        self.alpha = alpha

    def __call__(self, w):
        """Return the scalar orthogonality penalty for weight ``w``."""
        w = tf.convert_to_tensor(w)
        num_cols = w.shape[-1]
        # Collapse all leading axes so N-D kernels become a 2-D matrix whose
        # columns are the last-axis slices.
        w = tf.reshape(w, (-1, num_cols))
        gram = tf.matmul(w, w, transpose_a=True)
        # Build the identity in w's own dtype; a hard-coded float32 identity
        # breaks for float16 / float64 weights.
        identity = tf.eye(num_cols, dtype=w.dtype)
        return 0.5 * self.alpha * tf.reduce_sum(tf.square(gram - identity))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this regularizer."""
        return {'alpha': self.alpha}
    

class CN(tf.keras.layers.Layer):
    """Multiply the input by a single learnable scalar ``gamma``.

    The output is ``gamma * inputs``. ``gamma`` has shape ``(1,)`` and
    broadcasts over every element of the input.

    Args:
        gamma_initializer: Optional initializer for ``gamma`` (a name such as
            ``'ones'``, a config dict, or an initializer instance). ``None``
            keeps the original behavior of deferring to ``add_weight``'s
            default initializer.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, gamma_initializer=None, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): with None, Keras falls back to its default
        # floating-point initializer (glorot_uniform per the Keras docs), so
        # gamma starts at a random value that may be negative. Pass 'ones'
        # for an identity start if that was the intent.
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)

    def build(self, input_shape):
        """Create the scalar ``gamma`` weight."""
        self.gamma = self.add_weight(
            shape=(1,),
            initializer=self.gamma_initializer,
            name='gamma',
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``gamma * inputs``."""
        return self.gamma * inputs

    def get_config(self):
        """Return the layer config, including ``gamma_initializer``."""
        config = super().get_config()
        config['gamma_initializer'] = (
            None
            if self.gamma_initializer is None
            else tf.keras.initializers.serialize(self.gamma_initializer)
        )
        return config