# =============================================================================
# Description
# =============================================================================

# Call script for detecting single molecule curves with deep learning models
# =============================================================================

# =============================================================================
# Imports
# ============================================================================= 
import numpy as np
import sys
from matplotlib import pyplot as plt
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import sklearn as sk
import tensorflow.keras as keras
import tensorflow as tf
from sklearn.metrics import confusion_matrix

import sys
sys.path.append('../')

from APIs.utils import utils
from APIs.SMNN import SMNN
#%%
# Instantiate the project helper objects (general utilities and the SMNN API)
UtObj = utils()
SNObj = SMNN()

# Apply the shared font settings to every matplotlib figure
plot_fonts = UtObj.plot_fonts
plt.rcParams.update(plot_fonts)

# Plot styling presets exposed by the utilities object
color_set, color_set_hist = UtObj.color_set, UtObj.color_set_hist
marker_set, fmt_set = UtObj.marker_set, UtObj.fmt_set

#%%
# Draw a fresh seed for this run and print it so the run can be reproduced
random_seed = np.random.randint(0, 1000000)
print(f'The current seed = {random_seed}')
UtObj.seed_all(random_seed)

#%%
# List the physical devices visible to TensorFlow (used to check the GPU status)
print("Num Devices Available: ", tf.config.list_physical_devices())
#%%
################################################
#             Global parameters
################################################

# --- Dataset selection ---
molecule = 'Titin'  # one of 'Titin', 'UtrNR3', 'DysNR3_bact', 'fewshot' (note: fewshot is DDRs)
molecule_exp_use = 'Titin'  # which exp data to use; needs to change when molecule = 'all'
df_save_path = '../Data/ML_Dataset/'  # folder the dataset files are read from

# --- Preprocessing switches ---
resampling = True  # resample the curves or not
resample_size = 300  # number of points after resampling (399 for fewshot)
data_normalization = True  # normalize the data (minmax normalization)
Reference_data = True  # add reference data to the network input
no_refer = 1  # number of reference signals to add

# --- Which models to train / test ---
Train_Model_MLP = False  # train the MLP or not
Test_Model_MLP = True  # test the MLP or not
Train_Model_FCNN = False  # train the FCNN or not
Test_Model_FCNN = True  # test the FCNN or not
Train_Model_RESNET = False  # train the ResNet or not
Test_Model_RESNET = True  # test the ResNet or not
Train_Model_transfer = False  # run transfer training (FCNN / ResNet) or not
Train_Model_Triplet = False  # train the Triplet network or not
Test_Model_Triplet = True  # test the Triplet network or not

test_size = 0.2  # fraction of the data held out for testing (0.3 for fewshot)

# --- Output folders (created on demand) ---
data_save_path = f'ML_models/{molecule}/'
fig_save_path = f'{data_save_path}plots/'
output_directory = f'ML_models/{molecule}/saved_model/'  # model save path

for _folder in (data_save_path, fig_save_path, output_directory):
    os.makedirs(_folder, exist_ok=True)
#%%
"""
Dataset:
    1. Read dataset
    2. Format into NN dataset
    3. Preprocess of Data: Trim, resample + normalization, add refer data
"""
################################################
#             Read Dataset
################################################
# NOTE: the dataset files carry a '.csv' suffix but are loaded as pickles,
# so only files from a trusted source should be placed in df_save_path.

# Simulation data
data_pd = pd.read_pickle(f'{df_save_path}ML_data_{molecule}.csv')

# Reference data, obtained from the simulation table through the SMNN helper
# (alternative: pd.read_pickle(df_save_path + 'ML_data_refer_' + molecule + '.csv'))
data_pd_ref = SNObj.get_ref_data(data_pd, molecule)

# Experimental data
data_pd_exp = pd.read_pickle(f'{df_save_path}ML_data_exp_{molecule}.csv')

#%%
################################################
#             Format Dataset
################################################
# Format sim data
# 'No_molecule' is stored as a string; labels '2' and '3' share class index 2.
label_map = {'0': 0, '1': 1, '2': 2, '3': 2}

X_data = []
y_data = []
# Iterate the two columns in lockstep (positional, so it does not rely on the
# DataFrame having a default 0..n-1 index).
for fwlc, n_molecule in zip(data_pd['Fwlc'], data_pd['No_molecule']):
    if n_molecule not in label_map:
        # Fail loudly: appending the curve without a label would silently
        # leave X_data and y_data with different lengths.
        raise ValueError('Unexpected No_molecule label in simulation data: '
                         + repr(n_molecule))
    X_data.append(fwlc * 1e12)  # change into pN
    y_data.append(label_map[n_molecule])
X_data = np.array(X_data)
y_data = np.array(y_data)

# Split off the test data
[X_train, X_test, y_train, y_test] = train_test_split(
    X_data, y_data, test_size=test_size, random_state=random_seed)

# One hot encoding of labels (the encoder is fitted on all labels so that
# train and test share the same category set)
enc = sk.preprocessing.OneHotEncoder(categories='auto')
enc.fit(np.concatenate((y_train, y_test), axis=0).reshape(-1, 1))
y_train_oh = enc.transform(y_train.reshape(-1, 1)).toarray()
y_test_oh = enc.transform(y_test.reshape(-1, 1)).toarray()



# Format exp data
# 'No_molecule' is stored as a string; labels '2' and '3' share class index 2.
exp_label_map = {'0': 0, '1': 1, '2': 2, '3': 2}

X_test_exp = []
y_test_exp = []
file_name_exp = []
sim_data_size = np.size(X_data, 1)

# Length of the longest experimental trace (scaling to pN is not needed
# just to measure the length)
exp_data_size = max((np.size(fwlc) for fwlc in data_pd_exp['Fwlc']), default=0)

# Every experimental trace is zero-padded at its end to this common length
padded_size = max(exp_data_size, sim_data_size)

for fwlc, exp_file_name, n_molecule in zip(data_pd_exp['Fwlc'],
                                           data_pd_exp['file_name'],
                                           data_pd_exp['No_molecule']):
    if n_molecule not in exp_label_map:
        # Fail loudly: appending the trace without a label would silently
        # leave X_test_exp and y_test_exp with different lengths.
        raise ValueError('Unexpected No_molecule label in experimental data: '
                         + repr(n_molecule))

    fwlc = fwlc * 1e12  # change into pN
    fwlc = np.pad(fwlc, (0, padded_size - np.size(fwlc)), 'constant',
                  constant_values=(0, 0))

    X_test_exp.append(fwlc)
    file_name_exp.append(exp_file_name)
    y_test_exp.append(exp_label_map[n_molecule])

X_test_exp = np.array(X_test_exp)
y_test_exp = np.array(y_test_exp)
y_test_exp_oh = enc.transform(y_test_exp.reshape(-1, 1)).toarray()

#%%
################################################
#             Resample Data
################################################
# Options passed unchanged to every resample/normalization call
resample_opts = (resampling, resample_size, data_normalization)

# Simulation data (train and test splits)
sim_data = True
X_train, _ = SNObj.resample_normalization_data(X_train, y_train, molecule,
                                               sim_data, *resample_opts)
X_test, _ = SNObj.resample_normalization_data(X_test, y_test, molecule,
                                              sim_data, *resample_opts)

# Exp data
sim_data = False
X_test_exp, _ = SNObj.resample_normalization_data(X_test_exp, y_test_exp, molecule,
                                                  sim_data, *resample_opts)
#%%
################################################
#         Add reference data
################################################
if Reference_data:
    # Arguments shared by the three calls below
    refer_opts = (data_pd_ref, no_refer, resample_size, data_normalization)

    # Attach the reference signal(s) to each of the three input sets
    X_train = SNObj.add_refer_data(X_train, y_train, *refer_opts)
    X_test = SNObj.add_refer_data(X_test, y_test, *refer_opts)
    X_test_exp = SNObj.add_refer_data(X_test_exp, y_test_exp, *refer_opts)
#%%
"""
Networks:
    1. ** Build, train, test MLP **
    2. Build, train, test FCNN
    3. Build, train, test ResNet
    4. Build, train, test Triplet
"""

################################################
#        Build MLP
################################################
file_name = 'MLP'

# Network dimensions are taken from the prepared data
input_shape = X_train.shape[1:]
nb_classes = np.unique(np.concatenate((y_train, y_test), axis=0)).size

# Build model with keras from the layers returned by the SMNN helper
input_layer, output_layer = SNObj.create_MLP(input_shape, nb_classes)
model = keras.models.Model(inputs=input_layer, outputs=output_layer)

# Loss function and optimizer
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adadelta(),
    metrics=['accuracy'],
)

# Callbacks: halve the learning rate on a loss plateau and keep the
# lowest-loss model on disk.
# NOTE(review): min_lr=0.1 looks high for an Adadelta() built with default
# arguments - confirm the plateau schedule can actually lower the rate.
file_path = f'{output_directory}{file_name}_best.keras'
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=0.5, patience=200, min_lr=0.1)
model_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=file_path, monitor='loss', save_best_only=True)

callbacks = [reduce_lr, model_checkpoint]
#%%
################################################
#        Train MLP
################################################
if Train_Model_MLP:
    # Train from scratch; the fewshot dataset uses a smaller batch and fewer epochs
    batch_size, nb_epochs = (32, 500) if molecule != 'fewshot' else (8, 200)
    verbose = 1  # progress status
    diagonistic = True

    model = SNObj.train_models(X_train, y_train_oh, model, callbacks, batch_size,
                               nb_epochs, output_directory, file_name, fig_save_path,
                               verbose=verbose, diagonistic=diagonistic)
elif os.path.isfile(file_path):
    # Not training: reuse the checkpoint saved by a previous run
    model = keras.models.load_model(file_path)
else:
    print('No pretrained model available, please train the model first!')
#%%
################################################
#        Test MLP
################################################
if Test_Model_MLP:
    # Always evaluate the checkpoint stored on disk
    model = keras.models.load_model(file_path)

    # fewshot evaluates on its train/test splits; every other dataset
    # evaluates on the simulated test split and the experimental data
    if molecule == 'fewshot':
        eval_sets = (X_train, y_train, X_test, y_test)
    else:
        eval_sets = (X_test, y_test, X_test_exp, y_test_exp)

    SNObj.test_models(*eval_sets, model, file_name, fig_save_path)
#%%
"""
Networks:
    1. Build, train, test MLP
    2. ** Build, train, test FCNN **
    3. Build, train, test ResNet
    4. Build, train, test Triplet
"""
################################################
#       Fully Convolutional Neural Network (FCNN)
################################################
file_name = 'FCNN'

# Network dimensions are taken from the prepared data
input_shape = X_train.shape[1:]
nb_classes = np.unique(np.concatenate((y_train, y_test), axis=0)).size

# Kernel sizes handed to the FCNN builder
kernel_size_arr = np.array([8, 5, 3])

# Create model
input_layer, output_layer = SNObj.create_FCNN(input_shape, nb_classes, kernel_size_arr)
model = keras.models.Model(inputs=input_layer, outputs=output_layer)

# Loss function and optimizer
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
)

# Callbacks: halve the learning rate on a loss plateau and keep the
# lowest-loss model on disk
file_path = f'{output_directory}{file_name}_best.keras'
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=0.5, patience=50, min_lr=0.0001)
model_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=file_path, monitor='loss', save_best_only=True)

callbacks = [reduce_lr, model_checkpoint]

# Save the freshly built model; the transfer-learning section loads it back
model.save(f'{output_directory}{file_name}_initial.keras')
#%%
################################################
#      Train FCNN
################################################
if Train_Model_FCNN:
    # Non-fewshot datasets: batch 32, 300 epochs.
    # fewshot currently uses the "exp data only" setting (batch 32, 100 epochs);
    # the earlier fewshot setting was batch 8, 200 epochs.
    batch_size, nb_epochs = (32, 300) if molecule != 'fewshot' else (32, 100)

    verbose = 1  # progress status
    diagonistic = True
    model = SNObj.train_models(X_train, y_train_oh, model, callbacks, batch_size,
                               nb_epochs, output_directory, file_name, fig_save_path,
                               verbose=verbose, diagonistic=diagonistic)
elif os.path.isfile(file_path):
    # Not training: reuse the checkpoint saved by a previous run
    model = keras.models.load_model(file_path)
else:
    print('No pretrained model available, please train the model first!')


#%%
################################################
#      Test FCNN
################################################
if Test_Model_FCNN:
    # Always evaluate the checkpoint stored on disk
    model = keras.models.load_model(file_path)

    if molecule == 'fewshot':
        # fewshot evaluates on its own train/test splits
        SNObj.test_models(X_train, y_train, X_test, y_test, model,
                          file_name, fig_save_path)
    else:
        # simulated test split + experimental data
        SNObj.test_models(X_test, y_test, X_test_exp, y_test_exp, model,
                          file_name, fig_save_path)

        # threshold study on the experimental data
        SNObj.test_models_thresholds(X_test_exp, y_test_exp, model,
                                     file_name, fig_save_path, diagonistic=False)

#%%
"""
Networks:
    1. Build, train, test MLP
    2. Build, train, test FCNN
    3. ** Build, train, test ResNet **
    4. Build, train, test Triplet
"""
################################################
#       Deep Residual Network (ResNet)
################################################
file_name = 'ResNet'

# Network dimensions are taken from the prepared data
input_shape = X_train.shape[1:]
nb_classes = np.unique(np.concatenate((y_train, y_test), axis=0)).size

# Settings handed to the ResNet builder
n_feature_maps = 64
kernel_size_arr = np.array([8, 5, 3])

# Create model
input_layer, output_layer = SNObj.create_ResNet(input_shape, nb_classes,
                                                kernel_size_arr, n_feature_maps)
model = keras.models.Model(inputs=input_layer, outputs=output_layer)

# Loss function and optimizer
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
)

# Callbacks: halve the learning rate on a loss plateau and keep the
# lowest-loss model on disk
file_path = f'{output_directory}{file_name}_best.keras'
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=0.5, patience=50, min_lr=0.0001)
model_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=file_path, monitor='loss', save_best_only=True)

callbacks = [reduce_lr, model_checkpoint]

# Save the freshly built model; the transfer-learning section loads it back
model.save(f'{output_directory}{file_name}_initial.keras')
#%%
################################################
#       Train ResNet
################################################
if Train_Model_RESNET:
    # The fewshot dataset uses a smaller batch size; epochs are the same
    batch_size, nb_epochs = (64, 200) if molecule != 'fewshot' else (8, 200)
    verbose = 1  # progress status

    diagonistic = True
    model = SNObj.train_models(X_train, y_train_oh, model, callbacks, batch_size,
                               nb_epochs, output_directory, file_name, fig_save_path,
                               verbose=verbose, diagonistic=diagonistic)
elif os.path.isfile(file_path):
    # Not training: reuse the checkpoint saved by a previous run
    model = keras.models.load_model(file_path)
else:
    print('No pretrained model available, please train the model first!')


#%%
################################################
#       Test ResNet
################################################
if Test_Model_RESNET:
    # Always evaluate the checkpoint stored on disk
    model = keras.models.load_model(file_path)

    if molecule == 'fewshot':
        # fewshot evaluates on its own train/test splits
        SNObj.test_models(X_train, y_train, X_test, y_test, model,
                          file_name, fig_save_path)
    else:
        # simulated test split + experimental data
        SNObj.test_models(X_test, y_test, X_test_exp, y_test_exp, model,
                          file_name, fig_save_path)

        # threshold study on the experimental data
        SNObj.test_models_thresholds(X_test_exp, y_test_exp, model,
                                     file_name, fig_save_path, diagonistic=False)
#%%
################################################
#       Transfer learning on exp data
################################################
# Build training dataset from exp data
# 70% of the experimental data is held out as the test set and is never used
# for training; the remaining 30% is the pool that training subsets are drawn from.
[X_trans_train_all, X_trans_test, y_trans_train_all, y_trans_test] = train_test_split(
    X_test_exp, y_test_exp,
    test_size=0.7,
    random_state=random_seed,
    stratify=y_test_exp)

if Train_Model_transfer:
    # The held-out labels do not depend on any loop variable, so encode them once.
    y_trans_test_oh = enc.transform(y_trans_test.reshape(-1, 1)).toarray()

    for file_token in ['FCNN', 'ResNet']:
        # Each value is the fraction of the 30% pool that is set aside,
        # i.e. training uses 1/6, 1/3 and 2/3 of the pool respectively.
        for test_size_transfer in np.array([5/6, 2/3, 1/3]):
            [X_trans_train, X_trans_valid, y_trans_train, y_trans_valid] = train_test_split(
                X_trans_train_all, y_trans_train_all,
                test_size=test_size_transfer,
                random_state=random_seed,
                stratify=y_trans_train_all)

            y_trans_train_oh = enc.transform(y_trans_train.reshape(-1, 1)).toarray()

            # True  -> fine-tune the model previously trained on simulation data
            # False -> train the saved untrained ("initial") model on the same
            #          data, as the baseline without transfer
            for transfer_or_initial in np.array([True, False]):
                if transfer_or_initial:
                    # Load trained model
                    model = keras.models.load_model(
                        output_directory + file_token + '_best.keras')
                    # the current model name
                    file_name = file_token + '_transfer'
                else:
                    # Load initial model
                    model = keras.models.load_model(
                        output_directory + file_token + '_initial.keras')
                    file_name = file_token + '_initial'
                file_path = output_directory + file_name + '_best.keras'

                model_checkpoint = keras.callbacks.ModelCheckpoint(
                    filepath=file_path, monitor='loss', save_best_only=True)

                # NOTE: reduce_lr is not created in this section; it is whichever
                # ReduceLROnPlateau callback was defined last before this cell runs.
                callbacks = [reduce_lr, model_checkpoint]

                # Train the model
                batch_size = 32
                nb_epochs = 100
                verbose = 1  # progress status
                diagonistic = True
                model = SNObj.train_models(X_trans_train, y_trans_train_oh, model, callbacks,
                                           batch_size, nb_epochs, output_directory, file_name,
                                           fig_save_path, validation_split=0.001,
                                           verbose=verbose, diagonistic=diagonistic)

                ################################################
                #      Test the model on the held-out exp data
                ################################################
                txt_name = fig_save_path + file_name + '_logs.txt'
                y_pred_exp_multi = model.predict(X_trans_test)
                y_pred_exp = np.argmax(y_pred_exp_multi, axis=1)

                # Pin the label set so the matrix is always 3x3, even when a
                # class is absent from both the held-out labels and the
                # predictions (otherwise cm[2, 2] would raise IndexError).
                cm = confusion_matrix(y_trans_test, y_pred_exp, labels=[0, 1, 2])
                # Row-major flattening to plain Python ints for a clean log line
                cm_arr = cm.ravel().tolist()

                # Append so that results of all runs accumulate in one log file
                with open(txt_name, "a") as f:
                    print('Using transfer learning = ' + str(transfer_or_initial), file = f)
                    print('Molecule = ' + molecule_exp_use + ' of tranfer learning with model ' + file_token +
                          ' and train size = ' + str(30*(1-test_size_transfer)),
                          file = f)
                    print('Confusion_matrix = ' + str(cm_arr),
                          file = f)

                SNObj.test_models(X_test, y_test, X_trans_test, y_trans_test, model,
                                  file_name, fig_save_path)
            
    
#%%
"""
Networks:
    1. Build, train, test MLP
    2. Build, train, test FCNN
    3. Build, train, test ResNet
    4. ** Build, train, test Triplet **
"""

################################################
#       Triplet Network
################################################
#  Build dataset for Triplet network
# For every anchor sample pick one random sample of the same class (positive)
# and one random sample of a different class (negative). Indices into the full
# training set are recorded so that samples and labels always stay paired.
pos_indices = np.empty(len(y_train), dtype=int)
neg_indices = np.empty(len(y_train), dtype=int)
for ii, anchor_label in enumerate(y_train):
    same_class = np.flatnonzero(y_train == anchor_label)
    other_class = np.flatnonzero(y_train != anchor_label)
    if other_class.size == 0:
        raise ValueError('Triplet sampling needs at least two classes in y_train.')

    # Draw a position inside each candidate list, then translate it back to an
    # index of the full training set (the position alone is NOT a valid index
    # into y_train / X_train).
    pos_indices[ii] = same_class[np.random.choice(same_class.size, size=1)[0]]
    neg_indices[ii] = other_class[np.random.choice(other_class.size, size=1)[0]]

# Gather all selections in one indexing operation instead of stacking row by row
X_train_pos_all = X_train[pos_indices]
X_train_neg_all = X_train[neg_indices]

# Column vectors holding the labels of the selected positive / negative samples
y_train_pos_all = y_train[pos_indices].reshape(-1, 1)
y_train_neg_all = y_train[neg_indices].reshape(-1, 1)


y_train_oh = enc.transform(y_train.reshape(-1, 1)).toarray()
y_train_pos_all_oh = enc.transform(y_train_pos_all.reshape(-1, 1)).toarray()
y_train_neg_all_oh = enc.transform(y_train_neg_all.reshape(-1, 1)).toarray()

# Build model with keras
nb_classes = np.unique(np.concatenate((y_train, y_test), axis=0)).size
input_shape = X_train.shape[1:]

# file path to save model
file_name = 'Triplet'
file_path = f'{output_directory}{file_name}_last.keras'

# create model: the helper returns the embedding model and the full triplet model
embed_model, model = SNObj.create_Triplet(input_shape)


#%%
################################################
#  Train Triplet network
################################################
if Train_Model_Triplet:
    # The fewshot dataset uses a smaller batch and fewer epochs
    batch_size, nb_epochs = (32, 150) if molecule != 'fewshot' else (8, 100)
    verbose = 1
    diagonistic = True

    embed_model, model = SNObj.train_triplet(
        X_train, X_train_pos_all, X_train_neg_all, y_train, embed_model, model,
        batch_size, nb_epochs, output_directory, file_name, fig_save_path,
        verbose=verbose, diagonistic=diagonistic)
elif os.path.isfile(file_path):
    # Not training: reuse the embedding model saved by a previous run
    embed_model = keras.models.load_model(file_path)
else:
    print('No pretrained model available, please train the model first!')
#%%
################################################
#  Train Triplet network --- 3-layer network classification
################################################
if Test_Model_Triplet:
    file_name = 'MLP_Triplet'

    # The classifier is trained on the embedding model's output for X_train
    X_train_feature = embed_model(X_train)
    input_shape = X_train_feature.shape[1:]
    nb_classes = np.unique(np.concatenate((y_train, y_test), axis=0)).size

    # Build model with keras
    input_layer, output_layer = SNObj.create_MLP(input_shape, nb_classes)
    model_class = keras.models.Model(inputs=input_layer, outputs=output_layer)

    # Loss function and optimizer
    model_class.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.Adadelta(),
        metrics=['accuracy'],
    )

    # Callbacks: halve the learning rate on a loss plateau and keep the
    # lowest-loss classifier on disk.
    # NOTE(review): min_lr=0.1 looks high for an Adadelta() built with default
    # arguments - confirm the plateau schedule can actually lower the rate.
    file_path = f'{output_directory}{file_name}_best.keras'
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor='loss', factor=0.5, patience=200, min_lr=0.1)
    model_checkpoint = keras.callbacks.ModelCheckpoint(
        filepath=file_path, monitor='loss', save_best_only=True)

    callbacks = [reduce_lr, model_checkpoint]

    # Training settings
    batch_size, nb_epochs = 16, 20
    verbose = 1  # progress status
    diagonistic = True

    model_class = SNObj.train_models(X_train_feature, y_train_oh, model_class, callbacks,
                                     batch_size, nb_epochs, output_directory, file_name,
                                     fig_save_path, verbose=verbose, diagonistic=diagonistic)


#%%
################################################
#  Test Triplet network 
################################################
# Test on data
if Test_Model_Triplet:
    # fewshot evaluates on its train/test splits; every other dataset
    # evaluates on the simulated test split and the experimental data
    if molecule == 'fewshot':
        triplet_eval_sets = (X_train, y_train, X_test, y_test)
    else:
        triplet_eval_sets = (X_test, y_test, X_test_exp, y_test_exp)

    SNObj.test_models(*triplet_eval_sets, model, file_name, fig_save_path,
                      triplet_test=True, embed_model=embed_model,
                      model_class=model_class)
        

#%%




















