
import keras
import tensorflow as tf
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.keras import backend as K
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import constraints


class nanDense(keras.layers.Dense):
    """Fully connected layer that tolerates NaN entries in its inputs.

    NaN input features are "neutralized" instead of propagated. For output
    unit ``j``, every NaN feature contributes exactly ``-bias[j] / n`` to the
    pre-activation, where ``n`` is the number of kernel rows. Replacing a NaN
    at row ``k`` with ``-bias[j] / (n * W[k, j])`` has the same effect. This
    layer computes it in the equivalent closed form

        outputs = activation(nan_to_zero(x) @ W + bias * (1 - m / n))

    where ``m`` is the number of NaN features in a sample. The closed form
    avoids dividing by kernel entries, which overflowed or produced NaN when
    a weight was close to zero. ``W`` is the learned kernel shifted by
    ``keras.backend.epsilon()``. The shift is kept from the original
    implementation so that existing weights produce the same outputs.

    Args:
        units: Dimensionality of the output space.
        use_c: If True, the fraction of NaN features in each sample is
            appended as an extra "compensatory" input. This gives the kernel
            one extra row.
        activation, kernel_initializer, bias_initializer, kernel_regularizer,
        bias_regularizer, activity_regularizer, kernel_constraint,
        bias_constraint: Same meaning as for ``keras.layers.Dense``.
        **kwargs: Forwarded to ``keras.layers.Dense``. ``use_bias`` is
            ignored: the bias is always created because the NaN
            neutralization is defined in terms of it.
    """

    def __init__(self,
                 units,
                 use_c=False,     # A flag to use compensatory weight or not.
                 activation=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The bias is mandatory. Drop any caller-supplied use_bias (e.g. the
        # one Dense.get_config() stores) so it is not passed twice below.
        kwargs.pop('use_bias', None)
        # Let Dense resolve activation/initializers/regularizers/constraints
        # with the same Keras implementation it belongs to. The private
        # tensorflow.python.keras copies are not mixed in.
        super().__init__(
            units,
            activation=activation,
            use_bias=True,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        self.use_c = use_c

    def build(self, input_shape):
        """Create the kernel, the bias and the epsilon shift for the kernel.

        Args:
            input_shape: Shape of the layer input. The last dimension must
                be known.

        Raises:
            TypeError: If the layer dtype is neither floating point nor complex.
            ValueError: If the last input dimension is undefined.
        """
        dtype = tf.as_dtype(self.dtype or keras.backend.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError('Unable to build `nanDense` layer with non-floating point '
                            'dtype %s' % (dtype,))

        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError('The last dimension of the inputs to `nanDense` '
                             'should be defined. Found `None`.')
        self.input_spec = keras.layers.InputSpec(min_ndim=2, axes={-1: last_dim})

        # use_c appends one compensatory feature, hence one extra kernel row.
        n_rows = last_dim + 1 if self.use_c else last_dim
        self.kernel = self.add_weight(
            'kernel',
            shape=[n_rows, self.units],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)

        self.bias = self.add_weight(
            'bias',
            shape=[self.units],
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True)

        # Create the epsilon matrix in the layer's dtype. A bare Python float
        # would give a float32 tensor that cannot be added to e.g. a float64
        # kernel.
        self.epsilon = tf.fill(
            self.kernel.shape,
            tf.constant(keras.backend.epsilon(), dtype=dtype))

        self.built = True

    def call(self, inputs):
        """Apply the NaN-neutralizing dense transform.

        Args:
            inputs: Dense tensor of rank 2 (or unknown rank that is 2 at run
                time). NaN entries mark missing features.

        Returns:
            Tensor of shape ``(batch, units)`` after the activation.

        Raises:
            NotImplementedError: For SparseTensor inputs or a static rank
                other than 2.
        """
        dtype = self._compute_dtype_object

        # Validate before any op touches the inputs.
        if isinstance(inputs, tf.sparse.SparseTensor):
            raise NotImplementedError(
                '`nanDense` does not support SparseTensor inputs.')
        rank = inputs.shape.rank
        if rank is not None and rank != 2:
            raise NotImplementedError(
                '`nanDense` only supports rank-2 inputs; got rank %d.' % rank)

        # Cast first, so the NaN bookkeeping and the compensatory column
        # below share the compute dtype.
        if self.dtype and inputs.dtype.base_dtype != dtype.base_dtype:
            inputs = tf.cast(inputs, dtype)

        # Casting keeps kernel, epsilon and bias in the compute dtype
        # (relevant under mixed precision).
        kernel = tf.cast(self.kernel, dtype) + tf.cast(self.epsilon, dtype)
        bias = tf.cast(self.bias, dtype)
        n_rows = kernel.shape[0]  # static: number of (possibly augmented) inputs

        is_missing = tf.math.is_nan(inputs)
        n_missing = tf.math.reduce_sum(
            tf.cast(is_missing, dtype), axis=-1, keepdims=True)  # (batch, 1)
        filled = tf.where(is_missing, tf.zeros_like(inputs), inputs)

        if self.use_c:
            # Compensatory feature: fraction of the original features that
            # are NaN. It is never NaN itself, so n_missing is unchanged.
            n_features = n_rows - 1
            filled = tf.concat([filled, n_missing / n_features], axis=-1)

        # Each NaN feature contributes -bias / n_rows per unit; the zeros in
        # `filled` remove its data contribution from the matmul.
        outputs = tf.linalg.matmul(filled, kernel)
        outputs = outputs - (n_missing / n_rows) * bias
        outputs = tf.nn.bias_add(outputs, bias)

        if self.activation is not None:
            outputs = self.activation(outputs)

        return outputs

    def get_config(self):
        """Return the layer config, including ``use_c``.

        Without ``use_c``, a layer rebuilt from its config would create a
        kernel one row short and could not load the saved weights.
        """
        config = super().get_config()
        config['use_c'] = self.use_c
        return config
      
