from Rips_Model              import *
import random
import time
import numpy                 as np
import tensorflow            as tf
import tensorflow_addons     as tfa


############################
#      Original Model      #
############################
class Rips_linear_1L(tf.keras.Model):
    """A single bias-free linear layer followed by ``RipsModel2``.

    The Dense kernel is overwritten with ``X_init`` at construction time. As a
    result, ``inp @ X_init`` is the initial embedding, and feeding the identity
    matrix returns ``X_init`` itself.

    Args:
        X_init: 2-D array of shape ``(n_in, n_out)``. It is used as the initial
            kernel and fixes both the input width and the output width.
        mel, dim, card: forwarded unchanged to ``RipsModel2``.
    """

    def __init__(self, X_init, mel, dim, card):
        super(Rips_linear_1L, self).__init__()
        n_in = X_init.shape[0]
        n_out = X_init.shape[1]
        self.linear = tf.keras.models.Sequential([
            tf.keras.Input(shape=(n_in,)),
            tf.keras.layers.Dense(n_out, use_bias=False),
        ])
        # The only weight in ``self.linear`` is the Dense kernel.
        self.linear.set_weights([X_init])
        self.Rips_model = RipsModel2(mel, dim, card)

    def call(self, inp):
        """Return ``(embedding, dgm)``.

        ``embedding`` is ``self.linear(inp)``. ``dgm`` is whatever
        ``RipsModel2`` returns for that embedding (presumably a persistence
        diagram; confirm in Rips_Model).
        """
        embedding = self.linear(inp)
        return embedding, self.Rips_model(embedding)
    

############################
#      DeepFDL             #
############################
class Rips_linear(tf.keras.Model):
    """A bias-free MLP with tanh hidden layers, followed by ``RipsModel2``.

    Architecture: ``Input(X_init.shape[0])``, then one ``Dense(h, tanh)`` per
    entry of ``hidden_dims``, then a linear ``Dense(X_init.shape[1])``. No layer
    has a bias.

    Args:
        X_init: 2-D array. Only its shape is used (input width, output width).
        hidden_dims: iterable of hidden-layer widths, applied in order.
        mel, dim, card: forwarded unchanged to ``RipsModel2``.
        r: accepted for signature parity with the other models; unused here.
    """

    def __init__(self, X_init, hidden_dims, mel, dim, card, r):
        super(Rips_linear, self).__init__()
        # NOTE(review): this initializer is not passed to any layer below. The
        # Dense layers use Keras's default kernel initializer. It is kept as a
        # public attribute for compatibility.
        self.initializer = tf.keras.initializers.GlorotUniform()

        stack = [tf.keras.Input(shape=(X_init.shape[0],))]
        stack.extend(
            tf.keras.layers.Dense(width, use_bias=False, activation='tanh')
            for width in hidden_dims
        )
        # Final projection back to the point-cloud dimension (no activation).
        stack.append(tf.keras.layers.Dense(X_init.shape[1], use_bias=False))
        self.linear = tf.keras.models.Sequential(stack)

        self.Rips_model = RipsModel2(mel, dim, card)

    def call(self, inp):
        """Return ``(embedding, dgm)`` where ``dgm = RipsModel2(embedding)``."""
        embedding = self.linear(inp)
        return embedding, self.Rips_model(embedding)

    
############################
#      GCN                 #
############################    
class GCN(tf.keras.Model):
    """One graph-convolution step: ``dense(adj_mtx @ inp)`` with a bias-free dense layer.

    Args:
        input_dim: stored as ``self.input_dim``; not otherwise used here.
        output_dim: number of units of the bias-free dense layer.
        r: accepted for signature compatibility and unused. Earlier
            experiments used it as the bound of a ``RandomUniform(-r, r)``
            kernel initializer or as the gain of ``Orthogonal(gain=r)``.
    """

    def __init__(self, input_dim, output_dim, r):
        super(GCN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # A fresh initializer instance per GCN layer.
        kernel_init = tf.keras.initializers.GlorotUniform()
        self.dense = tf.keras.layers.Dense(
            output_dim,
            use_bias=False,
            kernel_initializer=kernel_init,
        )

    def call(self, inp, adj_mtx):
        """Propagate ``inp`` through ``adj_mtx``, then apply the dense map."""
        # ``@`` (not tf.matmul) is kept deliberately: it lets a numpy adj_mtx
        # be promoted to the tensor's dtype.
        return self.dense(adj_mtx @ inp)

class pointcloud_GCN_1L(tf.keras.Model):
    """Point-cloud generator: embedding -> one GCN layer -> skip-concat -> FC.

    ``call(inp, adj_mtx)`` computes::

        Z   = dense1(inp)                      # bias-free linear embedding
        G1  = GCN1(Z, adj_mtx)                 # one GCN layer
        out = dense2(concat([Z, G1], axis=1))  # bias-free output layer

    Args:
        X_init: 2-D array. Only its shape is used: ``shape[0]`` is stored as
            ``num_nodes`` and ``shape[1]`` is the output width.
        hidden_dim_1: width of the embedding ``Z``.
        hidden_dim_2: output width of the GCN layer.
        r: forwarded to ``GCN``. This class's own initializers do not use it.
    """

    def __init__(self, X_init, hidden_dim_1, hidden_dim_2, r):
        super(pointcloud_GCN_1L, self).__init__()
        self.num_nodes = X_init.shape[0]
        self.output_dim = X_init.shape[1]
        self.hidden_dim_1 = hidden_dim_1
        self.hidden_dim_2 = hidden_dim_2
        # Public attribute kept for compatibility; it initializes ``dense1`` only.
        # Every layer gets its own initializer instance. In recent Keras
        # versions, an unseeded initializer instance returns the same random
        # values on every call, so sharing one instance between dense1 and
        # dense2 would give them correlated kernels.
        self.initializer = tf.keras.initializers.GlorotUniform()

        # Embedding Z = inp @ W0 (no bias).
        self.dense1 = tf.keras.layers.Dense(
            hidden_dim_1, use_bias=False, kernel_initializer=self.initializer)

        # 1st GCN layer, input (A, Z).
        self.GCN1 = GCN(self.hidden_dim_1, self.hidden_dim_2, r)

        # Output FC layer, input concat(Z, G1) (no bias).
        self.dense2 = tf.keras.layers.Dense(
            self.output_dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
        )

    def call(self, inp, adj_mtx):
        """Return the generated point cloud for node input ``inp`` and adjacency ``adj_mtx``."""
        # Pipeline: input --> embedding --> GCN1 --> concat --> FC --> output.
        # Shape notes assume inp is (N, N), as the original comments state.
        # TODO confirm with callers.
        dense_1 = self.dense1(inp)                # (N, N)     -> (N, h1)
        gnn1 = self.GCN1(dense_1, adj_mtx)        # (N, h1)    -> (N, h2)
        output = tf.concat((dense_1, gnn1), 1)    # (N, h1+h2)
        return self.dense2(output)                # (N, h1+h2) -> (N, out)
    
class pointcloud_GCN_2L(tf.keras.Model):
    """Point-cloud generator: embedding -> two GCN layers -> skip-concat -> FC.

    ``call(inp, adj_mtx)`` computes::

        Z   = dense1(inp)                          # bias-free linear embedding
        G1  = GCN1(Z, adj_mtx)                     # 1st GCN layer
        G2  = GCN2(G1, adj_mtx)                    # 2nd GCN layer
        out = dense2(concat([Z, G1, G2], axis=1))  # bias-free output layer

    Args:
        X_init: 2-D array. Only its shape is used: ``shape[0]`` is stored as
            ``num_nodes`` and ``shape[1]`` is the output width.
        hidden_dim_1: width of the embedding ``Z``.
        hidden_dim_2: output width of the first GCN layer.
        hidden_dim_3: output width of the second GCN layer.
        r: forwarded to both ``GCN`` layers. This class's own initializers do
            not use it.
    """

    def __init__(self, X_init, hidden_dim_1, hidden_dim_2, hidden_dim_3, r):
        super(pointcloud_GCN_2L, self).__init__()
        self.num_nodes = X_init.shape[0]
        self.output_dim = X_init.shape[1]
        self.hidden_dim_1 = hidden_dim_1
        self.hidden_dim_2 = hidden_dim_2
        self.hidden_dim_3 = hidden_dim_3
        # Public attribute kept for compatibility; it initializes ``dense1`` only.
        # Every layer gets its own initializer instance. In recent Keras
        # versions, an unseeded initializer instance returns the same random
        # values on every call, so sharing one instance between dense1 and
        # dense2 would give them correlated kernels.
        self.initializer = tf.keras.initializers.GlorotUniform()

        # Embedding Z = inp @ W0 (no bias).
        self.dense1 = tf.keras.layers.Dense(
            hidden_dim_1, use_bias=False, kernel_initializer=self.initializer)

        # 1st GCN layer, input (A, Z).
        self.GCN1 = GCN(self.hidden_dim_1, self.hidden_dim_2, r)

        # 2nd GCN layer, input (A, G1).
        self.GCN2 = GCN(self.hidden_dim_2, self.hidden_dim_3, r)

        # Output FC layer, input concat(Z, G1, G2) (no bias).
        self.dense2 = tf.keras.layers.Dense(
            self.output_dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
        )

    def call(self, inp, adj_mtx):
        """Return the generated point cloud for node input ``inp`` and adjacency ``adj_mtx``."""
        # Pipeline: input --> embedding --> GCN1 --> GCN2 --> concat --> FC --> output.
        # Shape notes assume inp is (N, N), as the original comments state.
        # TODO confirm with callers.
        dense_1 = self.dense1(inp)                     # (N, N)  -> (N, h1)
        gnn1 = self.GCN1(dense_1, adj_mtx)             # (N, h1) -> (N, h2)
        gnn2 = self.GCN2(gnn1, adj_mtx)                # (N, h2) -> (N, h3)
        output = tf.concat((dense_1, gnn1, gnn2), 1)   # (N, h1+h2+h3)
        return self.dense2(output)                     # (N, h1+h2+h3) -> (N, out)

# integrate GCN with RipsModel
class Rips_GCN(tf.keras.Model):
    """A GCN point-cloud generator followed by ``RipsModel2``.

    Args:
        X_init: forwarded to the GCN model. Its shape fixes the node count and
            the output dimension.
        hidden_dims: sequence of hidden widths.
            Two values ``(h1, h2)`` select ``pointcloud_GCN_1L``.
            Three values ``(h1, h2, h3)`` select ``pointcloud_GCN_2L``.
        mel, dim, card: forwarded unchanged to ``RipsModel2``.
        r: must be strictly positive. Forwarded to the GCN model.

    Raises:
        ValueError: if ``hidden_dims`` does not have 2 or 3 entries, or if
            ``r`` is not > 0 (this includes NaN).
    """

    def __init__(self, X_init, hidden_dims, mel, dim, card, r):
        super(Rips_GCN, self).__init__()
        # Explicit checks instead of ``assert``. Asserts are stripped under
        # ``python -O``, which would let a bad ``hidden_dims`` fall through
        # to an obscure argument-count error below.
        n_hidden = len(hidden_dims)
        if n_hidden not in (2, 3):
            raise ValueError(
                f"hidden_dims must have 2 or 3 entries, got {n_hidden}")
        if not (r > 0):
            raise ValueError(f"r must be > 0, got {r!r}")

        if n_hidden == 2:
            self.GCN_model = pointcloud_GCN_1L(X_init, *hidden_dims, r)
        else:
            self.GCN_model = pointcloud_GCN_2L(X_init, *hidden_dims, r)

        self.Rips_model = RipsModel2(mel, dim, card)

    def call(self, inp, adj_mtx):
        """Return ``(output, dgm)``.

        ``output`` is the GCN-generated point cloud. ``dgm`` is what
        ``RipsModel2`` returns for it (presumably a persistence diagram;
        confirm in Rips_Model).
        """
        output = self.GCN_model(inp, adj_mtx)
        dgm = self.Rips_model(output)

        return output, dgm
