from NeuraLayout_Models_2      import *
import time
import numpy                 as np
import tensorflow            as tf
import tensorflow_addons     as tfa
import pickle


def get_persistence_diagram(DX,mel,dim,card):   # numpy function
    # Build a (card,2) diagram by reading DX at the index pairs returned by Rips.
    # inputs:
    # -- DX: distance matrix, indexed as DX[row,col]
    # -- mel, dim, card: forwarded unchanged to Rips
    #    (Rips must return 4*card indices, since they are reshaped to (2*card,2))
    # output:
    # -- dgm: (card,2) array; when dim is not positive, column 0 is filled with zeros
    index_pairs = np.array(Rips(DX, mel, dim, card)).reshape(2*card,2)

    if dim > 0:
        # every index pair contributes one diagram entry, two entries per row
        return DX[index_pairs[:,0],index_pairs[:,1]].reshape(card,2)

    # only every second index pair is used; it supplies column 1 of the diagram
    second_pairs = index_pairs[1::2,:]
    col_one = np.reshape(DX[second_pairs[:,0],second_pairs[:,1]],(card,1))
    col_zero = np.zeros((card,1))
    return np.concatenate([col_zero, col_one],axis=1)

# obtain adjacency list from persistence diagram
def get_normalized_adj_mtx(DX,dgm,lam):          # numpy function
    # Threshold a distance matrix into a 0/1 adjacency matrix, add the identity
    # and return the symmetrically degree-normalized result.
    # inputs:
    # -- DX: distance matrix, (N,N) array (left unmodified; a copy is thresholded)
    # -- dgm: persistence diagram; only column 0 is read
    # -- lam: mixing weight; threshold = lam*min + (1-lam)*max of the non-zero
    #         entries of dgm[:,0]
    # output:
    # -- (N,N) array D**(-1/2) (A + I) D**(-1/2)
    # raises ValueError (from np.min) when dgm[:,0] has no non-zero entry

    births = dgm[:,0]
    nonzero_births = births[births!=0]
    filt = np.min(nonzero_births)*lam + np.max(nonzero_births)*(1-lam)

    print('starting filtration value:',filt)

    # entries strictly below the threshold become edges, all others are cleared;
    # both masks are taken from DX itself, so the two writes do not interact
    is_edge = DX<filt
    is_not_edge = DX>=filt
    adjacency = DX.copy()
    adjacency[is_edge] = 1.
    adjacency[is_not_edge] = 0.

    # A + I. Note: a zero diagonal in DX already satisfies DX<filt whenever
    # filt > 0, so the diagonal of the sum is then 2 rather than 1.
    with_self_loops = adjacency + np.eye(adjacency.shape[0])

    # D**(-1/2), built from the column sums of A + I
    column_degree = with_self_loops.sum(axis=0)
    inv_sqrt_degree = np.diag(1/np.sqrt(column_degree))

    # f(A) = D**(-1/2) (A + I) D**(-1/2)
    return np.matmul(np.matmul(inv_sqrt_degree, with_self_loops), inv_sqrt_degree)

def converge(losses,thresh):
    # Relative-spread test over a window of losses.
    # inputs:
    # -- losses: non-empty sequence of scalar losses
    # -- thresh: convergence threshold
    # output:
    # -- True when |max - min| / |min| is strictly below thresh
    lowest = min(losses)
    highest = max(losses)
    spread = np.abs(highest-lowest)
    scale = np.abs(lowest)
    return spread/scale < thresh


class ModelTrain:
    # Training driver for the Rips-based layout models: builds the requested
    # model, then offers an MSE pre-fit (prefit) and two topological training
    # loops (train_regular, train_hybrid). Each loop pickles its results to a
    # path derived from file_dir.

    def __init__(self,model_name,arg_dict,X_init,file_dir,r=1.0):
        # -- model_name: string, name of the model class
        # -- arg_dict: dictionary storing model parameters
        #    -- 'Rips_linear_1L': 'mel', 'dim', 'card'
        #    -- 'Rips_linear': 'hidden_dims','mel', 'dim', 'card'
        #    -- 'Rips_GCN': 'hidden_dims','mel', 'dim', 'card','lam'
        # -- X_init: initial point cloud, numpy array
        # -- file_dir: string, path prefix for all files written by this object
        # -- r: positive bound used by the distance penalty max(|X_out| - r, 0)
        # raises ValueError when model_name is not one of the three names above
        assert isinstance(model_name,str)
        assert isinstance(arg_dict,dict)
        assert isinstance(X_init,np.ndarray)
        assert isinstance(file_dir,str)
        assert r > 0
        
        self.X_init = X_init
        self.model_name = model_name
        self.model = None
        self.file_dir = file_dir
        self.lam = None
        self.adj_mtx_N = None
        self.r = r
        self.card = arg_dict['card']
        if model_name == 'Rips_linear_1L':
            self.model = Rips_linear_1L(X_init, arg_dict['mel'], arg_dict['dim'], arg_dict['card'])
        elif model_name == 'Rips_linear':
            self.model = Rips_linear(X_init, arg_dict['hidden_dims'], arg_dict['mel'], arg_dict['dim'], arg_dict['card'], self.r)
            
        elif model_name == 'Rips_GCN':
            self.model = Rips_GCN(X_init, arg_dict['hidden_dims'], arg_dict['mel'], arg_dict['dim'], arg_dict['card'], self.r)
            self.lam = arg_dict['lam']
            # pairwise Euclidean distances of the initial point cloud
            initial_DX = np.sqrt(np.sum((np.expand_dims(self.X_init,1) - np.expand_dims(self.X_init,0))**2,2))   
            # NOTE(review): mel=12. and dim=1 are fixed here instead of taken
            # from arg_dict -- confirm this is intended.
            initial_dgm = get_persistence_diagram(initial_DX,mel=12.,dim=1,card=self.card)
            self.adj_mtx_N = tf.convert_to_tensor(get_normalized_adj_mtx(initial_DX,initial_dgm,self.lam))
        else:
            raise ValueError('Invalid Model Name!!!')
            
        print('type(self.model):',type(self.model))

    @staticmethod
    def _make_adam(lr):
        # Adam optimizer with the hyper-parameters shared by every training loop
        return tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False,
            name='Adam'
        )

    def _forward(self,dummy_input):
        # run self.model; the GCN variant also receives the normalized adjacency matrix
        if self.model_name == 'Rips_GCN':
            return self.model(dummy_input,tf.cast(self.adj_mtx_N,dtype=tf.float32))
        return self.model(dummy_input)

    def _topo_dist_loss(self,X_out,dgm):
        # negative sum of squared half-persistences plus the penalty for |X_out| exceeding r
        topo_loss = -tf.math.reduce_sum(tf.square(.5*(dgm[:,1]-dgm[:,0])))
        dist_loss = tf.reduce_sum(tf.maximum(tf.abs(X_out)-self.r, 0))
        return topo_loss + dist_loss

    @staticmethod
    def _seed_conv_window(losses,conv_patience):
        # last conv_patience losses as a new list, left-padded with zeros when fewer exist
        if len(losses) >= conv_patience:
            return list(losses[-conv_patience:])
        padding = list(np.zeros(conv_patience-len(losses)))
        return padding + list(losses)
        
    def prefit(self,mse_thresh,max_train_epoch):
        # Fit the model output to X_init with an MSE loss (Adam, lr=0.1).
        # inputs
        # -- mse_thresh: training stops once the summed mse drops below this value
        # -- max_train_epoch: maximum number of epoch for prefitting run
        # outputs
        # -- losses: list of summed mse losses, one per optimizer step
        # -- output_list: [initial output, final output]
        # -- runtime_list: elapsed seconds after every optimizer step
        # side effects: pickles [losses,output_list,runtime_list] to
        # file_dir+'_prefit_result.pickle' and saves the model weights to
        # file_dir+'_init_weights'
       
        optimizer = self._make_adam(0.1)

        losses = []
        runtime_list = []
        output_list = []
        dummy_input = tf.eye(self.X_init.shape[0])
        X = tf.Variable(initial_value=self.X_init)

        start_time = time.time()
        for epoch in range(max_train_epoch+1):
            print('epoch number:',epoch)

            with tf.GradientTape() as tape:
                X_out, dgm = self._forward(dummy_input)
                loss = tf.keras.metrics.mean_squared_error(X, X_out)

            # record the untrained output before the convergence check so that
            # output_list is [initial, final] even when epoch 0 already converges
            if epoch == 0:
                output_list.append(X_out.numpy())

            total_loss = loss.numpy().sum()
            if total_loss < mse_thresh:
                print('Converged at epoch:', epoch)
                break 

            # gradient of the loss w.r.t. the model's trainable variables
            gradients = tape.gradient(loss, self.model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            losses.append(total_loss)
            print('loss:',total_loss)
            
            runtime_list.append(time.time()-start_time)

        output_list.append(X_out.numpy())

        # runtime_list is empty when the threshold is met before the first
        # optimizer step; fall back to the wall-clock time in that case
        elapsed = runtime_list[-1] if runtime_list else time.time()-start_time
        print('Training with CPU: {:3f} s'.format(elapsed))

        # e.g. file_dir = 'model_results/GCN_2L_h864L8'
        with open(self.file_dir+'_prefit_result.pickle','wb') as f:
            pickle.dump([losses,output_list,runtime_list],f)

        self.model.save_weights(self.file_dir+'_init_weights')

        return losses,output_list,runtime_list


    def train_regular(self,max_train_epoch=1000,
                      conv_patience=100,conv_thresh=0.03,conv_val=-100):
        # Train self.model on the topological + distance loss (Adam, lr=0.01).
        # inputs
        # -- max_train_epoch: maximum number of epoch for training run
        # -- conv_patience: number of most recent losses used by the convergence test
        # -- conv_thresh: threshold of convergence criteria
        # -- conv_val: model converges at this value
        # outputs
        # -- losses: list of losses
        # -- runtime_list: list of runtime of every epoch
        # -- output_list: [initial output, final output]
        # side effect: pickles [losses,runtime_list,output_list] to
        # file_dir+'_train_result_regular.pickle'
        
        optimizer = self._make_adam(0.01)

        losses, runtime_list, output_list = [], [], []
        dummy_input = tf.eye(self.X_init.shape[0])

        conv_losses = []

        start_time = time.time()
        for epoch in range(max_train_epoch+1):
            print('epoch number:',epoch)

            with tf.GradientTape() as tape:
                X_out, dgm = self._forward(dummy_input)
                loss = self._topo_dist_loss(X_out,dgm)

            # gradient of the loss w.r.t. the model's trainable variables
            gradients = tape.gradient(loss, self.model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            losses.append(loss.numpy())
            runtime_list.append(time.time()-start_time)
            
            # save initial output for debug purpose
            if epoch == 0:
                output_list.append(X_out.numpy())
                
            conv_losses.append(loss.numpy())
            
            # the convergence test only starts once the window is full
            if len(conv_losses) > conv_patience:
                conv_losses.pop(0)
                if converge(conv_losses,conv_thresh) or loss <= conv_val:
                    print('model converges at epoch:',epoch)
                    break

        end_time = time.time()
        print('Training with CPU: {:3f} s'.format(end_time-start_time))
        
        output_list.append(X_out.numpy())
        
        # store training results 
        with open(self.file_dir+'_train_result_regular.pickle','wb') as f:
            pickle.dump([losses,runtime_list,output_list],f)

        return losses,runtime_list,output_list

    def train_hybrid(self,max_train_epoch=1000,
                     conv_patience=100,conv_thresh=0.03,conv_val=-100,
                     switch_patience=15):
        # Train self.model until the moving average of the loss rises, then
        # continue with a Rips_linear_1L model initialised at the current output.
        # inputs
        # -- max_train_epoch: maximum number of epoch for training run (both phases together)
        # -- conv_patience: number of most recent losses used by the convergence test
        # -- conv_thresh: threshold of convergence criteria
        # -- conv_val: model converges at this value
        # -- switch_patience: window length of the moving average used for switching
        # outputs
        # -- losses: list of losses
        # -- runtime_list: list of runtime of every epoch
        # -- output_list: [initial_output,output_before_switching,final_output];
        #                 the middle entry is absent when no switch happened
        # -- switch_point: epoch number when model switches, None if it never does
        # side effect: pickles [losses,runtime_list,output_list,switch_point] to
        # file_dir+'_train_result_hybrid.pickle'
        
        lr = 0.01
        optimizer = self._make_adam(lr)

        losses, runtime_list, output_list = [], [], []
        switch_point = None
        dummy_input = tf.eye(self.X_init.shape[0])

        # the window starts filled with the constant 5
        switch_losses = list(np.ones(switch_patience)*5)
        moving_avg = np.array(switch_losses).mean()

        start_time = time.time()
        for epoch in range(max_train_epoch+1):

            print('epoch number:',epoch)

            with tf.GradientTape() as tape:
                X_out, dgm = self._forward(dummy_input)
                loss = self._topo_dist_loss(X_out,dgm)

            # gradient of the loss w.r.t. the model's trainable variables
            gradients = tape.gradient(loss, self.model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            losses.append(loss.numpy())
            runtime_list.append(time.time()-start_time)
            
            # save initial output for debug purpose
            if epoch == 0:
                output_list.append(X_out.numpy())
                
            switch_losses.append(loss.numpy())
            switch_losses.pop(0)
            avg = np.array(switch_losses).mean()
            if avg > moving_avg:
                # the moving average went up: hand over to the single-layer model
                n = epoch + 1
                switch_point = epoch
                print('switch model at epoch:',epoch)
                # NOTE(review): mel=12. and dim=1 are fixed here instead of
                # taken from the constructor's arg_dict -- confirm this is intended.
                model = Rips_linear_1L(X_init=X_out.numpy(), mel=12., dim=1, card=self.card)
                output_list.append(X_out.numpy())
                break
            moving_avg = avg
        
        if switch_point is not None:
            conv_losses = self._seed_conv_window(losses,conv_patience)

            # The switched model has its own variables, so it gets its own
            # optimizer: an optimizer already applied to self.model either
            # rejects unseen variables or carries over a stale step counter.
            optimizer = self._make_adam(lr)

            for epoch in range(max_train_epoch+1-n):
                print('epoch number:',epoch+n)

                with tf.GradientTape() as tape:
                    X_out, dgm = model(dummy_input)
                    loss = self._topo_dist_loss(X_out,dgm)

                # gradient of the loss w.r.t. the switched model's trainable variables
                gradients = tape.gradient(loss, model.trainable_variables)

                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                losses.append(loss.numpy())
                runtime_list.append(time.time()-start_time)
                
                conv_losses.append(loss.numpy())
                conv_losses.pop(0)
                if converge(conv_losses,conv_thresh) or loss <= conv_val:
                    # report the global epoch, consistent with the progress print above
                    print('model converges at epoch:',epoch+n)
                    break

        output_list.append(X_out.numpy())
        end_time = time.time()
        print('Training with CPU: {:3f} s'.format(end_time-start_time))
        
        # store training results 
        with open(self.file_dir+'_train_result_hybrid.pickle','wb') as f:
            pickle.dump([losses,runtime_list,output_list,switch_point],f)

        return losses,runtime_list,output_list,switch_point
    
    def load_model_weights(self,directory):
        # load weights into self.model from the given path
        self.model.load_weights(directory)