import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path

# Inclusive [min, max] value ranges of the quantities used as model inputs/labels.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # samples per trace; used as the trace input length of the model
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1             # number of input for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3; which subkeys these are is decided by the code on the m4 (NOT by the code on the PC)
work = 'train'  # 'train' or 'attack'; may be overridden by the first command-line argument
training_file_list = [
    'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',
    'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5',
    # Further data files, currently disabled:
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5',
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5',
]

trained_model_path = 'trained_models/4C4FC_2BP4FC4FC_J4FCSM_hy0101010101_skpv0_bp01_100kDatax2/trained4C4FC_2BP4FC4FC_J4FCSM_hy0101010101_skpv0_bp01_byEpoch/4C4FC_2BP4FC4FC_J4FCSM_hy0101010101_skpv0_bp01_ep1271.h5'
data_path = 'data.npz'
nruns_default = 20
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232  # may be overridden by the second command-line argument

#print("len(sys.argv) =", len(sys.argv))
#print("sys.argv[0] =", sys.argv[0])
#print("sys.argv[1] =", sys.argv[1])
#print("sys.argv[2] =", sys.argv[2])
# Optional command-line overrides: <script> [work [attack_byModel_epNo]]
if len(sys.argv) > 1:
    work = sys.argv[1]
# Read the epoch number only when it was actually supplied; previously a lone
# `work` argument crashed with IndexError on sys.argv[2].
if len(sys.argv) > 2:
    attack_byModel_epNo = int(sys.argv[2])
print("work =", work)
print("attack_byModel_epNo =", attack_byModel_epNo)
#input()

# training parameters
# (trailing '#...' values on a line are alternative settings kept by the author)
train_batch_size = 500#100#150#200#250#500#640 #80 for mars45 #170 for mars56
period = 8 #8
maxEpochs = 1536#3072#2048#1536#1280#1024#512#256 #1536
# Epoch number divided by `period`, truncated to an int.
# NOTE(review): presumably `period` is the checkpoint-saving interval in epochs — confirm against the training code.
attack_byModel_fileNo = int(attack_byModel_epNo/period)
N_TRACE = 200000

#model hyper-parameters
# noConv1Dbranch: number of trace-input (Conv1D) branches built by subModels_gen.
noConv1Dbranch = 1
# noLayers: number of layer slots iterated per stage in subModels_gen; the
# subMods_* tables have one column per slot.
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True    # False

# training data type
# xType selects which inputs the model receives (see the xType-dependent
# subMods_Pext / noBPbranch selection below); yType names the label type.
xType = 'wavebp01'  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv' 
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
# MLmodel_detail is passed to subModels_gen, which uses it as the Keras Model name.
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    #hyper-parameter contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'    #'20kDatax25'
dataFile_attack = '100kDatax1_test'#'skvp0_0_700points100kDatax1attack' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
model_input_type = '_in[[][]]_tf2' #'in[]_tf2' #'[[][]]_tf2'
#data_type = dataFile_train + model_input_type
# data_type embeds the number of entries in training_file_list.
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
# NOTE(review): unlike the other folder strings this one has no trailing '/' — confirm how it is joined with file names.
database_folder_train = 'trained_models/active_models'
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
# All output folders below live under DLmodel_folder.
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
# The per-epoch folder currently aliases the per-data-file folder.
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Create the output folders up front. os.makedirs(..., exist_ok=True) replaces
# the check-then-mkdir pattern: it is not racy, creates missing parents, and is
# a no-op when the folder already exists (it still raises if the path is a file).
os.makedirs(DLmodel_folder, exist_ok=True)
os.makedirs(modelLogFolder, exist_ok=True)
os.makedirs(logTrainedModel_byFile_folder, exist_ok=True)
os.makedirs(logTrainedModel_byEp_folder, exist_ok=True)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# How these tables are read by subModels_gen:
#   * rows are indexed by conv1DbranchNo (subModel0..5), columns by layerNo (layer0..5);
#   * a 0 entry disables the corresponding layer;
#   * a Conv1D + MaxPooling1D pair is only created when the node count, kernel size,
#     pool size and pool stride of that slot are ALL non-zero.

# Input BatchNormalization for each PoI size
# (non-zero => a BatchNormalization layer is applied to the trace input of that branch)
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# matrix showing number of nodes (Conv1D filters) in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   [   512,    256,    128,    64,     0,      0], # subModel0
                            [   0,      0,      0,      0,      0,      0], # subModel1
                            [   0,      0,      0,      0,      0,      0], # subModel2
                            [   0,      0,      0,      0,      0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]] # subModel5
# Convolutional filter sizes
# matrix showing filter (kernel) sizes in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [ [   3,      3,      3,      3,      0,      0], # subModel0
                            [   0,      0,      0,      0,      0,      0], # subModel1
                            [   0,      0,      0,      0,      0,      0], # subModel2
                            [   0,      0,      0,      0,      0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]] # subModel5

###############################################
# Pooling size in convolutional layers
# matrix showing MaxPooling sizes in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   [   2,      2,      2,      2,      0,      0], # subModel0
                            [   0,      0,      0,      0,      0,      0], # subModel1
                            [   0,      0,      0,      0,      0,      0], # subModel2
                            [   0,      0,      0,      0,      0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]] # subModel5
# Pooling stride in convolutional layers
# matrix showing MaxPooling stride in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [ [   3,      3,      3,      3,      0,      0], # subModel0
                            [   0,      0,      0,      0,      0,      0], # subModel1
                            [   0,      0,      0,      0,      0,      0], # subModel2
                            [   0,      0,      0,      0,      0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]] # subModel5
# BatchNormalization in convolutional layers
# matrix showing BatchNormalization condition (non-zero => enabled) in each convolutional layer in each PoI length
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [  [   1,      1,      1,      1,      0,      0], # subModel0
                        [   0,      0,      0,      0,      0,      0], # subModel1
                        [   0,      0,      0,      0,      0,      0], # subModel2
                        [   0,      0,      0,      0,      0,      0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]] # subModel5
# Dropout in convolutional layers
# matrix showing Dropout value (passed to Dropout; 0 => no Dropout layer) in each convolutional layer in each PoI length
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   [   0,      0,      0,      0,      0,      0], # subModel0
                        [   0,      0,      0,      0,      0,      0], # subModel1
                        [   0,      0,      0,      0,      0,      0], # subModel2
                        [   0,      0,      0,      0,      0,      0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]] # subModel5

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# These tables are indexed [conv1DbranchNo][layerNo] by subModels_gen; 0 disables the slot.
# Flatten Convolutional feature map before Fully connected
# (non-zero => a Flatten layer is inserted before the first FC slot of that branch)
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected for convolutional value before adding Plaintext
# matrix showing fully-connected condition (Dense units per slot) before adding Plaintext
# NOTE(review): the lone `3` in the subModel3 row here and in the two tables below
# looks like a leftover; a Dropout value of 3 in particular looks invalid. These rows
# are only read when noConv1Dbranch > 3 — confirm and reset if unintended.
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [ [   1024,   512,    256,    128,    0,      0], # subModel0
                [   0,      0,      0,      0,      0,      0], # subModel1
                [   0,      0,      0,      0,      0,      0], # subModel2
                [   0,      0,      3,      0,      0,      0], # subModel3
                [   0,      0,      0,      0,      0,      0], # subModel4
                [   0,      0,      0,      0,      0,      0]] # subModel5
# BatchNormalization for fully-connected of convolutional value before adding Plaintext
# matrix showing BatchNormalization for fully-connected condition (non-zero => enabled) before adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   [   1,      1,      1,      1,      0,      0], # subModel0
                        [   0,      0,      0,      0,      0,      0], # subModel1
                        [   0,      0,      0,      0,      0,      0], # subModel2
                        [   0,      0,      3,      0,      0,      0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]] # subModel5
# Dropout for fully-connected of convolutional value before adding Plaintext
# matrix showing Dropout for fully-connected condition (value passed to Dropout; 0 => none) before adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [    [   0.2,    0,      0.2,    0,      0,      0], # subModel0
                        [   0,      0,      0,      0,      0,      0], # subModel1
                        [   0,      0,      0,      0,      0,      0], # subModel2
                        [   0,      0,      3,      0,      0,      0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]] # subModel5

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext adding here
# subMods_Pext is indexed [conv1DbranchNo][BPbranchNo] by subModels_gen: a non-zero
# entry concatenates the flattened one-hot input of that BP branch to the trace
# features. noBPbranch is the number of one-hot inputs created by subModels_gen.
# NOTE(review): noBPbranch is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01'. For 'wavebp0', 'wavebp1' and 'wavebp01next0' it stays undefined
# here (and subMods_Pext is all zeros), and an unknown xType defines neither name —
# these branches look unfinished; confirm the intended values before using them.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [[  0,      0,      0,      0,      0,      0]]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [[  0,      0,      0,      0,      0,      0]]  # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [[  0,      0,      0,      0,      0,      0]]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [[  1,      1,      0,      0,      0,      0]]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [[  0,      0,      0,      0,      0,      0]]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [[  1,      1,      1,      1,      0,      0]]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# The three tables below are indexed [conv1DbranchNo][BPbranchNo][layerNo] by
# subModels_gen: the outer list has one entry per Conv1D branch, and the rows
# labelled "subModel0..5" are selected by the BP branch number. 0 disables a slot.
# Fully-connected for convolutional value after adding Plaintext
# matrix showing fully-connected condition (Dense units per slot) after adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [[   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]]    # subModel5
# BatchNormalization for fully-connected of convolutional value after adding Plaintext
# matrix showing BatchNormalization for fully-connected condition (non-zero => enabled) after adding Plaintext
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [[ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]]    # subModel5
# Dropout for fully-connected of convolutional value after adding Plaintext
# matrix showing Dropout for fully-connected condition (value passed to Dropout; 0 => none) after adding Plaintext
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [[  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]]    # subModel5

# Softmax for each sub-model if available
# (indexed [conv1DbranchNo][BPbranchNo]; non-zero => a softmax Dense layer ends that BP branch)
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [[  0,      0,      0,      0,      0,      0]]

# subMods_join[conv1DbranchNo] != 0 => the BP-branch outputs are concatenated;
# otherwise subModels_gen continues from the trace features alone.
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]  

# Dense units per slot after joining, indexed [conv1DbranchNo][layerNo]; 0 disables a slot.
subMods_join_FCs =  [[  1024,   1024,   512,    256,    128,    0]]
# BatchNormalization for fully-connected of convolutional value after joining PoIs
# matrix showing BatchNormalization for fully-connected condition after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [[  1,      1,      1,      1,      1,      0]]
# Dropout for fully-connected of convolutional value after joining PoIs
# matrix showing Dropout for fully-connected condition after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [[  0.2,        0.2,        0,      0.1,        0,      0]]

# Softmax for joined-model if available
# (indexed by conv1DbranchNo; non-zero => a final softmax Dense layer is added)
subMods_join_classification =   [   1]

################################################################################################
##################################### MODELS STRUCTURE END #####################################
################################################################################################


def check_file_exists(file_path):
    """Terminate the process with exit code -1 when ``file_path`` does not exist.

    Prints an error message before exiting; returns ``None`` when the path exists.
    """
    if os.path.exists(file_path):
        return
    print("Error: provided file path '%s' does not exist!" % file_path)
    sys.exit(-1)

def listDirWithExt(directory, extension):
    """Return a generator over the entries of ``directory`` whose names end with ``'.' + extension``.

    The directory is listed immediately; the filtering happens lazily while iterating.
    """
    entries = os.listdir(directory)
    return (entry for entry in entries if entry.endswith('.' + extension))

def subModels_gen(xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the Keras model described by the module-level ``subMods_*`` tables.

    For every Conv1D branch the network is assembled in this order:
      1. optional BatchNormalization of the trace input;
      2. up to ``noLayers`` slots of Conv1D + MaxPooling1D / BatchNormalization / Dropout;
      3. optional Flatten, then up to ``noLayers`` slots of Dense / BatchNormalization / Dropout;
      4. for every BP branch: optional concatenation with the flattened one-hot
         input, Dense / BatchNormalization / Dropout slots, optional softmax;
      5. optional concatenation of all BP-branch outputs, Dense / BatchNormalization /
         Dropout slots, and an optional final softmax over ``classes``.
    A table entry of 0 disables the corresponding layer.

    Args:
        xType: input selection; ``'wave'`` builds a model with trace inputs only,
            any other value adds the one-hot inputs as a second input group.
        noConv1Dbranch: number of trace inputs / Conv1D branches.
        noBPbranch: number of one-hot inputs / BP branches.
        noLayers: number of layer slots read from each table row.
        tracelen: length of one trace; trace inputs have shape ``(tracelen, 1)``.
        NumBPinput: width of a one-hot input; those inputs have shape ``(NumBPinput, 1)``.
        MLmodel_detail: used as the name of the Keras ``Model``.
        modelLogFolder: unused here (only referenced by the disabled ``plot_model`` call).
        logFilename: unused here (only referenced by the disabled ``plot_model`` call).
        classes: number of units of the softmax layers.

    Returns:
        The compiled ``Model`` (categorical cross-entropy, RMSprop, accuracy metric).
        A model is built per Conv1D branch, but only the one built for the LAST
        branch is returned.
    """
    input_trace_shape = (tracelen,1)
    input_Ptext1hot_shape = (NumBPinput,1)
    # Every Input must be a distinct object so Keras treats them as different inputs.
    m_traceinputs = [Input(shape=input_trace_shape) for _ in range(noConv1Dbranch)]
    m_Ptextinputs = [Input(shape=input_Ptext1hot_shape) for _ in range(noBPbranch)]
    # inputs[0] holds the trace inputs, inputs[1] (if present) the one-hot inputs.
    if xType == 'wave':
        inputs = [m_traceinputs]
    else:
        inputs = [m_traceinputs, m_Ptextinputs]

    for conv1DbranchNo in range(0, noConv1Dbranch):
        print('\nconv1DbranchNo =', conv1DbranchNo)
        # Convolutional filter for input trace
        for layerNo in range(noLayers):
            # Conv_PtextExtension_Block_size*y*
            # *x* *y*: interesting point size *y*, convolutional layer *x*; start from 0
            if layerNo == 0:
                # First slot: start from the (optionally batch-normalized) trace input.
                if (subMods_inputBNorms[conv1DbranchNo]!=0):
                    conv1Dbranch_out = BatchNormalization(trainable=True, name='Input_'+'subModels'+str(conv1DbranchNo)+'_BNorm'+str(layerNo))(inputs[0][conv1DbranchNo])
                else:
                    conv1Dbranch_out = inputs[0][conv1DbranchNo]
                print('conv1Dbranch_out(trace_input).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)

            # ConvBlock_size*y*_layer*x*: only built when filters, kernel size, pool size and pool stride are all set
            if (subMods_NoConvNodes[conv1DbranchNo][layerNo]!=0 and subMods_convKernelSizes[conv1DbranchNo][layerNo]!=0 and subMods_convPoolSizes[conv1DbranchNo][layerNo]!=0 and subMods_convPoolStrides[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = Conv1D(subMods_NoConvNodes[conv1DbranchNo][layerNo], subMods_convKernelSizes[conv1DbranchNo][layerNo], activation='relu', padding='same', name='ConvBlock_'+'subModels'+str(conv1DbranchNo)+'_conv'+str(layerNo)+'_'+str(subMods_NoConvNodes[conv1DbranchNo][layerNo])+'nodes_sz'+str(subMods_convKernelSizes[conv1DbranchNo][layerNo]))(conv1Dbranch_out)
                conv1Dbranch_out = MaxPooling1D(subMods_convPoolSizes[conv1DbranchNo][layerNo], strides=subMods_convPoolStrides[conv1DbranchNo][layerNo], name='ConvBlock_'+'subModels'+str(conv1DbranchNo)+'_pool'+str(layerNo)+'_sz'+str(subMods_convPoolSizes[conv1DbranchNo][layerNo])+'stride'+str(subMods_convPoolStrides[conv1DbranchNo][layerNo]))(conv1Dbranch_out)
            if (subMods_convBNorms[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = BatchNormalization(trainable=True, name='ConvBlock_'+'subModels'+str(conv1DbranchNo)+'_BNorm'+str(layerNo))(conv1Dbranch_out)
            if (subMods_convDrops[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = Dropout(subMods_convDrops[conv1DbranchNo][layerNo], name='ConvBlock_'+'subModels'+str(conv1DbranchNo)+'_drop'+str(layerNo)+'_'+str(subMods_convDrops[conv1DbranchNo][layerNo]))(conv1Dbranch_out)
                print('conv1Dbranch_out(Conv1D).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)
        print('conv1Dbranch_out(Conv1D).shape, conv1DbranchNo =', conv1Dbranch_out.shape, conv1DbranchNo, '\n')

        # Fully connected layers for convolved trace
        for layerNo in range(noLayers):
            # FC_PoI_size*y*_layer*x*
            if ((layerNo==0) and (subMods_convFeatFlat[conv1DbranchNo]!=0)):
                conv1Dbranch_out = Flatten(name='FC_'+'subModels'+str(conv1DbranchNo)+'_flatten')(conv1Dbranch_out)
                print('conv1Dbranch_out(Flatten).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo, '\n')
            if (subMods_FCs[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = Dense(subMods_FCs[conv1DbranchNo][layerNo], activation='relu', name='FC_'+'subModels'+str(conv1DbranchNo)+'_FC'+str(layerNo)+'_'+str(subMods_FCs[conv1DbranchNo][layerNo])+'nodes')(conv1Dbranch_out)
                print('conv1Dbranch_out(Dense).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)
            if (subMods_FC_BNorms[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = BatchNormalization(trainable=True, name='FC_'+'subModels'+str(conv1DbranchNo)+'_BNorm'+str(layerNo))(conv1Dbranch_out)
            if (subMods_FC_Drops[conv1DbranchNo][layerNo]!=0):
                conv1Dbranch_out = Dropout(subMods_FC_Drops[conv1DbranchNo][layerNo], name='FC_'+'subModels'+str(conv1DbranchNo)+'_drop'+str(layerNo)+'_'+str(subMods_FC_Drops[conv1DbranchNo][layerNo]))(conv1Dbranch_out)
                print('conv1Dbranch_out(Drops).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo, )
        print('conv1Dbranch_out(FlattenDenseDrops).shape, conv1DbranchNo =', conv1Dbranch_out.shape, conv1DbranchNo, '\n')

        BPbranchOuts_list = []
        for BPbranchNo in range(noBPbranch):
            # PtextExt_size*y*
            print('BPbranchNo, noConv1Dbranch+BPbranchNo =', BPbranchNo, noConv1Dbranch+BPbranchNo)
            # Every BP branch starts from the trace features. Without this reset a branch
            # whose subMods_Pext entry is 0 either hit an unbound BPbranchOut (first
            # branch) or silently stacked its layers on top of the previous branch.
            BPbranchOut = conv1Dbranch_out
            if (subMods_Pext[conv1DbranchNo][BPbranchNo]!=0):
                print('Check zero')
                Ptext_flatten = Flatten(name='flatten_Ptext1hot'+str(BPbranchNo))(inputs[1][BPbranchNo])
                print('conv1Dbranch_out.shape, Ptext_flatten.shape =', conv1Dbranch_out.shape, Ptext_flatten.shape)
                BPbranchOut = Concatenate()([conv1Dbranch_out, Ptext_flatten])
                print('BPbranchOut(conv1D+Ptex).shape =', BPbranchOut.shape)
            for layerNo in range(noLayers):
                # FC_Pext_size*y*_layer*x*
                if (subMods_Pext_FCs[conv1DbranchNo][BPbranchNo][layerNo]!=0):
                    BPbranchOut = Dense(subMods_Pext_FCs[conv1DbranchNo][BPbranchNo][layerNo], activation='relu', name='FC_Pext_subModels'+str(conv1DbranchNo)+'_BPbranch'+str(BPbranchNo)+'_FC'+str(layerNo)+'_'+str(subMods_Pext_FCs[conv1DbranchNo][BPbranchNo][layerNo])+'nodes')(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - Dense).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
                if (subMods_Pext_FC_BNorms[conv1DbranchNo][BPbranchNo][layerNo]!=0):
                    BPbranchOut = BatchNormalization(trainable=True, name='subModels'+str(conv1DbranchNo)+'BPbranch'+str(BPbranchNo)+'_BNorm'+str(layerNo))(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - BN).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
                if (subMods_Pext_FC_Drops[conv1DbranchNo][BPbranchNo][layerNo]!=0):
                    BPbranchOut = Dropout(subMods_Pext_FC_Drops[conv1DbranchNo][BPbranchNo][layerNo], name='FC_Pext_subModels'+str(conv1DbranchNo)+'_BPbranch'+str(BPbranchNo)+'_drop'+str(layerNo)+'_'+str(subMods_Pext_FC_Drops[conv1DbranchNo][BPbranchNo][layerNo]))(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - Drop).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
            print('BPbranchOut(Conv1D+Ptex - DenseDrop).shape, BPbranchNo =', BPbranchOut.shape, BPbranchNo)

            ###################### CLASSIFICATION (SOFTMAX) ######################
            if (subMods_classification[conv1DbranchNo][BPbranchNo]!=0):
                BPbranchOut = Dense(classes, activation='softmax', name='Predictions_subModels'+str(conv1DbranchNo)+'_BPbranch'+str(BPbranchNo))(BPbranchOut)
                print('BPbranchOut(conv1D+Ptex - Class).shape =', BPbranchOut.shape, '\n')
            BPbranchOuts_list.append(BPbranchOut)
            print('*** len(BPbranchOuts_list) =', len(BPbranchOuts_list), '\n')

        # Join the BP branches, or fall back to the trace features when joining is disabled.
        if subMods_join[conv1DbranchNo] != 0:
            BPbranchOuts_joined = Concatenate()(BPbranchOuts_list)
            print('BPbranchOuts_joined.shape =', BPbranchOuts_joined.shape)
        else:
            BPbranchOuts_joined = conv1Dbranch_out # original author's note: "this will not work"

        print('noLayers =', noLayers)
        print('subMods_join_FCs[',conv1DbranchNo,'] =', subMods_join_FCs[conv1DbranchNo])
        for layerNo in range(noLayers):
            # FC_Pext_size*y*_layer*x*
            if (subMods_join_FCs[conv1DbranchNo][layerNo]!=0):
                BPbranchOuts_joined = Dense(subMods_join_FCs[conv1DbranchNo][layerNo], activation='relu', name='subMods_join_FCs'+str(conv1DbranchNo)+'_'+str(layerNo)+'_'+str(subMods_join_FCs[conv1DbranchNo][layerNo])+'nodes')(BPbranchOuts_joined)
                print('BPbranchOuts_joined(Dense).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)
            if (subMods_join_FC_BNorms[conv1DbranchNo][layerNo]!=0):
                BPbranchOuts_joined = BatchNormalization(trainable=True, name='subMods_join_FCs_BNorm'+str(conv1DbranchNo)+'_'+str(layerNo))(BPbranchOuts_joined)
                print('BPbranchOuts_joined(BN).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)
            if (subMods_join_FC_Drops[conv1DbranchNo][layerNo]!=0):
                BPbranchOuts_joined = Dropout(subMods_join_FC_Drops[conv1DbranchNo][layerNo], name='subMods_join_FCs_drop'+str(conv1DbranchNo)+'_'+str(layerNo)+'_'+str(subMods_join_FC_Drops[conv1DbranchNo][layerNo]))(BPbranchOuts_joined)
                print('BPbranchOuts_joined(Drop).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)


        ###################### CLASSIFICATION (SOFTMAX) ######################
        if (subMods_join_classification[conv1DbranchNo]!=0):
            BPbranchOuts_joined = Dense(classes, activation='softmax', name='Predictions_joinSModels')(BPbranchOuts_joined)
            print('BPbranchOuts_joined(classification).shape =', BPbranchOuts_joined.shape)

        sModel = Model(inputs, BPbranchOuts_joined, name=MLmodel_detail)
        sModel.summary()
        # plot graph of ensemble
        #plot_model(sModel, show_shapes=True, to_file=modelLogFolder + logFilename + '_modelGraph.png')
        # `learning_rate` is the supported keyword; the old `lr` alias is deprecated
        # and rejected by recent Keras releases.
        optimizer = RMSprop(learning_rate=0.00001)
        sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    # Only the model built for the last Conv1D branch is returned.
    return sModel

# make a prediction with a stacked model
# https://machinelearningmastery.com/stacking-ensemble-for-deep-learning-neural-networks/
def predict_stacked_model(model, inputX):
    """Feed the same ``inputX`` to every input of ``model`` and return ``model.predict``'s result.

    ``inputX`` is repeated ``len(model.input)`` times and prediction runs with ``verbose=0``.
    """
    n_inputs = len(model.input)
    replicated = [inputX] * n_inputs
    return model.predict(replicated, verbose=0)

def load_sca_model(model_file):
    """Load a saved Keras model from ``model_file``.

    The process exits with code -1 (after printing an error) when the file does
    not exist or when Keras fails to load it; otherwise the loaded model is returned.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as exc:
        # `except Exception` instead of a bare `except`, so Ctrl-C and SystemExit
        # are not swallowed; also report why loading failed.
        print("Error: can't load Keras model file '%s'" % model_file)
        print("Reason: %s: %s" % (type(exc).__name__, exc))
        sys.exit(-1)
    return model

####### THESE FUNCTIONS ARE SPECIALIZED FOR KYBER    #######
####### Loading traces and metadata from file ############
#def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
def load_meta_trace_files(data_path, start_trace=200000, end_trace=250000):
    """Load a window of traces and their metadata from an ``.npz`` archive.

    The archive must contain the arrays ``'data'``, ``'bp'`` and ``'label'``.
    ``'data'`` and ``'label'`` are sliced along their first axis, ``'bp'`` along
    its second axis.

    Args:
        data_path: path of the ``.npz`` file.
        start_trace: first trace index of the window (inclusive). The default
            keeps the previously hard-coded window.
        end_trace: end of the window (exclusive).

    Returns:
        Tuple ``(traces, bp, labels)`` restricted to ``[start_trace, end_trace)``.
    """
    # Use the archive as a context manager so its file handle is closed; each
    # array is fully read into memory on access, so slicing afterwards is safe.
    with np.load(data_path) as data:
        trace_profiling = data['data']
        bp_profiling = data['bp']
        skpv_profiling = data['label']

    return (trace_profiling[start_trace:end_trace], bp_profiling[:,start_trace:end_trace], skpv_profiling[start_trace:end_trace])

#### Converting traces and metadata to training format
# inputs = [[list of traces], [list of bp]]
#def create_training_data_form(database_folder_train_file, sKeyNo, trainPortion, xType, yType):
def _one_hot_bp(bp_values, num_classes):
    """Return ``bp_values`` one-hot encoded as an int array of shape (n, num_classes, 1)."""
    n_rows = bp_values.shape[0]
    # Allocate as int directly; zeros().astype(int) would build a float copy first.
    encoded = np.zeros((n_rows, num_classes), dtype=int)
    encoded[np.arange(n_rows), bp_values] = 1
    return encoded.reshape((n_rows, num_classes, 1))


def create_training_data_form(data_path, sKeyNo, trainPortion, xType, yType):
    """Build train/validation inputs and one-hot labels from the profiling data.

    Traces, bp values and skpv labels are read through
    ``load_meta_trace_files(data_path)``. Traces are reshaped to
    ``(n_traces, n_samples, 1)``; every bp row required by ``xType`` is
    one-hot encoded over ``NumBPinput`` classes and reshaped to
    ``(n_traces, NumBPinput, 1)``. Only the bp rows that ``xType`` actually
    uses are encoded.

    Args:
        data_path: Path forwarded to ``load_meta_trace_files``.
        sKeyNo: Not used by this function; kept so existing callers keep
            working.
        trainPortion: Fraction of the loaded traces used for training. The
            first ``floor(n_traces * trainPortion)`` traces form the training
            set, the rest the validation set. If that would leave the
            validation set empty, the last trace is used for validation too.
        xType: Input layout, one of ``'wave'``, ``'wavebp0'``, ``'wavebp1'``,
            ``'wavebp01'``, ``'wavebp01next0'``, ``'wavebp01next01'``.
        yType: Label type. Only ``'skpv'`` is available.

    Returns:
        ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``.
        ``xTrain``/``xVal`` are nested lists: ``[[traces]]`` for ``'wave'``,
        otherwise ``[[traces], [bp_one_hot, ...]]``. ``yTrain``/``yVal`` are
        the ``to_categorical`` labels over ``NumSKPVclasses`` classes;
        ``yTrain_value``/``yVal_value`` are the matching raw label values.

    Raises:
        ValueError: If ``xType`` or ``yType`` is unknown, or ``yType`` is
            ``'fqmul0'``/``'fqmul1'`` (fqmul labels are not loaded, so these
            modes cannot be served).
    """
    # Rows of bp_profiling that accompany the traces for each input layout.
    bp_rows_by_xtype = {
        'wave': (),
        'wavebp0': (0,),
        'wavebp1': (1,),
        'wavebp01': (0, 1),
        'wavebp01next0': (0, 1, 2),
        'wavebp01next01': (0, 1, 2, 3),
    }
    # Validate before loading anything so a typo fails fast instead of
    # surfacing as an unbound local after the expensive encoding work.
    if xType not in bp_rows_by_xtype:
        raise ValueError("Unknown xType %r; expected one of %s" % (xType, sorted(bp_rows_by_xtype)))
    if yType in ('fqmul0', 'fqmul1'):
        raise ValueError("yType %r is not supported: fqmul labels are not loaded from the data file" % (yType,))
    if yType != 'skpv':
        raise ValueError("Unknown yType %r; expected 'skpv'" % (yType,))

    (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path)

    # Add a trailing channel axis to the traces.
    traces = trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
    dataSize = traces.shape[0]
    trainSize = math.floor(dataSize * trainPortion)
    valLoc = trainSize
    if valLoc == dataSize:
        # Keep at least one validation sample when everything goes to training.
        valLoc = dataSize - 1

    bp_inputs = []
    for row in bp_rows_by_xtype[xType]:
        encoded = _one_hot_bp(bp_profiling[row], NumBPinput)
        print('bp_profiling[%d] one-hot shape =' % row, encoded.shape, '    bp_profiling[%d] =' % row, bp_profiling[row])
        bp_inputs.append(encoded)

    # Input data creation
    xTrain = [[traces[0:trainSize, :, :]]]
    xVal = [[traces[valLoc:, :, :]]]
    if bp_inputs:
        xTrain.append([encoded[0:trainSize, :] for encoded in bp_inputs])
        xVal.append([encoded[valLoc:, :] for encoded in bp_inputs])

    # Category creation
    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)
    yTrain = y_train_skpv[0:trainSize, :]
    yTrain_value = skpv_profiling[0:trainSize]
    yVal = y_train_skpv[valLoc:, :]
    yVal_value = skpv_profiling[valLoc:]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def get_subset(idxs, X_train, y_train):
    """Select the samples at ``idxs`` from a trace + two-bp training set.

    Args:
        idxs: Index (or index array) applied along the first axis.
        X_train: Nested list laid out as ``[[traces], [bp0, bp1, ...]]``;
            only ``traces``, ``bp0`` and ``bp1`` are read.
        y_train: Label array indexed with the same ``idxs``.

    Returns:
        ``(sub_X_train, sub_y_train)`` where ``sub_X_train`` has the layout
        ``[[traces_subset], [bp0_subset, bp1_subset]]``.
    """
    traces = X_train[0][0]
    # Report how many samples the full set holds.
    print(len(traces))
    bp_pair = (X_train[1][0], X_train[1][1])

    picked_traces = traces[idxs, :, :]
    picked_bps = [bp[idxs, :] for bp in bp_pair]
    return [[picked_traces], picked_bps], y_train[idxs, :]

def train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size):
    """Continue training a saved model on the profiling data and save the result.

    Steps:
      1. Load the model from ``trained_model_path`` and build the train and
         validation data with ``create_training_data_form``.
      2. Predict on the training set and report the samples whose predicted
         probability for their true class is below the mean of those
         probabilities. The subset is only reported; fitting below still
         uses the full training set.
      3. Check that the model's first input matches the trace length, then
         fit with a CSV logger and periodic checkpoints, and save the final
         model as ``model.h5`` in ``database_folder_train``.

    ``trained_model_path``, ``data_path``, ``trainPortion`` and ``yType`` are
    read as module-level names, not parameters.

    The dataset pipeline feeds the inputs ``input_1`` (traces), ``input_2``
    and ``input_3`` (first two bp inputs), so ``xType`` must produce at least
    two bp inputs (e.g. ``'wavebp01'``).

    Args:
        xType: Input layout forwarded to ``create_training_data_form``.
        database_folder_train: Output folder for ``cp-XXXX.h5`` checkpoints,
            ``log.csv`` and ``model.h5``.
        modelLogFolder, logTrainedModel_byFile_folder,
        logTrainedModel_byEp_folder, logFilename, MLmodel_detail: Not used by
            this function; kept for interface compatibility.
        sKeyNo: Forwarded to ``create_training_data_form``.
        class_weight: Class-weight mapping passed to ``model.fit``.
        period: Checkpoint interval passed to ``ModelCheckpoint``.
        maxEpochs: Number of epochs passed to ``model.fit``.
        train_batch_size: Currently has no effect. The tf.data pipelines are
            batched with a fixed size of 64 and ``fit`` must not be given a
            ``batch_size`` for dataset inputs.
            NOTE(review): confirm whether 64 or this parameter is intended.
    """
    # Batch size of the tf.data pipelines (see the note on train_batch_size).
    dataset_batch_size = 64

    model = load_model(trained_model_path)
    gc.collect()
    model.summary()
    print('load success')

    xTrain, yTrain, xVal, yVal, _, _ = create_training_data_form(data_path, sKeyNo, trainPortion, xType, yType)

    print(yTrain.shape)
    preds = model.predict(xTrain)
    print(preds.shape)

    # Probability assigned to the true class of every training sample
    # (yTrain is one-hot, so argmax recovers the class index).
    true_class = np.argmax(yTrain, axis=1)
    pred_probs = preds[np.arange(len(preds)), true_class]
    mean_prob = np.mean(pred_probs)
    # Samples the model is less confident about than on average.
    sample_idxs = np.where(pred_probs < mean_prob)[0]

    sub_xTrain, sub_yTrain = get_subset(sample_idxs.astype(int), xTrain, yTrain)
    print('---------------------')
    print(len(sub_xTrain))
    print(len(sub_xTrain[0][0]))
    print(len(sub_xTrain[1][0]))

    save_model_name = os.path.join(database_folder_train, 'cp-{epoch:04d}.h5')
    csv_logger = CSVLogger(filename=database_folder_train+'/log.csv', append=True, separator=';')
    save_model = ModelCheckpoint(save_model_name, period=period)
    callbacks = [csv_logger, save_model]

    # Get the input layer shape
    input_layer_shape = model.get_layer(index=0).input_shape
    trace_len = len(xTrain[0][0][0])
    print('input_layer_shape =', input_layer_shape)
    print('input_layer_shape[0][1] =', input_layer_shape[0][1])
    print('Number of sample points per trace: len(xTrain[0][0][0]) =', trace_len)
    print('Number of traces: len(xTrain[0][0]) =', len(xTrain[0][0]))

    # Sanity check: the model must accept traces of this length.
    if input_layer_shape[0][1] != trace_len:
        # Report the trace length that was actually compared, not the trace count.
        print("Error: model input shape %d instead of %d is not expected ..." % (input_layer_shape[0][1], trace_len))
        sys.exit(-1)

    # Build the input pipelines on the CPU device.
    with tf.device("CPU"):
        train = tf.data.Dataset.from_tensor_slices(({"input_1": xTrain[0][0], "input_2": xTrain[1][0], "input_3": xTrain[1][1]}, yTrain)).shuffle(4 * dataset_batch_size).batch(dataset_batch_size)
        val = tf.data.Dataset.from_tensor_slices(({"input_1": xVal[0][0], "input_2": xVal[1][0], "input_3": xVal[1][1]}, yVal)).shuffle(4 * dataset_batch_size).batch(dataset_batch_size)

    # The datasets are already batched, so no batch_size is passed to fit().
    model.fit(train, verbose=1, epochs=maxEpochs, callbacks=callbacks, class_weight=class_weight, validation_data=val)

    model.save(database_folder_train+'/model.h5')

# Script entry point: in 'train' mode report the visible GPUs, build a uniform
# class-weight mapping (weight 1 for each of the noClasses classes) and start
# the training run.
if work == 'train':
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
    print('Work =', work)
    classWeights = np.ones(noClasses, dtype=int)
    class_weight = {label: weight for label, weight in enumerate(classWeights)}

    train_model_multiEpochs(
        xType,
        database_folder_train,
        modelLogFolder,
        logTrainedModel_byFile_folder,
        logTrainedModel_byEp_folder,
        logFilename,
        MLmodel_detail,
        sKeyNo,
        class_weight,
        period,
        maxEpochs,
        train_batch_size,
    )