import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid


def parse_arguments():
    """Build the command-line parser for a training / active-learning run.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed); the
        caller is expected to call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='')

    # Experiment selection and data window.
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random, minmax or uncertainty')
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--xType', type=str, help='model input type, e.g. wave or wavebp01')
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int, help='end trace')

    # Training loop settings.
    parser.add_argument('--batch_size', type=int, help='batch_size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)
    parser.add_argument('--trained_model_path', type=str)
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)
    parser.add_argument('--all_ids', type=str)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--num_sample', type=int, help='number of samples', default=5)
    parser.add_argument('--eval_interval', type=int, help='evaluation interval', default=1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.0001)

    # MLP hyper-parameters. These are required, so argparse never falls back
    # to the defaults listed here; they are kept only as documentation of
    # typical values.
    parser.add_argument('--node', type=int, default=256, required=True)
    parser.add_argument('--n_layers', type=int, default=6, required=True)
    parser.add_argument('--batch_norm', type=int, default=1, required=True)
    parser.add_argument('--dropout', type=int, default=1, required=True)
    parser.add_argument('--dropout_rate', type=float, default=0.2, required=True)

    return parser

# Parse the command line at import time; `args` is read by the module-level
# configuration below (seed, batch size, epochs, xType, output folder name).
parser = parse_arguments()
args = parser.parse_args()
#tf.config.experimental.enable_op_determinism()

def set_seeds(seed):
    """Export PYTHONHASHSEED and seed the TensorFlow and NumPy global RNGs.

    Args:
        seed: value passed unchanged to both seeding functions and written
            (as a string) to the PYTHONHASHSEED environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Same order as before: TensorFlow first, then NumPy.
    for seed_rng in (tf.random.set_seed, np.random.seed):
        seed_rng(seed)

# Seed TensorFlow / NumPy (and export PYTHONHASHSEED) from the --seed CLI argument.
set_seeds(args.seed)

# [min, max] bounds; the class / input counts below are max - min + 1.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # trace length used for the model's trace input shape
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3; which subkeys these are is decided by the code on the m4 (NOT by the code on the PC)
work = 'train' #'train'  #'attack'
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',\
'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5']#,\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5']

# Paths and attack defaults.
trained_model_path = args.trained_model_path
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
#maxtrc_default = 115
testPortion = 1
attack_byModel_epNo = 232


# training parameters
train_batch_size = args.batch_size#100#150#200#250#500#640 #80 for mars45 #170 for mars56
period = 8 #8
maxEpochs = args.num_epoch#3072#2048#1536#1280#1024#512#256 #1536
# Index derived from the attack epoch number and `period` (truncated towards zero).
attack_byModel_fileNo = int(attack_byModel_epNo/period)
N_TRACE = 20000
Threshold_Save = 200

#model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True    # False

# training data type
xType = args.xType  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv' 
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    #hyper-parameter contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'    #'20kDatax25'
dataFile_attack = '100kDatax1_test'#'skvp0_0_700points100kDatax1attack' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
model_input_type = '_in[[][]]_tf2' #'in[]_tf2' #'[[][]]_tf2'
#data_type = dataFile_train + model_input_type
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
name_length = len('200k_2000cluster/minmax_')
# Output folder name encodes the CLI arguments that define this run.
save_path = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_lr_{}_MLP_{}_{}_{}_{}_{}'.format(args.name ,args.train_type, args.sampling ,xType, args.start_trace, args.end_trace, args.num_sample, args.seed, args.normalize, args.cummulative, args.lr, args.node, args.n_layers, args.batch_norm, args.dropout, args.dropout_rate)
#database_folder_train = os.path.join('trained_models', save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
# Created eagerly (with parents) so later writes into this folder cannot fail on a missing directory.
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
# Per-epoch checkpoints currently share the per-data-file folder.
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Create the model / log output folders. os.makedirs(exist_ok=True) replaces
# the isdir-then-mkdir pairs: it is safe when several runs start at the same
# time and race to create the same folder.
for _folder in (DLmodel_folder, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder):
    os.makedirs(_folder, exist_ok=True)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# Each table below has six slots; a 0 entry means "not used" for that slot.
# Input BatchNormalization for each PoI size
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# number of nodes in each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   512,    256,    128,    64,     0,      0]
# Convolutional filter sizes
# filter (kernel) size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [   3,      3,      3,      3,      0,      0] # subModel0

###############################################
# Pooling size in convolutional layers
# MaxPooling size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   2,      2,      2,      2,      0,      0] # subModel0

# Pooling stride in convolutional layers
# MaxPooling stride of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [   3,      3,      3,      3,      0,      0] # subModel0

# BatchNormalization in convolutional layers
# 1 enables BatchNormalization for the corresponding convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [   1,      1,      1,      1,      0,      0] # subModel0
# Dropout in convolutional layers
# Dropout value for each convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   0,      0,      0,      0,      0,      0] # subModel0

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten Convolutional feature map before Fully connected
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected layers applied to the convolutional output before adding Plaintext
# number of units in each fully-connected layer
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [   1024,   512,    256,    128,    0,      0] # subModel0
# BatchNormalization for those fully-connected layers (before adding Plaintext)
# 1 enables BatchNormalization for the corresponding fully-connected layer
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   1,      1,      1,      1,      0,      0] # subModel0

# Dropout for those fully-connected layers (before adding Plaintext)
# Dropout value for each fully-connected layer
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [   0.2,    0,      0.2,    0,      0,      0] # subModel0
###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext adding here
# noBPbranch is the number of one-hot bp inputs that accompany the trace for
# the selected xType. It is now assigned in every branch (previously it was
# left undefined for 'wavebp0', 'wavebp1' and 'wavebp01next0'); the new values
# match the number of bp arrays create_training_data_form() builds per xType.
# NOTE(review): an xType outside this list still leaves both names undefined.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    noBPbranch = 1
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    noBPbranch = 1
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    noBPbranch = 3
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected layers applied after adding Plaintext
# one row per sub-model, one column per layer; values are layer sizes (0 = unused)
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]    # subModel5
# BatchNormalization for the fully-connected layers after adding Plaintext
# one row per sub-model, one column per layer; 1 enables BatchNormalization
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5
# Dropout for the fully-connected layers after adding Plaintext
# one row per sub-model, one column per layer; values are Dropout settings
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5

# Softmax for each sub-model if available
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [  0,      0,      0,      0,      0,      0]

# Join flag: 0 for the trace-only input type, 1 for every other xType.
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]  

# Fully-connected layer sizes after joining PoIs (0 = unused)
subMods_join_FCs =  [  1024,   1024,   512,    256,    128,    0]
# BatchNormalization for the fully-connected layers after joining PoIs
# 1 enables BatchNormalization for the corresponding layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [  1,      1,      1,      1,      1,      0]
# Dropout for the fully-connected layers after joining PoIs
# Dropout value for each layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [  0.2,        0.2,        0,      0.1,        0,      0]

# Softmax for joined-model if available
subMods_join_classification =   [   1]

################################################################################################
##################################### MODELS STRUCTURE END #####################################
################################################################################################


def check_file_exists(file_path):
    """Terminate the process with exit code -1 unless ``file_path`` exists.

    Args:
        file_path: path to check (file or directory).

    Returns:
        None when the path exists.
    """
    if os.path.exists(file_path):
        return
    print("Error: provided file path '%s' does not exist!" % file_path)
    sys.exit(-1)

def listDirWithExt(directory, extension):
    """Return a generator over the names in ``directory`` ending in ``'.' + extension``.

    The directory is listed when the function is called; the filtering itself
    happens lazily as the generator is consumed.
    """
    entries = os.listdir(directory)
    return (entry for entry in entries if entry.endswith('.' + extension))

def subModels_gen(args,xType,noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the MLP classifier used for training.

    The network reads the trace input, squeezes its channel axis, and applies
    ``args.n_layers`` Dense layers (widths double for the first half of the
    stack and halve afterwards), each optionally followed by
    BatchNormalization / Dropout, then a softmax over ``classes``.

    Args:
        args: parsed CLI namespace; reads ``node``, ``n_layers``,
            ``batch_norm``, ``dropout``, ``dropout_rate`` and ``lr``.
        xType: input layout; only 'wave' and 'wavebp01' are supported here.
        noConv1Dbranch: number of trace Input tensors to instantiate (>= 1).
        tracelen: number of samples per trace.
        NumBPinput: length of each one-hot bp input.
        classes: number of output classes.
        noBPbranch, noLayers, MLmodel_detail, modelLogFolder, logFilename:
            accepted for interface compatibility; not used by this builder.

    Returns:
        The compiled Keras ``Model``.

    Raises:
        ValueError: if ``xType`` is unsupported or ``noConv1Dbranch`` < 1
            (previously an opaque IndexError / NameError).
    """
    if xType not in ('wave', 'wavebp01'):
        raise ValueError("subModels_gen: unsupported xType %r (expected 'wave' or 'wavebp01')" % (xType,))
    if noConv1Dbranch < 1:
        raise ValueError("subModels_gen: noConv1Dbranch must be >= 1, got %r" % (noConv1Dbranch,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)

    # trace_input is generated once per branch so each is a distinct input;
    # only the last one is wired into the model.
    for _ in range(noConv1Dbranch):
        trace_input = Input(shape=input_trace_shape)
    # All four Ptext inputs are still instantiated, as before.
    # NOTE(review): inputs 3 and 4 are never used; they are kept so that any
    # auto-generated layer names stay the same — confirm before removing.
    Ptext_input1 = Input(shape=input_Ptext1hot_shape)
    Ptext_input2 = Input(shape=input_Ptext1hot_shape)
    Ptext_input3 = Input(shape=input_Ptext1hot_shape)
    Ptext_input4 = Input(shape=input_Ptext1hot_shape)

    if xType == 'wave':
        inputs = [trace_input]
    else:  # 'wavebp01'
        # The Ptext inputs are model inputs only; no layer below reads them.
        inputs = [trace_input, Ptext_input1, Ptext_input2]

    # NOTE(review): this BatchNormalization output is discarded, because the
    # first Dense layer below reads the raw trace input. Kept as-is so the
    # trained-model behaviour does not change; confirm whether the
    # normalisation was meant to feed the Dense layer.
    _ = BatchNormalization()(inputs[0])
    # Drop the trailing channel axis: (batch, tracelen, 1) -> (batch, tracelen).
    x = Dense(args.node, input_dim=tracelen, activation='relu')(inputs[0][:,:,0])

    node = args.node
    for i in range(args.n_layers - 1):
        # Widen for the first half of the stack, then narrow again.
        if i < args.n_layers // 2:
            node = node * 2
        else:
            node = node // 2

        x = Dense(node, activation='relu')(x)
        if args.batch_norm == 1:
            x = BatchNormalization(trainable=True)(x)
        if args.dropout == 1:
            x = Dropout(args.dropout_rate)(x)

    outputs = Dense(classes, activation='softmax')(x)

    sModel = Model(inputs, outputs, name='model')
    sModel.summary()
    tf.keras.utils.plot_model(sModel, show_shapes=True, to_file='model.png')
    optimizer = RMSprop(learning_rate=args.lr)
    sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

# make a prediction with a stacked model
# https://machinelearningmastery.com/stacking-ensemble-for-deep-learning-neural-networks/
def predict_stacked_model(model, inputX):
    """Predict with a multi-input model by feeding ``inputX`` to every input.

    Args:
        model: object exposing ``input`` (a sized collection of inputs) and
            ``predict``.
        inputX: the data passed to each of the model's inputs.

    Returns:
        Whatever ``model.predict`` returns (called with ``verbose=0``).
    """
    n_inputs = len(model.input)
    # One reference to the same data per model input.
    replicated = [inputX] * n_inputs
    return model.predict(replicated, verbose=0)

def load_sca_model(model_file):
    """Load a saved Keras model, exiting the process on failure.

    Args:
        model_file: path to the saved model.

    Returns:
        The model returned by ``load_model``.

    Exits with status -1 if the path does not exist or loading fails.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as err:
        # `Exception` instead of a bare except, so Ctrl-C / SystemExit are not
        # swallowed; the underlying reason is now reported as well.
        print("Error: can't load Keras model file '%s'" % model_file)
        print("Reason: %r" % (err,))
        sys.exit(-1)
    return model

####### THESE FUNCTIONS ARE SPECIALIZED FOR KYBER    #######
####### Loading traces and metadata from file ############
#def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a window of traces, bp values and labels from an ``.npz`` archive.

    Args:
        data_path: path to an archive holding the arrays 'data', 'bp' and 'label'.
        start_trace: first trace index of the window (inclusive).
        end_trace: last trace index of the window (exclusive).

    Returns:
        Tuple ``(traces, bp, labels)``: 'data' and 'label' are sliced along
        their first axis, 'bp' along its second axis.
    """
    # The context manager closes the archive's file handle; the arrays are
    # fully read into memory on access, so they remain valid afterwards.
    with np.load(data_path) as data:
        trace_profiling = data['data']
        bp_profiling = data['bp']
        skpv_profiling = data['label']

    return (trace_profiling[start_trace:end_trace], bp_profiling[:,start_trace:end_trace], skpv_profiling[start_trace:end_trace])

def load_meta_trace_file_from_test(database_file, sKeyNo, load_metadata=False):
    """Load traces, bp values and skpv labels for one subkey from a Kyber HDF5 file.

    Only the datasets that contribute to the return value are read ('wave',
    'skpv_a_vec0_evenCoeff0', 'bp_b_vec0_evenCoeff0', 'bp_b_vec0_oddCoeff1');
    the previous version also read a number of datasets it then discarded.

    Args:
        database_file: path to the HDF5 file.
        sKeyNo: column index of the subkey; column ``sKeyNo + 1`` of the bp
            datasets is read as well.
        load_metadata: unused; kept for interface compatibility.

    Returns:
        Tuple ``(trace_profiling, bp_profiling, skpv_profiling)`` where
        ``bp_profiling`` is the list
        ``[even[sKeyNo], odd[sKeyNo], even[sKeyNo + 1], odd[sKeyNo + 1]]``.

    Exits with status -1 if the file is missing or cannot be opened.
    """
    print('\nLoad database_file =', database_file)
    check_file_exists(database_file)
    # Open the Kyber database HDF5 for reading
    try:
        in_file = h5py.File(database_file, "r")
    except Exception:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        sys.exit(-1)

    # Close the HDF5 handle as soon as everything is copied into memory.
    with in_file:
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = np.array(in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int))

        next_sKeyNo = sKeyNo + 1
        bp_profiling = [
            np.array(in_file['bp_b_vec0_evenCoeff0'][:, sKeyNo].astype(int)),
            np.array(in_file['bp_b_vec0_oddCoeff1'][:, sKeyNo].astype(int)),
            np.array(in_file['bp_b_vec0_evenCoeff0'][:, next_sKeyNo].astype(int)),
            np.array(in_file['bp_b_vec0_oddCoeff1'][:, next_sKeyNo].astype(int)),
        ]

    return (trace_profiling, bp_profiling, skpv_profiling)
#### Converting traces and metadata to training format
# inputs = [[list of traces], [list of bp]]
#def create_training_data_form(database_folder_train_file, sKeyNo, trainPortion, xType, yType):
    

# Indices into the four one-hot bp arrays (bp0, bp1, bp0_next, bp1_next)
# that accompany the trace for each supported xType.
_BP_INPUTS_BY_XTYPE = {
    'wave': (),
    'wavebp0': (0,),
    'wavebp1': (1,),
    'wavebp01': (0, 1),
    'wavebp01next0': (0, 1, 2),
    'wavebp01next01': (0, 1, 2, 3),
}


def _one_hot_bp(bp_values, shape_label, values_label):
    """One-hot encode integer bp values into an (n, NumBPinput, 1) int array."""
    n_rows = bp_values.shape[0]
    one_hot = np.zeros((n_rows, NumBPinput), dtype=int)
    print(shape_label, one_hot.shape, values_label, bp_values)
    one_hot[np.arange(n_rows), bp_values] = 1
    return one_hot.reshape((n_rows, NumBPinput, 1))


def create_training_data_form(data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces and metadata and split them into model-ready train / validation sets.

    Args:
        data_path: HDF5 file (``is_test`` true) or ``.npz`` archive (``is_test`` false).
        sKeyNo: subkey index; only used by the HDF5 loader.
        trainPortion: fraction of the data used for training; the rest is validation.
        xType: input layout, one of 'wave', 'wavebp0', 'wavebp1', 'wavebp01',
            'wavebp01next0', 'wavebp01next01'.
        yType: label type; only 'skpv' is implemented.
        is_test: selects which loader is used (see ``data_path``).
        start_trace, end_trace: trace window; only used by the ``.npz`` loader.

    Returns:
        ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``. ``xTrain``
        and ``xVal`` are ``[[traces]]`` for 'wave' and
        ``[[traces], [bp one-hots...]]`` otherwise. ``yTrain`` / ``yVal`` are
        one-hot labels; the ``*_value`` entries are the raw labels.

    Raises:
        ValueError: unknown ``xType`` or ``yType`` (previously an UnboundLocalError).
        NotImplementedError: ``yType`` 'fqmul0' / 'fqmul1', whose inputs are not
            loaded (previously a NameError).
    """
    if xType not in _BP_INPUTS_BY_XTYPE:
        raise ValueError("create_training_data_form: unsupported xType %r" % (xType,))
    if yType in ('fqmul0', 'fqmul1'):
        raise NotImplementedError("create_training_data_form: yType %r is not available (fqmul data is not loaded)" % (yType,))
    if yType != 'skpv':
        raise ValueError("create_training_data_form: unsupported yType %r" % (yType,))

    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_trace)

    # Add the trailing channel axis the model input expects.
    traces = trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
    dataSize = traces.shape[0]
    trainSize = math.floor(dataSize * trainPortion)
    # With trainPortion == 1 keep the last sample as a one-element validation set.
    valLoc = trainSize
    if valLoc == dataSize:
        valLoc = dataSize - 1

    print('2')
    print((bp_profiling[0].shape[0], NumBPinput))
    # All four encodings are built regardless of xType, as before.
    bp_one_hots = [
        _one_hot_bp(bp_profiling[0], 'bp0_1hot_profiling.shape =', '                bp_profiling[0] ='),
        _one_hot_bp(bp_profiling[1], 'bp1_1hot_profiling.shape =', '                bp_profiling[1] ='),
        _one_hot_bp(bp_profiling[2], 'bp0_1hot_profiling_next_sKeyNo.shape =', '    bp_profiling[2] ='),
        _one_hot_bp(bp_profiling[3], 'bp1_1hot_profiling_next_sKeyNo.shape =', '    bp_profiling[3] ='),
    ]

    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)

    # Input data creation
    selected_bps = [bp_one_hots[i] for i in _BP_INPUTS_BY_XTYPE[xType]]
    xTrain = [[traces[0:trainSize, :, :]]]
    xVal = [[traces[valLoc:, :, :]]]
    if selected_bps:
        xTrain.append([bp[0:trainSize, :] for bp in selected_bps])
        xVal.append([bp[valLoc:, :] for bp in selected_bps])

    # Category creation (yType == 'skpv')
    yTrain = y_train_skpv[0:trainSize, :]
    yTrain_value = skpv_profiling[0:trainSize]
    yVal = y_train_skpv[valLoc:, :]
    yVal_value = skpv_profiling[valLoc:]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def get_subset(idxs, X_train, y_train):
    """Select the rows ``idxs`` from a nested trace/bp0/bp1 input and its labels.

    ``X_train`` is indexed as ``[[traces], [bp0, bp1]]``; the result keeps that
    nesting. The number of available traces is printed before slicing.

    Note: a four-argument ``get_subset`` defined later in this module rebinds
    the same name, so this version is only reachable before that definition.

    Returns:
        ``(sub_X_train, sub_y_train)`` with every array restricted to ``idxs``.
    """
    traces = X_train[0][0]
    print(len(traces))
    bp0, bp1 = X_train[1][0], X_train[1][1]
    picked_inputs = [
        [traces[idxs, :, :]],
        [bp0[idxs, :], bp1[idxs, :]],
    ]
    return picked_inputs, y_train[idxs, :]

def load_multi_attack(data_path):
    """Load the multi-key attack set from a NumPy ``.npz`` archive.

    Args:
        data_path: path (or file object) of an archive holding the arrays
            ``'data'`` and ``'label'``.

    Returns:
        ``(data, labels)`` as in-memory arrays.
    """
    # An NpzFile keeps the underlying file open until it is closed; both
    # arrays are read inside the context so the handle is released on return.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Record how the true key's rank evolves as attack traces are accumulated.

    For every run ``krun`` the traces selected by ``batches[krun]`` are passed
    to ``model.predict``. The log-probabilities are summed trace by trace and,
    after each trace, the rank of the true key is stored as the number of
    candidates whose accumulated log-probability is strictly larger.

    The model inputs are chosen from the module-level ``xType``: the trace
    array ``xTest[0][0]`` is always used, followed by the side inputs from
    ``xTest[1]`` that the given ``xType`` requires.

    Args:
        model: object exposing ``predict``.
        nruns: number of runs (rows of ``batches``).
        maxtrc: number of traces per run.
        batches: integer array; row ``krun`` holds the trace indices of a run.
        xTest: nested inputs ``[[traces], [side inputs...]]``.
        yTest_value: true labels; only ``yTest_value[0]`` is used as the key.
        noHypoKeys: number of key hypotheses.
        noClasses: number of model output classes.

    Returns:
        ``(rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns)``.
        The rank matrices are ``(nruns, maxtrc)``; the others are
        ``(maxtrc, noClasses or noHypoKeys, nruns)``.

    Raises:
        ValueError: if the module-level ``xType`` is not a supported value.
    """
    # Indices into xTest[1] of the side inputs that accompany the trace.
    aux_inputs_by_type = {
        'wave': (),
        'wavebp0': (0,),
        'wavebp1': (1,),
        'wavebp01': (0, 1),
        'wavebp01next0': (0, 1, 2),
        'wavebp01next01': (0, 1, 2, 3),
    }
    if xType not in aux_inputs_by_type:
        raise ValueError('Unsupported xType for mk_rankmat: %r' % (xType,))
    aux_inputs = aux_inputs_by_type[xType]

    realkey = int(yTest_value[0])
    # The class index is used directly as the key hypothesis.
    realClass = realkey
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))

    for krun in range(nruns):
        samp = batches[krun, :]
        model_inputs = [xTest[0][0][samp, :, :]]
        model_inputs.extend(xTest[1][k][samp, :] for k in aux_inputs)
        ps = model.predict(model_inputs)

        # A probability of exactly 0 yields -inf here (with a NumPy warning).
        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realClass])
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        # Per-hypothesis log-probabilities are the per-class ones unchanged.
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the mean key rank after the last attack trace.

    Delegates to ``mk_rankmat`` with the same arguments, averages the by-key
    rank matrix over the runs (axis 0) and returns the value at the final
    trace position.
    """
    # Only the by-key rank matrix (first result) is needed here.
    rankmat_byKey = mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)[0]
    mean_rank_per_trace = np.mean(rankmat_byKey, 0)

    return mean_rank_per_trace[-1]

import pandas as pd

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that tracks the attack mean rank on a single-key test set.

    Every ``eval_interval`` epochs the current model is evaluated with
    ``eval_model`` on randomly drawn trace batches; the model with the lowest
    rank seen so far is saved to ``save_model_name + '.keras'``. All ranks are
    written to ``attack_rank.csv`` in ``database_folder_train`` when training
    ends.

    Relies on the module-level names ``nruns_default``, ``maxtrc_default``,
    ``noHypoKeys`` and ``noClasses``.
    """

    def __init__(self, xTest, yTest_value, save_model_name, database_folder_train, eval_interval):
        super().__init__()
        self.xTest = xTest
        self.yTest_value = yTest_value
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []  # one entry per evaluated epoch
        # Never updated by this class; the best model is persisted to disk.
        self.best_model = self.model
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the attack rank and save the model if it is the best so far."""
        if epoch % self.eval_interval != 0:
            return

        # Each run uses maxtrc_default distinct test traces drawn at random.
        batches = np.zeros((nruns_default, maxtrc_default), 'int')
        for i in range(nruns_default):
            batches[i, :] = np.random.choice(len(self.xTest[0][0]), maxtrc_default, False)
        test_rank = eval_model(self.model, nruns_default, maxtrc_default, batches, self.xTest, self.yTest_value, noHypoKeys, noClasses)

        # The first evaluation is by definition the best so far; without this
        # a run whose first evaluation is never beaten would save nothing.
        if not self.all_rank or test_rank < np.min(np.array(self.all_rank)):
            print(test_rank)
            self.model.save(self.save_model_name + '.keras')
        self.all_rank.append(test_rank)

    def on_train_end(self, logs=None):
        """Write the rank history to ``attack_rank.csv``."""
        df = pd.DataFrame({'Attack Mean Rank': self.all_rank})
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank.csv'))


class MultiKeyCallback(tf.keras.callbacks.Callback):
    """Keras callback that tracks the attack rank separately for several keys.

    Every ``eval_interval`` epochs the model is evaluated once per key with
    ``eval_model`` on the first ``maxtrc`` traces of that key. The model is
    saved after every evaluation (suffix ``<epoch>.keras``), additionally as
    ``_overall.keras`` whenever the rank averaged over the keys is the lowest
    seen so far, and as ``_end.keras`` when training ends. The per-key rank
    history is written to ``attack_rank_multi.csv``.

    Relies on the module-level names ``noHypoKeys`` and ``noClasses``.
    """

    def __init__(self, xTest_multi, yTest_vals, save_model_name, database_folder_train, eval_interval):
        super().__init__()
        self.xTest_multi = xTest_multi
        self.yTest_vals = yTest_vals
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []  # one list of per-key ranks per evaluated epoch
        # Never updated by this class; models are persisted to disk instead.
        self.best_model = self.model
        self.maxtrc = 80 #Max trace num for multi label
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate every key, update the best overall model, save a snapshot."""
        if epoch % self.eval_interval != 0:
            return

        # Data for multi-label is limited, so a single run over the first
        # maxtrc traces is used for every key.
        nruns = 1
        batches = np.zeros((nruns, self.maxtrc), 'int')
        batches[0] = np.arange(self.maxtrc)

        multi_rank = []
        for i, key_value in enumerate(self.yTest_vals):
            test_rank = eval_model(self.model, nruns, self.maxtrc, batches, [[self.xTest_multi[i]]], [key_value], noHypoKeys, noClasses)
            multi_rank.append(test_rank)

        # Compare like with like: the current key-averaged rank against the
        # best key-averaged rank of earlier evaluations. The first evaluation
        # is the best so far by definition.
        current_avg = np.mean(np.array(multi_rank))
        if not self.all_rank:
            is_best = True
        else:
            all_rank_avg = np.mean(np.array(self.all_rank), axis=1)
            is_best = current_avg < np.min(all_rank_avg)
        if is_best:
            self.model.save(self.save_model_name + '_overall.keras')
        self.all_rank.append(multi_rank)
        self.model.save(self.save_model_name + str(epoch) + '.keras')

    def on_train_end(self, logs=None):
        """Write the per-key rank history and save the final model."""
        all_cols = ['Mean Rank Key No.' + str(i) for i in range(len(self.yTest_vals))]
        df = pd.DataFrame(self.all_rank, columns=all_cols)
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank_multi.csv'))
        self.model.save(self.save_model_name + '_end.keras')

def get_subset(idxs, X_train, y_train, xType):
    """Restrict the training inputs and labels to the rows ``idxs``.

    Args:
        idxs: row indices to keep.
        X_train: for ``xType == 'wave'`` a 3-D trace array; for
            ``xType == 'wavebp01'`` the nested form ``[[traces], [bp0, bp1]]``.
        y_train: 2-D label array.
        xType: input layout, ``'wave'`` or ``'wavebp01'``.

    Returns:
        ``(sub_X_train, sub_y_train)``; ``sub_X_train`` has the same layout as
        ``X_train``.

    Raises:
        ValueError: if ``xType`` is not one of the supported layouts.
    """
    if xType == 'wave':
        sub_X_train = X_train[idxs, :, :]
    elif xType == 'wavebp01':
        # The traces live at X_train[0][0]; slicing the nested list itself
        # (as the side inputs X_train[1][k] imply) cannot work.
        profile_data = X_train[0][0]
        profile_bp0 = X_train[1][0]
        profile_bp1 = X_train[1][1]
        sub_X_train = [[profile_data[idxs, :, :]], [profile_bp0[idxs, :], profile_bp1[idxs, :]]]
    else:
        raise ValueError('Unsupported xType for get_subset: %r' % (xType,))
    sub_y_train = y_train[idxs, :]

    return sub_X_train, sub_y_train

from sklearn.cluster import KMeans

def min_max_sampling(data, num_cluster, num_sample):
    """Pick the ``num_sample`` samples lying farthest from their nearest centroid.

    Each sample is summarised by its sum along axis 1 (duplicated into two
    identical columns), K-means with ``num_cluster`` clusters is fitted on
    these summaries, and for every sample the smallest distance to any
    centroid is computed. The indices of the ``num_sample`` largest such
    distances are returned (in no particular order).

    Args:
        data: array of samples, first axis indexes the samples.
        num_cluster: number of K-means clusters.
        num_sample: number of indices to return.

    Returns:
        Array of ``num_sample`` indices into ``data``.
    """
    sum_data = np.sum(data, axis=1)
    exp_data = np.expand_dims(sum_data, axis=1)
    exp_data = np.squeeze(np.hstack((exp_data, exp_data)))
    # Honour the requested cluster count (it used to be fixed at 200).
    kmeans = KMeans(n_clusters=num_cluster, random_state=0, n_init="auto").fit(exp_data)
    centroids = kmeans.cluster_centers_

    # NOTE(review): the centroids live in the 2-column summary space while
    # `sample` is a raw row of `data`; the subtraction relies on broadcasting.
    # Confirm whether distances to `exp_data` rows were intended instead.
    min_dist = [
        np.min(np.array([np.linalg.norm(centroid - sample) for centroid in centroids]))
        for sample in data
    ]
    max_idxs = np.argpartition(min_dist, -num_sample)[-num_sample:]

    return max_idxs

def random_sampling(data, num_sample):
    """Draw ``num_sample`` distinct indices into ``data`` uniformly at random.

    NumPy's global random state is re-seeded with the module-level
    ``args.seed`` on every call, so the same seed yields the same indices.
    The number of drawn indices and a separator line are printed.
    """
    np.random.seed(args.seed)
    pool_size = len(data)
    chosen = np.random.choice(pool_size, num_sample, replace=False)
    print(len(chosen))
    print('---')
    return chosen

def unceratainty_sampling(model, xTrain_in, yTrain_in, num_sample):
    """Return the indices of the samples the model is least sure about.

    For every sample the probability that ``model.predict`` assigns to its
    true class (the position where the one-hot row of ``yTrain_in`` equals 1)
    is collected; the ``num_sample`` samples with the lowest such probability
    are selected.

    Args:
        model: object exposing ``predict``.
        xTrain_in: inputs passed to ``model.predict`` unchanged.
        yTrain_in: one-hot labels, one row per sample.
        num_sample: number of indices to return.

    Returns:
        Array of selected indices (unordered). If ``num_sample`` is at least
        the number of samples, all indices are returned.
    """
    preds = model.predict(xTrain_in)
    # Probability of the true class for each sample.
    pred_probs = np.atleast_1d(np.squeeze(np.array([
        row[np.where(onehot == 1)] for row, onehot in zip(preds, yTrain_in)
    ])))

    # argpartition needs kth < len; asking for everything means "take all".
    if num_sample >= len(pred_probs):
        return np.arange(len(pred_probs))
    min_idxs = np.argpartition(pred_probs, num_sample)[:num_sample]

    return min_idxs



import keras
from keras import layers

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one encoder block: self-attention followed by a Conv1D feed-forward.

    Both sub-blocks follow the same pattern: transform, dropout (only in the
    first feed-forward stage and after attention), layer normalisation, then
    a residual addition. The output has the same last dimension as ``inputs``.

    Args:
        inputs: input tensor of the block.
        head_size: ``key_dim`` of the multi-head attention.
        num_heads: number of attention heads.
        ff_dim: number of filters of the first feed-forward convolution.
        dropout: dropout rate used inside the attention and after it.
    """
    # --- self-attention sub-block (query and value are both `inputs`) ---
    attention = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )
    attended = attention(inputs, inputs)
    attended = layers.Dropout(dropout)(attended)
    attended = layers.LayerNormalization(epsilon=1e-6)(attended)
    skip = attended + inputs

    # --- feed-forward sub-block ---
    expand = layers.Conv1D(filters=ff_dim, kernel_size=3, activation="relu", padding='same')
    hidden = expand(skip)
    hidden = layers.Dropout(dropout)(hidden)
    # Project back to the channel count of the block input.
    project = layers.Conv1D(filters=inputs.shape[-1], kernel_size=3, padding='same')
    hidden = project(hidden)
    hidden = layers.LayerNormalization(epsilon=1e-6)(hidden)
    return hidden + skip

def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    """Assemble a transformer classifier.

    ``num_transformer_blocks`` encoder blocks are stacked on the input, the
    result is flattened and passed through one ReLU dense layer (plus
    dropout) per entry of ``mlp_units``, and a softmax layer with the
    module-level ``noClasses`` units produces the output.

    Returns:
        An uncompiled ``keras.Model``.
    """
    model_input = keras.Input(shape=input_shape)

    features = model_input
    for _block in range(num_transformer_blocks):
        features = transformer_encoder(features, head_size, num_heads, ff_dim, dropout)

    features = layers.Flatten()(features)
    for units in mlp_units:
        features = layers.Dense(units, activation="relu")(features)
        features = layers.Dropout(mlp_dropout)(features)

    class_probs = layers.Dense(noClasses, activation="softmax")(features)
    return keras.Model(model_input, class_probs)

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` to the range [0, 1].

    A constant series has no range to scale by; it is mapped to all zeros
    instead of dividing by zero (which would produce NaN values).
    """
    low = timeseries.min()
    span = timeseries.max() - low
    if span == 0:
        return np.zeros_like(timeseries, dtype=float)
    return (timeseries - low) / span

def z_norm(timeseries):
    """Return the z-scores of ``timeseries`` as computed by ``scipy.stats.zscore``.

    ``zscore`` is called with its default arguments.
    """
    scores = stats.zscore(timeseries)
    return scores

def normalize_data_per_trace(data):
    """Min-max scale every trace of ``data`` independently, in place.

    Prints ``data.shape``, replaces each ``data[i]`` with ``normalize(data[i])``
    and returns the same ``data`` object.
    """
    print(data.shape)
    for row_idx, trace in enumerate(data):
        data[row_idx] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-normalise every trace of ``data`` independently, in place.

    Prints ``data.shape``, replaces each ``data[i]`` with ``z_norm(data[i])``
    and returns the same ``data`` object.
    """
    print(data.shape)
    for row_idx, trace in enumerate(data):
        data[row_idx] = z_norm(trace)

    return data

def cummulative_transform(data):
    """Replace every trace of ``data`` by its running sum, in place.

    The running sum is computed in floating point and written back into
    ``data[i]``; the same ``data`` object is returned.
    """
    for row_idx, trace in enumerate(data):
        running_total = np.cumsum(trace, dtype=float)
        data[row_idx] = running_total
    return data

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces and labels and split them into a train and a validation part.

    The first ``end_trace - start_trace`` loaded traces form the training
    part; they are optionally transformed per trace according to
    ``args.normalize`` (0: none, 1: min-max, otherwise: z-score) and
    ``args.cummulative`` (1: cumulative sum before normalising, 2: after,
    otherwise: none; ignored when ``args.normalize == 0``). A trailing channel
    axis is appended to the trace arrays.

    Validation data is the slice ``[dataSize:end_val_trace]`` when ``is_test``
    is true, otherwise the rows listed under ``'ids'`` in ``val_ids.npz``
    (read from the current working directory).

    ``trainPortion``, ``xType`` and ``yType`` are accepted but not used here.
    Relies on the module-level ``NumSKPVclasses`` and on the loader functions
    ``load_meta_trace_file_from_test`` / ``load_meta_trace_files``.

    Returns:
        ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)`` where the
        ``y*`` arrays are one-hot and the ``*_value`` arrays are the labels as
        returned by the loader.
    """
    val_num = 300000
    # Close the archive once the id array has been read.
    with np.load('val_ids.npz') as val_file:
        val_ids = val_file['ids']
    print(len(val_ids))
    print(val_ids[:10])
    # Load val_num extra traces beyond end_trace for validation.
    end_val_trace = end_trace + val_num
    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)

    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)

    # NOTE(review): the per-trace transforms below write into their argument;
    # if this slice is a view, `trace_profiling` itself is modified before the
    # validation rows are taken from it further down - confirm this is intended.
    train_traces = trace_profiling[:dataSize]
    if args.normalize != 0:
        if args.normalize == 1:
            per_trace_norm = normalize_data_per_trace
        else:
            per_trace_norm = z_normalize_data_per_trace
        if args.cummulative == 1:
            train_traces = per_trace_norm(cummulative_transform(train_traces))
        elif args.cummulative == 2:
            train_traces = cummulative_transform(per_trace_norm(train_traces))
        else:
            train_traces = per_trace_norm(train_traces)
    xTrain = np.expand_dims(train_traces, axis = 2)
    print(xTrain[0][:10])

    if is_test:
        xVal = np.expand_dims(trace_profiling[dataSize:end_val_trace], axis = 2)
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        xVal = np.expand_dims(trace_profiling[val_ids], axis = 2)
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size, args):
    """Build or load a model, prepare the data pipelines and run training.

    Steps, in order:
      1. create a fresh model (``args.train_type == 'baseline'``) or load one
         from the module-level ``trained_model_path``;
      2. load training/validation data and a separate test set through
         ``create_training_data_optimize``, plus the multi-key attack set from
         ``args.eval_path``;
      3. restrict the training data to ``args.num_sample`` rows, chosen at
         random or read from ``args.all_ids``; the chosen ids are saved to
         ``all_ids.npy``;
      4. set up CSV logging and the two attack-rank callbacks;
      5. wrap the data in ``tf.data`` datasets;
      6. if ``args.transformer == 1``, replace the model by a freshly built and
         compiled transformer;
      7. call ``model.fit``.

    Besides its parameters the function reads several module-level names that
    are not defined in it (e.g. ``data_path``, ``yType``, ``testPortion``,
    ``trained_model_path``, ``NumSKPVclasses``, ``noClasses``, ``tracelen``).
    ``logTrainedModel_byFile_folder``, ``logTrainedModel_byEp_folder`` and
    ``period`` are not used in the body.
    """
    
    if args.train_type == 'baseline':
        model = subModels_gen(args, xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
    else:
        model = load_model(trained_model_path)


            
    gc.collect()
    #model.summary()
    print('load success')
    trainPortion = 1.0 #To get all data
    #xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value = create_training_data_form(data_path,sKeyNo, trainPortion, xType, yType,False, args.start_trace, args.end_trace)
    #xTrain = xTrain[0][0]
    #print(yTrain.shape)

    #xTrain_Pool, yTrain_Pool, _, _, _, _ = create_training_data_form(data_path,sKeyNo, trainPortion, xType, yType,False, args.end_trace, 100000)
    xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args, data_path,sKeyNo, trainPortion, xType, yType,False, args.start_trace, args.end_trace)
    print(yTrain.shape)
    print(len(xTrain))
    # From here on the labels are the raw class values, not the one-hot
    # arrays returned above; one-hot encoding happens again in the dataset.
    yTrain = np.expand_dims(yTrain_value, axis = 1)
    yVal = yVal_value
    '''
    xTest, yTest, _xVal_, yVal_, yTest_value, yVal_value = create_training_data_form("KYBER51.H5", sKeyNo, testPortion, xType, yType,True, args.start_trace, args.end_trace)
    print('-------------------')
    print(xTest[0][0].shape)
    print(yTest_value)
    print(yTest.shape)
    print('-----------------')
    '''
    # Test set from a fixed file name; this call rebinds yVal_value, which is
    # not read again below.
    xTest, yTest, _xVal_, yVal_ , yTest_value, yVal_value = create_training_data_optimize(args, "KYBER51.H5", sKeyNo, testPortion, xType, yType,True, args.start_trace, args.end_trace)
    # Wrap in the nested [[traces]] form that mk_rankmat indexes as xTest[0][0].
    xTest = [[xTest]]
    print(xTest[0][0].shape)
    print(yTest_value)
    print(yTest.shape)
    xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
    print(yTest_value)
    print(yTest_multi)
    # Append a trailing axis of size 1 to the multi-key traces.
    xTest_multi = np.expand_dims(xTest_multi, axis = 3)
    print(xTest_multi.shape)
    print(yTrain.shape)


    # Choose which training rows to keep and record the choice on disk.
    if args.sampling == 'random':
        all_ids = random_sampling( xTrain, num_sample = args.num_sample)
    else:
        all_ids = np.load(args.all_ids)[:args.num_sample]
    np.save(os.path.join(database_folder_train,'all_ids.npy'), all_ids)
    #xTrain[0][0], yTrain = get_subset(all_ids.astype(int),  xTrain[0][0], yTrain, xType)
    xTrain, yTrain = get_subset(all_ids.astype(int),  xTrain, yTrain, xType)


    database_folder_train_it = database_folder_train
    Path(database_folder_train_it).mkdir(parents=True, exist_ok=True)
    csv_logger = CSVLogger(filename=os.path.join(database_folder_train_it+'/log.csv'), append=True, separator=';')

    # Both attack callbacks share the same file-name prefix for saved models.
    save_model_name = (database_folder_train_it+'/model_best')
    attack_callback = CustomCallback(xTest, yTest_value, save_model_name, database_folder_train_it, args.eval_interval)
    attack_multi_callback = MultiKeyCallback(xTest_multi, yTest_multi, save_model_name, database_folder_train_it, args.eval_interval)

    callbacks=[csv_logger, attack_callback, attack_multi_callback]
    print('Len training:')
    print(xTrain.shape)
    print(yTrain.shape)


    # Drop the axis added above so the labels can be one-hot encoded.
    yTrain = np.squeeze(yTrain)
    #yVal= np.squeeze(yVal)
    print(yTrain.shape)
    print(yVal.shape)
    print(xVal.shape)

    print(NumSKPVclasses)
    print(yTrain[-10:])
    #exit()

    # NOTE(review): `train`/`val` are only assigned for xType 'wave' and
    # 'wavebp01'; any other xType reaches model.fit with them undefined.
    with tf.device("CPU"):
        #train = tf.data.Dataset.from_tensor_slices(({"input_1": xTrain[0][0], "input_2": xTrain[1][0], "input_3": xTrain[1][1]}, yTrain)).shuffle(4*64).batch(64)
        #val = tf.data.Dataset.from_tensor_slices(({"input_1": xVal[0][0], "input_2": xVal[1][0], "input_3": xVal[1][1]}, yVal)).shuffle(4*64).batch(64)
        if xType == 'wave':
            train = tf.data.Dataset.from_tensor_slices(( xTrain, tf.one_hot(yTrain, NumSKPVclasses) )).shuffle(4*64).batch(train_batch_size)
            val = tf.data.Dataset.from_tensor_slices((xVal, tf.one_hot(yVal,NumSKPVclasses) )).shuffle(4*64).batch(train_batch_size)
        elif xType == 'wavebp01':
            # NOTE(review): this branch passes the labels without one-hot
            # encoding and indexes xTrain/xVal as nested inputs although the
            # loader above returns plain arrays - presumably stale; confirm.
            train = tf.data.Dataset.from_tensor_slices(({"input_layer": xTrain, "input_layer_1": xTrain[1][0], "input_layer_2": xTrain[1][1]}, yTrain)).shuffle(4*64).batch(train_batch_size)
            val = tf.data.Dataset.from_tensor_slices(({"input_layer": xVal, "input_layer_1": xVal[1][0], "input_layer_2": xVal[1][1]}, yVal)).shuffle(4*64).batch(train_batch_size)

    # The transformer option discards the model created or loaded at the top
    # of this function and trains a new one instead.
    if args.transformer == 1:
        input_shape = xTrain.shape[1:]
        
        model = build_model(
            input_shape,
            head_size=args.head_size,
            num_heads=args.num_heads,
            ff_dim=args.ff_dim,
            num_transformer_blocks=args.num_transformer_blocks,
            mlp_units=[512, 256],
            mlp_dropout=0.4,
            dropout=0.25,
        )
        '''
        model.compile(
            loss="sparse_categorical_crossentropy",
            optimizer=keras.optimizers.Adam(learning_rate=1e-4),
            metrics=["sparse_categorical_accuracy"],
        )
        '''
        model.compile(loss='categorical_crossentropy', 
            optimizer=keras.optimizers.Adam(learning_rate=1e-4), metrics=['accuracy'])
        model.summary()
        #exit()

    # NOTE(review): the datasets are already batched; the Keras documentation
    # advises against also passing batch_size for dataset inputs - confirm.
    model.fit(train, batch_size=train_batch_size, verbose = 1, epochs=maxEpochs, callbacks=callbacks, class_weight=class_weight, validation_data=val)        
        #test_rank = eval_model(model, nruns_default, maxtrc_default, batches, xTest, yTest_value, noHypoKeys, noClasses)
        #print('Test Rank:, ', test_rank)




        #model.save(database_folder_train+'/model_iteration_{}.keras'.format(it))

# Training entry point: report the visible GPUs, give every class the same
# weight of 1 and start training.
if work == 'train':
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
    print('Work =', work)
    classWeights = np.ones(noClasses).astype(int)
    # Map class index -> weight, as expected by the `class_weight` argument.
    class_weight = {label: weight for label, weight in enumerate(classWeights)}

    train_model_multiEpochs(
        xType,
        database_folder_train,
        modelLogFolder,
        logTrainedModel_byFile_folder,
        logTrainedModel_byEp_folder,
        logFilename,
        MLmodel_detail,
        sKeyNo,
        class_weight,
        period,
        maxEpochs,
        train_batch_size,
        args,
    )
