import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import metrics
from lib.tf import layers as layers
from lib.tf import tf_utils
    
class DNN(keras.Model):
    """Fully-connected network: ``layers.DenseDrop`` hidden layers followed by a ``layers.Dense`` output layer.

    Args:
        num_units: Hidden-layer widths, one ``DenseDrop`` layer per entry.
            ``None`` (the default) means no hidden layers.
        num_output: Number of output units.
        activation: Activation of every hidden layer.
        output_act: Activation of the output layer.
        dropout: Dropout rate passed to every hidden layer.
        reg_type: Regularization type passed to every hidden layer.
        reg_size: Regularization strength passed to every hidden layer.
        seed: Base seed; hidden layer ``i`` (0-based) gets ``seed + i + 1``.
            ``None`` leaves the hidden layers unseeded.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self,
                 num_units=None,
                 num_output=1,
                 activation="relu",
                 output_act="linear",
                 dropout=0.0,
                 reg_type="none",
                 reg_size=0.0,
                 seed=None,
                 **kwargs):
        super(DNN, self).__init__(**kwargs)
        # Fresh list per instance: a `[]` default would be one shared object,
        # and it is exposed to callers through return_config().
        if num_units is None:
            num_units = []

        self._config = {"num_units": num_units, "num_output": num_output, "activation": activation,
                        "output_act": output_act, "dropout": dropout,
                        "reg_type": reg_type, "reg_size": reg_size, "seed": seed}

        self._fcn_layers = keras.Sequential()
        for no, unit in enumerate(num_units):
            seed_value = seed + no + 1 if seed is not None else None
            self._fcn_layers.add(layers.DenseDrop(unit, activation=activation, drop_rate=dropout,
                                                  seed=seed_value, reg_type=reg_type, reg_size=reg_size))

        self._link_layer = layers.Dense(num_output, activation=output_act)

    @tf.function
    def call(self, inputs, training=True):
        """Run the hidden stack (passing ``training`` through) and then the output layer.

        NOTE(review): ``training`` defaults to True here, unlike the usual
        Keras ``None`` default — confirm bare calls are meant to run in
        training mode.
        """
        output = self._fcn_layers(inputs, training)
        output = self._link_layer(output)
        return output

    @tf.function
    def predict(self, inputs):
        """Forward pass with ``training=False`` for the hidden stack (overrides ``keras.Model.predict``)."""
        output = self._fcn_layers(inputs, False)
        output = self._link_layer(output)
        return output

    def return_config(self):
        """Return a shallow copy of the constructor arguments recorded in ``__init__``."""
        return dict(self._config.items())
    
################################################################################################################################
class SINNModel(keras.Model):
    """Interval layer -> reduce layer -> FCN stack -> link layer.

    ``train_type`` selects the training step:
    ``"default"`` (plain loss), ``"random"`` (extra loss on a randomly
    perturbed representation), ``"adv"`` (extra loss on an adversarially
    perturbed representation) or ``"vat"`` (extra loss on a VAT-style
    perturbation).

    ``noise_type`` selects where the perturbation is injected:
    ``"input"`` (raw inputs), ``"matrix"`` (interval-layer output) or
    ``"vector"`` (reduce-layer output). Any other value, including the
    default ``"none"``, behaves like ``"input"``.

    Raises:
        ValueError: on an unknown ``train_type`` or when ``fcn_units`` and
            ``fcn_drop`` differ in length.
    """

    def __init__(self,
                 # Interval layer
                 intv_units, intv_act="softsign", beta1=4.0, beta2=1.0, beta3=1.0, intv_drop=0.0, std=0.5,
                 intv_l1=0.0, intv_l2=0.0, intv_reg="none", intv_lambda=0.0,
                 # Reduce layer
                 rdc_param=True, rdc_units=None, rdc_act="linear", rdc_method="max", rdc_reg="none", rdc_lambda=0.0,
                 # FCN layers (None means no hidden layers)
                 fcn_units=None, fcn_act="swish", fcn_drop=None, fcn_reg="none", fcn_lambda=0.0,
                 # Link layer: reuses the FCN regularization settings
                 output_act=None, num_output=1, seed=None,
                 # Training scheme
                 train_type="default", noise_type="none", noise_size=1e-3,
                 **kwargs):
        super(SINNModel, self).__init__(**kwargs)
        # Fresh lists per instance instead of shared mutable defaults.
        if fcn_units is None:
            fcn_units = []
        if fcn_drop is None:
            fcn_drop = []

        self._config = {
            # Interval layer
            "intv_units": intv_units, "intv_act": intv_act, "beta1": beta1, "beta2": beta2, "beta3": beta3,
            "intv_drop": intv_drop, "std": std,
            "intv_l1": intv_l1, "intv_l2": intv_l2, "intv_reg": intv_reg, "intv_lambda": intv_lambda,
            # Reduce layer
            "rdc_param": rdc_param, "rdc_units": rdc_units, "rdc_act": rdc_act, "rdc_method": rdc_method,
            "rdc_reg": rdc_reg, "rdc_lambda": rdc_lambda,
            # FCN layers
            "fcn_units": fcn_units, "fcn_act": fcn_act, "fcn_drop": fcn_drop, "fcn_reg": fcn_reg,
            "fcn_lambda": fcn_lambda,
            # Link layer
            "output_act": output_act, "num_output": num_output, "seed": seed,
            # Training scheme
            "train_type": train_type, "noise_type": noise_type, "noise_size": noise_size
        }

        ##### Interval layer
        self._intv_layer = layers.IntervalLayer(
            units=intv_units, activation=intv_act, beta1=beta1, beta2=beta2, beta3=beta3, drop=intv_drop,
            exp_type="id", mean=0.0, std=std,
            reg_type=intv_reg, reg_size=intv_lambda,
            seed=seed, name="SIL"
        )
        # Non-trainable weights of the center-gap penalty in calc_interval_loss().
        self._intv_l1 = tf.Variable(intv_l1, trainable=False, name="lambda_l1")
        self._intv_l2 = tf.Variable(intv_l2, trainable=False, name="lambda_l2")

        ##### Reduce layer
        if rdc_param:
            self._rdc_layer = layers.ParameterizedReduceLayer(
                intv_units, intv_units if rdc_units is None else rdc_units, activation=rdc_act, drop=0.0,
                reg_type=rdc_reg, reg_size=rdc_lambda,
                method=rdc_method, seed=seed, name="RDC")
        else:
            self._rdc_layer = layers.ReduceLayer(method=rdc_method, activation=rdc_act, name="RDC")

        ##### FCN layers
        if len(fcn_units) != len(fcn_drop):
            raise ValueError("fcn_units and fcn_drop must have the same length, got {} and {}".format(
                len(fcn_units), len(fcn_drop)))
        self._fcn_layers = keras.Sequential(name="FCN")
        for no, (fcn_unit, drop_rate) in enumerate(zip(fcn_units, fcn_drop)):
            seed_value = seed + no + 1 if seed is not None else None
            self._fcn_layers.add(layers.DenseDrop(fcn_unit, activation=fcn_act, drop_rate=drop_rate,
                                                  seed=seed_value, reg_type=fcn_reg, reg_size=fcn_lambda))

        ##### Link layer
        self._link_layer = layers.DenseDrop(num_output, activation=output_act, seed=seed, drop_rate=0.0,
                                            reg_type=fcn_reg, reg_size=fcn_lambda, name="LINK")

        ##### Training step selection
        self._norm_metrics = []
        train_funcs = {
            "default": self._default_ts,
            "random": self._random_ts,
            "adv": self._adv_ts,
            "vat": self._vat_ts,
        }
        if train_type not in train_funcs:
            raise ValueError("Unknown train_type {!r}; expected one of {}".format(
                train_type, sorted(train_funcs)))
        self._ts_func = train_funcs[train_type]

        ##### Noise injection point
        if noise_type == "matrix":
            self.forward_to = self.forward_to_matrix
            self.forward_from = self.forward_from_matrix
            self._normalize_func = tf_utils.get_normalized_matrix
        elif noise_type == "vector":
            self.forward_to = self.forward_to_vector
            self.forward_from = self.forward_from_vector
            self._normalize_func = tf_utils.get_normalized_vector
        else:
            # "input" and any unrecognised value (including the default "none").
            self.forward_to = self.forward_to_input
            self.forward_from = self.forward_from_input
            self._normalize_func = tf_utils.get_normalized_vector

        self._noise_size = noise_size
        self._W_init = {}

    def build_norm_metrics(self):
        """Create one ``Mean`` metric per trainable variable; the default step feeds them mean |gradient|."""
        for var in self.trainable_variables:
            self._norm_metrics.append(tf.keras.metrics.Mean(name=var.name))

    @tf.function
    def call(self, inputs, training):
        """Full forward pass: interval -> reduce -> FCN -> link."""
        tl_output = self._intv_layer(inputs, training)
        reduced_output = self._rdc_layer(tl_output, training)
        fcn_output = self._fcn_layers(reduced_output, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    @tf.function
    def predict(self, inputs):
        """Forward pass with ``training=False`` (overrides ``keras.Model.predict``)."""
        tl_output = self._intv_layer(inputs, False)
        reduced_output = self._rdc_layer(tl_output, False)
        fcn_output = self._fcn_layers(reduced_output, False)
        final_output = self._link_layer(fcn_output, False)
        return final_output

    @tf.function
    def calc_interval_loss(self):
        """Return the scalar L1/L2 penalty on the gap between the interval layer's left and right centers.

        The element-wise penalties are summed. Left unreduced (as before),
        a non-scalar center tensor made the whole training loss a tensor, and
        ``tape.gradient`` then implicitly summed it, scaling the data loss
        and every other term by the number of center elements.
        """
        center_diff = self._intv_layer._center_left - self._intv_layer._center_right
        l1_sum = tf.reduce_sum(tf.math.abs(center_diff))
        l2_sum = tf.reduce_sum(tf.math.square(center_diff))
        return self._intv_l1 * l1_sum + self._intv_l2 * l2_sum

    @tf.function
    def train_step(self, data):
        """Dispatch to the training step chosen by ``train_type`` in ``__init__``."""
        result = self._ts_func(data)
        return result

    def _sum_reg_losses(self):
        """Sum the interval penalty and every sub-layer's ``calc_reg()`` term."""
        reg = self.calc_interval_loss()
        reg += self._intv_layer.calc_reg()
        reg += self._rdc_layer.calc_reg()
        for layer in self._fcn_layers.layers:
            reg += layer.calc_reg()
        reg += self._link_layer.calc_reg()
        return reg

    @tf.function
    def _default_ts(self, data):
        """Plain training step; also feeds mean |gradient| into the norm metrics when they were built."""
        X, y = data
        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)
            # compiled_loss / compiled_metrics keep the loss and metrics passed to compile() in effect.
            loss = self.compiled_loss(y, y_pred) + self._sum_reg_losses()
        trainable_weights = self.trainable_weights
        if trainable_weights:
            gradients = tape.gradient(loss, trainable_weights)
            self.optimizer.apply_gradients(zip(gradients, trainable_weights))
            for metric, grad in zip(self._norm_metrics, gradients):
                metric.update_state(tf.math.abs(grad))
        else:
            tf.print("WARNING : The model has no trainable weigths")
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def _noisy_ts(self, data, make_noise):
        """Shared body of the perturbation-based training steps.

        Loss = loss on the clean representation + loss on the perturbed one
        + regularization. ``make_noise(input_to, logit, y)`` returns the
        perturbation; it runs with the outer tape paused, so no gradient
        flows through the noise computation.
        """
        X, y = data
        with tf.GradientTape() as tape:
            input_to = self.forward_to(X, training=True)  # noise is added after forward_to (and its dropout)
            logit = self.forward_from(input_to, training=True)
            with tape.stop_recording():
                noise = make_noise(input_to, logit, y)
            logit_noisy = self.forward_from(input_to + noise, training=True)
            loss = self.compiled_loss(y, logit) + self.compiled_loss(y, logit_noisy)
            loss += self._sum_reg_losses()
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        self.compiled_metrics.update_state(y, logit)
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def _random_ts(self, data):
        """Training step with an extra loss on a randomly perturbed representation."""
        return self._noisy_ts(data, lambda input_to, _logit, _y: self.gen_random_noise(input_to))

    @tf.function
    def _adv_ts(self, data):
        """Training step with an extra loss on an adversarially (gradient) perturbed representation."""
        return self._noisy_ts(data, lambda input_to, _logit, y: self.gen_at_noise((input_to, y)))

    @tf.function
    def _vat_ts(self, data):
        """Training step with an extra loss on a VAT-style perturbed representation."""
        return self._noisy_ts(data, lambda input_to, logit, _y: self.gen_vat_noise((input_to, logit)))

    def gen_vat_noise(self, inputs, num_iter=3):
        """Return a perturbation found by ``num_iter`` rounds of KL-gradient / normalize iteration.

        Args:
            inputs: ``(X, logit)`` — the representation to perturb and the
                model output computed on the clean ``X``.
            num_iter: Number of gradient/normalize rounds.
        Returns:
            ``noise_size * _normalize_func(grad)`` from the last round.
        """
        X, logit = inputs
        # NOTE(review): standard VAT starts from a random direction; this starts from X itself — confirm intended.
        d = self._noise_size * self._normalize_func(X)
        for _ in range(num_iter):
            with tf.GradientTape() as tape:
                tape.watch([X, d])
                X_d = X + d
                logit_vat = self.forward_from(X_d, True)
                kld_loss_val = tf_utils.kld_loss_func(logit, logit_vat)
            d = tape.gradient(kld_loss_val, X_d)
            d = self._noise_size * self._normalize_func(d)
        return d

    def gen_at_noise(self, inputs):
        """Return ``noise_size * _normalize_func(dLoss/dX)`` for ``inputs = (X, y)``.

        NOTE(review): ``self.loss`` is called directly, so this presumably
        requires a callable loss in ``compile()`` (not a string name).
        """
        X, y = inputs
        with tf.GradientTape() as tape:
            tape.watch(X)
            logit = self.forward_from(X, training=False)
            loss_value = self.loss(y, logit)
        grad = tape.gradient(loss_value, X)
        # FGSM-style alternative: self._noise_size * tf.math.sign(grad)
        adv_noise = self._noise_size * self._normalize_func(grad)
        return adv_noise

    def gen_random_noise(self, inputs):
        """Return standard-normal noise shaped like ``inputs``, normalized and scaled by ``noise_size``."""
        # tf.shape (dynamic) rather than inputs.shape so an unknown batch size works inside tf.function.
        noise = tf.random.normal(tf.shape(inputs), dtype=inputs.dtype)
        return self._normalize_func(noise) * self._noise_size

    # noise_type="input": perturb the raw inputs.
    def forward_to_input(self, inputs, training=False):
        """Identity: the perturbation is applied to the raw inputs."""
        return inputs

    def forward_from_input(self, inputs, training=False):
        """Full forward pass from raw inputs."""
        tl_output = self._intv_layer(inputs, training)
        rdc_output = self._rdc_layer(tl_output, training)
        fcn_output = self._fcn_layers(rdc_output, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    # noise_type="matrix": perturb the interval-layer output.
    def forward_to_matrix(self, inputs, training=False):
        """Interval layer only."""
        tl_output = self._intv_layer(inputs, training)
        return tl_output

    def forward_from_matrix(self, inputs, training=False):
        """Reduce -> FCN -> link, starting from the interval-layer output."""
        rdc_output = self._rdc_layer(inputs, training)
        fcn_output = self._fcn_layers(rdc_output, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    # noise_type="vector": perturb the reduce-layer output.
    def forward_to_vector(self, inputs, training=False):
        """Interval -> reduce."""
        tl_output = self._intv_layer(inputs, training)
        rdc_output = self._rdc_layer(tl_output, training)
        return rdc_output

    def forward_from_vector(self, inputs, training=False):
        """FCN -> link, starting from the reduce-layer output."""
        fcn_output = self._fcn_layers(inputs, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    def return_config(self):
        """Return a shallow copy of the constructor arguments recorded in ``__init__``."""
        return dict(self._config.items())
    
################################################################################################################################
class SINNModel2(keras.Model):
    """SIL block (fused interval + reduce layer) -> FCN stack -> link layer.

    ``train_type`` selects the training step (``"default"``, ``"random"``,
    ``"adv"``, ``"vat"``). ``noise_type`` selects where perturbations are
    injected: ``"input"`` (raw inputs) or ``"vector"`` (SIL output). Any
    other value, including the default ``"none"``, behaves like ``"input"``.
    ``"matrix"`` is accepted for parity with ``SINNModel``, but its forward
    methods raise ``NotImplementedError``: SIL fuses the interval and reduce
    stages, so there is no intermediate matrix to perturb.

    Raises:
        ValueError: on an unknown ``train_type`` or when ``fcn_units`` and
            ``fcn_drop`` differ in length.
    """

    def __init__(self,
                 # Interval (SIL) layer
                 intv_units, intv_act="softsign", beta1=4.0, beta2=4.0, intv_drop=0.0, std=0.5,
                 intv_l1=0.0, intv_l2=0.0, intv_reg="none", intv_lambda=0.0,
                 # Reduce
                 rdc_param=True, rdc_units=None, rdc_act="linear", rdc_method="max", rdc_reg="none", rdc_lambda=0.0,
                 # FCN layers (None means no hidden layers)
                 fcn_units=None, fcn_act="swish", fcn_drop=None, fcn_reg="none", fcn_lambda=0.0,
                 # Link layer: reuses the FCN regularization settings
                 output_act=None, num_output=1, seed=None,
                 # Training scheme
                 train_type="default", noise_type="none", noise_size=1e-3,
                 **kwargs):
        # Was super(SINNModel, self): SINNModel2 is not a SINNModel subclass, so that raised TypeError.
        super(SINNModel2, self).__init__(**kwargs)
        # Fresh lists per instance instead of shared mutable defaults.
        if fcn_units is None:
            fcn_units = []
        if fcn_drop is None:
            fcn_drop = []

        self._config = {
            # Interval layer
            "intv_units": intv_units, "intv_act": intv_act, "beta1": beta1, "beta2": beta2,
            "intv_drop": intv_drop, "std": std,
            "intv_l1": intv_l1, "intv_l2": intv_l2, "intv_reg": intv_reg, "intv_lambda": intv_lambda,
            # Reduce
            "rdc_param": rdc_param, "rdc_units": rdc_units, "rdc_act": rdc_act, "rdc_method": rdc_method,
            "rdc_reg": rdc_reg, "rdc_lambda": rdc_lambda,
            # FCN layers
            "fcn_units": fcn_units, "fcn_act": fcn_act, "fcn_drop": fcn_drop, "fcn_reg": fcn_reg,
            "fcn_lambda": fcn_lambda,
            # Link layer
            "output_act": output_act, "num_output": num_output, "seed": seed,
            # Training scheme
            "train_type": train_type, "noise_type": noise_type, "noise_size": noise_size
        }

        ##### SIL
        # NOTE(review): std, intv_l2, intv_reg, intv_lambda, rdc_param, rdc_units and rdc_act are
        # recorded in the config but not forwarded here, and intv_l1 is passed as SIL's
        # intv_lambda — confirm against layers.SIL's signature.
        self._sil_layers = keras.Sequential(name="SIL")
        self._sil_layers.add(layers.SIL(intv_units, intv_units, intv_act, beta1, beta2, intv_drop, 0.0,
                                        intv_lambda=intv_l1,
                                        rdc_method=rdc_method, reg_type=rdc_reg, reg_size=rdc_lambda,
                                        seed=seed))

        ##### FCN layers
        if len(fcn_units) != len(fcn_drop):
            raise ValueError("fcn_units and fcn_drop must have the same length, got {} and {}".format(
                len(fcn_units), len(fcn_drop)))
        self._fcn_layers = keras.Sequential(name="FCN")
        for no, (fcn_unit, drop_rate) in enumerate(zip(fcn_units, fcn_drop)):
            seed_value = seed + no + 1 if seed is not None else None
            self._fcn_layers.add(layers.DenseDrop(fcn_unit, activation=fcn_act, drop_rate=drop_rate,
                                                  seed=seed_value, reg_type=fcn_reg, reg_size=fcn_lambda))

        ##### Link layer
        self._link_layer = layers.DenseDrop(num_output, activation=output_act, seed=seed, drop_rate=0.0,
                                            reg_type=fcn_reg, reg_size=fcn_lambda, name="LINK")

        ##### Training step selection
        self._norm_metrics = []
        train_funcs = {
            "default": self._default_ts,
            "random": self._random_ts,
            "adv": self._adv_ts,
            "vat": self._vat_ts,
        }
        if train_type not in train_funcs:
            raise ValueError("Unknown train_type {!r}; expected one of {}".format(
                train_type, sorted(train_funcs)))
        self._ts_func = train_funcs[train_type]

        ##### Noise injection point
        if noise_type == "matrix":
            self.forward_to = self.forward_to_matrix
            self.forward_from = self.forward_from_matrix
            self._normalize_func = tf_utils.get_normalized_matrix
        elif noise_type == "vector":
            self.forward_to = self.forward_to_vector
            self.forward_from = self.forward_from_vector
            self._normalize_func = tf_utils.get_normalized_vector
        else:
            # "input" and any unrecognised value (including the default "none").
            self.forward_to = self.forward_to_input
            self.forward_from = self.forward_from_input
            self._normalize_func = tf_utils.get_normalized_vector

        self._noise_size = noise_size
        self._W_init = {}

    def build_norm_metrics(self):
        """Create one ``Mean`` metric per trainable variable; the default step feeds them mean |gradient|."""
        for var in self.trainable_variables:
            self._norm_metrics.append(tf.keras.metrics.Mean(name=var.name))

    @tf.function
    def call(self, inputs, training):
        """Full forward pass: SIL -> FCN -> link."""
        sil_output = self._sil_layers(inputs, training)
        fcn_output = self._fcn_layers(sil_output, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    @tf.function
    def predict(self, inputs):
        """Forward pass with ``training=False`` (overrides ``keras.Model.predict``)."""
        sil_output = self._sil_layers(inputs, False)
        fcn_output = self._fcn_layers(sil_output, False)
        final_output = self._link_layer(fcn_output, False)
        return final_output

    @tf.function
    def train_step(self, data):
        """Dispatch to the training step chosen by ``train_type`` in ``__init__``."""
        result = self._ts_func(data)
        return result

    def _sum_reg_losses(self):
        """Sum every sub-layer's ``calc_reg()`` term."""
        reg = self._link_layer.calc_reg()
        for layer in self._sil_layers.layers:
            reg += layer.calc_reg()
        for layer in self._fcn_layers.layers:
            reg += layer.calc_reg()
        return reg

    @tf.function
    def _default_ts(self, data):
        """Plain training step; also feeds mean |gradient| into the norm metrics when they were built."""
        X, y = data
        with tf.GradientTape() as tape:
            y_pred = self(X, training=True)
            # compiled_loss / compiled_metrics keep the loss and metrics passed to compile() in effect.
            loss = self.compiled_loss(y, y_pred) + self._sum_reg_losses()
        trainable_weights = self.trainable_weights
        if trainable_weights:
            gradients = tape.gradient(loss, trainable_weights)
            self.optimizer.apply_gradients(zip(gradients, trainable_weights))
            for metric, grad in zip(self._norm_metrics, gradients):
                metric.update_state(tf.math.abs(grad))
        else:
            tf.print("WARNING : The model has no trainable weigths")
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def _noisy_ts(self, data, make_noise):
        """Shared body of the perturbation-based training steps.

        Loss = loss on the clean representation + loss on the perturbed one
        + regularization. ``make_noise(input_to, logit, y)`` returns the
        perturbation; it runs with the outer tape paused, so no gradient
        flows through the noise computation.
        """
        X, y = data
        with tf.GradientTape() as tape:
            input_to = self.forward_to(X, training=True)  # noise is added after forward_to (and its dropout)
            logit = self.forward_from(input_to, training=True)
            with tape.stop_recording():
                noise = make_noise(input_to, logit, y)
            logit_noisy = self.forward_from(input_to + noise, training=True)
            loss = self.compiled_loss(y, logit) + self.compiled_loss(y, logit_noisy)
            loss += self._sum_reg_losses()
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        self.compiled_metrics.update_state(y, logit)
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def _random_ts(self, data):
        """Training step with an extra loss on a randomly perturbed representation."""
        return self._noisy_ts(data, lambda input_to, _logit, _y: self.gen_random_noise(input_to))

    @tf.function
    def _adv_ts(self, data):
        """Training step with an extra loss on an adversarially (gradient) perturbed representation."""
        return self._noisy_ts(data, lambda input_to, _logit, y: self.gen_at_noise((input_to, y)))

    @tf.function
    def _vat_ts(self, data):
        """Training step with an extra loss on a VAT-style perturbed representation."""
        return self._noisy_ts(data, lambda input_to, logit, _y: self.gen_vat_noise((input_to, logit)))

    def gen_vat_noise(self, inputs, num_iter=3):
        """Return a perturbation found by ``num_iter`` rounds of KL-gradient / normalize iteration.

        Args:
            inputs: ``(X, logit)`` — the representation to perturb and the
                model output computed on the clean ``X``.
            num_iter: Number of gradient/normalize rounds.
        Returns:
            ``noise_size * _normalize_func(grad)`` from the last round.
        """
        X, logit = inputs
        # NOTE(review): standard VAT starts from a random direction; this starts from X itself — confirm intended.
        d = self._noise_size * self._normalize_func(X)
        for _ in range(num_iter):
            with tf.GradientTape() as tape:
                tape.watch([X, d])
                X_d = X + d
                logit_vat = self.forward_from(X_d, True)
                kld_loss_val = tf_utils.kld_loss_func(logit, logit_vat)
            d = tape.gradient(kld_loss_val, X_d)
            d = self._noise_size * self._normalize_func(d)
        return d

    def gen_at_noise(self, inputs):
        """Return ``noise_size * _normalize_func(dLoss/dX)`` for ``inputs = (X, y)``.

        NOTE(review): ``self.loss`` is called directly, so this presumably
        requires a callable loss in ``compile()`` (not a string name).
        """
        X, y = inputs
        with tf.GradientTape() as tape:
            tape.watch(X)
            logit = self.forward_from(X, training=False)
            loss_value = self.loss(y, logit)
        grad = tape.gradient(loss_value, X)
        # FGSM-style alternative: self._noise_size * tf.math.sign(grad)
        adv_noise = self._noise_size * self._normalize_func(grad)
        return adv_noise

    def gen_random_noise(self, inputs):
        """Return standard-normal noise shaped like ``inputs``, normalized and scaled by ``noise_size``."""
        # tf.shape (dynamic) rather than inputs.shape so an unknown batch size works inside tf.function.
        noise = tf.random.normal(tf.shape(inputs), dtype=inputs.dtype)
        return self._normalize_func(noise) * self._noise_size

    # noise_type="input": perturb the raw inputs.
    def forward_to_input(self, inputs, training=False):
        """Identity: the perturbation is applied to the raw inputs."""
        return inputs

    def forward_from_input(self, inputs, training=False):
        """Full forward pass from raw inputs (SIL -> FCN -> link)."""
        sil_output = self._sil_layers(inputs, training)
        fcn_output = self._fcn_layers(sil_output, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    # noise_type="matrix": not available, because SIL has no exposed intermediate matrix.
    def forward_to_matrix(self, inputs, training=False):
        """Unsupported: the SIL layer fuses the interval and reduce stages."""
        raise NotImplementedError(
            "noise_type='matrix' is not supported by SINNModel2: the SIL layer fuses the "
            "interval and reduce stages, so there is no intermediate matrix to perturb")

    def forward_from_matrix(self, inputs, training=False):
        """Unsupported: the SIL layer fuses the interval and reduce stages."""
        raise NotImplementedError(
            "noise_type='matrix' is not supported by SINNModel2: the SIL layer fuses the "
            "interval and reduce stages, so there is no intermediate matrix to perturb")

    # noise_type="vector": perturb the SIL output.
    def forward_to_vector(self, inputs, training=False):
        """SIL block only."""
        sil_output = self._sil_layers(inputs, training)
        return sil_output

    def forward_from_vector(self, inputs, training=False):
        """FCN -> link, starting from the SIL output."""
        fcn_output = self._fcn_layers(inputs, training)
        final_output = self._link_layer(fcn_output)
        return final_output

    def return_config(self):
        """Return a shallow copy of the constructor arguments recorded in ``__init__``."""
        return dict(self._config.items())
