# =============================================================================
# Description
# =============================================================================
# Class to classify single molecular data using deep learning models
# =============================================================================


# =============================================================================
# Imports
# ============================================================================= 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tslearn.utils import to_time_series_dataset
from tslearn.preprocessing import TimeSeriesScalerMinMax, TimeSeriesResampler
import tensorflow.keras as keras
import tensorflow as tf
import time
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
# =============================================================================


#%%
class SMNN():
    
    
    #--------------------------------------------------------------------------        
    # Init
    #--------------------------------------------------------------------------  
    def __init__(self):
        """Set up the helper object.

        The only state stored is ``kb``, the Boltzmann constant expressed in
        m^2 kg s^-2 K^-1.
        """
        boltzmann_constant = 1.38064 * (10 ** -23)
        self.kb = boltzmann_constant
    #--------------------------------------------------------------------------       
    
    
    #--------------------------------------------------------------------------        
    # Get reference data from simulation data
    # data_pd: simulation data
    # molecule: choose from 'Titin', 'UtrNR3', 'DysNR3_bact', 'fewshot'
    #--------------------------------------------------------------------------  
    def get_ref_data(self, data_pd, molecule):
        
        # Build reference data 
        data_pd_ref = pd.DataFrame() 
        
        if (molecule == 'Titin'):
            no_sample_arr = np.array([16,4,1]) 
        elif (molecule == 'fewshot'):
            no_sample_arr = np.array([10,10,10]) 
        elif (molecule == 'UtrNR3') or (molecule == 'DysNR3_bact'):
            no_sample_arr = np.array([16,8,3]) 
                
        ii = 0
        for No_mol in data_pd['No_molecule'].unique():
            data_pd_cur_mol =  data_pd.loc[data_pd['No_molecule'] == No_mol]
            no_sample = no_sample_arr[ii]
            ii = ii + 1
            for ini_pos in data_pd_cur_mol['Initial_Pos'].unique():
                data_pd_cur_mol_cur_ini_pos = data_pd_cur_mol.loc[data_pd_cur_mol['Initial_Pos']==ini_pos]
                data_selected = data_pd_cur_mol_cur_ini_pos.sample( n = no_sample )
                data_pd_ref = pd.concat([data_pd_ref, data_selected])
                
        return data_pd_ref
    
    
    
    #--------------------------------------------------------------------------        
    # Resample and normalization of data
    # X_train, y_train: Data and label to process
    # sim_data: True if simulation data is used
    # resampling: True or False, perform resample or not
    # resample_size: length after resample
    # data_normalization: True or False, perform data minmax normalization or not. 
    #--------------------------------------------------------------------------  
    def resample_normalization_data(self, X_train, y_train, molecule, sim_data,
                                    resampling, resample_size, data_normalization):         
        
        # Get rid of extra zeros at the end
        if (sim_data == True):
            if (molecule != 'fewshot'):
              X_train_index = np.zeros((len(X_train),int(X_train.shape[1]-X_train.shape[1]/5)), dtype=bool)
              X_train_index = np.concatenate((X_train_index, X_train[:,-int(X_train.shape[1]/5+2):-1]==0), axis = 1)
              X_train[X_train_index] = 'nan'
        
        else:
            X_train[X_train==0] = 'nan'
            
        # Resample data
        if (resampling == True):
            X_train = to_time_series_dataset(X_train)
            X_train = TimeSeriesResampler(sz = resample_size).fit_transform(X_train)
        else:
            if len(X_train.shape) == 2:  # if univariate
                # add a dimension to make it multivariate with one dimension
                X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))

        # Normalize the data
        if(data_normalization == True):
            X_train = TimeSeriesScalerMinMax().fit_transform(X_train)

        return (X_train, y_train)
    #--------------------------------------------------------------------------     
    
    

    
    
    
    #--------------------------------------------------------------------------        
    # Augment reference curves to input curve
    # X_train, y_train: Data and label to process
    # data_pd_ref: reference curve dataframe
    # no_refer: the number of reference curves
    #--------------------------------------------------------------------------  
    def add_refer_data(self, X_train, y_train, data_pd_ref, no_refer, resample_size, 
                       data_normalization):
        
        X_train_ref = np.zeros((X_train.shape[0], X_train.shape[1], no_refer))
        
        
        for ii in range((len(y_train))):
            X_data_reference_df = pd.DataFrame()
            X_data_reference_df = data_pd_ref.sample(n = no_refer, replace = False)
      
            cur_refer = 0
            for jj in X_data_reference_df.index:
                X_data_reference_data = X_data_reference_df['Fwlc'][jj]*1e12
                X_data_reference_data = to_time_series_dataset(X_data_reference_data)
                X_data_reference_data = TimeSeriesResampler(sz = resample_size).fit_transform(X_data_reference_data)
                # Adding normalization to reference data
                if(data_normalization == True):
                    X_data_reference_data = TimeSeriesScalerMinMax().fit_transform(X_data_reference_data)
    
                X_data_reference_data_load = X_train[ii,:,0] - X_data_reference_data[0,:,0]
                X_train_ref[ii,:,cur_refer] = X_data_reference_data_load
                cur_refer = cur_refer + 1

        X_train = np.concatenate((X_train,X_train_ref), axis = 2)
        return X_train
    #--------------------------------------------------------------------------     

    

    
    #--------------------------------------------------------------------------        
    # Create MLP 
    # input_shape: the number of neurons in the input layer
    # nb_classes: the number of neurons in the output layer
    #--------------------------------------------------------------------------  
    def create_MLP(self, input_shape, nb_classes):
        """Assemble the layer graph of a dropout-regularised MLP classifier.

        The input is flattened, passed through three Dropout + Dense(500,
        relu) stages with dropout rates 0.1, 0.2 and 0.2, and finished with
        Dropout(0.3) followed by a softmax layer of ``nb_classes`` units.

        Parameters
        ----------
        input_shape : shape passed to ``keras.layers.Input``
        nb_classes : number of units of the softmax output layer

        Returns
        -------
        tuple
            ``(input_layer, output_layer)`` Keras tensors; no ``Model`` is
            built here.
        """
        net_input = keras.layers.Input(input_shape)

        # Merge all non-batch axes so multivariate series share one axis.
        hidden = keras.layers.Flatten()(net_input)

        for drop_rate in (0.1, 0.2, 0.2):
            hidden = keras.layers.Dropout(drop_rate)(hidden)
            hidden = keras.layers.Dense(500, activation='relu')(hidden)

        hidden = keras.layers.Dropout(0.3)(hidden)
        net_output = keras.layers.Dense(nb_classes, activation='softmax')(hidden)

        return (net_input, net_output)
    #--------------------------------------------------------------------------  

    
    #--------------------------------------------------------------------------        
    # Create FCNN 
    # input_shape: the number of neurons in the input layer
    # nb_classes: the number of neurons in the output layer
    #--------------------------------------------------------------------------  
    def create_FCNN(self, input_shape, nb_classes, kernel_size_arr):
        """Assemble the layer graph of a fully convolutional classifier.

        Three Conv1D -> BatchNormalization -> ReLU stages with 128, 256 and
        128 filters are followed by global average pooling and a softmax
        layer.

        Parameters
        ----------
        input_shape : shape passed to ``keras.layers.Input``
        nb_classes : number of units of the softmax output layer
        kernel_size_arr : indexable holding the kernel size of each of the
            three convolutions; every entry is converted with ``int``

        Returns
        -------
        tuple
            ``(input_layer, output_layer)`` Keras tensors.
        """
        net_input = keras.layers.Input(input_shape)

        features = net_input
        for stage, n_filters in enumerate((128, 256, 128)):
            features = keras.layers.Conv1D(
                filters=n_filters,
                kernel_size=int(kernel_size_arr[stage]),
                padding='same',
            )(features)
            features = keras.layers.BatchNormalization()(features)
            features = keras.layers.Activation('relu')(features)

        pooled = keras.layers.GlobalAveragePooling1D()(features)
        net_output = keras.layers.Dense(nb_classes, activation='softmax')(pooled)

        return (net_input, net_output)
    #--------------------------------------------------------------------------  
    
    
    #--------------------------------------------------------------------------        
    # Create ResNet 
    # input_shape: the number of neurons in the input layer
    # nb_classes: the number of neurons in the output layer
    #--------------------------------------------------------------------------  
    def create_ResNet(self, input_shape, nb_classes, kernel_size_arr, n_feature_maps):
        """Assemble the layer graph of a three-block 1D ResNet classifier.

        Each residual block stacks three Conv1D + BatchNormalization stages
        (ReLU after the first two) and adds a shortcut branch before a final
        ReLU.  The blocks use ``n_feature_maps``, ``2 * n_feature_maps`` and
        ``2 * n_feature_maps`` filters.  The first two shortcuts go through a
        kernel-size-1 convolution plus batch normalisation; the third is only
        batch-normalised because its channel count already matches.

        Parameters
        ----------
        input_shape : shape passed to ``keras.layers.Input``
        nb_classes : number of units of the softmax output layer
        kernel_size_arr : indexable with the kernel sizes of the three
            convolutions of a block (shared by all blocks)
        n_feature_maps : number of filters in the first block

        Returns
        -------
        tuple
            ``(input_layer, output_layer)`` Keras tensors.
        """

        def conv_bn(tensor, n_filters, kernel_size, with_relu):
            # Conv1D + BatchNormalization, optionally followed by ReLU.
            out = keras.layers.Conv1D(filters=n_filters, kernel_size=kernel_size, padding='same')(tensor)
            out = keras.layers.BatchNormalization()(out)
            if with_relu:
                out = keras.layers.Activation('relu')(out)
            return out

        def residual_block(block_input, n_filters, project_shortcut):
            # Main branch first, then the shortcut, then sum and ReLU.
            main = conv_bn(block_input, n_filters, int(kernel_size_arr[0]), True)
            main = conv_bn(main, n_filters, int(kernel_size_arr[1]), True)
            main = conv_bn(main, n_filters, int(kernel_size_arr[2]), False)

            if project_shortcut:
                # expand channels for the sum
                shortcut = conv_bn(block_input, n_filters, 1, False)
            else:
                # channel counts already agree, no projection required
                shortcut = keras.layers.BatchNormalization()(block_input)

            summed = keras.layers.add([shortcut, main])
            return keras.layers.Activation('relu')(summed)

        net_input = keras.layers.Input(input_shape)

        # The network can be made deeper by chaining further blocks here.
        block_1 = residual_block(net_input, n_feature_maps, True)
        block_2 = residual_block(block_1, n_feature_maps * 2, True)
        block_3 = residual_block(block_2, n_feature_maps * 2, False)

        # Classification head
        pooled = keras.layers.GlobalAveragePooling1D()(block_3)
        net_output = keras.layers.Dense(nb_classes, activation='softmax')(pooled)

        return (net_input, net_output)
    #--------------------------------------------------------------------------  
    
    #--------------------------------------------------------------------------        
    # Create Triplet 
    # input_shape: the number of neurons in the input layer
    #--------------------------------------------------------------------------  
    def create_Triplet(self, input_shape):
        """Build a shared embedding network and the triplet model around it.

        The embedding network is a three-block 1D residual network with
        kernel sizes 8, 5 and 3 and 128, 256 and 256 filters.  Its output is
        the sum of the third block's two branches, without a final ReLU.  The
        same embedding model (shared weights) is applied to the inputs named
        ``"anchor"``, ``"positive"`` and ``"negative"``.

        Parameters
        ----------
        input_shape : shape passed to ``keras.Input`` for all three inputs and
            for the embedding model's own input

        Returns
        -------
        tuple
            ``(embeddingModel, siamese_network)``; the second model maps the
            three inputs to their three embeddings.
        """

        def conv_bn(tensor, n_filters, kernel_size, with_relu):
            # Conv1D + BatchNormalization, optionally followed by ReLU.
            out = keras.layers.Conv1D(filters=n_filters, kernel_size=kernel_size, padding='same')(tensor)
            out = keras.layers.BatchNormalization()(out)
            if with_relu:
                out = keras.layers.Activation('relu')(out)
            return out

        def residual_block(block_input, n_filters, project_shortcut, final_relu):
            main = conv_bn(block_input, n_filters, 8, True)
            main = conv_bn(main, n_filters, 5, True)
            main = conv_bn(main, n_filters, 3, False)

            if project_shortcut:
                # expand channels for the sum
                shortcut = conv_bn(block_input, n_filters, 1, False)
            else:
                # channel counts already agree, no projection required
                shortcut = keras.layers.BatchNormalization()(block_input)

            summed = keras.layers.add([shortcut, main])
            if final_relu:
                summed = keras.layers.Activation('relu')(summed)
            return summed

        def create_embedding(input_shape):
            # One embedding model, reused for all three triplet branches.
            n_feature_maps = 128
            embed_input = keras.layers.Input(input_shape)

            block_1 = residual_block(embed_input, n_feature_maps, True, True)
            block_2 = residual_block(block_1, n_feature_maps * 2, True, True)
            block_3 = residual_block(block_2, n_feature_maps * 2, False, False)

            return keras.models.Model(inputs=embed_input, outputs=block_3)

        # Anchor, positive and negative inputs, created in that order.
        triplet_inputs = [
            keras.Input(input_shape, name=role)
            for role in ("anchor", "positive", "negative")
        ]

        embeddingModel = create_embedding(input_shape)

        # Apply the shared embedding model to each branch.
        triplet_embeddings = [embeddingModel(branch) for branch in triplet_inputs]

        siamese_network = keras.Model(
            inputs=triplet_inputs,
            outputs=triplet_embeddings
        )

        return (embeddingModel, siamese_network)
    #--------------------------------------------------------------------------        
    
    
        
    #--------------------------------------------------------------------------        
    # Train models
    #--------------------------------------------------------------------------  
    def train_models(self, X_train, y_train_oh, model, callbacks, batch_size, 
                     nb_epochs, output_directory, file_name, fig_save_path, 
                     validation_split = 0.2, verbose = 1, diagonistic = False):
        """Fit ``model``, save it and optionally plot the training history.

        Parameters
        ----------
        X_train, y_train_oh : training data and targets handed to ``model.fit``
        model : object providing ``fit`` and ``save`` (a compiled Keras model)
        callbacks : callbacks forwarded to ``model.fit``
        batch_size : upper bound for the mini-batch size; the size actually
            used is also capped at a tenth of the number of samples and is
            never smaller than 1
        nb_epochs : number of epochs
        output_directory, file_name : the model is saved to
            ``output_directory + file_name + '_last.keras'``
        fig_save_path : prefix for the figure files
        validation_split, verbose : forwarded to ``model.fit``
        diagonistic : when True, loss and accuracy curves are written as
            ``.png`` and ``.svg`` and then shown

        Returns
        -------
        The fitted ``model``.
        """
        # int(n / 10) is 0 for fewer than ten samples, which is not a usable
        # batch size; clamp to at least one sample per batch.
        mini_batch_size = max(1, int(min(X_train.shape[0] / 10, batch_size)))

        hist = model.fit(X_train, y_train_oh, batch_size=mini_batch_size, epochs=nb_epochs,
                         validation_split=validation_split, verbose=verbose, callbacks=callbacks)

        model.save(output_directory + file_name + '_last.keras')

        # Plotting training statistics
        if diagonistic:
            history_dict = hist.history
            # (history key, title, y label, legend text, legend loc, file suffix)
            panels = (
                ('loss', 'Training and validation loss', 'Loss', 'loss', None, '_loss'),
                ('accuracy', 'Training and validation accuracy', 'Accuracy', 'acc',
                 'lower right', '_acc'),
            )
            for metric, title, ylabel, short, legend_loc, suffix in panels:
                # A metric is only present when the model was compiled with it.
                if metric not in history_dict:
                    continue
                train_values = history_dict[metric]
                epochs = range(1, len(train_values) + 1)

                plt.figure()
                plt.plot(epochs, train_values, 'bo', label='Training ' + short)
                # Validation curves are absent when no validation data was used.
                val_values = history_dict.get('val_' + metric)
                if val_values is not None:
                    plt.plot(epochs, val_values, 'ro', label='Validation ' + short)
                plt.title(title)
                plt.xlabel('Epochs')
                plt.ylabel(ylabel)
                if legend_loc is None:
                    plt.legend()
                else:
                    plt.legend(loc=legend_loc)

                # Save before show(): with blocking backends the figure is
                # released when the window closes and a later savefig would
                # write an empty canvas.
                plt.savefig(fname = fig_save_path + file_name + suffix + '.png')
                plt.savefig(fname = fig_save_path + file_name + suffix + '.svg')
                plt.show()

        return (model)
    #--------------------------------------------------------------------------  


    #--------------------------------------------------------------------------        
    # Train models: different ways of training triplet models
    #--------------------------------------------------------------------------  
    def train_triplet(self, X_train, X_train_pos_all, X_train_neg_all, y_train, embed_model, model, batch_size, 
                     nb_epochs, output_directory, file_name, fig_save_path, verbose = 1, diagonistic = False):
        """Train the triplet network with a custom loop and save the embedding.

        Parameters
        ----------
        X_train, X_train_pos_all, X_train_neg_all : anchor, positive and
            negative samples, aligned row by row and indexed as 3-D arrays
        y_train : labels; sliced per batch and passed along, but not used by
            the loss
        embed_model : embedding model, saved to
            ``output_directory + file_name + '_last.keras'`` after training
        model : triplet model returning the three embeddings for
            ``[anchor, positive, negative]``
        batch_size : number of triplets per optimisation step; samples beyond
            the last full batch are not used
        nb_epochs : maximum number of epochs
        fig_save_path : prefix for the figure files
        verbose : when 1, the mean loss of every epoch is printed
        diagonistic : when True, the loss curve is written as ``.png`` and
            ``.svg`` and then shown

        Returns
        -------
        tuple
            ``(embed_model, model)``

        Notes
        -----
        Training stops early when the mean epoch loss changes by no more than
        1e-8 from the previous epoch (the first epoch is compared with 0).
        """

        def triplet_loss(net, x, y, training, margin = 10.0):
            # NOTE(review): `training` is accepted but not forwarded to the
            # model call, so layers that distinguish training from inference
            # (the embedding uses BatchNormalization) run with the model's
            # default mode.  Left unchanged because forwarding it would alter
            # training results - confirm the intended behaviour.
            all_outputs = net(x)

            anchor_output = all_outputs[0]
            positive_output = all_outputs[1]
            negative_output = all_outputs[2]

            # Squared differences summed over axis 1 of the embeddings.
            d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
            d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)

            return tf.reduce_mean(tf.nn.relu(margin + d_pos - d_neg))

        def grad(net, inputs, targets):
            with tf.GradientTape() as tape:
                loss_value = triplet_loss(net, inputs, targets, training=True)
            return loss_value, tape.gradient(loss_value, net.trainable_variables)

        epoch_loss_avg = tf.keras.metrics.Mean()
        optimizer = tf.keras.optimizers.Adam()

        # Compiled step to make training faster
        @tf.function
        def train_step(x, y):
            loss_value, grads = grad(model, x, y)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            epoch_loss_avg.update_state(loss_value)
            return loss_value

        file_path = output_directory + file_name + '_last.keras'

        batch_size = int(batch_size)
        n_batches = int(X_train.shape[0] / batch_size)

        train_loss_results = []
        previous_epoch_loss = 0.0
        for epoch in range(nb_epochs):

            # Slice with the requested batch size; a fixed width of 32 only
            # matched the batch count when batch_size happened to be 32.
            for ii in range(n_batches):
                start = batch_size * ii
                stop = batch_size * (ii + 1)
                x = [X_train[start:stop, :, :],
                     X_train_pos_all[start:stop, :, :],
                     X_train_neg_all[start:stop, :, :]]
                y = y_train[start:stop]
                # Optimize the model
                train_step(x, y)

            # Mean loss over this epoch's batches (0.0 if there were none).
            epoch_loss = float(epoch_loss_avg.result())
            if (verbose == 1):
                print("Epoch {:03d}: Loss: {:.20f}".format(epoch, epoch_loss))
            train_loss_results.append(epoch_loss)

            # Quit when the loss is zero or no longer changes between epochs.
            if abs(epoch_loss - previous_epoch_loss) <= 1e-8:
                break
            previous_epoch_loss = epoch_loss

            # Reset training metrics at the end of each epoch
            epoch_loss_avg.reset_state()

        # Save model
        embed_model.save(file_path)

        # Plot loss
        if diagonistic:
            epochs = range(1, len(train_loss_results) + 1)
            plt.figure()
            plt.plot(epochs, train_loss_results, 'bo', label='Training loss')
            plt.title('Training loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.ylim([0,10])

            # Save before show(): with blocking backends the figure is
            # released when the window closes and a later savefig would write
            # an empty canvas.
            plt.savefig(fname = fig_save_path + file_name + 'acc_loss' + '.png')
            plt.savefig(fname = fig_save_path + file_name + 'acc_loss' + '.svg')
            plt.show()

        return (embed_model, model)
    #--------------------------------------------------------------------------  
 
    
 
    
 
    #--------------------------------------------------------------------------        
    # Test models
    # triplet_test, embed_model are used to test triplet model
    #--------------------------------------------------------------------------  
    def test_models(self, X_test, y_test, X_test_exp, y_test_exp, model,
                     file_name, fig_save_path, 
                     triplet_test = False, embed_model = 0, model_class = 0):
        """Score a classifier on simulated and experimental data and log it.

        Accuracy, weighted F1 and weighted one-vs-rest ROC AUC are printed and
        appended to ``fig_save_path + file_name + '_logs.txt'``, first for the
        simulated set and then for the experimental set.  The confusion matrix
        of the experimental set is appended to the same file, flattened row by
        row.

        Parameters
        ----------
        X_test, y_test : simulated test data and integer class labels
        X_test_exp, y_test_exp : experimental test data and integer labels
        model : classifier with ``predict`` (used when ``triplet_test`` is
            False)
        file_name, fig_save_path : used in the messages and the log file path
        triplet_test : when True, predictions are obtained as
            ``model_class.predict(embed_model.predict(X))``
        embed_model, model_class : embedding model and downstream classifier,
            only used when ``triplet_test`` is True

        Returns
        -------
        None
        """
        txt_name = fig_save_path + file_name + '_logs.txt'

        def predict_scores(X):
            # Class scores, one row per sample.
            if not triplet_test:
                return model.predict(X)
            return model_class.predict(embed_model.predict(X))

        def evaluate(X, y_true, data_tag):
            # Predict, compute the three metrics, print and log them.
            scores = predict_scores(X)
            # index of the highest score = predicted class
            predicted = np.argmax(scores, axis=1)
            # rescale rows to sum to one before computing the ROC AUC
            scores_norm = scores / scores.sum(axis=1, keepdims=True)

            acc = accuracy_score(y_true, predicted, normalize=True)
            f1 = f1_score(y_true, predicted, average = 'weighted')
            auc = roc_auc_score(y_true, scores_norm, average = 'weighted', multi_class= 'ovr')

            report = [
                'The percentage accuracy for ' + file_name + ' with ' + data_tag + ' data is :' + str(acc),
                'The F1 score (weighted) for ' + file_name + ' with ' + data_tag + ' data is :' + str(f1),
                'The ROC AUC (weighted, ovr) for ' + file_name + ' with ' + data_tag + ' data is :' + str(auc),
            ]
            for line in report:
                print(line)
            with open(txt_name, "a") as f:
                for line in report:
                    print(line, file = f)

            return predicted, scores.shape[1]

        # Test on sim data
        evaluate(X_test, y_test, 'sim')

        # Test on exp data
        y_pred_exp, n_classes = evaluate(X_test_exp, y_test_exp, 'exp')

        # Confusion matrix.  The label set is fixed to 0..n_classes-1 so the
        # matrix keeps its full size even when a class occurs neither in the
        # labels nor in the predictions (it would otherwise shrink and the
        # element lookups would fail).
        cm = confusion_matrix(y_test_exp, y_pred_exp, labels=np.arange(n_classes))
        cm_arr = [cm[row, col] for row in range(n_classes) for col in range(n_classes)]

        with open(txt_name, "a") as f:
            print('Confusion_matrix = ' + str(cm_arr), 
                  file = f)

        return None
    #--------------------------------------------------------------------------  

    #--------------------------------------------------------------------------        
    # Test models with improved thresholds -- use FCN and ResNet only
    #--------------------------------------------------------------------------  
    def test_models_thresholds(self, X_test_exp, y_test_exp, model,
                     file_name, fig_save_path, diagonistic = False):
        """Compare class-1-vs-rest results before and after threshold tuning.

        One-vs-rest ROC curves are computed for classes 0, 1 and 2 from the
        row-normalised model scores.  For class 1 the threshold maximising
        Youden's J statistic (TPR - FPR) is selected.  Accuracy, weighted F1
        and the 2x2 confusion matrix of "class 1 vs rest" are reported twice:
        using the arg-max prediction, and using "normalised class-1 score >=
        threshold".  Results are printed and appended to
        ``fig_save_path + file_name + '_threshold_logs.txt'``.

        Parameters
        ----------
        X_test_exp, y_test_exp : experimental data and integer class labels
        model : classifier with ``predict`` returning one score per class
        file_name, fig_save_path : used in the messages and the log file path
        diagonistic : when True, the ROC curves are drawn and the selected
            threshold is printed

        Returns
        -------
        None
        """
        n_class = 3
        class_no = 1

        # Labels as a flat array, then class 1 vs the rest
        y_true = np.asarray(y_test_exp).ravel()
        y_test_exp_one_vs_r = (y_true == class_no).astype(int)

        # ROC curves per class, one vs rest
        y_pred_exp_multi = model.predict(X_test_exp)
        y_pred_exp_multi_norm = y_pred_exp_multi / y_pred_exp_multi.sum(axis=1, keepdims=True)

        fpr = {}
        tpr = {}
        thresh = {}
        for i in range(n_class):
            fpr[i], tpr[i], thresh[i] = roc_curve(y_true, y_pred_exp_multi_norm[:, i], pos_label=i)

        # plotting
        # NOTE(review): the figure is created but neither saved nor shown in
        # this method - confirm whether it should be written to fig_save_path.
        if diagonistic:
            plt.figure()
            plt.plot(fpr[0], tpr[0], linestyle='--',color='orange', label='Class 0 vs Rest')
            plt.plot(fpr[1], tpr[1], linestyle='--',color='green', label='Class 1 vs Rest')
            plt.plot(fpr[2], tpr[2], linestyle='--',color='blue', label='Class 2 vs Rest')
            plt.title('Multiclass ROC curve')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive rate')
            plt.legend(loc='best')

        # Optimal threshold from Youden's J statistic
        # https://stackoverflow.com/questions/28719067/roc-curve-and-cut-off-point-python
        optimal_idx = np.argmax(tpr[class_no] - fpr[class_no])
        optimal_threshold = thresh[class_no][optimal_idx]
        if diagonistic:
            print("Threshold value is:", optimal_threshold)

        # Metrics before changing threshold: arg-max decision
        y_pred_exp = np.argmax(y_pred_exp_multi, axis=1)
        y_pred_exp_one_vs_r = (y_pred_exp == class_no).astype(int)

        acc_before = accuracy_score(y_test_exp_one_vs_r, y_pred_exp_one_vs_r, normalize=True)
        f1_score_before = f1_score(y_test_exp_one_vs_r, y_pred_exp_one_vs_r, average = 'weighted')

        print('The percentage accuracy for ' + file_name + ' with exp data before thresholding is :' +  str(acc_before),)
        print('The F1 score (weighted) for ' + file_name + ' with exp data before thresholding is :' +  str(f1_score_before),)

        # labels=[0, 1] keeps the matrix 2x2 even if only one class occurs
        cm_before = confusion_matrix(y_test_exp_one_vs_r, y_pred_exp_one_vs_r, labels=[0, 1])
        cm_before_arr = [cm_before[0,0],cm_before[0,1],cm_before[1,0],cm_before[1,1]]

        # Metrics after changing threshold.  The threshold was derived from
        # the normalised scores, and a point of the ROC curve counts a sample
        # as positive when score >= threshold, so the same scores and the same
        # comparison are used here to land on the selected operating point.
        y_pred_thres = (y_pred_exp_multi_norm[:, class_no] >= optimal_threshold).astype(int)
        acc_after = accuracy_score(y_test_exp_one_vs_r, y_pred_thres, normalize=True)
        # pos_label has no effect with average='weighted', so it is not passed
        f1_score_after = f1_score(y_test_exp_one_vs_r, y_pred_thres, average = 'weighted')

        print('The percentage accuracy for ' + file_name + ' with exp data after thresholding is :' +  str(acc_after),)
        print('The F1 score (weighted) for ' + file_name + ' with exp data after thresholding is :' +  str(f1_score_after),)

        cm_after = confusion_matrix(y_test_exp_one_vs_r, y_pred_thres, labels=[0, 1])
        cm_after_arr = [cm_after[0,0],cm_after[0,1],cm_after[1,0],cm_after[1,1]]

        # writing to text file
        txt_name = fig_save_path + file_name + '_threshold_logs.txt'
        with open(txt_name, "a") as f:
            print('The percentage accuracy for ' + file_name + ' with exp data before thresholding is :' +  str(acc_before),
                  file = f)
            print('The F1 score (weighted) for ' + file_name + ' with exp data before thresholding is :' +  str(f1_score_before),
                  file = f)
            print('Confusion_matrix before thresholding = ' + str(cm_before_arr), 
                  file = f)

            print('The percentage accuracy for ' + file_name + ' with exp data after thresholding is :' +  str(acc_after),
                  file = f)
            print('The F1 score (weighted) for ' + file_name + ' with exp data after thresholding is :' +  str(f1_score_after),
                  file = f)
            print('Confusion_matrix after thresholding = ' + str(cm_after_arr), 
                  file = f)

        return None
    #--------------------------------------------------------------------------        
    
  
    
      

    
    
    
    
    
#--------------------------------------------------------------------------   
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    