import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Dropout, Dense, BatchNormalization
from lib.tf import regularizers
from tensorflow.keras.constraints import NonNeg, MaxNorm
from lib.tf import initializers
from lib.tf import activations
from lib.tf import layers

def return_regularizer(reg_type, reg_size, d=None):
    """Build a weight regularizer selected by name.

    Args:
        reg_type: Regularizer kind, matched case-insensitively after ``str()``:
            "spec", "l2", "l1" or "intv". Any other value (including "none")
            yields a regularizer that always returns 0.0.
        reg_size: Regularization strength forwarded to the chosen regularizer.
        d: Input dimension; required (asserted non-None) for "spec".

    Returns:
        A callable regularizer.
    """
    kind = str(reg_type).lower()

    if kind == "spec":
        assert d is not None
        return regularizers.SpectralNormRegularizer(d, reg_lambda=reg_size, num_iter=2, p=1.0)

    # Builders are lazy so only the selected regularizer is constructed.
    builders = {
        "l2": lambda: regularizers.L2(reg_size),
        "l1": lambda: regularizers.L1(reg_size),
        "intv": lambda: regularizers.IntervalRegularizer(reg_size, reg_size),
    }
    builder = builders.get(kind)
    if builder is None:
        # "none" and unrecognised kinds both disable regularization.
        return lambda W: 0.0
    return builder()

class IntervalLayer(keras.layers.Layer):
    """Layer computing a learned per-(feature, unit) interval response.

    ``call`` appends a trailing axis to the input and, for each input feature
    ``i`` and output unit ``j``, evaluates::

        act(beta2 * W[i, j] * (intv_act(x_i - cl[i, j]) + intv_act(cr[i, j] - x_i))
            + beta3 * b[i, j])

    then applies dropout. All parameters have shape ``(1, d, units)``, so an
    input of shape ``(..., d)`` produces an output of shape ``(..., d, units)``.

    Args:
        units: Number of output units per input feature.
        activation: Name passed to ``activations.activation_generator`` for both
            the interval activation and the output activation.
        beta1: Second argument to ``activation_generator`` (its meaning is
            defined by that helper).
        beta2: Scale applied to the kernel.
        beta3: Scale applied to the bias.
        drop: Dropout rate applied to the output.
        exp_type: One of "exp", "exp_sym", "id"; selects ``self._exp_transform``.
        mean: Mean of the truncated-normal kernel initializer.
        std: Standard deviation of the truncated-normal kernel initializer.
        reg_type: Stored only; ``build`` always creates an IntervalRegularizer.
        reg_size: Strength passed to the IntervalRegularizer.
        seed: Optional base seed. ``seed``, ``seed + 1`` and ``seed + 2`` seed the
            centre initializer, the kernel initializer and dropout respectively.

    Raises:
        ValueError: If ``exp_type`` is not recognised.
    """

    def __init__(self, units, activation="softsign", beta1=4.0, beta2=4.0, beta3=4.0,
                 drop=0.0, exp_type="id", mean=0.0, std=0.5, reg_type="none",
                 reg_size=0.0, seed=None, **kwargs):
        super(IntervalLayer, self).__init__(**kwargs)

        # Two independently generated activations with identical settings:
        # one for the interval terms, one for the affine output.
        self._intv_act = activations.activation_generator(activation, beta1)
        self._act = activations.activation_generator(activation, beta1)
        self._units = units
        self._beta2 = float(beta2)
        self._beta3 = float(beta3)

        def _derived_seed(offset):
            # Distinct seed per random component; stay unseeded when seed is None.
            return None if seed is None else seed + offset

        self._c_initializer = initializers.RandomUniform(-1.0, 1.0, seed=_derived_seed(0))
        self._w_initializer = initializers.TruncatedNormal(mean, std, seed=_derived_seed(1))
        self._b_initializer = initializers.Zeros()
        self._drop_layer = layers.Dropout(drop, seed=_derived_seed(2))

        # NOTE(review): _exp_transform is never applied in this class — confirm
        # whether call() was meant to transform the kernel with it.
        if exp_type == "exp":
            self._exp_transform = lambda x: tf.exp(x)
        elif exp_type == "exp_sym":
            self._exp_transform = lambda x: tf.exp(x) - tf.exp(-x)
        elif exp_type == "id":
            self._exp_transform = lambda x: x
        else:
            raise ValueError(f"Invalid exp_type : {exp_type}")

        self._reg_type = reg_type
        self._reg_size = reg_size

    def build(self, input_shape):
        """Create centres, kernel and bias, each of shape (1, d, units)."""
        d = int(input_shape[-1])
        # NOTE(review): reg_type is ignored here; calc_reg passes four tensors,
        # which matches IntervalRegularizer's call convention only.
        self._regularizer = regularizers.IntervalRegularizer(self._reg_size, self._reg_size)

        # Weights are sized by the LAST input axis so they broadcast against
        # the (..., d, 1) tensor produced by expand_dims in call(). For 2-D
        # inputs this equals input_shape[1]; it also supports higher ranks.
        weight_shape = [1, d, self._units]
        self._center_left = self.add_weight(name="cl",
                                            shape=weight_shape,
                                            initializer=self._c_initializer,
                                            trainable=True)
        self._center_right = self.add_weight(name="cr",
                                             shape=weight_shape,
                                             initializer=self._c_initializer,
                                             trainable=True)
        self.kernel = self.add_weight(name="kernel",
                                      shape=weight_shape,
                                      initializer=self._w_initializer,
                                      trainable=True)
        self._bias = self.add_weight(name="b",
                                     shape=weight_shape,
                                     initializer=self._b_initializer,
                                     trainable=True)
        # Non-trainable record of the input dimension, stored in the bias dtype.
        self._d = tf.Variable(d, trainable=False, dtype=self._bias.dtype, name="d")

    @tf.function
    def call(self, inputs, training=False):
        # (..., d) -> (..., d, 1) so each feature broadcasts against its
        # (1, d, units) parameters, giving (..., d, units).
        inputs = tf.expand_dims(inputs, -1)
        interval = self._intv_act(inputs - self._center_left) + self._intv_act(self._center_right - inputs)
        output = (self._beta2 * self.kernel) * interval
        output += self._beta3 * self._bias
        output = self._act(output)
        return self._drop_layer(output, training)

    @tf.function
    def calc_reg(self):
        """Regularization value from the (d, units)-shaped squeezed parameters."""
        return self._regularizer(tf.squeeze(self.kernel, 0),
                                 tf.squeeze(self._center_left, 0),
                                 tf.squeeze(self._center_right, 0),
                                 tf.squeeze(self._bias, 0))
    
class ParameterizedReduceLayer(keras.layers.Layer):
    """Project the last axis with a learned matrix, then reduce over axis 1.

    ``call`` computes ``dropout(act(reduce(inputs @ kernel, axis=1)))`` where
    ``reduce`` is the ``ReduceOpLayer`` selected by ``method``.

    Args:
        dim_in: Rows of the kernel; presumably equals the input's last
            dimension (TODO confirm — ``build`` does not check this).
        dim_out: Columns of the kernel; defaults to ``dim_in``.
        activation: Name passed to ``activations.activation_generator``.
        method: Reduction name forwarded to ``ReduceOpLayer``.
        drop: Dropout rate applied after the activation.
        reg_type: Regularizer kind passed to ``return_regularizer``.
        reg_size: Regularizer strength passed to ``return_regularizer``.
        seed: Optional seed; ``seed`` seeds the initializer and ``seed + 1``
            the dropout layer.
    """

    def __init__(self, dim_in, dim_out=None, activation="linear", method="max", drop=0.0,
                 reg_type="none", reg_size=0.0, seed=None, **kwargs):
        super(ParameterizedReduceLayer, self).__init__(**kwargs)
        self._dim_in = dim_in
        self._dim_out = dim_in if dim_out is None else dim_out
        self._act = activations.activation_generator(activation, None)
        # GlorotNormal(seed=None) is already the unseeded case, so one call covers both.
        self._m_initializer = initializers.GlorotNormal(seed=seed)
        # The unseeded branch must use layers.Dropout too; an offset seed keeps
        # dropout's random stream distinct from the initializer's.
        self._drop_layer = layers.Dropout(drop, seed=None if seed is None else seed + 1)
        self._reduce_op = layers.ReduceOpLayer(method, axis=1)

        self._reg_type = reg_type
        self._reg_size = reg_size

    def build(self, input_shape):
        """Create the (dim_in, dim_out) kernel and its regularizer."""
        d = int(input_shape[-1])
        self._regularizer = return_regularizer(self._reg_type, self._reg_size, d)
        self.kernel = self.add_weight(name="kernel",
                                      shape=[self._dim_in, self._dim_out],
                                      initializer=self._m_initializer,
                                      trainable=True)

    @tf.function
    def call(self, inputs, training=False):
        output = tf.matmul(inputs, self.kernel)
        # Original note: n x d x h -> n x h (reduction over axis 1).
        reduce_output = self._reduce_op(output)
        reduce_output = self._act(reduce_output)
        return self._drop_layer(reduce_output, training)

    @tf.function
    def calc_reg(self):
        """Regularization value of the kernel."""
        return self._regularizer(self.kernel)

class ReduceOpLayer(keras.layers.Layer):
    """Non-trainable layer that reduces its input along ``axis``.

    Supported ``method`` values:
        "mean", "sum", "max": the corresponding ``tf.reduce_*`` op.
        "concat": ``reduce_max + reduce_mean`` — an elementwise SUM of the two
            reductions, not a concatenation, despite the name.

    Args:
        method: Reduction name (see above).
        axis: Axis to reduce over.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
            Any ``trainable`` entry is ignored; the layer is always non-trainable.

    Raises:
        ValueError: If ``method`` is not recognised.
    """

    def __init__(self, method, axis=1, **kwargs):
        # Forward base-layer kwargs instead of silently dropping them; drop any
        # caller-supplied ``trainable`` to avoid a duplicate-keyword error.
        kwargs.pop("trainable", None)
        super(ReduceOpLayer, self).__init__(trainable=False, **kwargs)
        self._method = method
        self._axis = axis
        if method == "mean":
            self._reduce_op = lambda x: tf.reduce_mean(x, axis)
        elif method == "sum":
            self._reduce_op = lambda x: tf.reduce_sum(x, axis)
        elif method == "max":
            self._reduce_op = lambda x: tf.reduce_max(x, axis)
        elif method == "concat":
            self._reduce_op = lambda x: tf.reduce_max(x, axis) + tf.reduce_mean(x, axis)
        else:
            raise ValueError(f"Invalid method {method}")

    @tf.function
    def call(self, inputs, training=False):
        # ``training`` is accepted for interface uniformity; reduction ignores it.
        return self._reduce_op(inputs)
        
class DenseDrop(keras.layers.Layer):
    """Fully connected layer followed by an activation and dropout.

    ``call`` computes ``dropout(act(inputs @ kernel + bias))``. When
    ``use_bias`` is False no bias weight is created and the bias term is skipped.

    Args:
        units: Number of output units.
        activation: ``None`` for ``activations.linear``; otherwise a name passed
            to ``activations.activation_generator``.
        use_bias: Whether to create and add a bias vector.
        kernel_initializer: Name of an attribute of ``initializers``; it receives
            ``seed`` only if the name is listed in ``initializers.RANDOM_INIT``.
        bias_initializer: Same convention as ``kernel_initializer``.
        drop_rate: Dropout rate applied after the activation.
        reg_type: Regularizer kind passed to ``return_regularizer``.
        reg_size: Regularizer strength passed to ``return_regularizer``.
        seed: Optional seed for random initializers and dropout.
    """

    def __init__(self, units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='GlorotNormal',
                 bias_initializer='Zeros',
                 drop_rate=0.0,
                 reg_type="none", reg_size=0.0,
                 seed=None,
                 **kwargs
                 ):
        super(DenseDrop, self).__init__(**kwargs)
        self._units = int(units)
        self._use_bias = bool(use_bias)
        if activation is None:
            self._act = getattr(activations, "linear")
        else:
            self._act = activations.activation_generator(activation, None)
        self._kernel_initializer = self._make_initializer(kernel_initializer, seed)
        self._bias_initializer = self._make_initializer(bias_initializer, seed)
        self._drop = keras.layers.Dropout(drop_rate, seed=seed)

        self._reg_type = reg_type
        self._reg_size = reg_size

    @staticmethod
    def _make_initializer(name, seed):
        """Instantiate ``initializers.<name>``, passing ``seed`` only to random ones."""
        init_cls = getattr(initializers, name)
        return init_cls(seed=seed) if name in initializers.RANDOM_INIT else init_cls()

    def build(self, input_shape):
        """Create the (d, units) kernel and, if enabled, the (units,) bias."""
        d = int(input_shape[-1])
        self._regularizer = return_regularizer(self._reg_type, self._reg_size, d)

        self.kernel = self.add_weight(name="kernel",
                                      shape=[d, self._units],
                                      initializer=self._kernel_initializer,
                                      trainable=True)
        if self._use_bias:
            self.bias = self.add_weight(name="bias",
                                        shape=[self._units],
                                        initializer=self._bias_initializer,
                                        trainable=True)
        else:
            self.bias = None

    def _forward(self, inputs, training):
        """Shared dense -> activation -> dropout computation."""
        preactivation = tf.matmul(inputs, self.kernel)
        # Python-level check: resolved once at trace time inside tf.function.
        if self.bias is not None:
            preactivation = preactivation + self.bias
        activation = self._act(preactivation)
        return self._drop(activation, training)

    @tf.function
    def call(self, inputs, training=False):
        return self._forward(inputs, training)

    @tf.function
    def predict(self, inputs):
        """Inference-mode forward pass (dropout disabled)."""
        return self._forward(inputs, False)

    @tf.function
    def calc_reg(self):
        """Regularization value of the kernel."""
        return self._regularizer(self.kernel)