import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model	#, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split

# Trace and metadata parameters
# Value ranges are inclusive [min, max] (hence the "+ 1" in the class counts below).
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # samples per trace

# Class / input counts derived from the inclusive ranges above.
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1  # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1  # number of inputs for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
# sKeyNo is in range 0 to 3; which sub-keys these are is decided by the code
# running on the m4 target, NOT by the code on the PC.
sKeyNo = 0
work = 'train'  # 'train' or 'attack'
training_file_list = [
    'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',
    'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5',
    # Further 100k-trace chunks, currently disabled:
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5',
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',
    # 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5',
]

nruns_default = 20
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232

# Optional command-line overrides:
#   sys.argv[1] -> work (e.g. 'train' or 'attack')
#   sys.argv[2] -> attack_byModel_epNo (parsed as int)
# Each argument is applied independently, so giving only the mode keeps the
# default epoch number instead of raising IndexError on sys.argv[2].
if len(sys.argv) > 1:
    work = sys.argv[1]
if len(sys.argv) > 2:
    attack_byModel_epNo = int(sys.argv[2])
print("work =", work)
print("attack_byModel_epNo =", attack_byModel_epNo)

# ---- Training parameters ----
# Batch size (alternatives noted previously: 150, 200, 250, 500, 640; 80 for mars45, 170 for mars56)
train_batch_size = 100
period = 8
# Number of epochs (alternatives noted previously: 256 ... 3072)
maxEpochs = 1536
# Saved-model file index for epoch attack_byModel_epNo (truncating division by
# `period`; presumably one file is saved every `period` epochs - confirm).
attack_byModel_fileNo = int(attack_byModel_epNo / period)

# ---- Model hyper-parameters ----
noConv1Dbranch = 1
noLayers = 6  # if newly train
noClassificationLayer = 1
GPU_clear = True  # or False

# ---- Training data type ----
# xType options: 'wave', 'wavebp0', 'wavebp1', 'wavebp01', 'wavebp01next0', 'wavebp01next01'
xType = 'wave'
# yType options: 'fqmul0', 'fqmul1', 'skpv'
yType = 'skpv'
trainPortion = 0.8

# ---- Database and logs for the model and training progress (epochs) ----
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
device = 'm4_CWLite'
attackModel_dev = f'{attackModel}_{device}'
attackModel_dev_folder = f'../{attackModel_dev}/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

# Hyper-parameter version; contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
hyper_ver = 'hy0001010101_skpv0'
dataFile_train_folder = '100kDatax5_train'
dataFile_attack = '100kDatax1_test'
model_input_type = '_in[[][]]_tf2'
# Encodes how many training files are in use, e.g. '100kDataxN2_in[[][]]_tf2'.
data_type = f'100kDataxN{len(training_file_list)}{model_input_type}'
database_folder_train = f'{attackModel_dev_folder}{attackModel}_{dataFile_train_folder}_h5/'
database_folder_attack = f'{attackModel_dev_folder}{attackModel}_{dataFile_attack}_h5/'
logFilename = f'{MLmodelStruct}_{hyper_ver}'
DLmodel_name = logFilename
DLmodel_folder = f'{attackModel_dev_folder}{logFilename}_{data_type}/'

modelLogFolder = f'{DLmodel_folder}log{DLmodel_name}/'
logTrainedModel_byFile_folder = f'{DLmodel_folder}trained{DLmodel_name}_byDataFile/'
# Per-epoch checkpoints share the per-data-file folder
# (alternative: a separate 'trained<name>_byEpoch/' folder).
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = f'{DLmodel_folder}log{DLmodel_name}_attack/'

# Create output folders. os.mkdir is not recursive, so each parent must exist
# (DLmodel_folder is created first, then its sub-folders).
if not os.path.isdir(DLmodel_folder):
    os.mkdir(DLmodel_folder)
if not os.path.isdir(modelLogFolder):
    os.mkdir(modelLogFolder)
if not os.path.isdir(logTrainedModel_byFile_folder):
    os.mkdir(logTrainedModel_byFile_folder)
if not os.path.isdir(logTrainedModel_byEp_folder):
    os.mkdir(logTrainedModel_byEp_folder)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# Input BatchNormalization per sub-model (sub-model = conv1D branch number);
# a non-zero entry adds BatchNormalization directly on that branch's trace input.
#                     sub0 sub1 sub2 sub3 sub4 sub5
subMods_inputBNorms = [1, 0, 0, 0, 0, 0]

###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Every 2-D table below is indexed [subModel][layer].
# A Conv1D + MaxPooling1D block is added for a layer only when its node count,
# kernel size, pool size and pool stride are ALL non-zero; BatchNormalization
# and Dropout are added whenever their own entry is non-zero.

# Number of Conv1D filters (nodes) per convolutional layer
subMods_NoConvNodes = [
    # l0    l1   l2   l3  l4 l5
    [512, 256, 128, 64, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],         # subModel1
    [0, 0, 0, 0, 0, 0],         # subModel2
    [0, 0, 0, 0, 0, 0],         # subModel3
    [0, 0, 0, 0, 0, 0],         # subModel4
    [0, 0, 0, 0, 0, 0],         # subModel5
]
# Conv1D kernel (filter) size per convolutional layer
subMods_convKernelSizes = [
    # l0 l1 l2 l3 l4 l5
    [3, 3, 3, 3, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 0, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]

###############################################
# MaxPooling1D pool size per convolutional layer
subMods_convPoolSizes = [
    # l0 l1 l2 l3 l4 l5
    [2, 2, 2, 2, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 0, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]
# MaxPooling1D stride per convolutional layer
subMods_convPoolStrides = [
    # l0 l1 l2 l3 l4 l5
    [3, 3, 3, 3, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 0, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]
# BatchNormalization flag per convolutional layer
subMods_convBNorms = [
    # l0 l1 l2 l3 l4 l5
    [1, 1, 1, 1, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 0, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]
# Dropout rate per convolutional layer
subMods_convDrops = [
    # l0 l1 l2 l3 l4 l5
    [0, 0, 0, 0, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 0, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten the convolutional feature map before the fully-connected layers
#                      sub0 sub1 sub2 sub3 sub4 sub5
subMods_convFeatFlat = [1, 0, 0, 0, 0, 0]

# NOTE(review): the subModel3 rows of the three FC tables below hold stray 3s.
# They are unused while only subModel0 is built (noConv1Dbranch = 1), but a
# Dropout rate of 3 is outside the valid 0-1 range if that sub-model is enabled.

# Dense (ReLU) units per fully-connected layer on the conv features, before
# the plaintext (bp) is concatenated
subMods_FCs = [
    # l0     l1   l2   l3  l4 l5
    [1024, 512, 256, 128, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],           # subModel1
    [0, 0, 0, 0, 0, 0],           # subModel2
    [0, 0, 3, 0, 0, 0],           # subModel3
    [0, 0, 0, 0, 0, 0],           # subModel4
    [0, 0, 0, 0, 0, 0],           # subModel5
]
# BatchNormalization flag for those fully-connected layers
subMods_FC_BNorms = [
    # l0 l1 l2 l3 l4 l5
    [1, 1, 1, 1, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],  # subModel1
    [0, 0, 0, 0, 0, 0],  # subModel2
    [0, 0, 3, 0, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]
# Dropout rate for those fully-connected layers
subMods_FC_Drops = [
    # l0   l1  l2   l3 l4 l5
    [0.2, 0, 0.2, 0, 0, 0],  # subModel0
    [0, 0, 0, 0, 0, 0],      # subModel1
    [0, 0, 0, 0, 0, 0],      # subModel2
    [0, 0, 3, 0, 0, 0],      # subModel3
    [0, 0, 0, 0, 0, 0],      # subModel4
    [0, 0, 0, 0, 0, 0],      # subModel5
]

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext (bp one-hot) extension, selected by xType.
#   noBPbranch   : number of bp one-hot inputs, each getting its own branch.
#   subMods_Pext : indexed [conv1D branch][BP branch]; a non-zero entry
#                  concatenates that bp input onto the conv1D branch features.
# Only the 'wave', 'wavebp01' and 'wavebp01next01' branches define noBPbranch.
# For the other known xTypes it stays undefined, so any later use of
# noBPbranch raises NameError.
if xType == 'wave':
    noBPbranch = 0
    #                sub0 sub1 sub2 sub3 sub4 sub5
    subMods_Pext = [[0, 0, 0, 0, 0, 0]]  # conv1D branch 0
elif xType == 'wavebp0':
    # NOTE(review): noBPbranch not set for this xType.
    subMods_Pext = [[0, 0, 0, 0, 0, 0]]  # conv1D branch 0
elif xType == 'wavebp1':
    # NOTE(review): noBPbranch not set for this xType.
    subMods_Pext = [[0, 0, 0, 0, 0, 0]]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext = [[1, 1, 0, 0, 0, 0]]  # conv1D branch 0
elif xType == 'wavebp01next0':
    # NOTE(review): noBPbranch not set for this xType.
    subMods_Pext = [[0, 0, 0, 0, 0, 0]]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext = [[1, 1, 1, 1, 0, 0]]  # conv1D branch 0
else:
    # Fail fast on a typo instead of leaving subMods_Pext undefined and
    # crashing much later with a NameError.
    raise ValueError(
        "Unsupported xType %r; expected one of 'wave', 'wavebp0', 'wavebp1', "
        "'wavebp01', 'wavebp01next0', 'wavebp01next01'" % (xType,))


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected stack applied after the plaintext (bp) concatenation.
# These tables are indexed [conv1D branch][BP branch][layer]; a 0 entry
# disables that Dense layer / BatchNormalization / Dropout.

# Dense (ReLU) units per layer
subMods_Pext_FCs = [[
    # l0     l1    l2   l3   l4  l5
    [1024, 1024, 512, 256, 128, 0],  # subModel0
    [1024, 1024, 512, 256, 128, 0],  # subModel1
    [1024, 1024, 512, 256, 128, 0],  # subModel2
    [1024, 1024, 512, 256, 128, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],              # subModel4
    [0, 0, 0, 0, 0, 0],              # subModel5
]]
# BatchNormalization flag per layer
subMods_Pext_FC_BNorms = [[
    # l0 l1 l2 l3 l4 l5
    [1, 1, 1, 1, 1, 0],  # subModel0
    [1, 1, 1, 1, 1, 0],  # subModel1
    [1, 1, 1, 1, 1, 0],  # subModel2
    [1, 1, 1, 1, 1, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]]
# Dropout rate per layer
subMods_Pext_FC_Drops = [[
    # l0   l1   l2  l3  l4 l5
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel0
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel1
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel2
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],        # subModel4
    [0, 0, 0, 0, 0, 0],        # subModel5
]]

# Per-BP-branch softmax, indexed [conv1D branch][BP branch]; non-zero adds it
#                             sub0 sub1 sub2 sub3 sub4 sub5
subMods_classification = [[0, 0, 0, 0, 0, 0]]

# Concatenate the BP-branch outputs only when bp inputs are used (indexed by conv1D branch).
subMods_join = [0] if xType == 'wave' else [1]

# Fully-connected stack after joining, indexed [conv1D branch][layer]; 0 disables an entry.
# Dense (ReLU) units per layer
subMods_join_FCs = [[1024, 1024, 512, 256, 128, 0]]
# BatchNormalization flag per layer
#                          l0 l1 l2 l3 l4 l5
subMods_join_FC_BNorms = [[1, 1, 1, 1, 1, 0]]
# Dropout rate per layer
#                          l0   l1  l2  l3  l4 l5
subMods_join_FC_Drops = [[0.2, 0.2, 0, 0.1, 0, 0]]

# Final softmax for the joined model (indexed by conv1D branch)
subMods_join_classification = [1]

################################################################################################
##################################### MODELS STRUCTURE END #####################################
################################################################################################


def check_file_exists(file_path):
    """Terminate the program unless *file_path* exists.

    Prints an error message and calls ``sys.exit(-1)`` when the path is
    missing; otherwise returns None.
    """
    if not os.path.exists(file_path):
        print("Error: provided file path '%s' does not exist!" % file_path)
        sys.exit(-1)

def listDirWithExt(directory, extension):
    """Return a lazy generator of entry names in *directory* ending with '.<extension>'.

    The directory is listed immediately (at call time); the name filtering
    happens lazily while the generator is consumed.
    """
    entries = os.listdir(directory)
    return (entry for entry in entries if entry.endswith('.' + extension))

def subModels_gen(xType,noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the classifier described by the module-level ``subMods_*`` tables.

    For each conv1D branch ``b`` the graph is:
      1. optional BatchNormalization on the trace input (``subMods_inputBNorms``),
      2. up to ``noLayers`` Conv1D(relu, padding='same') + MaxPooling1D blocks,
         each followed by optional BatchNormalization / Dropout,
      3. optional Flatten, then up to ``noLayers`` Dense(relu) layers with
         optional BatchNormalization / Dropout,
      4. for each of the ``noBPbranch`` bp inputs: when ``subMods_Pext[b][bp]``
         is set, concatenate the flattened bp one-hot with the branch features,
         then Dense/BatchNormalization/Dropout layers and an optional softmax,
      5. concatenate the BP-branch outputs when ``subMods_join[b]`` is set,
         otherwise continue from the step-3 features; then the join
         Dense/BatchNormalization/Dropout layers and the final softmax.

    Args:
        xType: 'wave' builds a trace-only model; any other value adds the bp
            one-hot inputs as a second input group.
        noConv1Dbranch: number of trace inputs / conv1D branches (must be >= 1).
        noBPbranch: number of bp one-hot inputs.
        noLayers: number of layer columns read from the configuration tables.
        tracelen: samples per trace; trace inputs have shape (tracelen, 1).
        NumBPinput: bp one-hot width; bp inputs have shape (NumBPinput, 1).
        MLmodel_detail: name given to the Keras Model.
        modelLogFolder, logFilename: the model graph image is written to
            ``modelLogFolder + logFilename + '_modelGraph.png'`` when possible.
        classes: number of softmax output classes.

    Returns:
        The compiled Keras Model (categorical cross-entropy, RMSprop with
        learning rate 1e-5, accuracy metric). When several conv1D branches are
        configured, only the model built in the last iteration is returned.

    Raises:
        ValueError: if ``noConv1Dbranch`` < 1 (no model would be built).
    """
    if noConv1Dbranch < 1:
        raise ValueError('noConv1Dbranch must be >= 1 to build a model, got %r' % (noConv1Dbranch,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)
    # A fresh Input per branch so every branch is a distinct model input.
    m_traceinputs = [Input(shape=input_trace_shape) for _ in range(noConv1Dbranch)]
    m_Ptextinputs = [Input(shape=input_Ptext1hot_shape) for _ in range(noBPbranch)]
    if xType == 'wave':
        inputs = [m_traceinputs]
    else:
        inputs = [m_traceinputs, m_Ptextinputs]

    for conv1DbranchNo in range(0, noConv1Dbranch):
        print('\nconv1DbranchNo =', conv1DbranchNo)

        # ---- Convolutional feature extraction on the trace input ----
        for layerNo in range(noLayers):
            if layerNo == 0:
                if subMods_inputBNorms[conv1DbranchNo] != 0:
                    conv1Dbranch_out = BatchNormalization(trainable=True, name=f'Input_subModels{conv1DbranchNo}_BNorm{layerNo}')(inputs[0][conv1DbranchNo])
                else:
                    conv1Dbranch_out = inputs[0][conv1DbranchNo]
                print('conv1Dbranch_out(trace_input).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)

            nNodes = subMods_NoConvNodes[conv1DbranchNo][layerNo]
            kernelSize = subMods_convKernelSizes[conv1DbranchNo][layerNo]
            poolSize = subMods_convPoolSizes[conv1DbranchNo][layerNo]
            poolStride = subMods_convPoolStrides[conv1DbranchNo][layerNo]
            # A conv block is only added when all four size parameters are non-zero.
            if nNodes != 0 and kernelSize != 0 and poolSize != 0 and poolStride != 0:
                conv1Dbranch_out = Conv1D(nNodes, kernelSize, activation='relu', padding='same', name=f'ConvBlock_subModels{conv1DbranchNo}_conv{layerNo}_{nNodes}nodes_sz{kernelSize}')(conv1Dbranch_out)
                conv1Dbranch_out = MaxPooling1D(poolSize, strides=poolStride, name=f'ConvBlock_subModels{conv1DbranchNo}_pool{layerNo}_sz{poolSize}stride{poolStride}')(conv1Dbranch_out)
            if subMods_convBNorms[conv1DbranchNo][layerNo] != 0:
                conv1Dbranch_out = BatchNormalization(trainable=True, name=f'ConvBlock_subModels{conv1DbranchNo}_BNorm{layerNo}')(conv1Dbranch_out)
            dropRate = subMods_convDrops[conv1DbranchNo][layerNo]
            if dropRate != 0:
                conv1Dbranch_out = Dropout(dropRate, name=f'ConvBlock_subModels{conv1DbranchNo}_drop{layerNo}_{dropRate}')(conv1Dbranch_out)
                print('conv1Dbranch_out(Conv1D).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)
        print('conv1Dbranch_out(Conv1D).shape, conv1DbranchNo =', conv1Dbranch_out.shape, conv1DbranchNo, '\n')

        # ---- Fully-connected layers on the convolved trace ----
        for layerNo in range(noLayers):
            if layerNo == 0 and subMods_convFeatFlat[conv1DbranchNo] != 0:
                conv1Dbranch_out = Flatten(name=f'FC_subModels{conv1DbranchNo}_flatten')(conv1Dbranch_out)
                print('conv1Dbranch_out(Flatten).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo, '\n')
            nUnits = subMods_FCs[conv1DbranchNo][layerNo]
            if nUnits != 0:
                conv1Dbranch_out = Dense(nUnits, activation='relu', name=f'FC_subModels{conv1DbranchNo}_FC{layerNo}_{nUnits}nodes')(conv1Dbranch_out)
                print('conv1Dbranch_out(Dense).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)
            if subMods_FC_BNorms[conv1DbranchNo][layerNo] != 0:
                conv1Dbranch_out = BatchNormalization(trainable=True, name=f'FC_subModels{conv1DbranchNo}_BNorm{layerNo}')(conv1Dbranch_out)
            dropRate = subMods_FC_Drops[conv1DbranchNo][layerNo]
            if dropRate != 0:
                conv1Dbranch_out = Dropout(dropRate, name=f'FC_subModels{conv1DbranchNo}_drop{layerNo}_{dropRate}')(conv1Dbranch_out)
                print('conv1Dbranch_out(Drops).shape, conv1DbranchNo, layerNo =', conv1Dbranch_out.shape, conv1DbranchNo, layerNo)
        print('conv1Dbranch_out(FlattenDenseDrops).shape, conv1DbranchNo =', conv1Dbranch_out.shape, conv1DbranchNo, '\n')

        # ---- Plaintext (bp) extension branches ----
        BPbranchOuts_list = []
        for BPbranchNo in range(noBPbranch):
            print('BPbranchNo, noConv1Dbranch+BPbranchNo =', BPbranchNo, noConv1Dbranch + BPbranchNo)
            # NOTE(review): when subMods_Pext[...] is 0, BPbranchOut is not
            # (re)initialised here: it keeps the previous branch's tensor or is
            # undefined for the first branch.
            if subMods_Pext[conv1DbranchNo][BPbranchNo] != 0:
                print('Check zero')
                Ptext_flatten = Flatten(name=f'flatten_Ptext1hot{BPbranchNo}')(inputs[1][BPbranchNo])
                print('conv1Dbranch_out.shape, Ptext_flatten.shape =', conv1Dbranch_out.shape, Ptext_flatten.shape)
                BPbranchOut = Concatenate()([conv1Dbranch_out, Ptext_flatten])
                print('BPbranchOut(conv1D+Ptex).shape =', BPbranchOut.shape)
            for layerNo in range(noLayers):
                nUnits = subMods_Pext_FCs[conv1DbranchNo][BPbranchNo][layerNo]
                if nUnits != 0:
                    BPbranchOut = Dense(nUnits, activation='relu', name=f'FC_Pext_subModels{conv1DbranchNo}_BPbranch{BPbranchNo}_FC{layerNo}_{nUnits}nodes')(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - Dense).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
                if subMods_Pext_FC_BNorms[conv1DbranchNo][BPbranchNo][layerNo] != 0:
                    BPbranchOut = BatchNormalization(trainable=True, name=f'subModels{conv1DbranchNo}BPbranch{BPbranchNo}_BNorm{layerNo}')(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - BN).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
                dropRate = subMods_Pext_FC_Drops[conv1DbranchNo][BPbranchNo][layerNo]
                if dropRate != 0:
                    BPbranchOut = Dropout(dropRate, name=f'FC_Pext_subModels{conv1DbranchNo}_BPbranch{BPbranchNo}_drop{layerNo}_{dropRate}')(BPbranchOut)
                    print('BPbranchOut(conv1D+Ptex - Drop).shape, BPbranchNo, layerNo =', BPbranchOut.shape, BPbranchNo, layerNo)
            print('BPbranchOut(Conv1D+Ptex - DenseDrop).shape, BPbranchNo =', BPbranchOut.shape, BPbranchNo)

            ###################### CLASSIFICATION (SOFTMAX) ######################
            if subMods_classification[conv1DbranchNo][BPbranchNo] != 0:
                BPbranchOut = Dense(classes, activation='softmax', name=f'Predictions_subModels{conv1DbranchNo}_BPbranch{BPbranchNo}')(BPbranchOut)
                print('BPbranchOut(conv1D+Ptex - Class).shape =', BPbranchOut.shape, '\n')
            BPbranchOuts_list.append(BPbranchOut)
            print('*** len(BPbranchOuts_list) =', len(BPbranchOuts_list), '\n')

        # ---- Join the BP branches ----
        if subMods_join[conv1DbranchNo] != 0:
            BPbranchOuts_joined = Concatenate()(BPbranchOuts_list)
            print('BPbranchOuts_joined.shape =', BPbranchOuts_joined.shape)
        else:
            # No join: the join FC stack continues directly from the conv/FC features.
            BPbranchOuts_joined = conv1Dbranch_out

        print('noLayers =', noLayers)
        print('subMods_join_FCs[', conv1DbranchNo, '] =', subMods_join_FCs[conv1DbranchNo])
        for layerNo in range(noLayers):
            nUnits = subMods_join_FCs[conv1DbranchNo][layerNo]
            if nUnits != 0:
                BPbranchOuts_joined = Dense(nUnits, activation='relu', name=f'subMods_join_FCs{conv1DbranchNo}_{layerNo}_{nUnits}nodes')(BPbranchOuts_joined)
                print('BPbranchOuts_joined(Dense).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)
            if subMods_join_FC_BNorms[conv1DbranchNo][layerNo] != 0:
                BPbranchOuts_joined = BatchNormalization(trainable=True, name=f'subMods_join_FCs_BNorm{conv1DbranchNo}_{layerNo}')(BPbranchOuts_joined)
                print('BPbranchOuts_joined(BN).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)
            dropRate = subMods_join_FC_Drops[conv1DbranchNo][layerNo]
            if dropRate != 0:
                BPbranchOuts_joined = Dropout(dropRate, name=f'subMods_join_FCs_drop{conv1DbranchNo}_{layerNo}_{dropRate}')(BPbranchOuts_joined)
                print('BPbranchOuts_joined(Drop).shape, conv1DbranchNo, layerNo =', BPbranchOuts_joined.shape, conv1DbranchNo, layerNo)

        ###################### CLASSIFICATION (SOFTMAX) ######################
        if subMods_join_classification[conv1DbranchNo] != 0:
            BPbranchOuts_joined = Dense(classes, activation='softmax', name='Predictions_joinSModels')(BPbranchOuts_joined)
            print('BPbranchOuts_joined(classification).shape =', BPbranchOuts_joined.shape)

        sModel = Model(inputs, BPbranchOuts_joined, name=MLmodel_detail)
        sModel.summary()
        # The graph image is diagnostic only. plot_model raises ImportError when
        # pydot/graphviz are missing, which should not abort model construction.
        try:
            plot_model(sModel, show_shapes=True, to_file=modelLogFolder + logFilename + '_modelGraph.png')
        except ImportError as err:
            print('Warning: could not plot the model graph:', err)
        # Use the TF2 keyword `learning_rate`. Recent Keras releases either
        # ignore the legacy `lr` alias with a warning (silently training at the
        # 1e-3 default) or reject it outright.
        optimizer = RMSprop(learning_rate=0.00001)
        sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

def predict_stacked_model(model, inputX):
    """Predict with a multi-input (stacked) model by feeding *inputX* to every input.

    Adapted from
    https://machinelearningmastery.com/stacking-ensemble-for-deep-learning-neural-networks/

    Returns whatever ``model.predict`` returns; it is called with ``verbose=0``.
    """
    # One reference to the same input per model input.
    replicated_inputs = [inputX] * len(model.input)
    return model.predict(replicated_inputs, verbose=0)

def load_sca_model(model_file):
    """Load a saved Keras model, terminating the program on failure.

    Exits with code -1 (after printing an error) when *model_file* does not
    exist or when ``load_model`` raises.

    Returns:
        The loaded Keras model.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as err:
        # Catch ordinary errors only: a bare ``except:`` also swallowed
        # KeyboardInterrupt and hid the actual reason for the failure.
        print("Error: can't load Keras model file '%s'" % model_file)
        print("Reason:", err)
        sys.exit(-1)
    return model

#######	THESE FUNCTIONS ARE SPECIALIZED FOR KYBER	 #######
####### Loading traces and metadata from file ############
#def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
def _read_kyber_h5_columns(database_file, sKeyNo):
    """Read the traces and the sKeyNo / sKeyNo+1 label columns from one HDF5 file.

    The file is closed before returning. Exits the program with code -1 if it
    cannot be opened as HDF5.

    Returns:
        (wave, skpv, bp_even, bp_even_next, bp_odd, bp_odd_next): ``wave`` is
        the float array of dataset 'wave'. The others are int arrays taken from
        column ``sKeyNo`` (or ``sKeyNo + 1`` for the *_next values) of datasets
        'skpv_a_vec0_evenCoeff0', 'bp_b_vec0_evenCoeff0' and 'bp_b_vec0_oddCoeff1'.
    """
    try:
        in_file = h5py.File(database_file, "r")
    except (OSError, ValueError):
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        sys.exit(-1)
    with in_file:
        wave = np.array(in_file['wave'], dtype=float)
        skpv = np.array(in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int))
        bp_even_ds = in_file['bp_b_vec0_evenCoeff0']
        bp_odd_ds = in_file['bp_b_vec0_oddCoeff1']
        bp_even = np.array(bp_even_ds[:, sKeyNo].astype(int))
        bp_even_next = np.array(bp_even_ds[:, sKeyNo + 1].astype(int))
        bp_odd = np.array(bp_odd_ds[:, sKeyNo].astype(int))
        bp_odd_next = np.array(bp_odd_ds[:, sKeyNo + 1].astype(int))
    return (wave, skpv, bp_even, bp_even_next, bp_odd, bp_odd_next)


def load_meta_trace_files(database_folder_train, sKeyNo, work, load_metadata=False):
    """Load Kyber traces and labels from one or more HDF5 files and concatenate them.

    Args:
        database_folder_train: in 'train' mode, a folder prefix that is joined
            with every name in the module-level ``training_file_list``;
            otherwise, the full path of a single HDF5 file.
        sKeyNo: label column to read; column ``sKeyNo + 1`` is also read for
            the "next" bp values.
        work: 'train' loads all training files; any other value loads one file.
        load_metadata: accepted for backward compatibility; currently ignored.

    Returns:
        (trace_profiling, bp_profiling, skpv_profiling), where
        ``trace_profiling`` is the row-wise concatenation of every file's
        'wave' dataset (float) and ``skpv_profiling`` holds the skpv labels.
        ``bp_profiling`` is the list [bp even, bp odd, bp even next, bp odd next]
        of int arrays.

    Raises:
        ValueError: if there are no files to load.
    Exits the program (code -1) if a file is missing or cannot be opened.
    """
    if work == 'train':
        database_files = [database_folder_train + name for name in training_file_list]
    else:
        database_files = [database_folder_train]
    if not database_files:
        raise ValueError('no database files to load (training_file_list is empty)')

    per_file_columns = []
    for database_file in database_files:
        print('\nLoad database_file =', database_file)
        check_file_exists(database_file)
        per_file_columns.append(_read_kyber_h5_columns(database_file, sKeyNo))

    # Concatenate each column across files (rows are stacked file after file).
    (trace_profiling,
     skpv_profiling,
     bp_b_vec0_evenCoeff0,
     bp_b_vec0_evenCoeff0_next_sKeyNo,
     bp_b_vec0_oddCoeff1,
     bp_b_vec0_oddCoeff1_next_sKeyNo) = [np.concatenate(column) for column in zip(*per_file_columns)]

    print('trace_profiling.shape =                    ', trace_profiling.shape)
    print('skpv_profiling.shape =                     ', skpv_profiling.shape)
    print('bp_b_vec0_evenCoeff0.shape =               ', bp_b_vec0_evenCoeff0.shape)
    print('bp_b_vec0_evenCoeff0_next_sKeyNo.shape =   ', bp_b_vec0_evenCoeff0_next_sKeyNo.shape)
    print('bp_b_vec0_oddCoeff1.shape =                ', bp_b_vec0_oddCoeff1.shape)
    print('bp_b_vec0_oddCoeff1_next_sKeyNo.shape =    ', bp_b_vec0_oddCoeff1_next_sKeyNo.shape)

    # Order matters to callers: [even, odd, even_next, odd_next].
    bp_profiling = [bp_b_vec0_evenCoeff0, bp_b_vec0_oddCoeff1, bp_b_vec0_evenCoeff0_next_sKeyNo, bp_b_vec0_oddCoeff1_next_sKeyNo]
    return (trace_profiling, bp_profiling, skpv_profiling)

#### Converting traces and metadata to training format
# inputs = [[list of traces], [list of bp]]
#def create_training_data_form(database_folder_train_file, sKeyNo, trainPortion, xType, yType):
def create_training_data_form(database_folder_train, sKeyNo, trainPortion, xType, yType):
	"""Load traces/metadata and split them into Keras-ready training and validation sets.

	Parameters
	----------
	database_folder_train : passed, with sKeyNo and the module-level `work`, to
		load_meta_trace_files.
	sKeyNo : sub-key selector forwarded to load_meta_trace_files.
	trainPortion : fraction of the traces (taken from the start) used for training;
		the training slice has floor(dataSize * trainPortion) traces.
	xType : model-input layout, one of 'wave', 'wavebp0', 'wavebp1', 'wavebp01',
		'wavebp01next0', 'wavebp01next01'.
	yType : label type; only 'skpv' is supported (fqmul labels are not loaded).

	Returns
	-------
	(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)
		xTrain / xVal are [[traces]] for 'wave', otherwise [[traces], [bp one-hots ...]],
		where traces is (n, tracelen_in_file, 1) and each bp one-hot is (n, NumBPinput, 1).
		yTrain / yVal are one-hot skpv labels; *_value are the raw skpv labels.

	Raises
	------
	ValueError
		For an unsupported xType or yType (checked before any data is loaded).
	"""
	# Which bp_profiling entries (in model-input order) each xType feeds to the model.
	bpInputsByXType = {
		'wave': [],
		'wavebp0': [0],
		'wavebp1': [1],
		'wavebp01': [0, 1],
		'wavebp01next0': [0, 1, 2],
		'wavebp01next01': [0, 1, 2, 3],
	}
	if xType not in bpInputsByXType:
		raise ValueError("create_training_data_form: unsupported xType %r" % (xType,))
	if yType in ('fqmul0', 'fqmul1'):
		# fqmul labels are no longer returned by load_meta_trace_files.
		raise ValueError("create_training_data_form: yType %r needs fqmul labels, which are not loaded" % (yType,))
	if yType != 'skpv':
		raise ValueError("create_training_data_form: unsupported yType %r" % (yType,))

	(trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(database_folder_train, sKeyNo, work)
	Reshaped_trace_profiling = trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
	dataSize = Reshaped_trace_profiling.shape[0]
	trainSize = math.floor(dataSize * trainPortion)
	# Validation starts right after the training slice. When everything is used for
	# training, keep the last trace as a (overlapping) validation sample so xVal is not empty.
	valLoc = trainSize
	if valLoc == dataSize:
		valLoc = dataSize - 1

	# Name / print padding per bp entry, reproducing the diagnostic prints of the inline code.
	bpOneHotSpecs = [
		('bp0_1hot_profiling', ' ' * 16),
		('bp1_1hot_profiling', ' ' * 16),
		('bp0_1hot_profiling_next_sKeyNo', ' ' * 4),
		('bp1_1hot_profiling_next_sKeyNo', ' ' * 4),
	]
	# Only build the one-hot arrays this xType actually uses: each is n x NumBPinput ints.
	bpOneHots = []
	for bpIdx in bpInputsByXType[xType]:
		oneHotName, printPad = bpOneHotSpecs[bpIdx]
		bpOneHots.append(_bp_one_hot_column(bp_profiling[bpIdx], oneHotName, bpIdx, printPad))

	# Input data creation
	xTrain = [[Reshaped_trace_profiling[0:trainSize, :, :]]]
	xVal = [[Reshaped_trace_profiling[valLoc:, :, :]]]
	if bpOneHots:
		xTrain.append([bpOneHot[0:trainSize, :] for bpOneHot in bpOneHots])
		xVal.append([bpOneHot[valLoc:, :] for bpOneHot in bpOneHots])

	# Category creation (yType == 'skpv')
	y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)
	yTrain = y_train_skpv[0:trainSize, :]
	yTrain_value = skpv_profiling[0:trainSize]
	yVal = y_train_skpv[valLoc:, :]
	yVal_value = skpv_profiling[valLoc:]

	return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def _bp_one_hot_column(bpValues, oneHotName, bpIdx, printPad):
	"""One-hot encode one bp coefficient vector into an (n, NumBPinput, 1) int array.

	bpValues entries are used directly as column indices (expected in [0, NumBPinput)).
	oneHotName, bpIdx and printPad only shape the diagnostic print.
	"""
	numRows = bpValues.shape[0]
	oneHot = np.zeros((numRows, NumBPinput), dtype=int)
	print(oneHotName + '.shape =', oneHot.shape, printPad + 'bp_profiling[%d] =' % bpIdx, bpValues)
	oneHot[np.arange(numRows), bpValues] = 1
	return oneHot.reshape((numRows, NumBPinput, 1))

def train_model_1datafile(xTrain, yTrain, xVal, yVal, model, class_weight, modelLogFolder, logFilename, period, epochNo, train_batch_size, moreEpNo):
	"""Fit `model` for `moreEpNo` epochs on one prepared dataset, with checkpointing and CSV logging.

	Parameters
	----------
	xTrain, xVal : nested input lists as built by create_training_data_form;
		xTrain[0][0] is the (traces, samples, 1) trace array.
	yTrain, yVal : one-hot labels for the training / validation traces.
	model : Keras model whose first layer's input width must equal the per-trace sample count.
	class_weight : forwarded to model.fit.
	modelLogFolder, logFilename : prefix of the CSV log (<prefix>.csv) and final model (<prefix>.h5).
	period : ModelCheckpoint saving period, in epochs.
	epochNo : epoch offset written as "_epSt####" into checkpoint file names.
	train_batch_size, moreEpNo : batch size and number of epochs for model.fit.

	Returns
	-------
	(model, history)

	Exits the process with status -1 when the model input width does not match the traces.
	"""
	# Checkpoints go to the module-level logTrainedModel_byFile_folder (not modelLogFolder):
	# train_model_multiEpochs resumes by parsing the "epSt"/"epTr" fields of files in that folder.
	save_model_name = logTrainedModel_byFile_folder + logFilename + "_epSt" + str(int(epochNo)).zfill(4) + "_epTr{epoch:04d}.h5"
	print("save_model_name =", save_model_name)
	print("modelLogFolder+logFilename =", modelLogFolder + logFilename)
	train_AccLoss_LogFile = modelLogFolder + logFilename + ".csv"
	check_file_exists(os.path.dirname(save_model_name))
	check_file_exists(os.path.dirname(train_AccLoss_LogFile))
	csv_logger = CSVLogger(filename=train_AccLoss_LogFile, append=True, separator=';')
	save_model = ModelCheckpoint(save_model_name, period=period)
	callbacks = [csv_logger, save_model]

	# Sanity check: the model's first-layer input width must match the trace length.
	input_layer_shape = model.get_layer(index=0).input_shape
	modelSampleCount = input_layer_shape[0][1]
	traceSampleCount = len(xTrain[0][0][0])
	print('input_layer_shape =', input_layer_shape)
	print('input_layer_shape[0][1] =', modelSampleCount)
	print('Number of sample points per trace: len(xTrain[0][0][0]) =', traceSampleCount)
	print('Number of traces: len(xTrain[0][0]) =', len(xTrain[0][0]))
	if modelSampleCount != traceSampleCount:
		# Report the per-trace sample count that was actually compared (not the trace count).
		print("Error: model input shape %s instead of %s is not expected ..." % (modelSampleCount, traceSampleCount))
		sys.exit(-1)

	history = model.fit(x=xTrain, y=yTrain, batch_size=train_batch_size, verbose=1, epochs=moreEpNo, callbacks=callbacks, class_weight=class_weight, validation_data=(xVal, yVal))
	print('Save log model to', modelLogFolder + logFilename + '.h5')
	model.save(modelLogFolder + logFilename + '.h5')
	return model, history

def train_model_1epoch(database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logFilename, model, epochNo, sKeyNo, class_weight, period, startFileNo, train_batch_size, moreEpNo):
	"""Run one training pass over the training database and snapshot the resulting model.

	The data is built by create_training_data_form from the whole database folder (using the
	module-level trainPortion, xType and yType), then trained via train_model_1datafile.
	The trained model is saved as
	<logTrainedModel_byFile_folder><logFilename>_ep<epochNo+period>_file<fileNo>.h5
	and, when the module-level GPU_clear is set, the Keras session is cleared and the model
	is reloaded from exactly that file.

	Returns the (possibly reloaded) trained model.
	"""
	databaseFileList = sorted(os.listdir(database_folder_train))
	print('startFileNo =', startFileNo, ';    len(databaseFileList) =', len(databaseFileList))
	# NOTE(review): startFileNo is currently unused; the per-file loop is pinned to a single
	# pass because create_training_data_form loads the whole folder at once.
	for fileNo in range(0, 1):
		xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value = create_training_data_form(database_folder_train, sKeyNo, trainPortion, xType, yType)
		model, history = train_model_1datafile(xTrain, yTrain, xVal, yVal, model, class_weight, modelLogFolder, logFilename, period, epochNo, train_batch_size, moreEpNo)

		# Save and reload must use the same path, otherwise the reload reads a file that
		# was never written.
		partEpochModelFile = logTrainedModel_byFile_folder + logFilename + '_ep' + str(epochNo + period).zfill(4) + '_file' + str(fileNo).zfill(2) + '.h5'
		print('\n*** Save part epoch model to', partEpochModelFile)
		model.save(partEpochModelFile)

		if GPU_clear:
			# https://github.com/keras-team/keras/issues/5345
			print('*** Delete model and flush GPU')
			del model
			del history
			K.clear_session()
			print('*** Load continuous model from', partEpochModelFile)
			model = load_model(partEpochModelFile)
		else:
			print('*** No GPU flushing')

		# Bound the number of saved models: when there are more than database files, drop the
		# first one in sorted order. Done after the reload so the just-saved file cannot be
		# removed before it is read back.
		savedModelFiles = sorted(os.listdir(logTrainedModel_byFile_folder))
		if len(savedModelFiles) > len(databaseFileList):
			os.remove(logTrainedModel_byFile_folder + savedModelFiles[0])
	print('********* END epoch ', epochNo, 'training **********')
	return model

def train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size):
	"""Create a new model or resume the newest saved one, then train epoch by epoch up to maxEpochs.

	If logTrainedModel_byFile_folder is empty a fresh model is built with subModels_gen.
	Otherwise the last file (in sorted order) of that folder is loaded, and its epoch is
	parsed from "epSt####...epTr####" (sum of both numbers) or, failing that, "_ep####";
	its data-file index is parsed from "_file##". Every `period`-th epoch the model is also
	saved to logTrainedModel_byEp_folder as <logFilename>_ep####.h5.
	"""
	def _epoch_from_name(fileName):
		"""Return the epoch encoded in a saved-model file name, or None if none is found."""
		stPos = fileName.rfind('epSt')
		trPos = fileName.rfind('epTr')
		if stPos >= 0 and trPos >= 0:
			return int(fileName[stPos + 4:stPos + 8]) + int(fileName[trPos + 4:trPos + 8])
		epPos = fileName.rfind('_ep')
		if epPos >= 0 and fileName[epPos + 3:epPos + 7].isdigit():
			return int(fileName[epPos + 3:epPos + 7])
		return None

	databaseFileList = sorted(os.listdir(database_folder_train))
	logTrainedModel_byFile_fileList = sorted(os.listdir(logTrainedModel_byFile_folder))
	print('len(logTrainedModel_byFile_fileList) =', len(logTrainedModel_byFile_fileList))

	if not logTrainedModel_byFile_fileList:   # new model
		print('\nGenerating model')
		model = subModels_gen(xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
		nextEpNo = 0		# start from the first epoch
		nextFileNo = 0		# start from the first data file
		moreEpNo = maxEpochs
	else:   # continue training from the last saved model (sorted order)
		logTrainedModel_lastEpDatafile = logTrainedModel_byFile_fileList[-1]
		print('\nLoading initial model from', logTrainedModel_byFile_folder + logTrainedModel_lastEpDatafile)
		model = load_model(logTrainedModel_byFile_folder + logTrainedModel_lastEpDatafile)
		parsedEpNo = _epoch_from_name(logTrainedModel_lastEpDatafile)
		lastEpNo = 0 if parsedEpNo is None else parsedEpNo
		print("lastEpNo =", lastEpNo)
		# Search for the appended "_file##" suffix so "file" inside logFilename is not matched.
		filePos = logTrainedModel_lastEpDatafile.rfind('_file')
		fileDigits = logTrainedModel_lastEpDatafile[filePos + 5:filePos + 7] if filePos >= 0 else ''
		lastFileNo = int(fileDigits) if fileDigits.isdigit() else -1	# -1 so the next file is No. 0
		if lastFileNo == len(databaseFileList) - 1:
			nextEpNo = lastEpNo + 1		# next epoch from the first file
			nextFileNo = 0				# back to the first data file
		else:
			nextEpNo = lastEpNo				# continue current epoch
			nextFileNo = lastFileNo + 1		# from the next data file
		moreEpNo = maxEpochs - lastEpNo
	print('nextEpNo =', nextEpNo)
	print('moreEpNo =', moreEpNo)
	# NOTE(review): every iteration below calls train_model_1epoch, whose model.fit runs for
	# moreEpNo epochs, so the total is (maxEpochs - nextEpNo) * moreEpNo epochs -- confirm intended.
	for epochNo in range(nextEpNo, maxEpochs):
		startFileNo = nextFileNo if epochNo == nextEpNo else 0
		model = train_model_1epoch(database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logFilename, model, epochNo, sKeyNo, class_weight, period, startFileNo, train_batch_size, moreEpNo)
		print('********* EXIT epoch ', epochNo, 'training **********')
		if (epochNo + 1) % period == 0:
			byEpModelFile = logTrainedModel_byEp_folder + logFilename + '_ep' + str(epochNo).zfill(4) + '.h5'
			print('\n***---*** Save ONE EPOCH model to', byEpModelFile)
			model.save(byEpModelFile)

def plot_meanrank(rankmat, maxtrc, label):
	"""Draw the mean rank versus number of traces on the current matplotlib axes.

	rankmat : rank matrix averaged over axis 0; with mk_rankmat's output that is one row per
		attack run and column j holding the rank after j+1 traces.
	maxtrc : number of x-axis points (trace counts 1..maxtrc).
	label : legend label of the curve.
	Nothing is returned.
	"""
	meanRankPerTraceCount = np.mean(rankmat, axis=0)
	traceCounts = np.arange(1, maxtrc + 1)
	plt.xlabel('number of traces')
	plt.ylabel('mean rank')
	plt.plot(traceCounts, meanRankPerTraceCount, label=label)

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
	"""Run `nruns` attacks with `model` and record key and class ranks per number of traces.

	Parameters
	----------
	model : Keras model (uses .predict and .name); its inputs are chosen by module-level xType.
	nruns, maxtrc : number of attack runs and traces per run.
	batches : (nruns, maxtrc) trace indices; row krun selects the traces of run krun.
	xTest : nested inputs as built by create_training_data_form ([[traces]] or [[traces], [bp...]]).
	yTest_value : raw labels; only yTest_value[0] is used as the real key
		(assumes a fixed key across the attack set -- TODO confirm).
	noHypoKeys, noClasses : widths of the per-key / per-class result arrays.

	Returns
	-------
	rankmat_byKey, rankmat_byClass : (nruns, maxtrc) int arrays. Entry [krun, i] counts the
		hypotheses whose accumulated log-probability over traces 0..i (byKey), or whose
		log-probability on trace i alone (byClass), is strictly greater than the real key's.
	ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns : (maxtrc, classes, nruns)
		probabilities and log-probabilities per run.
	lpsums_AllHypoKeys_Nruns : (maxtrc, noHypoKeys, nruns) accumulated log-probabilities.

	Raises ValueError for an unsupported xType.
	"""
	# Number of one-hot bp inputs (taken in order from xTest[1]) the model expects per xType.
	# create_training_data_form stores only the selected bp arrays, so e.g. 'wavebp1' has its
	# single bp input at xTest[1][0].
	bpInputCountByXType = {'wave': 0, 'wavebp0': 1, 'wavebp1': 1, 'wavebp01': 2, 'wavebp01next0': 3, 'wavebp01next01': 4}
	if xType not in bpInputCountByXType:
		raise ValueError('mk_rankmat: unsupported xType %r' % (xType,))
	bpInputCount = bpInputCountByXType[xType]
	traceInputs = xTest[0][0]
	bpInputs = xTest[1][:bpInputCount] if bpInputCount else []

	realkey = int(yTest_value[0])
	realClass = realkey		# the model's classes are the key hypotheses themselves
	rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
	rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
	ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
	lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
	lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
	lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
	print("")
	for krun in range(nruns):
		if krun == 0:	# progress message is printed for the first run only
			print('%s  run %d of %d' % (model.name, krun+1, nruns))
		samp = batches[krun, :]
		ps = model.predict([traceInputs[samp, :, :]] + [bpInput[samp, :] for bpInput in bpInputs])

		# NOTE(review): a predicted probability of exactly 0 gives -inf here (numpy warns).
		lps = np.log(ps)
		# Running sum over traces 0..i for every i; accumulated in float64 so the values match
		# a per-trace float64 accumulator exactly.
		lpsums = np.cumsum(lps[:maxtrc], axis=0, dtype=np.float64)
		rankmat_byKey[krun, :] = np.sum(lpsums > lpsums[:, [realkey]], axis=1)
		rankmat_byClass[krun, :] = np.sum(lps[:maxtrc] > lps[:maxtrc, [realClass]], axis=1)

		lpsums_AllHypoKeys_Nruns[:, :, krun] = lpsums
		ps_AllClasses_Nruns[:, :, krun] = ps
		lps_AllClasses_Nruns[:, :, krun] = lps
		lps_AllHypoKeys_Nruns[:, :, krun] = lps
	return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def one_model_attack_and_plot(logTrainedModel_byEp_folder, attack_byModel_fileNo, database_folder_attack, nruns, maxtrc, attackLogFolder, logFilename):
    """Attack the first attack-database file with one trained model and save a mean-rank plot.

    attack_byModel_fileNo selects the model as trainedModeFileList[attack_byModel_fileNo - 1]
    from the sorted contents of logTrainedModel_byEp_folder. The model's epoch is parsed from
    its file name ("epSt####...epTr####" summed, or "_ep####"). The plot is saved to
    attackLogFolder + logFilename + 'attack<nruns>run<maxtrc>trc_epNo<epoch>.png'.
    Exits the process with status -1 on an out-of-range index or an unparsable epoch.
    """
    def _epoch_from_name(fileName):
        """Return the epoch encoded in a saved-model file name, or None if none is found."""
        stPos = fileName.rfind('epSt')
        trPos = fileName.rfind('epTr')
        if stPos >= 0 and trPos >= 0:
            return int(fileName[stPos + 4:stPos + 8]) + int(fileName[trPos + 4:trPos + 8])
        epPos = fileName.rfind('_ep')
        if epPos >= 0 and fileName[epPos + 3:epPos + 7].isdigit():
            return int(fileName[epPos + 3:epPos + 7])
        return None

    if not os.path.isdir(attackLogFolder):
        os.mkdir(attackLogFolder)
    trainedModeFileList = sorted(os.listdir(logTrainedModel_byEp_folder))
    print('attack_byModel_fileNo =', attack_byModel_fileNo, ';   len(trainedModeFileList) =', len(trainedModeFileList))
    if len(trainedModeFileList) < attack_byModel_fileNo:
        print('\nInvalid attack_byModel_fileNo (', attack_byModel_fileNo, ') or not enough trained epoch files (len(trainedModeFileList) =', len(trainedModeFileList), '). Exit attacking.')
        sys.exit(-1)

    attack_model_filename = trainedModeFileList[attack_byModel_fileNo - 1]
    attack_model_path = logTrainedModel_byEp_folder + attack_model_filename
    print('\nLoad_model: ', attack_model_path)
    # Models saved per epoch by train_model_multiEpochs are named "_ep####.h5" (no epSt/epTr),
    # so both naming schemes are accepted.
    EpNo = _epoch_from_name(attack_model_filename)
    if EpNo is None:
        print('\nInvalid epoch number for file    ', attack_model_path, '   . Exit attacking.')
        sys.exit(-1)
    print('Attack to epoch ', EpNo, '\n')
    model = load_model(attack_model_path)
    print('Model: ', model.name)

    databaseFileList = sorted(os.listdir(database_folder_attack))
    database_folder_attack_file = database_folder_attack + databaseFileList[0]
    print('database_folder_attack_file =', database_folder_attack_file)
    xTest, yTest, xVal_, yVal_, yTest_value, yVal_value = create_training_data_form(database_folder_attack_file, sKeyNo, testPortion, xType, yType)
    print('len(xTest[0][0]) =', len(xTest[0][0]))
    print('yTest[0][1730:1740] =', yTest[0][1730:1740])
    print('yTest[1][1730:1740] =', yTest[1][1730:1740])
    print('yTest[1][1733] =', yTest[1][1733])
    print('yTest_value =', yTest_value)

    # Each run draws maxtrc distinct trace indices (without replacement).
    batches = np.zeros((nruns, maxtrc), 'int')
    for i in range(nruns):
        batches[i, :] = np.random.choice(len(xTest[0][0]), maxtrc, False)
    print("batches.shape =", batches.shape)
    print("batches =", batches)
    print('Attack to epoch ', EpNo, '\n')

    rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns = mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)

    plt.figure(figsize=(20, 15))
    plt.rcParams.update({'font.size': 25})
    plt.grid()
    plot_meanrank(rankmat_byKey, maxtrc, model.name)
    plt.legend()
    plt.title('Model comparison using %d test runs' % nruns)
    plt.tight_layout()
    plt.savefig(attackLogFolder + logFilename + 'attack' + str(nruns) + 'run' + str(maxtrc) + 'trc_epNo' + str(EpNo) + '.png')
    plt.show(block=False)
    print('FINISH attacking with epoch ', EpNo)

if work == 'train':
    print('Work =', work)
    # All-ones class weights: every class contributes equally in model.fit.
    classWeights = np.ones(noClasses).astype(int)
    class_weight = dict(enumerate(classWeights))

    train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size)

elif work == 'attack':
    print('Work =', work)
    # Optional command-line overrides: argv[1] = number of attack runs, argv[2] = traces per run.
    # Read via sys.argv and catch only missing / non-integer arguments, so unrelated errors
    # are no longer silently turned into the defaults.
    try:
        nruns = int(sys.argv[1])
    except (IndexError, ValueError):
        nruns = nruns_default
    try:
        maxtrc = int(sys.argv[2])
    except (IndexError, ValueError):
        maxtrc = maxtrc_default

    one_model_attack_and_plot(logTrainedModel_byEp_folder, attack_byModel_fileNo, database_folder_attack, nruns, maxtrc, attackLogFolder, logFilename)