#Use this after base model training to get coreset

import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid

def parse_arguments():
    """Build the command-line parser for this training / coreset script.

    Returns:
        An ``argparse.ArgumentParser``; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random, minmax or uncertainty')
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--xType', type=str,
                        help="model input type, e.g. 'wave' (trace only) or 'wavebp01' (trace + bp inputs)")
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int, help='end trace')
    parser.add_argument('--batch_size', type=int, help='batch size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)
    parser.add_argument('--trained_model_path', type=str, help='path to the trained model')
    parser.add_argument('--num_iteration', type=int, help='number of iterations', default=5)
    parser.add_argument('--all_ids', type=str)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--num_sample', type=int, help='number of samples', default=5)
    parser.add_argument('--eval_interval', type=int, help='evaluation interval', default=1)
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    # 0/1 integer switches.
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--mlp', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--use_bp', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.00001, help='RMSprop learning rate')
    parser.add_argument('--coreset_ratio', type=float, default=0.1)
    parser.add_argument('--coreset_key', type=str)

    # MLP builder hyper-parameters.
    parser.add_argument('--node', type=int, default=256, help='width of the first MLP hidden layer')
    parser.add_argument('--n_layers', type=int, default=6, help='number of MLP hidden Dense layers')
    parser.add_argument('--batch_norm', type=int, default=1,
                        help='1 to add BatchNormalization after MLP hidden layers')
    parser.add_argument('--dropout', type=int, default=1, help='1 to add Dropout after MLP hidden layers')
    parser.add_argument('--dropout_rate', type=float, default=0.2, help='Dropout rate')

    return parser

# Parse the command line once, at import time; later module-level code reads `args`.
parser = parse_arguments()
args = parser.parse_args(sys.argv[1:])
# Optional, for reproducible TF op results:
# tf.config.experimental.enable_op_determinism()

def set_seeds(seed):
    """Seed Python's, NumPy's and TensorFlow's global random generators.

    Args:
        seed: integer seed (the ``--seed`` command-line option).

    Note:
        ``PYTHONHASHSEED`` is only read when an interpreter starts, so exporting
        it here affects child processes but not hash randomization in the
        current process.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)

set_seeds(args.seed)

# Inclusive value ranges of the modelled Kyber quantities (class counts add 1).
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # trace length (samples per trace fed to the model)
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1             # number of inputs for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3; which subkeys these are is decided by the code on the m4 (NOT by the code on the PC)
work = 'train'  # 'train' or 'attack'
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',
                      'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
# Further available training files:
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5'
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5'
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5'

trained_model_path = args.trained_model_path
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232


# training parameters
train_batch_size = args.batch_size
period = 8
maxEpochs = args.num_epoch
attack_byModel_fileNo = attack_byModel_epNo // period  # integer division, no float round-trip
N_TRACE = 20000
Threshold_Save = 200

# model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True

# training data type
xType = args.xType  # e.g. 'wave', 'wavebp0', 'wavebp1', 'wavebp01', 'wavebp01next0', 'wavebp01next01'
yType = 'skpv'      # alternatives: 'fqmul0', 'fqmul1'
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    # hyper-parameter tag with 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
dataFile_train_folder = '100kDatax5_train'
dataFile_attack = '100kDatax1_test'
model_input_type = '_in[[][]]_tf2'
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
name_length = len('200k_2000cluster/minmax_')

# Per-experiment output folder; the name encodes the main command-line settings.
save_path = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_lr_{}_mlp_{}'.format(
    args.name, args.train_type, args.sampling, xType, args.start_trace, args.end_trace,
    args.num_sample, args.seed, args.normalize, args.cummulative, args.lr, args.mlp)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder  # by-epoch and by-file outputs share one folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'

# Create the model/log folders. exist_ok makes this idempotent without a
# separate isdir() check (which could race), and parents=True creates
# DLmodel_folder itself when needed. A path that exists as a regular file
# still raises FileExistsError.
for _model_dir in (DLmodel_folder, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder):
    Path(_model_dir).mkdir(parents=True, exist_ok=True)
del _model_dir

print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# Conventions for the tables below: unless stated otherwise each list holds one
# entry per layer (layer0 ... layer5) of sub-model 0, and an entry of 0 switches
# that layer / option off (the model builders test each entry with `!= 0`).

# Input BatchNormalization flag, one entry per sub-model (subMod0 ... subMod5).
subMods_inputBNorms = [1, 0, 0, 0, 0, 0]

###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Number of nodes (filters) in each layer of the convolutional stage.
subMods_NoConvNodes = [512, 256, 128, 64, 0, 0]
# Kernel size of each convolutional layer.
subMods_convKernelSizes = [3, 3, 3, 3, 0, 0]

###############################################
# MaxPooling window size after each convolutional layer.
subMods_convPoolSizes = [2, 2, 2, 2, 0, 0]
# MaxPooling stride after each convolutional layer.
# NOTE(review): stride 3 > window 2, so every third time step is never pooled;
# confirm this is intended.
subMods_convPoolStrides = [3, 3, 3, 3, 0, 0]
# BatchNormalization flag after each convolutional layer.
subMods_convBNorms = [1, 1, 1, 1, 0, 0]
# Dropout rate after each convolutional layer (0 = no Dropout).
subMods_convDrops = [0] * 6

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten-before-FC flag, one entry per sub-model (subMod0 ... subMod5).
subMods_convFeatFlat = [1, 0, 0, 0, 0, 0]
# Units of each fully-connected layer on the convolutional features, applied
# before the plaintext (bp) input is added.
subMods_FCs = [1024, 512, 256, 128, 0, 0]
# BatchNormalization flag after each of those fully-connected layers.
subMods_FC_BNorms = [1, 1, 1, 1, 0, 0]
# Dropout rate after each of those fully-connected layers (0 = no Dropout).
subMods_FC_Drops = [0.2, 0, 0.2, 0, 0, 0]

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext (bp / ciphertext) extension, selected by xType.
# subMods_Pext[i] != 0 makes BP branch i concatenate the flattened bp input
# inputs[1 + i] to the CNN features in the model builders; noBPbranch is
# presumably the branch count passed to those builders (call site not visible
# here -- TODO confirm).
# NOTE(review): noBPbranch is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01'. For 'wavebp0', 'wavebp1' and 'wavebp01next0' it stays
# undefined, and for any other xType (including the argparse default None)
# neither noBPbranch nor subMods_Pext is defined, so any later use raises
# NameError. The intended branch counts for those types are not visible
# here -- TODO confirm and set them (or reject unsupported xType early).
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Per-BP-branch tables, used after the plaintext (bp) input has been joined to
# the convolutional features. Rows: subModel0 ... subModel5 (indexed by BP
# branch number in the builders); columns: layer0 ... layer5; 0 switches an
# entry off.

# Units of each fully-connected layer.
subMods_Pext_FCs = [
    [1024, 1024, 512, 256, 128, 0],  # subModel0
    [1024, 1024, 512, 256, 128, 0],  # subModel1
    [1024, 1024, 512, 256, 128, 0],  # subModel2
    [1024, 1024, 512, 256, 128, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],              # subModel4
    [0, 0, 0, 0, 0, 0],              # subModel5
]
# BatchNormalization flag after each fully-connected layer.
subMods_Pext_FC_BNorms = [
    [1, 1, 1, 1, 1, 0],  # subModel0
    [1, 1, 1, 1, 1, 0],  # subModel1
    [1, 1, 1, 1, 1, 0],  # subModel2
    [1, 1, 1, 1, 1, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],  # subModel4
    [0, 0, 0, 0, 0, 0],  # subModel5
]
# Dropout rate after each fully-connected layer (0 = no Dropout).
subMods_Pext_FC_Drops = [
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel0
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel1
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel2
    [0.2, 0.2, 0, 0.1, 0, 0],  # subModel3
    [0, 0, 0, 0, 0, 0],        # subModel4
    [0, 0, 0, 0, 0, 0],        # subModel5
]

# Per-branch softmax flag, one entry per sub-model (subMod0 ... subMod5).
subMods_classification = [0, 0, 0, 0, 0, 0]

# Flag (one-element list): concatenate the BP-branch outputs before the
# joined fully-connected stage. Off only for the trace-only input type.
subMods_join = [0] if xType == 'wave' else [1]

# Fully-connected stage after the BP branches are joined (layer0 ... layer5);
# 0 switches an entry off.
subMods_join_FCs = [1024, 1024, 512, 256, 128, 0]
# BatchNormalization flag after each joined fully-connected layer.
subMods_join_FC_BNorms = [1, 1, 1, 1, 1, 0]
# Dropout rate after each joined fully-connected layer (0 = no Dropout).
subMods_join_FC_Drops = [0.2, 0.2, 0, 0.1, 0, 0]

# Softmax for joined-model if available
subMods_join_classification = [1]

################################################################################################
##################################### MODELS STRUCTURE END #####################################
################################################################################################


def check_file_exists(file_path):
    """Terminate the process with exit status -1 if ``file_path`` does not exist.

    Args:
        file_path: path of a file or directory that must exist.
    """
    if not os.path.exists(file_path):
        print("Error: provided file path '%s' does not exist!" % file_path)
        sys.exit(-1)

def listDirWithExt(directory, extension):
    """Return a lazy iterator over entry names in ``directory`` ending in ``.<extension>``.

    ``os.listdir`` runs immediately (so a missing directory raises here); only
    the suffix filtering is lazy. Names are returned without the directory prefix.
    """
    entries = os.listdir(directory)
    return (entry for entry in entries if entry.endswith('.' + extension))

def subModels_gen_MLP_Phu(args, xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile a plain MLP classifier over the raw trace.

    Layers: ``Dense(args.node, relu)`` on the trace (channel axis dropped),
    then ``args.n_layers - 1`` ``Dense(relu)`` layers whose width doubles for
    the first ``args.n_layers // 2`` of them and halves afterwards, each
    optionally followed by BatchNormalization (``args.batch_norm == 1``) and
    ``Dropout(args.dropout_rate)`` (``args.dropout == 1``), and finally a
    softmax over ``classes``.

    Args:
        args: parsed command-line namespace; reads ``node``, ``n_layers``,
            ``batch_norm``, ``dropout``, ``dropout_rate`` and ``lr``.
        xType: 'wave' (trace only) or 'wavebp01' (trace + two bp inputs).
        noConv1Dbranch: number of trace Input layers to create; only the last
            one is used.
        tracelen: number of samples per trace.
        NumBPinput: length of each bp one-hot input.
        noBPbranch, noLayers, MLmodel_detail, modelLogFolder, logFilename:
            unused; kept so the signature matches the other model builders.
        classes: number of output classes.

    Returns:
        The compiled ``Model`` (RMSprop with ``args.lr``, categorical
        cross-entropy loss, accuracy metric).

    Raises:
        ValueError: if ``xType`` is neither 'wave' nor 'wavebp01'.

    Side effects: prints the model summary and writes ``model.png`` to the
    current working directory.
    """
    if xType not in ('wave', 'wavebp01'):
        raise ValueError("unsupported xType %r: expected 'wave' or 'wavebp01'" % (xType,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)
    for _ in range(noConv1Dbranch):
        # A fresh Input per iteration; only the last one is wired into the model.
        trace_input = Input(shape=input_trace_shape)
    Ptext_input1 = Input(shape=input_Ptext1hot_shape)
    Ptext_input2 = Input(shape=input_Ptext1hot_shape)
    # Ptext_input3/4 are never used, but they are still created so that Keras'
    # auto-generated layer names stay the same as before.
    Ptext_input3 = Input(shape=input_Ptext1hot_shape)
    Ptext_input4 = Input(shape=input_Ptext1hot_shape)
    if xType == 'wave':
        inputs = [trace_input]
    else:  # 'wavebp01'
        # NOTE(review): the two bp inputs are model inputs but are not wired
        # into this MLP, so they are ignored; confirm this is intended.
        inputs = [trace_input, Ptext_input1, Ptext_input2]

    # NOTE(review): this BatchNormalization output is discarded (the first
    # Dense below reads inputs[0] directly), so it is not part of the model.
    # It is still constructed to keep auto-generated layer names unchanged;
    # confirm whether input normalization was intended.
    BatchNormalization()(inputs[0])
    # Drop the trailing channel axis: (batch, tracelen, 1) -> (batch, tracelen).
    x = Dense(args.node, input_dim=tracelen, activation='relu')(inputs[0][:, :, 0])

    node = args.node
    for i in range(args.n_layers - 1):
        # Widen for the first half of the hidden layers, then narrow again.
        node = node * 2 if i < args.n_layers // 2 else node // 2
        x = Dense(node, activation='relu')(x)
        if args.batch_norm == 1:
            x = BatchNormalization(trainable=True)(x)
        if args.dropout == 1:
            x = Dropout(args.dropout_rate)(x)

    outputs = Dense(classes, activation='softmax')(x)

    sModel = Model(inputs, outputs, name='model')
    sModel.summary()
    tf.keras.utils.plot_model(sModel, show_shapes=True, to_file='model.png')
    optimizer = RMSprop(learning_rate=args.lr)
    sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

def subModels_gen_MLP(args,xType,noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the Dense-feature (+ optional bp-branch) classifier.

    Same structure as ``subModels_gen`` but the feature stage uses ``Dense``
    (applied along the last axis of the (batch, tracelen, ...) tensor) instead
    of ``Conv1D``. Driven by the module-level ``subMods_*`` tables:

    1. BatchNormalization on the raw trace input.
    2. Up to ``noLayers`` Dense + MaxPooling1D blocks, each optionally followed
       by BatchNormalization / Dropout.
    3. Flatten, then up to ``noLayers`` Dense layers (no activation given) with
       optional BatchNormalization / Dropout.
    4. For each of ``noBPbranch`` branches: start from those features,
       optionally concatenate the flattened bp input ``inputs[1 + branch]``,
       then a Dense(relu) stack and an optional per-branch softmax.
    5. Concatenate the branch outputs (if any and if joining is enabled),
       apply the joined Dense(relu) stack, and end with a softmax.

    Args:
        args: parsed command-line namespace; only ``args.lr`` is read.
        xType: 'wave' (trace only) or 'wavebp01' (trace + two bp inputs).
        noConv1Dbranch: number of trace Input layers to create; only the last
            one is used.
        noBPbranch: number of bp branches to build (0 for 'wave').
        noLayers: number of entries of each ``subMods_*`` table to consult.
        tracelen: number of samples per trace.
        NumBPinput: length of each bp one-hot input.
        MLmodel_detail, modelLogFolder, logFilename: unused; kept for
            interface compatibility.
        classes: number of output classes.

    Returns:
        The compiled ``Model`` (RMSprop with ``args.lr``, categorical
        cross-entropy loss, accuracy metric).

    Raises:
        ValueError: if ``xType`` is neither 'wave' nor 'wavebp01'.

    Side effects: prints the model summary and intermediate tensors, and
    writes ``model.png`` to the current working directory.
    """
    if xType not in ('wave', 'wavebp01'):
        raise ValueError("unsupported xType %r: expected 'wave' or 'wavebp01'" % (xType,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)
    for _ in range(noConv1Dbranch):
        # A fresh Input per iteration; only the last one is wired into the model.
        trace_input = Input(shape=input_trace_shape)
    Ptext_input1 = Input(shape=input_Ptext1hot_shape)
    Ptext_input2 = Input(shape=input_Ptext1hot_shape)
    # Ptext_input3/4 are never used, but they are still created so that Keras'
    # auto-generated layer names stay the same as before.
    Ptext_input3 = Input(shape=input_Ptext1hot_shape)
    Ptext_input4 = Input(shape=input_Ptext1hot_shape)
    if xType == 'wave':
        inputs = [trace_input]
    else:  # 'wavebp01'
        inputs = [trace_input, Ptext_input1, Ptext_input2]

    x = BatchNormalization()(inputs[0])

    # Feature stage (Dense in place of Conv1D) with pooling.
    for layerNo in range(noLayers):
        if (subMods_NoConvNodes[layerNo] != 0 and subMods_convKernelSizes[layerNo] != 0
                and subMods_convPoolSizes[layerNo] != 0 and subMods_convPoolStrides[layerNo] != 0):
            x = Dense(subMods_NoConvNodes[layerNo])(x)
            x = MaxPooling1D(subMods_convPoolSizes[layerNo], strides=subMods_convPoolStrides[layerNo], name='ConvBlock_'+'subModels'+'_pool'+str(layerNo))(x)
        if subMods_convBNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_convDrops[layerNo] != 0:
            x = Dropout(subMods_convDrops[layerNo])(x)

    # Fully-connected stage on the flattened features.
    for layerNo in range(noLayers):
        # subMods_convFeatFlat holds one flag per sub-model; this is subModel0.
        # (Comparing the whole list with 0 was always true.)
        if layerNo == 0 and subMods_convFeatFlat[0] != 0:
            x = Flatten()(x)
        if subMods_FCs[layerNo] != 0:
            # NOTE(review): no activation, so these Dense layers are linear;
            # confirm whether activation='relu' was intended.
            x = Dense(subMods_FCs[layerNo])(x)
        if subMods_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_FC_Drops[layerNo])(x)
    output_cnn = x

    # Plaintext (bp) branches.
    BPbranchOuts_list = []
    for BPbranchNo in range(noBPbranch):
        # Every branch starts from the shared features; otherwise a branch with
        # subMods_Pext == 0 would continue from the previous branch's output.
        x = output_cnn
        if subMods_Pext[BPbranchNo] != 0:
            Ptext_flatten = Flatten()(inputs[1 + BPbranchNo])
            x = Concatenate()([output_cnn, Ptext_flatten])
        for layerNo in range(noLayers):
            if subMods_Pext_FCs[BPbranchNo][layerNo] != 0:
                x = Dense(subMods_Pext_FCs[BPbranchNo][layerNo], activation='relu')(x)
            if subMods_Pext_FC_BNorms[BPbranchNo][layerNo] != 0:
                x = BatchNormalization(trainable=True)(x)
            if subMods_Pext_FC_Drops[BPbranchNo][layerNo] != 0:
                x = Dropout(subMods_Pext_FC_Drops[BPbranchNo][layerNo])(x)

        ###################### CLASSIFICATION (SOFTMAX) ######################
        if subMods_classification[BPbranchNo] != 0:
            x = Dense(classes, activation='softmax')(x)
        print(x)
        BPbranchOuts_list.append(x)

    print(BPbranchOuts_list)
    # subMods_join is a one-element list; test its flag, not the list itself.
    if BPbranchOuts_list and subMods_join[0] != 0:
        x = Concatenate()(BPbranchOuts_list)

    # Joined fully-connected stage.
    for layerNo in range(noLayers):
        if subMods_join_FCs[layerNo] != 0:
            x = Dense(subMods_join_FCs[layerNo], activation='relu')(x)
        if subMods_join_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_join_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_join_FC_Drops[layerNo])(x)

    outputs = Dense(classes, activation='softmax')(x)

    sModel = Model(inputs, outputs, name='model')
    sModel.summary()
    tf.keras.utils.plot_model(sModel, show_shapes=True, to_file='model.png')
    optimizer = RMSprop(learning_rate=args.lr)
    sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

def subModels_gen(args,xType,noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the CNN (+ optional bp-branch) classifier.

    Driven by the module-level ``subMods_*`` tables:

    1. BatchNormalization on the raw trace input.
    2. Up to ``noLayers`` Conv1D(relu, 'same') + MaxPooling1D blocks, each
       optionally followed by BatchNormalization / Dropout.
    3. Flatten, then up to ``noLayers`` Dense layers (no activation given) with
       optional BatchNormalization / Dropout.
    4. For each of ``noBPbranch`` branches: start from the CNN features,
       optionally concatenate the flattened bp input ``inputs[1 + branch]``,
       then a Dense(relu) stack and an optional per-branch softmax.
    5. Concatenate the branch outputs (if any and if joining is enabled),
       apply the joined Dense(relu) stack, and end with a softmax.

    Args:
        args: parsed command-line namespace; only ``args.lr`` is read.
        xType: 'wave' (trace only) or 'wavebp01' (trace + two bp inputs).
        noConv1Dbranch: number of trace Input layers to create; only the last
            one is used.
        noBPbranch: number of bp branches to build (0 for 'wave').
        noLayers: number of entries of each ``subMods_*`` table to consult.
        tracelen: number of samples per trace.
        NumBPinput: length of each bp one-hot input.
        MLmodel_detail, modelLogFolder, logFilename: unused; kept for
            interface compatibility.
        classes: number of output classes.

    Returns:
        The compiled ``Model`` (RMSprop with ``args.lr``, categorical
        cross-entropy loss, accuracy metric).

    Raises:
        ValueError: if ``xType`` is neither 'wave' nor 'wavebp01'.

    Side effects: prints the model summary and intermediate tensors, and
    writes ``model.png`` to the current working directory.
    """
    if xType not in ('wave', 'wavebp01'):
        raise ValueError("unsupported xType %r: expected 'wave' or 'wavebp01'" % (xType,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)
    for _ in range(noConv1Dbranch):
        # A fresh Input per iteration; only the last one is wired into the model.
        trace_input = Input(shape=input_trace_shape)
    Ptext_input1 = Input(shape=input_Ptext1hot_shape)
    Ptext_input2 = Input(shape=input_Ptext1hot_shape)
    # Ptext_input3/4 are never used, but they are still created so that Keras'
    # auto-generated layer names stay the same as before.
    Ptext_input3 = Input(shape=input_Ptext1hot_shape)
    Ptext_input4 = Input(shape=input_Ptext1hot_shape)
    if xType == 'wave':
        inputs = [trace_input]
    else:  # 'wavebp01'
        inputs = [trace_input, Ptext_input1, Ptext_input2]

    x = BatchNormalization()(inputs[0])

    # Convolutional feature extractor.
    for layerNo in range(noLayers):
        if (subMods_NoConvNodes[layerNo] != 0 and subMods_convKernelSizes[layerNo] != 0
                and subMods_convPoolSizes[layerNo] != 0 and subMods_convPoolStrides[layerNo] != 0):
            x = Conv1D(subMods_NoConvNodes[layerNo], subMods_convKernelSizes[layerNo], activation='relu', padding='same', name='ConvBlock_'+'subModels'+'_conv'+str(layerNo))(x)
            x = MaxPooling1D(subMods_convPoolSizes[layerNo], strides=subMods_convPoolStrides[layerNo], name='ConvBlock_'+'subModels'+'_pool'+str(layerNo))(x)
        if subMods_convBNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_convDrops[layerNo] != 0:
            x = Dropout(subMods_convDrops[layerNo])(x)

    # Fully-connected stage on the flattened CNN features.
    for layerNo in range(noLayers):
        # subMods_convFeatFlat holds one flag per sub-model; this is subModel0.
        # (Comparing the whole list with 0 was always true.)
        if layerNo == 0 and subMods_convFeatFlat[0] != 0:
            x = Flatten()(x)
        if subMods_FCs[layerNo] != 0:
            # NOTE(review): no activation, so these Dense layers are linear;
            # confirm whether activation='relu' was intended.
            x = Dense(subMods_FCs[layerNo])(x)
        if subMods_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_FC_Drops[layerNo])(x)
    output_cnn = x

    # Plaintext (bp) branches.
    BPbranchOuts_list = []
    for BPbranchNo in range(noBPbranch):
        # Every branch starts from the CNN features; otherwise a branch with
        # subMods_Pext == 0 would continue from the previous branch's output.
        x = output_cnn
        if subMods_Pext[BPbranchNo] != 0:
            Ptext_flatten = Flatten()(inputs[1 + BPbranchNo])
            x = Concatenate()([output_cnn, Ptext_flatten])
        for layerNo in range(noLayers):
            if subMods_Pext_FCs[BPbranchNo][layerNo] != 0:
                x = Dense(subMods_Pext_FCs[BPbranchNo][layerNo], activation='relu')(x)
            if subMods_Pext_FC_BNorms[BPbranchNo][layerNo] != 0:
                x = BatchNormalization(trainable=True)(x)
            if subMods_Pext_FC_Drops[BPbranchNo][layerNo] != 0:
                x = Dropout(subMods_Pext_FC_Drops[BPbranchNo][layerNo])(x)

        ###################### CLASSIFICATION (SOFTMAX) ######################
        if subMods_classification[BPbranchNo] != 0:
            x = Dense(classes, activation='softmax')(x)
        print(x)
        BPbranchOuts_list.append(x)

    print(BPbranchOuts_list)
    # subMods_join is a one-element list; test its flag, not the list itself.
    if BPbranchOuts_list and subMods_join[0] != 0:
        x = Concatenate()(BPbranchOuts_list)

    # Joined fully-connected stage.
    for layerNo in range(noLayers):
        if subMods_join_FCs[layerNo] != 0:
            x = Dense(subMods_join_FCs[layerNo], activation='relu')(x)
        if subMods_join_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_join_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_join_FC_Drops[layerNo])(x)

    outputs = Dense(classes, activation='softmax')(x)

    sModel = Model(inputs, outputs, name='model')
    sModel.summary()
    tf.keras.utils.plot_model(sModel, show_shapes=True, to_file='model.png')
    optimizer = RMSprop(learning_rate=args.lr)
    sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

# make a prediction with a stacked model
# https://machinelearningmastery.com/stacking-ensemble-for-deep-learning-neural-networks/
def predict_stacked_model(model, inputX):
    """Run ``model.predict`` with the same array fed to every model input.

    Args:
        model: a Keras model (any object with an ``inputs`` list and a
            ``predict`` method).
        inputX: the array passed to each of the model's inputs.

    Returns:
        Whatever ``model.predict`` returns.
    """
    # ``model.inputs`` is always a list. ``model.input`` is a bare tensor for
    # single-input models, so len() on it does not give the input count.
    X = [inputX] * len(model.inputs)
    return model.predict(X, verbose=0)

def load_sca_model(model_file):
    """Load a saved Keras model, exiting the process on failure.

    Args:
        model_file: path accepted by Keras ``load_model``.

    Returns:
        The loaded model.

    Exits with status -1 if ``model_file`` does not exist (via
    ``check_file_exists``) or if loading fails.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    # Catch ordinary errors only; a bare ``except:`` would also swallow
    # KeyboardInterrupt and SystemExit.
    except Exception as err:
        print("Error: can't load Keras model file '%s'" % model_file)
        print("Reason: %s" % err)
        sys.exit(-1)
    return model

####### THESE FUNCTIONS ARE SPECIALIZED FOR KYBER    #######
####### Loading traces and metadata from file ############
#def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a slice of traces, bp values and labels from a ``.npz`` archive.

    The archive must contain the arrays ``data`` (traces along the first axis),
    ``bp`` (traces along the second axis) and ``label`` (traces along the
    first axis).

    Args:
        data_path: path of the ``.npz`` file.
        start_trace: first trace index of the slice (Python slice semantics).
        end_trace: end index of the slice (exclusive, Python slice semantics).

    Returns:
        ``(data[start_trace:end_trace], bp[:, start_trace:end_trace],
        label[start_trace:end_trace])``.
    """
    # Use the NpzFile as a context manager so its file handle is closed; the
    # arrays read from it are in-memory and stay valid after the block.
    with np.load(data_path) as data:
        trace_profiling = data['data']
        bp_profiling = data['bp']
        skpv_profiling = data['label']

    return (trace_profiling[start_trace:end_trace],
            bp_profiling[:, start_trace:end_trace],
            skpv_profiling[start_trace:end_trace])

def load_meta_trace_file_from_test(database_file, sKeyNo, load_metadata=False):
    """Load traces, bp metadata and skpv labels for key ``sKeyNo`` from an HDF5 file.

    Only the datasets that are returned are read: ``'wave'``,
    ``'skpv_a_vec0_evenCoeff0'`` (column ``sKeyNo``) and the two
    ``'bp_b_vec0_*'`` datasets (columns ``sKeyNo`` and ``sKeyNo + 1``).
    The file is closed afterwards. ``load_metadata`` is accepted but unused.
    If the file cannot be opened, an error is printed and ``sys.exit(-1)``
    is called.

    Returns:
        (trace_profiling, bp_profiling, skpv_profiling). ``trace_profiling``
        is a float array, ``skpv_profiling`` an int vector, and
        ``bp_profiling`` a 4-row int array: vec0 even/odd coefficient for
        ``sKeyNo`` followed by vec0 even/odd coefficient for ``sKeyNo + 1``.
    """
    print('\nLoad database_file =', database_file)
    check_file_exists(database_file)
    # Open the Kyber database HDF5 for reading
    try:
        in_file = h5py.File(database_file, "r")
    except OSError:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        sys.exit(-1)

    with in_file:
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int)
        even0 = in_file['bp_b_vec0_evenCoeff0']
        odd1 = in_file['bp_b_vec0_oddCoeff1']
        bp_profiling = [
            even0[:, sKeyNo].astype(int),
            odd1[:, sKeyNo].astype(int),
            even0[:, sKeyNo + 1].astype(int),
            odd1[:, sKeyNo + 1].astype(int),
        ]

    return (trace_profiling, np.array(bp_profiling), skpv_profiling)
#### Converting traces and metadata to training format
# inputs = [[list of traces], [list of bp]]

# Rows of bp_profiling fed to the model for each supported xType, in input order.
_BP_ROWS_BY_XTYPE = {
    'wave': (),
    'wavebp0': (0,),
    'wavebp1': (1,),
    'wavebp01': (0, 1),
    'wavebp01next0': (0, 1, 2),
    'wavebp01next01': (0, 1, 2, 3),
}


def _one_hot_bp(bp_values):
    """One-hot encode a vector of integer bp values into an int array of shape (N, NumBPinput, 1)."""
    one_hot = np.zeros((bp_values.shape[0], NumBPinput), dtype=int)
    one_hot[np.arange(bp_values.shape[0]), bp_values] = 1
    return one_hot.reshape((one_hot.shape[0], NumBPinput, 1))


def create_training_data_form(data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces + metadata and split them into nested Keras train/validation inputs.

    The first ``floor(N * trainPortion)`` traces form the training set and the
    traces from that point on the validation set; when the training set would
    cover every trace, the last trace is reused as the validation set.

    Inputs are nested as ``[[traces]]`` for ``xType == 'wave'`` and as
    ``[[traces], [bp one-hot arrays...]]`` otherwise, with the bp rows chosen
    by ``xType`` (see ``_BP_ROWS_BY_XTYPE``). Traces get a trailing channel
    axis, i.e. shape (N, L, 1).

    Raises:
        ValueError: for an unsupported ``xType``, or a ``yType`` other than
            ``'skpv'`` (fqmul labels are not loaded by the loaders).

    Returns:
        xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value
    """
    # Validate before doing any (potentially large) I/O.
    if xType not in _BP_ROWS_BY_XTYPE:
        raise ValueError("Unsupported xType %r" % (xType,))
    if yType != 'skpv':
        raise ValueError("Unsupported yType %r: only 'skpv' labels are loaded" % (yType,))

    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_trace)

    traces = trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
    dataSize = traces.shape[0]
    trainSize = math.floor(dataSize * trainPortion)
    valLoc = trainSize
    if valLoc == dataSize:
        # Keep a non-empty validation set when everything is used for training.
        valLoc = dataSize - 1

    # Only encode the bp rows this xType uses, so metadata with fewer bp rows
    # still works for xTypes that do not need them.
    bp_inputs = [_one_hot_bp(bp_profiling[row]) for row in _BP_ROWS_BY_XTYPE[xType]]

    xTrain = [[traces[0:trainSize]]]
    xVal = [[traces[valLoc:]]]
    if bp_inputs:
        xTrain.append([bp[0:trainSize] for bp in bp_inputs])
        xVal.append([bp[valLoc:] for bp in bp_inputs])

    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)
    yTrain = y_train_skpv[0:trainSize, :]
    yTrain_value = skpv_profiling[0:trainSize]
    yVal = y_train_skpv[valLoc:, :]
    yVal_value = skpv_profiling[valLoc:]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def get_subset(idxs, X_train, y_train):
    """Select rows ``idxs`` from a ``[[traces], [bp0, bp1]]`` input and from ``y_train``.

    Prints the number of traces before selecting. Returns the subset in the
    same nesting, ``[[traces[idxs]], [bp0[idxs], bp1[idxs]]]``, plus
    ``y_train[idxs]``.

    NOTE(review): a later four-argument ``get_subset(idxs, X_train, y_train,
    xType)`` in this module rebinds the name, so this version is not reachable
    once the module has finished loading.
    """
    traces = X_train[0][0]
    print(len(traces))
    bp_first, bp_second = X_train[1][0], X_train[1][1]
    picked_inputs = [[traces[idxs, :, :]], [bp_first[idxs, :], bp_second[idxs, :]]]
    picked_labels = y_train[idxs, :]
    return picked_inputs, picked_labels

def load_multi_attack(data_path):
    """Load the multi-key attack set from an ``.npz`` archive.

    Returns the archive's ``'data'`` and ``'label'`` arrays. The archive is
    opened in a ``with`` block so its file handle is released once both arrays
    have been read.
    """
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def _rank_model_inputs(x_type, xTest, samp):
    """Return the list of model inputs for trace indices ``samp`` given the input layout ``x_type``.

    The indexing per layout matches what each caller builds; note that for
    ``'wavebp01'`` it expects ``xTest[0][0] == [traces, [bp0, bp1]]``.
    """
    if x_type == 'wave':
        return [xTest[0][0][samp, :, :]]
    if x_type == 'wavebp0':
        return [xTest[0][0][samp, :, :], xTest[1][0][samp, :]]
    if x_type == 'wavebp1':
        return [xTest[0][0][samp, :, :], xTest[1][1][samp, :]]
    if x_type == 'wavebp01':
        return [xTest[0][0][0][samp, :, :], xTest[0][0][1][0][samp, :, :], xTest[0][0][1][1][samp, :, :]]
    if x_type == 'wavebp01next0':
        return [xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :], xTest[1][2][samp, :]]
    if x_type == 'wavebp01next01':
        return [xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :],
                xTest[1][2][samp, :], xTest[1][3][samp, :]]
    raise ValueError("mk_rankmat: unsupported xType %r" % (x_type,))


def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses, x_type=None):
    """Compute key and class rank matrices over ``nruns`` trace batches.

    For run ``k`` the model predicts class probabilities for the traces
    ``batches[k]``. The key rank after ``i + 1`` traces is the number of
    hypotheses whose summed log-probability is strictly greater than the true
    key's; the class rank uses the single-trace log-probabilities instead.

    Args:
        model: object with a Keras-style ``predict(inputs, verbose=...)``
            returning one probability row per selected trace.
        nruns: number of runs (rows of ``batches`` used).
        maxtrc: number of traces accumulated per run.
        batches: integer index array; ``batches[k]`` selects the traces of run k.
        xTest: nested test inputs, indexed according to ``x_type``.
        yTest_value: labels; only ``yTest_value[0]`` is used as the true key/class.
        noHypoKeys, noClasses: sizes of the key and class axes. Log-probabilities
            are used directly as key scores, so these must be equal.
        x_type: input layout; defaults to the module-level ``xType`` global.

    Returns:
        (rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns).
        The rank matrices are (nruns, maxtrc); the other arrays are
        (maxtrc, classes, nruns).

    Raises:
        ValueError: if ``x_type`` is not a supported layout.
    """
    if x_type is None:
        x_type = xType  # module-level setting, kept for backward compatibility
    realkey = int(yTest_value[0])
    realClass = realkey  # classes are the key values themselves
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))

    for krun in range(nruns):
        ps = model.predict(_rank_model_inputs(x_type, xTest, batches[krun, :]), verbose=0)
        lps = np.log(ps)
        lps_used = lps[:maxtrc]
        # Running log-likelihood sum per hypothesis after 1..maxtrc traces
        # (float64 accumulation, in trace order).
        lpsums = np.cumsum(lps_used, axis=0, dtype=np.float64)
        lpsums_AllHypoKeys_Nruns[:, :, krun] = lpsums
        # Rank = number of candidates scoring strictly higher than the truth.
        rankmat_byKey[krun, :] = np.sum(lpsums > lpsums[:, realkey][:, None], axis=1)
        rankmat_byClass[krun, :] = np.sum(lps_used > lps_used[:, realClass][:, None], axis=1)
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the key rank after all ``maxtrc`` traces, averaged over the runs.

    Delegates to ``mk_rankmat`` and keeps only its key-rank matrix.
    """
    key_ranks, _, _, _, _, _ = mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)
    mean_rank_per_trace = np.mean(key_ranks, 0)
    return mean_rank_per_trace[-1]

import pandas as pd

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically measures the attack rank on held-out traces.

    Every ``eval_interval`` epochs it draws ``nruns_default`` random batches of
    ``maxtrc_default`` test traces and scores the model with ``eval_model``.
    Whenever that score is the best so far, including the first evaluation,
    it prints the score and saves the model to ``<save_model_name>.keras``.
    At the end of training the score history is written to
    ``<database_folder_train>/attack_rank.csv``.
    """

    def __init__(self, xTest, yTest_value, save_model_name, database_folder_train, eval_interval):
        super().__init__()
        self.xTest = xTest
        self.yTest_value = yTest_value
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []  # one eval_model score per evaluated epoch
        self.best_model = self.model
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.eval_interval != 0:
            return
        batches = np.zeros((nruns_default, maxtrc_default), 'int')
        # For wavebp01 the traces sit one nesting level deeper in xTest.
        if args.xType == 'wavebp01':
            len_test = len(self.xTest[0][0][0])
        else:
            len_test = len(self.xTest[0][0])
        for i in range(nruns_default):
            batches[i, :] = np.random.choice(len_test, maxtrc_default, False)
        test_rank = eval_model(self.model, nruns_default, maxtrc_default, batches,
                               self.xTest, self.yTest_value, noHypoKeys, noClasses)

        # Save on the first evaluation too; otherwise a run that never
        # improves on its first score would leave no saved "best" model.
        if not self.all_rank or test_rank < min(self.all_rank):
            print(test_rank)
            self.model.save(self.save_model_name + '.keras')
        self.all_rank.append(test_rank)

    def on_train_end(self, logs=None):
        df = pd.DataFrame({'Attack Mean Rank': self.all_rank})
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank.csv'))


class MultiKeyCallback(tf.keras.callbacks.Callback):
    """Keras callback that tracks the attack rank on several fixed-key test sets.

    Every ``eval_interval`` epochs, the first ``self.maxtrc`` (80) traces of
    each key's set are scored with one run of ``eval_model``. The model is
    saved as ``<save_model_name><epoch>.keras`` at every evaluation, and as
    ``<save_model_name>_overall.keras`` whenever the mean score across keys is
    the best mean seen so far (including the first evaluation). At the end of
    training the per-key history goes to ``attack_rank_multi.csv`` and the
    final model to ``<save_model_name>_end.keras``.
    """

    def __init__(self, xTest_multi, yTest_vals, save_model_name, database_folder_train, eval_interval):
        super().__init__()
        self.xTest_multi = xTest_multi  # one trace set per key
        self.yTest_vals = yTest_vals  # true key value for each set
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []  # per evaluation: list of per-key scores
        self.best_model = self.model
        self.maxtrc = 80  # Max trace num for multi label
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.eval_interval != 0:
            return
        # Data for multi-label is limited, so a single run over the first
        # maxtrc traces of each key is used. mk_rankmat only reads batches,
        # so one array serves every key.
        nruns = 1
        batches = np.zeros((nruns, self.maxtrc), 'int')
        batches[0] = np.arange(self.maxtrc)

        multi_rank = []
        for key_no in range(len(self.yTest_vals)):
            test_rank = eval_model(self.model, nruns, self.maxtrc, batches,
                                   [[self.xTest_multi[key_no]]], [self.yTest_vals[key_no]],
                                   noHypoKeys, noClasses)
            multi_rank.append(test_rank)

        # Compare the current mean across keys with the best previous *mean*.
        mean_rank = np.mean(np.array(multi_rank))
        if not self.all_rank:
            is_best = True
        else:
            previous_means = np.mean(np.array(self.all_rank), axis=1)
            is_best = mean_rank < np.min(previous_means)
        if is_best:
            self.model.save(self.save_model_name + '_overall.keras')
        self.all_rank.append(multi_rank)
        self.model.save(self.save_model_name + str(epoch) + '.keras')

    def on_train_end(self, logs=None):
        all_cols = ['Mean Rank Key No.' + str(i) for i in range(len(self.yTest_vals))]
        df = pd.DataFrame(self.all_rank, columns=all_cols)
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank_multi.csv'))
        self.model.save(self.save_model_name + '_end.keras')

def get_subset(idxs, X_train, y_train, xType):
    """Select the samples ``idxs`` from the training inputs and labels.

    Args:
        idxs: indices usable for NumPy fancy indexing.
        X_train: for ``xType == 'wave'`` a 3-D trace array; for
            ``'wavebp01'`` the nested list ``[traces, [bp0, bp1]]`` with 3-D
            arrays throughout.
        y_train: 2-D label array (e.g. one-hot).
        xType: ``'wave'`` or ``'wavebp01'``.

    Returns:
        (sub_X_train, sub_y_train) with the same nesting as the inputs.

    Raises:
        ValueError: for any other ``xType``.
    """
    if xType == 'wave':
        sub_X_train = X_train[idxs, :, :]
    elif xType == 'wavebp01':
        # X_train is a list here, so the traces must be taken from X_train[0].
        traces = X_train[0]
        profile_bp0, profile_bp1 = X_train[1][0], X_train[1][1]
        sub_X_train = [traces[idxs, :, :], [profile_bp0[idxs, :, :], profile_bp1[idxs, :, :]]]
    else:
        raise ValueError("get_subset: unsupported xType %r" % (xType,))
    sub_y_train = y_train[idxs, :]

    return sub_X_train, sub_y_train

from sklearn.cluster import KMeans

def min_max_sampling(data, num_cluster, num_sample):
    """Return indices of the ``num_sample`` samples farthest from their nearest k-means centroid.

    Each sample is summarised by its sum over axis 1, duplicated into two
    columns. The summaries are clustered with ``num_cluster`` k-means
    clusters, and the samples with the largest distance to their closest
    centroid (measured in that same feature space) are selected. The returned
    indices are not sorted.
    """
    summed = np.sum(data, axis=1).reshape(len(data), -1)
    features = np.hstack((summed, summed))
    kmeans = KMeans(n_clusters=num_cluster, random_state=0, n_init="auto").fit(features)
    # transform() gives each sample's distance to every centroid; keep the nearest.
    min_dist = kmeans.transform(features).min(axis=1)
    max_idxs = np.argpartition(min_dist, -num_sample)[-num_sample:]

    return max_idxs

def random_sampling(data, num_sample):
    """Draw ``num_sample`` distinct indices uniformly at random.

    NumPy's global RNG is re-seeded with ``args.seed`` on every call, so the
    draw is reproducible. For ``args.xType == 'wavebp01'`` only ``data[0]``
    (the traces) determines the population size. The count and a separator
    line are printed before returning.
    """
    np.random.seed(args.seed)
    population = data[0] if args.xType == 'wavebp01' else data
    picked = np.random.choice(len(population), num_sample, replace=False)
    print(len(picked))
    print('---')
    return picked

def unceratainty_sampling(model, xTrain_in, yTrain_in, num_sample):
    """Return indices of the ``num_sample`` samples with the lowest true-class probability.

    ``yTrain_in`` is expected to be one-hot; the model's predicted probability
    of each sample's labelled class is gathered, and the smallest are
    selected. When ``num_sample`` covers every sample, all indices are
    returned. The returned indices are not sorted by probability in the
    general case.
    """
    preds = np.asarray(model.predict(xTrain_in))
    n_pred = preds.shape[0]
    labels = np.argmax(np.asarray(yTrain_in)[:n_pred], axis=1)
    true_class_probs = preds[np.arange(n_pred), labels]
    if num_sample >= n_pred:
        # argpartition requires kth < n; every sample is selected anyway.
        return np.argsort(true_class_probs)[:num_sample]
    min_idxs = np.argpartition(true_class_probs, num_sample)[:num_sample]

    return min_idxs



import keras
from keras import layers

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one Transformer encoder block to ``inputs``.

    Self-attention (query and value both ``inputs``) is followed by dropout
    and layer normalisation, then added to ``inputs``. A feed-forward stage of
    two kernel-size-1 ``Conv1D`` layers (the second projecting back to
    ``inputs.shape[-1]`` channels) is then normalised and added to that
    residual.
    """
    # Attention sub-layer (layers are created and called in the original order).
    attn_out = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(inputs, inputs)
    attn_out = layers.Dropout(dropout)(attn_out)
    attn_out = layers.LayerNormalization(epsilon=1e-6)(attn_out)
    residual = attn_out + inputs

    # Position-wise feed-forward sub-layer.
    ffn_out = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(residual)
    ffn_out = layers.Dropout(dropout)(ffn_out)
    ffn_out = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(ffn_out)
    ffn_out = layers.LayerNormalization(epsilon=1e-6)(ffn_out)
    return ffn_out + residual

def build_model(input_shape, head_size, num_heads, ff_dim, num_transformer_blocks,
                mlp_units, dropout=0, mlp_dropout=0):
    """Build an (uncompiled) Transformer classifier.

    The model stacks ``num_transformer_blocks`` encoder blocks on
    ``keras.Input(shape=input_shape)`` and average-pools globally over the
    sequence axis. It then applies one ReLU ``Dense`` + ``Dropout`` pair per
    entry of ``mlp_units`` and ends with a softmax over the module-level
    ``noClasses`` classes.
    """
    model_input = keras.Input(shape=input_shape)
    hidden = model_input
    for _block in range(num_transformer_blocks):
        hidden = transformer_encoder(hidden, head_size, num_heads, ff_dim, dropout)

    hidden = layers.GlobalAveragePooling1D(data_format="channels_last")(hidden)
    for width in mlp_units:
        hidden = layers.Dense(width, activation="relu")(hidden)
        hidden = layers.Dropout(mlp_dropout)(hidden)

    class_probs = layers.Dense(noClasses, activation="softmax")(hidden)
    return keras.Model(model_input, class_probs)

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` to the range [0, 1].

    A constant input has zero range, and the plain formula would then yield
    0/0 = NaN everywhere. In that case an all-zero float array of the same
    shape is returned instead.
    """
    span = timeseries.max() - timeseries.min()
    if span == 0:
        return np.zeros(np.shape(timeseries), dtype=float)
    return (timeseries - timeseries.min()) / span

def z_norm(timeseries):
    """Standardise ``timeseries`` to zero mean and unit variance with ``scipy.stats.zscore``.

    A constant (non-empty) input has zero standard deviation, which makes
    ``zscore`` return NaN everywhere. In that case an all-zero float array
    of the same shape is returned instead.
    """
    values = np.asarray(timeseries)
    if values.size and np.ptp(values) == 0:
        return np.zeros(values.shape, dtype=float)
    return stats.zscore(values)

def normalize_data_per_trace(data):
    """Min-max normalise every row of ``data`` in place and return ``data``.

    The array's shape is printed first. Results are written back into
    ``data``, so an integer array would truncate the [0, 1] values.
    NOTE(review): confirm that callers pass floating-point traces.
    """
    print(data.shape)
    for row_no, row in enumerate(data):
        data[row_no] = normalize(row)
    return data

def z_normalize_data_per_trace(data):
    """Z-score every row of ``data`` in place with ``z_norm`` and return ``data``.

    The array's shape is printed first. Results are written back into
    ``data``, so an integer array would truncate the standardised values.
    """
    print(data.shape)
    for row_no, row in enumerate(data):
        data[row_no] = z_norm(row)
    return data

def cummulative_transform(data):
    """Replace each row of ``data`` in place with its running sum and return ``data``.

    The cumulative sum is computed as float64 and cast back to ``data``'s
    dtype on assignment.
    """
    for row_no, row in enumerate(data):
        data[row_no] = np.cumsum(row, dtype=float)
    return data

def convert_bp(bp_input):
    """One-hot encode bp values over ``NumSKPVclasses + 1`` classes and insert a new axis at position 2.

    For a 1-D input of length N the result has shape (N, NumSKPVclasses + 1, 1).
    """
    encoded = to_categorical(bp_input, num_classes=NumSKPVclasses + 1)
    return encoded[:, :, np.newaxis]

def _preprocess_traces(args, traces):
    """Apply the normalisation pipeline chosen by ``args`` and append a channel axis.

    ``args.normalize``: 0 = raw, 1 = per-trace min-max, anything else =
    per-trace z-score. ``args.cummulative``: 1 = cumulative sum before
    normalising, 2 = after normalising, anything else = none. The per-trace
    helpers modify their input in place.
    """
    if args.normalize == 0:
        return np.expand_dims(traces, axis=2)
    if not np.issubdtype(traces.dtype, np.floating):
        # In-place per-trace normalisation would truncate into an integer array.
        traces = traces.astype(float)
    per_trace = normalize_data_per_trace if args.normalize == 1 else z_normalize_data_per_trace
    if args.cummulative == 1:
        traces = per_trace(cummulative_transform(traces))
    elif args.cummulative == 2:
        traces = cummulative_transform(per_trace(traces))
    else:
        traces = per_trace(traces)
    return np.expand_dims(traces, axis=2)


def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace,
                                  val_num=300000, val_ids_path='val_ids.npz'):
    """Load traces and build preprocessed train / validation sets.

    The first ``end_trace - start_trace`` loaded traces form the training set.
    The validation traces are chosen as follows:

    * when ``is_test``, the slice ``[dataSize:end_trace + val_num]``;
    * otherwise, the rows listed under ``'ids'`` in ``val_ids_path``.

    Training and validation traces go through the same ``args.normalize`` /
    ``args.cummulative`` pipeline. ``trainPortion`` and ``yType`` are
    accepted for compatibility but unused; labels are always skpv.

    Returns:
        xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value. For
        ``xType == 'wavebp01'`` both xTrain and xVal are nested as
        ``[traces, [bp0, bp1]]``.
    """
    end_val_trace = end_trace + val_num
    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)
    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)

    if is_test:
        val_rows = slice(dataSize, end_val_trace)
    else:
        with np.load(val_ids_path) as val_file:
            val_rows = val_file['ids']
        print(len(val_rows))
        print(val_rows[:10])

    # Take the validation traces before preprocessing the training traces.
    # That preprocessing works in place on a view of trace_profiling, and
    # must not touch validation rows a second time.
    val_traces = trace_profiling[val_rows]
    yVal = y_train_skpv[val_rows]
    yVal_value = skpv_profiling[val_rows]
    xVal = _preprocess_traces(args, val_traces)

    xTrain = _preprocess_traces(args, trace_profiling[:dataSize])
    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    print(bp_profiling.shape)
    if xType == 'wavebp01':
        xTrain = [xTrain, [convert_bp(bp_profiling[0, :dataSize]),
                           convert_bp(bp_profiling[1, :dataSize])]]
        xVal = [xVal, [convert_bp(bp_profiling[0, val_rows]),
                       convert_bp(bp_profiling[1, val_rows])]]
    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value


#Infomax

def get_median(features, targets):
    """Compute one median "prototype" feature vector per class.

    Args:
        features: 2-D array, one feature row per example.
        targets: per-example integer labels, assumed to be 0..K-1 and 1-D
            (TODO confirm with callers).

    Returns:
        (K, features.shape[-1]) array whose row i is the element-wise median
        of the features of class i, where K is the number of distinct labels.
    """
    num_classes = len(np.unique(targets, axis=0))
    prot = np.zeros((num_classes, features.shape[-1]), dtype=features.dtype)

    for i in range(num_classes):
        # A boolean mask keeps the (n_i, D) shape even when class i has a
        # single member, so the median stays per-feature.
        prot[i] = np.median(features[targets == i], axis=0)
    return prot


def get_distance(features, labels):
    """Return each example's Euclidean distance to the median prototype of its class.

    Prototypes come from ``get_median``. Examples whose label is not in
    ``range(number of distinct labels)`` are compared against a zero vector.
    """
    prototypes = get_median(features, labels)
    assigned = np.zeros(shape=(features.shape[0], prototypes.shape[-1]))

    for cls in range(len(np.unique(labels))):
        members = (labels == cls).nonzero()[0]
        assigned[members, :] = prototypes[cls]

    return np.linalg.norm(features - assigned, axis=1)


def bin_allocate(num, bins, mode='uniform', initial_budget=None):
    """Split a budget of ``num`` samples across bins without exceeding any bin's size.

    Args:
        num: total number of samples to allocate.
        bins: 1-D integer tensor of per-bin capacities.
        mode: ``'uniform'`` processes bins from smallest to largest. Each
            non-empty bin gets an equal share of the remaining budget, capped
            at its capacity. Any other value starts from ``initial_budget``
            and spreads each bin's unused share over the larger bins.
        initial_budget: per-bin starting budgets for the non-uniform mode.
            Modified in place.

    Returns:
        int tensor of per-bin allocations, in the original bin order.
    """
    import torch  # torch is not among this module's visible top-level imports

    sorted_index = torch.argsort(bins)
    sort_bins = bins[sorted_index]

    num_bin = bins.shape[0]

    rest_exp_num = num
    budgets = []
    for i in range(num_bin):
        if sort_bins[i] == 0:
            budgets.append(0)
            continue
        rest_bins = int(torch.count_nonzero(sort_bins[i:]))
        if mode == 'uniform':
            avg = rest_exp_num // rest_bins
            cur_num = min(sort_bins[i].item(), avg)
            rest_exp_num -= cur_num
        else:
            avg = initial_budget[sorted_index[i]]
            cur_num = min(sort_bins[i].item(), avg)
            delta = int((avg - cur_num) / max(1, (rest_bins - 1)))
            for j in range(i + 1, num_bin):
                initial_budget[sorted_index[j]] += delta
        budgets.append(cur_num)

    budgets = torch.tensor(budgets)
    if torch.sum(budgets) < num:  # TODO: check again
        # Hand out the leftover one sample at a time, from the largest bin down.
        delta = num - torch.sum(budgets)
        i = 1
        while delta and i <= num_bin:
            if budgets[-i] < sort_bins[-i]:
                budgets[-i] += 1
                delta -= 1
            i += 1

    rst = torch.zeros((num_bin,), dtype=torch.int)
    rst[sorted_index] = budgets.type(torch.int)

    assert all(b <= r for r, b in zip(bins, rst)), (
        [(r.item(), b.item()) for r, b in zip(bins, rst)], bins, budgets[sorted_index].tolist())
    return rst


class CoresetSelection(object):
    """Static coreset-selection strategies driven by per-sample scores.

    Most methods take ``data_score``: a dict mapping score names (``'targets'``,
    ``'confidence'``, ``args.coreset_key``, ...) to per-sample torch tensors that
    share one sample ordering.
    """

    @staticmethod
    def moderate_selection(data_score, ratio, features):
        """Return the indices at both ends of the distance-to-prototype ranking.

        Samples are ranked by ``get_distance(features, data_score['targets'])``.
        The returned indices are the ``ratio / 2`` closest plus the ``ratio / 2``
        farthest samples (the helper calls them prune indices). The middle
        ``1 - ratio`` band of moderate-distance samples is what is left out.
        """

        def get_prune_idx(rate, distance):
            rate = 1 - rate
            low = 0.5 - rate / 2
            high = 0.5 + rate / 2

            sorted_idx = distance.argsort()
            low_idx = round(distance.shape[0] * low)
            high_idx = round(distance.shape[0] * high)

            return np.concatenate((sorted_idx[:low_idx], sorted_idx[high_idx:]))

        targets_list = data_score['targets']
        distance = get_distance(features, targets_list)
        return get_prune_idx(ratio, distance)

    @staticmethod
    def score_monotonic_selection(data_score, key, ratio, descending, class_balanced):
        """Keep the top ``ratio`` fraction of samples ranked by ``data_score[key]``.

        Args:
            data_score: score dict (must contain ``'targets'`` when ``class_balanced``).
            key: name of the score to rank by.
            ratio: fraction of samples to keep.
            descending: rank the highest scores first when true.
            class_balanced: when true, keep the top ``ratio`` fraction of every
                label separately instead of globally.

        Returns:
            Long tensor of selected sample indices in ranking order (grouped by
            label when ``class_balanced``).
        """
        score = data_score[key]
        score_sorted_index = score.argsort(descending=descending)
        total_num = ratio * data_score[key].shape[0]
        print("Selecting from %s samples" % total_num)

        if not class_balanced:
            print(f'High priority {key}: {score[score_sorted_index[:15]]}')
            print(f'Low priority {key}: {score[score_sorted_index[-15:]]}')
            return score_sorted_index[:int(total_num)]

        print('Class balance mode.')
        all_index = torch.arange(data_score['targets'].shape[0])
        selected_index = []
        # Labels listed in ranking order, so positions in all_index are ranks.
        targets_list = data_score['targets'][score_sorted_index]
        for target in torch.unique(targets_list):
            target_index_mask = (targets_list == target)
            target_index = all_index[target_index_mask]
            target_coreset_num = int(target_index_mask.sum() * ratio)
            chosen = target_index[:target_coreset_num].tolist()
            selected_index.extend(chosen)
            # Report the count for this label, not the running total.
            print("Selected %s samples for %s label" % (len(chosen), target))
        # Explicit long dtype: an empty list would otherwise become a float
        # tensor, which cannot be used as an index.
        selected_index = torch.tensor(selected_index, dtype=torch.long)
        print(f'High priority {key}: {score[score_sorted_index[selected_index][:15]]}')
        print(f'Low priority {key}: {score[score_sorted_index[selected_index][-15:]]}')
        return score_sorted_index[selected_index]

    @staticmethod
    def mislabel_mask(data_score, mis_key, mis_num, mis_descending, coreset_key):
        """Drop the ``mis_num`` samples ranked first by ``data_score[mis_key]``.

        ``data_score[coreset_key]`` is replaced in place by its values for the
        remaining ("easy") samples.

        Returns:
            ``(data_score, easy_index)``. ``easy_index`` holds the original indices
            of the kept samples, in ``mis_key`` ranking order.
        """
        mis_score = data_score[mis_key]
        mis_score_sorted_index = mis_score.argsort(descending=mis_descending)
        hard_index = mis_score_sorted_index[:mis_num]
        print(f'Bad data -> High priority {mis_key}: {data_score[mis_key][hard_index][:15]}')
        print(f'Prune {hard_index.shape[0]} samples.')

        easy_index = mis_score_sorted_index[mis_num:]
        data_score[coreset_key] = data_score[coreset_key][easy_index]

        return data_score, easy_index

    @staticmethod
    def _sample_pool(pool, score, budget, args, data_embeds):
        """Pick ``budget`` positions into ``pool`` using ``args.sampling_mode``."""
        if args.sampling_mode == 'random':
            rand_index = torch.randperm(pool.shape[0])
            return [idx.item() for idx in rand_index[:budget]]
        if args.sampling_mode == 'graph':
            if pool.shape[0] <= args.n_neighbor:
                # Too few samples for the neighbour graph: take a random subset.
                rand_index = torch.randperm(pool.shape[0])
                return rand_index[:budget].numpy().tolist()
            sampling_method = InfoGraphDensitySampler(
                X=None if data_embeds is None else data_embeds[pool], y=None,
                gamma=args.gamma, seed=0, importance_scores=score[pool], args=args)
            return sampling_method.select_batch_(budget)
        raise ValueError

    @staticmethod
    def stratified_sampling(data_score, coreset_num, args, data_embeds=None):
        """Split samples into ``args.stratas`` equal-width score strata and sample each.

        Per-stratum budgets follow ``args.budget_mode`` ('uniform', 'confidence'
        or 'aucpr'). Within a stratum, samples are drawn according to
        ``args.sampling_mode`` ('random' or 'graph').

        Returns:
            ``(selected_index, (pools, budgets))``:
              * ``selected_index`` lists the chosen sample indices.
              * ``pools[k]`` holds every index in stratum ``k``.
              * ``budgets[k]`` is the sample count allocated to stratum ``k``.

        Raises:
            ValueError: on an unknown ``budget_mode`` or ``sampling_mode``.
        """
        if args.sampling_mode == 'graph' and args.coreset_key in ['accumulated_margin']:  # TODO: check again
            raw = data_score[args.coreset_key]
            # Shift so the minimum is 0, then negate.
            data_score[args.coreset_key] = -(raw - torch.min(raw))

        print('Using stratified sampling...')
        score = data_score[args.coreset_key]
        if args.graph_score:
            graph = InfoGraphDensitySampler(X=data_embeds, y=None, gamma=args.gamma,
                                            seed=0, importance_scores=score, args=args)
            score = torch.tensor(graph.graph_density)

        min_score = torch.min(score)
        raw_max = torch.max(score)
        # Strata are half-open [start, end), so the top edge must lie strictly
        # above the maximum score. Scaling by 1.0001 only does that for a positive
        # maximum (the negated 'accumulated_margin' scores above peak at 0), so
        # pad additively otherwise.
        if raw_max > 0:
            max_score = raw_max * 1.0001
        else:
            max_score = raw_max + raw_max.abs() * 1e-4 + 1e-6
        print("Min score: %s, max score: %s" % (min_score.item(), max_score.item()))
        step = (max_score - min_score) / args.stratas

        def bin_range(k):
            return min_score + k * step, min_score + (k + 1) * step

        def bin_mask(k):
            start, end = bin_range(k)
            return torch.logical_and(score >= start, score < end)

        ##### calculate number of samples in each strata #####
        strata_num = torch.tensor([bin_mask(i).sum() for i in range(args.stratas)])

        if args.budget_mode == 'uniform':
            budgets = bin_allocate(coreset_num, strata_num)
        elif args.budget_mode == 'confidence':
            confs = data_score['confidence']
            mean_confs = []
            for i in range(args.stratas):
                # as_tuple=True always yields a 1-D index tensor; .nonzero().squeeze()
                # collapsed single-sample strata to 0-d and broke .size()[0].
                sample_idxs = bin_mask(i).nonzero(as_tuple=True)[0]
                if sample_idxs.numel() != 0:
                    # Less confident strata get a larger share of the budget.
                    mean_confs.append(1 - torch.mean(confs[sample_idxs]).item())
                else:
                    mean_confs.append(0)
            total_conf = np.sum(mean_confs)
            if total_conf > 0:
                budgets = [int(n * coreset_num / total_conf) for n in mean_confs]
            else:
                # Every stratum is fully confident: nothing to weight by. Start
                # from zero and let bin_allocate's top-up assign samples.
                budgets = [0] * args.stratas
            print("Initial budget", budgets)
            budgets = bin_allocate(coreset_num, strata_num, mode='confidence', initial_budget=budgets)
        elif args.budget_mode == 'aucpr':
            budgets = bin_allocate(coreset_num, strata_num)
            sample_index = torch.arange(data_score[args.coreset_key].shape[0])
            aucpr_values = []
            min_budgets = {}
            for i in tqdm(range(args.stratas), desc='Getting k-centers for aucpr-based budgeting'):
                if budgets[i] == 0:
                    aucpr_values.append(0)
                    continue
                pool = sample_index[bin_mask(i)]
                selected_idxs = CoresetSelection._sample_pool(pool, score, budgets[i], args, data_embeds)

                kcenters = pool[selected_idxs]
                non_coreset = list(set(pool.tolist()).difference(set(kcenters.tolist())))
                aucpr = get_aucpr(data_embeds[kcenters], data_embeds[non_coreset])
                aucpr_values.append(round(aucpr, 3))
                if aucpr == 0:
                    min_budgets[i] = budgets[i]

            print("Initial AUCpr values: ", aucpr_values)
            print("Initial mean AUCpr: ", np.mean(aucpr_values))
            total_aucpr = np.sum(aucpr_values)
            print("Uniform budget", budgets)
            if total_aucpr != 0:
                remaining = coreset_num - sum(min_budgets.values())
                budgets = [int(n * remaining / total_aucpr) if i not in min_budgets else min_budgets[i]
                           for i, n in enumerate(aucpr_values)]
                print("Initial budget", budgets)
                budgets = bin_allocate(coreset_num, strata_num, mode='aucpr', initial_budget=budgets)
        else:
            raise ValueError
        print(budgets, budgets.sum())

        ##### sampling in each strata #####
        selected_index = []
        sample_index = torch.arange(data_score[args.coreset_key].shape[0])

        # kcenters[k] always corresponds to pools[k] (empty for skipped strata),
        # so the two lists can safely be zipped below.
        pools, kcenters = [], []
        for i in tqdm(range(args.stratas), desc='sampling from each strata'):
            pool = sample_index[bin_mask(i)]
            pools.append(pool)

            if pool.numel() == 0 or budgets[i] == 0:
                kcenters.append(pool[:0])
                continue
            selected_idxs = CoresetSelection._sample_pool(pool, score, budgets[i], args, data_embeds)
            kcenters.append(pool[selected_idxs])

        if args.aucpr:
            final_aucpr_values = []
            for pool, samples in zip(pools, kcenters):
                non_coreset = list(set(pool.tolist()).difference(set(samples.tolist())))
                if samples.numel() == 0 or len(non_coreset) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[samples], data_embeds[non_coreset])
                final_aucpr_values.append(round(aucpr, 3))
            print("Final AUCpr values: ", final_aucpr_values)
            print("Final mean AUCpr: ", np.mean(final_aucpr_values))

        for samples in kcenters:
            selected_index += samples

        return selected_index, (pools, budgets)

    @staticmethod
    def density_sampling(data_score, bins, coreset_num, args, data_embeds=None):
        """Sample a coreset from precomputed per-sample bin assignments.

        ``bins[j]`` is the bin id of sample ``j``. It presumably holds
        non-negative ints in a numpy array, since np.where/np.histogram are
        applied to it. Budgets per non-empty bin follow ``args.budget_mode``
        ('density', 'uniform', 'confidence' or 'aucpr'). Within a bin, samples
        are drawn according to ``args.sampling_mode``. The union is shuffled,
        topped up with random unselected samples if short, and truncated to
        ``coreset_num``.

        Returns:
            ``(selected_idxs, None)``.

        Raises:
            ValueError: on an unknown ``budget_mode`` or ``sampling_mode``.
        """
        import random
        from collections import Counter

        if args.sampling_mode == 'graph' and args.coreset_key in ['accumulated_margin']:  # TODO: check again
            raw = data_score[args.coreset_key]
            # Shift so the minimum score is 0 (no negation here).
            data_score[args.coreset_key] = raw - torch.min(raw)

        hist = np.histogram(bins, bins=np.arange(0, np.amax(bins) + 2, 1))[0]
        n_bins = np.amax(bins) + 1
        bin_pop_density = Counter(hist.tolist())
        print("Frequency of bin counts", bin_pop_density.most_common(20))

        non_empty_bins = np.where(hist != 0)[0]
        print("Skipping %s empty bins in total %s bins" % ((n_bins - non_empty_bins.shape[0]), n_bins))

        bin2size = {bin_idx: hist[bin_idx] for bin_idx in non_empty_bins}
        strata_num = torch.tensor([bin2size[i] for i in non_empty_bins])

        if args.budget_mode == 'density':
            total_num = sum(bin2size.values())
            bin2budget = {bin_idx: math.ceil(bin2size[bin_idx] * coreset_num / total_num)
                          for bin_idx in non_empty_bins}
        elif args.budget_mode == 'uniform':
            budgets = bin_allocate(coreset_num, strata_num)
            bin2budget = {bin_i: budgets[i] for i, bin_i in enumerate(non_empty_bins)}
        elif args.budget_mode == 'confidence':
            confs = data_score['confidence']
            mean_confs = []
            for i in non_empty_bins:
                sample_idxs = np.where(bins == i)[0]
                # Less confident bins get a larger share of the budget.
                mean_confs.append(1 - torch.mean(confs[sample_idxs]).item())
            total_conf = np.sum(mean_confs)
            if total_conf > 0:
                budgets = [int(n * coreset_num / total_conf) for n in mean_confs]
            else:
                budgets = [0] * len(mean_confs)
            print("Initial budget", budgets)
            budgets = bin_allocate(coreset_num, strata_num, mode='confidence', initial_budget=budgets)
            print("Final budget", budgets)
            bin2budget = {bin_idx: budgets[i] for i, bin_idx in enumerate(non_empty_bins)}
        elif args.budget_mode == 'aucpr':
            budgets = bin_allocate(coreset_num, strata_num)
            aucpr_values = []
            min_budgets = {}
            for i, bin_idx in tqdm(enumerate(non_empty_bins), desc='Getting k-centers for aucpr-based budgeting'):
                if budgets[i] == 0:
                    aucpr_values.append(0)
                    continue
                sample_idxs = np.where(bins == bin_idx)[0]

                if args.sampling_mode == 'random':
                    rand_index = np.random.permutation(sample_idxs.shape[0])
                    selected_idx = rand_index[:budgets[i]]
                elif args.sampling_mode == 'graph':
                    if sample_idxs.shape[0] <= 10:
                        # Tiny bin: take every sample.
                        selected_idx = np.random.permutation(sample_idxs.shape[0])
                    else:
                        # NOTE(review): scored with 'forgetting' here but with
                        # args.coreset_key in the final sampling below -- confirm intended.
                        sampling_method = InfoGraphDensitySampler(
                            X=None if data_embeds is None else data_embeds[sample_idxs], y=None,
                            gamma=args.gamma, seed=0,
                            importance_scores=data_score['forgetting'][sample_idxs], args=args)
                        selected_idx = sampling_method.select_batch_(budgets[i])
                else:
                    raise ValueError

                kcenters = sample_idxs[selected_idx]
                non_coreset = list(set(sample_idxs.tolist()).difference(set(kcenters.tolist())))
                aucpr = get_aucpr(data_embeds[kcenters], data_embeds[non_coreset])
                aucpr_values.append(aucpr)
                if aucpr == 0:
                    # Keyed by position i (not bin id) to match the lookup below.
                    min_budgets[i] = budgets[i]

            print("Initial AUCpr values: ", aucpr_values)
            print("Initial mean AUCpr: ", np.mean(aucpr_values))
            total_aucpr = np.sum(aucpr_values)
            print("Uniform budget", budgets)
            if total_aucpr != 0:
                remaining = coreset_num - sum(min_budgets.values())
                budgets = [int(n * remaining / total_aucpr) if i not in min_budgets else min_budgets[i]
                           for i, n in enumerate(aucpr_values)]
                print("Initial budget", budgets)
                budgets = bin_allocate(coreset_num, strata_num, mode='aucpr', initial_budget=budgets)
            bin2budget = {bin_idx: budgets[i] for i, bin_idx in enumerate(non_empty_bins)}
        else:
            raise ValueError

        print('Using density sampling...')
        pools = []
        final_idxs = []

        def get_coreset_for_bin(bin_idx):
            """Return the selected sample indices of one bin (and record its pool)."""
            sample_idxs = np.where(bins == bin_idx)[0]
            pools.append(sample_idxs)
            print("Starting process for label  ", bin_idx, "with %s samples" % len(sample_idxs))
            if len(sample_idxs) == 0:
                return sample_idxs  # empty array, so the caller's .tolist() still works
            budget = bin2budget[bin_idx]
            if budget > len(sample_idxs):
                positions = np.random.permutation(sample_idxs.shape[0])
            elif args.sampling_mode == 'random':
                positions = np.random.permutation(sample_idxs.shape[0])[:budget]
            elif args.sampling_mode == 'graph':
                if sample_idxs.shape[0] <= args.n_neighbor:
                    positions = np.random.permutation(sample_idxs.shape[0])
                else:
                    sampling_method = InfoGraphDensitySampler(
                        X=None if data_embeds is None else data_embeds[sample_idxs], y=None,
                        gamma=args.gamma, seed=0,
                        importance_scores=data_score[args.coreset_key][sample_idxs], args=args)
                    positions = sampling_method.select_batch_(budget)
            else:
                raise ValueError
            return sample_idxs[positions]

        for bin_idx in non_empty_bins:
            final_idxs.append(get_coreset_for_bin(bin_idx).tolist())

        if args.aucpr:
            final_aucpr_values = []
            for pool, selected in zip(pools, final_idxs):
                # final_idxs holds plain lists, so no .tolist() on `selected`.
                non_coreset = list(set(pool.tolist()).difference(selected))
                if len(non_coreset) == 0 or len(selected) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[selected], data_embeds[non_coreset])
                final_aucpr_values.append(round(aucpr, 3))
            print("Final AUCpr values: ", final_aucpr_values)
            print("Final mean AUCpr: ", np.mean(final_aucpr_values))

        selected_idxs = [idx for idxs in final_idxs for idx in idxs]
        random.shuffle(selected_idxs)
        target = int(coreset_num)
        if len(selected_idxs) < target:
            extra_sample_set = list(set(range(len(bins))).difference(set(selected_idxs)))
            selected_idxs = selected_idxs + random.sample(
                extra_sample_set, k=min(len(extra_sample_set), target - len(selected_idxs)))

        return selected_idxs[:target], None

    @staticmethod
    def random_selection(total_num, num):
        """Return ``int(num)`` distinct indices drawn uniformly from ``range(total_num)``."""
        print('Random selection.')
        score_random_index = torch.randperm(total_num)

        return score_random_index[:int(num)]

"""Calculate loss and entropy"""
def post_training_metrics(model, dataloader, data_importance, device):
    """Calculate per-sample loss, entropy and confidence of a trained classifier.

    Fills ``data_importance['entropy']``, ``['loss']`` and ``['confidence']``
    (in place) with float tensors indexed by the sample indices that
    ``dataloader`` yields. Confidence is the softmax probability of the true
    target.

    Args:
        model: PyTorch classifier; ``model(inputs)`` must return logits of shape
            (batch, classes).
        dataloader: iterable of ``(idx, (inputs, targets))`` batches. If it has a
            ``dataset`` attribute, its length sizes the outputs; otherwise
            ``max(idx) + 1`` is used.
        data_importance: dict updated in place.
        device: device the forward pass runs on.
    """
    import torch
    import torch.nn.functional as F

    model.eval()
    all_idx, entropies, losses, confidences = [], [], [], []
    with torch.no_grad():
        for idx, (inputs, targets) in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            prob = F.softmax(logits, dim=1)
            # 1e-10 keeps log() finite when a probability underflows to 0.
            entropies.append(torch.sum(-prob * torch.log(prob + 1e-10), dim=1).cpu())
            losses.append(F.cross_entropy(logits, targets, reduction='none').cpu())
            confidences.append(prob[torch.arange(logits.shape[0], device=prob.device), targets].cpu())
            all_idx.append(torch.as_tensor(idx).long().cpu())

    flat_idx = torch.cat(all_idx) if all_idx else torch.zeros(0, dtype=torch.long)
    if hasattr(dataloader, 'dataset'):
        total = len(dataloader.dataset)
    else:
        total = int(flat_idx.max()) + 1 if flat_idx.numel() else 0

    for key, parts in (('entropy', entropies), ('loss', losses), ('confidence', confidences)):
        values = torch.zeros(total)
        if parts:
            values[flat_idx] = torch.cat(parts).float()
        data_importance[key] = values


"""Extract feature map"""
def feature_maps(model, dataloader, device, layer=4):
    """Extract ``model.feature_map(inputs, layer=layer)`` for every batch.

    Args:
        model: PyTorch-style model exposing ``eval()`` and
            ``feature_map(inputs, layer=...)``.
        dataloader: iterable of ``(idx, (inputs, targets))`` batches; only the
            inputs are used.
        device: device the inputs are moved to before the forward pass.
        layer: forwarded to ``model.feature_map``.

    Returns:
        numpy array of the per-batch features concatenated along axis 0, in
        iteration order (not reordered by ``idx``).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    import torch

    model.eval()
    features = []
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for _idx, (inputs, _targets) in dataloader:
            feats = model.feature_map(inputs.to(device), layer=layer)
            features.append(feats.detach().cpu().numpy())

    if not features:
        raise ValueError("feature_maps: dataloader yielded no batches")
    return np.concatenate(features)

def get_important_score(args, model, trainset):
    """Compute per-sample entropy, loss and confidence scores for coreset selection.

    NOTE(review): the original body iterated over an undefined ``dataloader``,
    used undefined ``device``/``nn`` names, and assigned by index into empty
    lists, so it could never run. This version:
      * iterates ``trainset`` itself as the batch source;
      * runs on ``args.device`` (default ``'cpu'``).
    It expects PyTorch-style ``(idx, (inputs, targets))`` batches and a model
    that returns logits. Confirm this against the caller, which currently
    passes Keras training arrays.

    Args:
        args: options; only ``args.device`` is read (optional).
        model: classifier; ``model(inputs)`` returns (batch, classes) logits.
        trainset: iterable of ``(idx, (inputs, targets))`` batches. If it has a
            ``dataset`` attribute, its length sizes the outputs; otherwise
            ``max(idx) + 1`` is used.

    Returns:
        dict with keys 'entropy', 'loss' and 'confidence', each mapping to a
        float tensor indexed by sample index.
    """
    import torch
    import torch.nn.functional as F

    device = getattr(args, 'device', 'cpu')
    if hasattr(model, 'eval'):
        model.eval()

    all_idx, entropies, losses, confidences = [], [], [], []
    with torch.no_grad():
        for idx, (inputs, targets) in trainset:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            prob = F.softmax(logits, dim=1)
            # 1e-10 keeps log() finite when a probability underflows to 0.
            entropies.append(torch.sum(-prob * torch.log(prob + 1e-10), dim=1).cpu())
            losses.append(F.cross_entropy(logits, targets, reduction='none').cpu())
            confidences.append(prob[torch.arange(logits.shape[0], device=prob.device), targets].cpu())
            all_idx.append(torch.as_tensor(idx).long().cpu())

    flat_idx = torch.cat(all_idx) if all_idx else torch.zeros(0, dtype=torch.long)
    if hasattr(trainset, 'dataset'):
        total = len(trainset.dataset)
    else:
        total = int(flat_idx.max()) + 1 if flat_idx.numel() else 0

    data_importance = {}
    for key, parts in (('entropy', entropies), ('loss', losses), ('confidence', confidences)):
        values = torch.zeros(total)
        if parts:
            # Scatter batch results back to their sample positions.
            values[flat_idx] = torch.cat(parts).float()
        data_importance[key] = values
    return data_importance

def infomax(args, trainset, model=None):
    """Select a coreset from ``trainset`` with the graph-density sampler.

    Steps:
      1. Score every sample with ``get_important_score``.
      2. Drop the ``args.mis_ratio`` fraction ranked first by ``args.mis_key``
         (``CoresetSelection.mislabel_mask``) and save the kept indices to
         ``./temp/imagenet/score_index_aum_<ratio>.npy``.
      3. Run ``GraphDensitySampler`` on the kept samples and map its picks back
         to ``trainset`` indices.
      4. If fewer than ``int(args.coreset_ratio * len(trainset))`` indices
         result, fill the gap with random kept samples.

    Args:
        args: parsed options. Reads coreset_ratio, coreset_key, label_balanced,
            mis_ratio, mis_key, sampling_mode, precomputed_dists,
            precomputed_neighbors, feature_path and gamma, and sets
            ``data_score_descending`` for descending score keys.
        trainset: training data; ``len(trainset)`` is the population size and the
            object is forwarded to ``get_important_score``.
        model: trained model used for scoring. If omitted, a module-level
            ``model`` is used, as the original body did.

    Returns:
        numpy array of selected indices into ``trainset``.

    Raises:
        ValueError: if no model is available, or graph sampling needs
            ``args.feature_path`` and it is unset.
        NotImplementedError: if ``args.label_balanced`` is set.
    """
    import random

    if model is None:
        # Backward compatibility: the original body read ``model`` from module scope.
        model = globals().get('model')
    if model is None:
        raise ValueError("infomax() needs a trained model to score samples; pass model=...")
    if args.label_balanced:
        # The former per-label branch read ``uniq_labels`` before assigning it
        # (UnboundLocalError), and its per-label sizes were never used.
        raise NotImplementedError("label-balanced selection is not implemented for infomax")

    total_num = len(trainset)
    data_score = get_important_score(args, model, trainset)
    descending_keys = ['entropy', 'forgetting', 'el2n', 'ssl']
    if args.coreset_key in descending_keys:
        print("Using descending order")
        args.data_score_descending = 1

    mis_num = int(args.mis_ratio * total_num)
    _, score_index = CoresetSelection.mislabel_mask(
        data_score, mis_key=args.mis_key, mis_num=mis_num,
        mis_descending=args.mis_key in descending_keys,
        coreset_key=args.coreset_key)

    score_index_path = Path('./temp/imagenet/score_index_aum_%s.npy' % args.coreset_ratio)
    score_index_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(score_index_path, score_index)

    if args.sampling_mode == 'graph' and not args.precomputed_dists and not args.precomputed_neighbors:
        if not args.feature_path:
            raise ValueError("args.feature_path is required for graph sampling "
                             "without precomputed distances/neighbors")
        features = np.load(args.feature_path)[score_index]
    else:
        features = None

    coreset_num = int(args.coreset_ratio * total_num)
    sampling_method = GraphDensitySampler(X=features, y=None,
                                          gamma=args.gamma,
                                          seed=0, importance_scores=data_score[args.coreset_key], args=args)
    coreset_positions = sampling_method.select_batch_(coreset_num)
    # Sampler positions index the mislabel-filtered subset; map them back to
    # trainset indices. np.asarray so callers can use ndarray methods (.astype).
    coreset_index = np.asarray(score_index[coreset_positions])

    if len(coreset_index) < coreset_num:
        extra_sample_set = list(set(score_index.tolist()).difference(set(coreset_index.tolist())))
        # Count before extending; computing it afterwards misreported the number added.
        n_extra = int(min(len(extra_sample_set), coreset_num - len(coreset_index)))
        if n_extra > 0:
            coreset_index = np.hstack((coreset_index, np.array(random.sample(extra_sample_set, k=n_extra))))
        print("Added extra %s samples" % n_extra)
        print(coreset_index.shape)

    return coreset_index

def train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size, args):
    """Build or load a model, choose a training subset, save its ids and train on it.

    The subset is chosen by ``args.sampling``:
      * ``'random'`` uses ``random_sampling``;
      * ``'infomax'`` uses ``infomax``;
      * anything else loads the first ``args.num_sample`` ids from ``args.all_ids``.
    The chosen ids are saved to ``<database_folder_train>/all_ids.npy`` before
    training. Training logs to ``log.csv`` and evaluates with the attack
    callbacks every ``args.eval_interval``.

    Raises:
        ValueError: if ``xType`` is neither ``'wave'`` nor ``'wavebp01'``.
    """
    if args.train_type == 'baseline':
        if args.mlp == 1:
            model = subModels_gen_MLP(args, xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
        else:
            model = subModels_gen(args, xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
    else:
        model = load_model(trained_model_path)

    gc.collect()
    print('load success')
    trainPortion = 1.0  # To get all data

    xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, False, args.start_trace, args.end_trace)
    print(yTrain.shape)
    print(len(xTrain))
    yTrain = np.expand_dims(yTrain_value, axis=1)
    yVal = yVal_value

    xTest, yTest, _xVal_, yVal_, yTest_value, yVal_value = create_training_data_optimize(args, "KYBER51.H5", sKeyNo, testPortion, xType, yType, True, args.start_trace, args.end_trace)
    xTest = [[xTest]]

    print(yTest_value)
    print(yTest.shape)
    xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
    print(yTest_value)
    print(yTest_multi)
    xTest_multi = np.expand_dims(xTest_multi, axis=3)
    print(xTest_multi.shape)
    print(yTrain.shape)

    # Get coreset. Mutually exclusive choices: previously a second plain `if`
    # let the `else` branch overwrite the random selection with np.load(...).
    if args.sampling == 'random':
        all_ids = random_sampling(xTrain, num_sample=args.num_sample)
    elif args.sampling == 'infomax':
        all_ids = infomax(args, xTrain)
    else:
        all_ids = np.load(args.all_ids)[:args.num_sample]

    # A leftover debug exit() used to stop here, discarding the coreset before it
    # was saved. Create the output folder before the first write into it.
    database_folder_train_it = database_folder_train
    Path(database_folder_train_it).mkdir(parents=True, exist_ok=True)
    np.save(os.path.join(database_folder_train_it, 'all_ids.npy'), all_ids)
    xTrain, yTrain = get_subset(all_ids.astype(int), xTrain, yTrain, xType)

    csv_logger = CSVLogger(filename=os.path.join(database_folder_train_it, 'log.csv'), append=True, separator=';')

    save_model_name = (database_folder_train_it + '/model_best')
    attack_callback = CustomCallback(xTest, yTest_value, save_model_name, database_folder_train_it, args.eval_interval)
    attack_multi_callback = MultiKeyCallback(xTest_multi, yTest_multi, save_model_name, database_folder_train_it, args.eval_interval)

    callbacks = [csv_logger, attack_callback, attack_multi_callback]
    print('Len training:')
    print(yTrain.shape)

    yTrain = np.squeeze(yTrain)
    print(yTrain.shape)

    print(NumSKPVclasses)
    print(yTrain[-10:])
    print(xTrain[0].shape)
    print(xTrain[1][0].shape)
    print(np.max(xTrain[1][0]))

    with tf.device("CPU"):
        if xType == 'wave':
            train = tf.data.Dataset.from_tensor_slices((xTrain, tf.one_hot(yTrain, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
            val = tf.data.Dataset.from_tensor_slices((xVal, tf.one_hot(yVal, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
        elif xType == 'wavebp01':
            callbacks = [csv_logger, attack_callback]
            train = tf.data.Dataset.from_tensor_slices(({"input_layer": xTrain[0], "input_layer_1": xTrain[1][0], "input_layer_2": xTrain[1][1]}, tf.one_hot(yTrain, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
            val = tf.data.Dataset.from_tensor_slices(({"input_layer": xVal[0], "input_layer_1": xVal[1][0], "input_layer_2": xVal[1][1]}, tf.one_hot(yVal, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
        else:
            # Without this, `train`/`val` would be unbound at model.fit below.
            raise ValueError("Unsupported xType %r; expected 'wave' or 'wavebp01'" % (xType,))

    model.fit(train, batch_size=train_batch_size, verbose=1, epochs=maxEpochs, callbacks=callbacks, class_weight=class_weight, validation_data=val)

if work == 'train':
    # Report the accelerator setup before starting the training run.
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
    print('Work =', work)

    # Unit weight for every class, i.e. no re-weighting.
    classWeights = np.ones(noClasses).astype(int)
    class_weight = {cls_id: w for cls_id, w in enumerate(classWeights)}

    train_model_multiEpochs(
        xType, database_folder_train, modelLogFolder,
        logTrainedModel_byFile_folder, logTrainedModel_byEp_folder,
        logFilename, MLmodel_detail, sKeyNo, class_weight, period,
        maxEpochs, train_batch_size, args,
    )
