import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid

def parse_arguments():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: the configured parser. The parser itself is
        returned (not the parsed namespace); the caller invokes
        ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random or uncertainty')
    # xType selects the model inputs; the model builder in this file handles
    # 'wave' and 'wavebp01'.
    parser.add_argument('--xType', type=str,
                        help="model input type: 'wave' (trace only) or 'wavebp01' "
                             "(trace plus two ciphertext inputs)")
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int)
    parser.add_argument('--batch_size', type=int, help='batch_size', default=256)
    # The value is stored in maxEpochs at module level.
    parser.add_argument('--num_epoch', type=int, help='maximum number of training epochs', default=256)
    parser.add_argument('--trained_model_path', type=str)
    parser.add_argument('--all_ids', type=str)
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)
    parser.add_argument('--num_sample', type=int, default=100)
    parser.add_argument('--schedule_iteration', type=int, nargs='+', help='when to train with sampling data')
    parser.add_argument('--resume_it', type=int)
    # NOTE(review): the three help texts below look copy-pasted from
    # --num_iteration; confirm the intended meaning before rewording them.
    parser.add_argument('--eval_interval', type=int, help='iteration_num', default=1)
    parser.add_argument('--update_sampling', type=str, help='iteration_num')
    parser.add_argument('--subtrain_interval', type=int, help='iteration_num', default=1)
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--medoids_path', type=str)
    parser.add_argument('--kl_std', type=int, default=1)
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--alpha', type=float, default=0.5)

    return parser

# Parse the command line once, at import time. The resulting `args` namespace
# is read by the module-level configuration that follows.
parser = parse_arguments()
args = parser.parse_args()

def set_seeds(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, TensorFlow and NumPy with the same value.

    Args:
        seed: integer seed.
    """
    # Local import: ``random`` is not among the imports at the top of this file.
    import random

    # Exported for child processes only: hash randomization of the interpreter
    # that is already running is decided at start-up and is not changed here.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)

# Seed the RNGs before anything else in the module runs.
set_seeds(args.seed)

# Inclusive [min, max] value ranges; the class/input counts below are
# computed as max - min + 1.
# NOTE(review): bp_range ends at 3329 while skpv_range ends at 3328 — confirm
# that the extra bp value (3330 one-hot inputs) is intended.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # presumably samples per trace (data file names end in '600samples') — confirm
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
# Both the classifier output size and the number of key hypotheses follow the skpv class count.
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3 and which subkeys are they are decided by code in m4 (NOT by code in PC)
work = 'train' #'train'  #'attack'
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',\
'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5']#,\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5']

trained_model_path = args.trained_model_path
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232


# training parameters (batch size and epoch count come from the command line)
train_batch_size = args.batch_size#100#150#200#250#500#640 #80 for mars45 #170 for mars56
period = 8 #8
maxEpochs = args.num_epoch#3072#2048#1536#1280#1024#512#256 #1536
# Epoch number converted to a file index: one file per `period` epochs (232 / 8 = 29).
attack_byModel_fileNo = int(attack_byModel_epNo/period)
N_TRACE = 20000
Threshold_Save = 200

#model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True    # False

# training data type
xType = args.xType  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv' 
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    #hyper-parameter contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'    #'20kDatax25'
dataFile_attack = '100kDatax1_test'#'skvp0_0_700points100kDatax1attack' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
model_input_type = '_in[[][]]_tf2' #'in[]_tf2' #'[[][]]_tf2'
#data_type = dataFile_train + model_input_type
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
# Experiment folder name built from the command-line arguments; the last field
# is num_sample * num_iteration. schedule_iteration is a list (or None) and is
# formatted with its default str() representation.
save_path = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_num_{}'.format(args.name ,args.train_type, args.sampling, xType, args.num_iteration , args.start_trace, args.end_trace, args.num_sample, args.schedule_iteration, args.update_sampling, args.seed, args.num_sample * args.num_iteration)
print(save_path)
#database_folder_train = os.path.join('trained_models', save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
# Created immediately (including parents); an existing folder is not an error.
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
# The per-epoch folder is currently the same path as the per-data-file folder.
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Make sure every output folder exists. exist_ok=True turns the call into a
# no-op for folders that are already present and avoids the check-then-create
# race, in line with the Path(...).mkdir(exist_ok=True) call used for
# database_folder_train.
os.makedirs(DLmodel_folder, exist_ok=True)
os.makedirs(modelLogFolder, exist_ok=True)
os.makedirs(logTrainedModel_byFile_folder, exist_ok=True)
os.makedirs(logTrainedModel_byEp_folder, exist_ok=True)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# The tables below are indexed by layer number (0 .. noLayers-1) unless the
# column header says "subMod". A 0 entry disables that layer / option.

# Input BatchNormalization flag per sub-model
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# list giving the number of filters of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   512,    256,    128,    64,     0,      0]
# Convolutional filter sizes
# list giving the kernel size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [   3,      3,      3,      3,      0,      0] # subModel0

###############################################
# Pooling size in convolutional layers
# list giving the MaxPooling size after each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   2,      2,      2,      2,      0,      0] # subModel0

# Pooling stride in convolutional layers
# list giving the MaxPooling stride after each convolutional layer
# NOTE(review): the stride (3) is larger than the pool size (2) — confirm this is intended.
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [   3,      3,      3,      3,      0,      0] # subModel0

# BatchNormalization in convolutional layers
# list of flags: non-zero adds BatchNormalization after that convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [   1,      1,      1,      1,      0,      0] # subModel0
# Dropout in convolutional layers
# list giving the Dropout rate after each convolutional layer (0 = no Dropout)
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   0,      0,      0,      0,      0,      0] # subModel0

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten Convolutional feature map before Fully connected
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected for convolutional value before adding Plaintext
# list giving the width of each fully-connected layer before adding Plaintext
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [   1024,   512,    256,    128,    0,      0] # subModel0
# BatchNormalization for fully-connected of convolutional value before adding Plaintext
# list of flags: non-zero adds BatchNormalization after that fully-connected layer
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   1,      1,      1,      1,      0,      0] # subModel0

# Dropout for fully-connected of convolutional value before adding Plaintext
# list giving the Dropout rate after each fully-connected layer (0 = no Dropout)
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [   0.2,    0,      0.2,    0,      0,      0] # subModel0

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext adding here: per-xType configuration of the ciphertext branches.
#   noBPbranch   - number of ciphertext branches (subModels_gen takes a parameter of the same name)
#   subMods_Pext - per-branch flag; non-zero concatenates that branch's ciphertext input to the CNN output
# NOTE(review): noBPbranch is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01'. For 'wavebp0', 'wavebp1' and 'wavebp01next0' it is left
# undefined by this section, and an unrecognised xType defines neither name.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected for convolutional value after adding Plaintext
# matrix (one row per ciphertext branch, one column per layer) giving the
# width of each fully-connected layer after adding Plaintext; 0 = no layer
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]    # subModel5
# BatchNormalization for fully-connected of convolutional value after adding Plaintext
# matrix of flags: non-zero adds BatchNormalization after that fully-connected layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5
# Dropout for fully-connected of convolutional value after adding Plaintext
# matrix giving the Dropout rate after each fully-connected layer (0 = no Dropout)
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5

# Softmax for each sub-model if available (non-zero adds a softmax layer at the end of that branch)
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [  0,      0,      0,      0,      0,      0]

# Join flag, stored as a one-element list: 0 for the trace-only model, 1 otherwise.
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]  

# Widths of the fully-connected layers applied after joining the branches (0 = no layer)
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FCs =  [  1024,   1024,   512,    256,    128,    0]
# BatchNormalization for fully-connected of convolutional value after joining PoIs
# list of flags: non-zero adds BatchNormalization after that joined fully-connected layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [  1,      1,      1,      1,      1,      0]
# Dropout for fully-connected of convolutional value after joining PoIs
# list giving the Dropout rate after each joined fully-connected layer (0 = no Dropout)
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [  0.2,        0.2,        0,      0.1,        0,      0]

# Softmax for joined-model if available
subMods_join_classification =   [   1]

################################################################################################
##################################### MODELS STRUCTURE END #####################################
################################################################################################


def check_file_exists(file_path):
    """Abort the program when *file_path* does not exist.

    Prints an error message and calls ``sys.exit(-1)`` if the path is
    missing; returns ``None`` otherwise.
    """
    if os.path.exists(file_path):
        return
    print("Error: provided file path '%s' does not exist!" % file_path)
    sys.exit(-1)

def listDirWithExt(directory, extension):
    """Return a generator over the entries of *directory* ending in '.<extension>'.

    The directory is listed when the function is called; the filtering
    happens lazily while the returned generator is consumed. Entries are
    bare names (no directory prefix), in ``os.listdir`` order.
    """
    entries = os.listdir(directory)
    return (entry for entry in entries if entry.endswith('.' + extension))

def subModels_gen(xType,noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses):
    """Build and compile the Keras classification model.

    Structure (layer sizes and flags come from the module-level ``subMods_*``
    tables): input BatchNormalization -> Conv1D/MaxPooling1D stack -> Flatten
    -> dense stack -> one dense branch per ciphertext input (if any) -> joined
    dense stack -> softmax over ``classes``.

    Side effects: prints the model summary (and some debug output) and writes
    the model graph to ``model.png`` in the current directory.

    Args:
        xType: 'wave' (trace input only) or 'wavebp01' (trace plus two
            ciphertext inputs). Other values are rejected.
        noConv1Dbranch: number of trace ``Input`` tensors to create (>= 1);
            only the last one created is wired into the model.
        noBPbranch: number of ciphertext branches to build.
        noLayers: number of entries read from each per-layer table.
        tracelen: length of the trace input; its shape is ``(tracelen, 1)``.
        NumBPinput: length of each ciphertext input; shape ``(NumBPinput, 1)``.
        MLmodel_detail: unused in this function.
        modelLogFolder: unused in this function.
        logFilename: unused in this function.
        classes: size of the softmax output.

    Returns:
        The compiled ``Model`` (RMSprop, categorical cross-entropy, accuracy).

    Raises:
        ValueError: if ``noConv1Dbranch`` is below 1 or ``xType`` is not one
            of the supported values.
    """
    if noConv1Dbranch < 1:
        raise ValueError("noConv1Dbranch must be >= 1, got %r" % (noConv1Dbranch,))

    input_trace_shape = (tracelen, 1)
    input_Ptext1hot_shape = (NumBPinput, 1)

    # A separate Input object is created per branch so that they are distinct
    # inputs; the last one is the model's trace input.
    m_traceinputs = []
    for _ in range(noConv1Dbranch):
        m_traceinputs.append(Input(shape=input_trace_shape))
    trace_input = m_traceinputs[-1]

    # All four ciphertext inputs are still created, in the original order,
    # even though only the first two are used, so that the sequence of layers
    # instantiated by this function is unchanged.
    Ptext_input1 = Input(shape=input_Ptext1hot_shape)
    Ptext_input2 = Input(shape=input_Ptext1hot_shape)
    Ptext_input3 = Input(shape=input_Ptext1hot_shape)
    Ptext_input4 = Input(shape=input_Ptext1hot_shape)

    if xType == 'wave':
        inputs = [trace_input]
    elif xType == 'wavebp01':
        inputs = [trace_input, Ptext_input1, Ptext_input2]
    else:
        # Previously this fell through with an empty input list and failed
        # later with an IndexError.
        raise ValueError("Unsupported xType %r: expected 'wave' or 'wavebp01'" % (xType,))

    # First block: normalise the raw trace.
    x = BatchNormalization()(inputs[0])

    # Convolutional stack. A layer is built only when its filter count, kernel
    # size, pool size and pool stride are all non-zero.
    for layerNo in range(noLayers):
        if (subMods_NoConvNodes[layerNo] != 0 and subMods_convKernelSizes[layerNo] != 0
                and subMods_convPoolSizes[layerNo] != 0 and subMods_convPoolStrides[layerNo] != 0):
            x = Conv1D(subMods_NoConvNodes[layerNo], subMods_convKernelSizes[layerNo], activation='relu',
                       padding='same', name='ConvBlock_subModels_conv' + str(layerNo))(x)
            x = MaxPooling1D(subMods_convPoolSizes[layerNo], strides=subMods_convPoolStrides[layerNo],
                             name='ConvBlock_subModels_pool' + str(layerNo))(x)
        if subMods_convBNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_convDrops[layerNo] != 0:
            x = Dropout(subMods_convDrops[layerNo])(x)

    # Dense stack on top of the convolutional features. These Dense layers
    # have no activation argument (linear), as in the original.
    for layerNo in range(noLayers):
        # The flatten flag is a per-sub-model list; entry 0 belongs to the
        # single conv branch built here. (Comparing the whole list with 0, as
        # before, was always true.)
        if layerNo == 0 and subMods_convFeatFlat[0] != 0:
            x = Flatten()(x)
        if subMods_FCs[layerNo] != 0:
            x = Dense(subMods_FCs[layerNo])(x)
        if subMods_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_FC_Drops[layerNo])(x)
    output_cnn = x

    # Ciphertext branches: each one concatenates the CNN output with one
    # flattened ciphertext input and runs its own dense stack.
    BPbranchOuts_list = []
    for BPbranchNo in range(noBPbranch):
        # When the Pext flag is 0 the branch continues from the current x
        # (i.e. the previous branch's output), as in the original.
        if subMods_Pext[BPbranchNo] != 0:
            Ptext_flatten = Flatten()(inputs[1 + BPbranchNo])
            x = Concatenate()([output_cnn, Ptext_flatten])
        for layerNo in range(noLayers):
            if subMods_Pext_FCs[BPbranchNo][layerNo] != 0:
                x = Dense(subMods_Pext_FCs[BPbranchNo][layerNo], activation='relu')(x)
            if subMods_Pext_FC_BNorms[BPbranchNo][layerNo] != 0:
                x = BatchNormalization(trainable=True)(x)
            if subMods_Pext_FC_Drops[BPbranchNo][layerNo] != 0:
                x = Dropout(subMods_Pext_FC_Drops[BPbranchNo][layerNo])(x)

        ###################### CLASSIFICATION (SOFTMAX) ######################
        if subMods_classification[BPbranchNo] != 0:
            x = Dense(classes, activation='softmax')(x)
        print(x)
        BPbranchOuts_list.append(x)

    print(BPbranchOuts_list)
    # subMods_join is a one-element list; test its entry rather than the list
    # itself (the list-vs-int comparison was always true).
    if BPbranchOuts_list and subMods_join[0] != 0:
        x = Concatenate()(BPbranchOuts_list)

    # Dense stack applied after joining the branches (or directly after the
    # CNN part when there are no branches).
    for layerNo in range(noLayers):
        if subMods_join_FCs[layerNo] != 0:
            x = Dense(subMods_join_FCs[layerNo], activation='relu')(x)
        if subMods_join_FC_BNorms[layerNo] != 0:
            x = BatchNormalization(trainable=True)(x)
        if subMods_join_FC_Drops[layerNo] != 0:
            x = Dropout(subMods_join_FC_Drops[layerNo])(x)

    # Softmax
    outputs = Dense(classes, activation='softmax')(x)

    sModel = Model(inputs, outputs, name='model')
    sModel.summary()
    tf.keras.utils.plot_model(sModel, show_shapes=True, to_file='model.png')
    optimizer = RMSprop(learning_rate=0.00001)
    sModel.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return sModel

# make a prediction with a stacked model
# https://machinelearningmastery.com/stacking-ensemble-for-deep-learning-neural-networks/
def predict_stacked_model(model, inputX):
    """Predict with *model*, feeding the same *inputX* to every model input.

    The number of copies is ``len(model.input)``; prediction runs with
    ``verbose=0`` and its result is returned unchanged.
    """
    n_inputs = len(model.input)
    replicated = [inputX] * n_inputs
    return model.predict(replicated, verbose=0)

def load_sca_model(model_file):
    """Load a saved Keras model, exiting the process if that is not possible.

    Args:
        model_file: path of the saved model.

    Returns:
        The model returned by ``load_model``.

    The process exits with status -1 (after printing an error) when the path
    does not exist or the model cannot be loaded.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
        # sys.exit() raised during loading are not swallowed.
        print("Error: can't load Keras model file '%s'" % model_file)
        print("Reason: %s" % exc)
        sys.exit(-1)
    return model

####### THESE FUNCTIONS ARE SPECIALIZED FOR KYBER    #######
####### Loading traces and metadata from file ############
#def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a slice of traces, ciphertext values and labels from an ``.npz`` file.

    The archive must contain the arrays ``'data'``, ``'bp'`` and ``'label'``.
    ``'data'`` and ``'label'`` are sliced along their first axis, ``'bp'``
    along its second axis.

    Args:
        data_path: path of the ``.npz`` archive.
        start_trace: first trace index (inclusive).
        end_trace: last trace index (exclusive).

    Returns:
        Tuple ``(trace_profiling, bp_profiling, skpv_profiling)`` holding the
        selected slices of ``'data'``, ``'bp'`` and ``'label'``.
    """
    # The context manager closes the archive's file handle; the arrays are
    # read into memory on access, so they remain valid afterwards.
    with np.load(data_path) as data:
        trace_profiling = data['data']
        bp_profiling = data['bp']
        skpv_profiling = data['label']

    return (trace_profiling[start_trace:end_trace], bp_profiling[:, start_trace:end_trace], skpv_profiling[start_trace:end_trace])

def load_meta_trace_file_from_test(database_file, sKeyNo, load_metadata=False):
    """Load traces, ciphertext coefficients and key labels from a Kyber HDF5 file.

    Only the datasets that contribute to the return value are read: ``'wave'``,
    ``'skpv_a_vec0_evenCoeff0'``, ``'bp_b_vec0_evenCoeff0'`` and
    ``'bp_b_vec0_oddCoeff1'``. Earlier versions also read further datasets
    into locals that were never used.

    Args:
        database_file: path of the HDF5 file.
        sKeyNo: column selecting the sub-key; column ``sKeyNo + 1`` is read as
            well for the "next" ciphertext coefficients.
        load_metadata: unused; kept so existing callers keep working.

    Returns:
        Tuple ``(trace_profiling, bp_profiling, skpv_profiling)``:
        ``trace_profiling`` is the ``'wave'`` dataset as a float array;
        ``bp_profiling`` is the list ``[even(sKeyNo), odd(sKeyNo),
        even(sKeyNo + 1), odd(sKeyNo + 1)]`` of int arrays taken from the two
        ``bp_b_vec0_*`` datasets; ``skpv_profiling`` is column ``sKeyNo`` of
        ``'skpv_a_vec0_evenCoeff0'`` as an int array.

    The process exits with status -1 when the file is missing or cannot be
    opened.
    """
    print('\nLoad database_file =', database_file)
    check_file_exists(database_file)
    # Open the Kyber database HDF5 for reading
    try:
        in_file = h5py.File(database_file, "r")
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so Ctrl-C still interrupts.
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        print("Reason: %s" % exc)
        sys.exit(-1)

    def _int_column(dataset_name, column):
        # Read one column of a 2-D dataset into memory as an int array.
        return in_file[dataset_name][:, column].astype(int)

    # The ``with`` block closes the file once everything needed is in memory.
    with in_file:
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = _int_column('skpv_a_vec0_evenCoeff0', sKeyNo)
        bp_profiling = [
            _int_column('bp_b_vec0_evenCoeff0', sKeyNo),
            _int_column('bp_b_vec0_oddCoeff1', sKeyNo),
            _int_column('bp_b_vec0_evenCoeff0', sKeyNo + 1),
            _int_column('bp_b_vec0_oddCoeff1', sKeyNo + 1),
        ]

    return (trace_profiling, bp_profiling, skpv_profiling)
#### Converting traces and metadata to training format
# inputs = [[list of traces], [list of bp]]
#def create_training_data_form(database_folder_train_file, sKeyNo, trainPortion, xType, yType):

def create_training_data_form(data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    #(trace_profiling, bp_profiling, skpv_profiling, fqmul_profiling) = load_meta_trace_file(database_folder_train_file, sKeyNo)
    #(trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file(database_folder_train_file, sKeyNo)
    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_trace)

    Reshaped_trace_profiling= trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
    dataSize = Reshaped_trace_profiling.shape[0]
    trainSize = math.floor(dataSize * trainPortion)
    valLoc = trainSize
    if valLoc == dataSize:
        valLoc = dataSize - 1

    lineNo = list(range(0, bp_profiling[0].shape[0]))
    #bp0_1hot_profiling = np.zeros((bp_profiling[0].shape[0], NumBPinput)).astype(np.int)
    print('2')
    print((bp_profiling[0].shape[0], NumBPinput))
    bp0_1hot_profiling = np.zeros((bp_profiling[0].shape[0], NumBPinput)).astype(int)
    print('bp0_1hot_profiling.shape =', bp0_1hot_profiling.shape, '                bp_profiling[0] =', bp_profiling[0])
    bp0_1hot_profiling[lineNo,bp_profiling[0]] = 1
    Reshaped_bp0_1hot_profiling = bp0_1hot_profiling.reshape((bp0_1hot_profiling.shape[0], NumBPinput, 1))
    #print('Reshaped_bp0_1hot_profiling.shape = ', Reshaped_bp0_1hot_profiling.shape)
   
    lineNo = list(range(0, bp_profiling[1].shape[0]))
    #bp1_1hot_profiling = np.zeros((bp_profiling[1].shape[0], NumBPinput)).astype(np.int)
    bp1_1hot_profiling = np.zeros((bp_profiling[1].shape[0], NumBPinput)).astype(int)
    print('bp1_1hot_profiling.shape =', bp1_1hot_profiling.shape, '                bp_profiling[1] =', bp_profiling[1])
    #input()
    bp1_1hot_profiling[lineNo,bp_profiling[1]] = 1
    Reshaped_bp1_1hot_profiling = bp1_1hot_profiling.reshape((bp1_1hot_profiling.shape[0], NumBPinput, 1))
    
    lineNo = list(range(0, bp_profiling[2].shape[0]))
    #bp0_1hot_profiling_next_sKeyNo = np.zeros((bp_profiling[2].shape[0], NumBPinput)).astype(np.int)
    bp0_1hot_profiling_next_sKeyNo = np.zeros((bp_profiling[2].shape[0], NumBPinput)).astype(int)
    print('bp0_1hot_profiling_next_sKeyNo.shape =', bp0_1hot_profiling_next_sKeyNo.shape, '    bp_profiling[2] =', bp_profiling[2])
    bp0_1hot_profiling_next_sKeyNo[lineNo,bp_profiling[2]] = 1
    Reshaped_bp0_1hot_profiling_next_sKeyNo = bp0_1hot_profiling_next_sKeyNo.reshape((bp0_1hot_profiling_next_sKeyNo.shape[0], NumBPinput, 1))
    
    lineNo = list(range(0, bp_profiling[3].shape[0]))
    #bp1_1hot_profiling_next_sKeyNo = np.zeros((bp_profiling[3].shape[0], NumBPinput)).astype(np.int)
    bp1_1hot_profiling_next_sKeyNo = np.zeros((bp_profiling[3].shape[0], NumBPinput)).astype(int)
    print('bp1_1hot_profiling_next_sKeyNo.shape =', bp1_1hot_profiling_next_sKeyNo.shape, '    bp_profiling[3] =', bp_profiling[3])
    #input()
    bp1_1hot_profiling_next_sKeyNo[lineNo,bp_profiling[3]] = 1
    Reshaped_bp1_1hot_profiling_next_sKeyNo = bp1_1hot_profiling_next_sKeyNo.reshape((bp1_1hot_profiling_next_sKeyNo.shape[0], NumBPinput, 1))
    
    #y_train_fqmul0 = to_categorical(fqmul_profiling[0], num_classes=NumFQMULclasses)
    #y_train_fqmul1 = to_categorical(fqmul_profiling[1], num_classes=NumFQMULclasses)
    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)

    #xTrain_wave = [Reshaped_trace_profiling[0:trainSize,:,:]]
    #xTrain_wavebp0 = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:]]]
    #xTrain_wavebp1 = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp1_1hot_profiling[0:trainSize,:]]]
    #xTrain_wavebp01 = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:]]]
    #xTrain_wavebp01next0 = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[0:trainSize,:]]]
    #xTrain_wavebp01next01 = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[0:trainSize,:], Reshaped_bp1_1hot_profiling_next_sKeyNo[0:trainSize,:]]]
    #xTrain_wave = Reshaped_trace_profiling[0:trainSize,:,:]
    #xTrain_wavebp0 = [Reshaped_trace_profiling[0:trainSize,:,:], Reshaped_bp0_1hot_profiling[0:trainSize,:]]
    #xTrain_wavebp1 = [Reshaped_trace_profiling[0:trainSize,:,:], Reshaped_bp1_1hot_profiling[0:trainSize,:]]
    #xTrain_wavebp01 = [Reshaped_trace_profiling[0:trainSize,:,:], Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:]]
    #yTrain_fqmul0 = y_train_fqmul0[0:trainSize,:]
    #yTrain_fqmul1 = y_train_fqmul1[0:trainSize,:]
    #yTrain_skpv = y_train_skpv[0:trainSize,:]

    #xVal_wave = [Reshaped_trace_profiling[valLoc:,:,:]]
    #xVal_wavebp0 = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:]]]
    #xVal_wavebp1 = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp1_1hot_profiling[valLoc:,:]]]
    #xVal_wavebp01 = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:]]]
    #xVal_wavebp01next0 = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[valLoc:,:]]]
    #xVal_wavebp01next01 = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[valLoc:,:], Reshaped_bp1_1hot_profiling_next_sKeyNo[valLoc:,:]]]
    #xVal_wave = Reshaped_trace_profiling[valLoc:,:,:]
    #xVal_wavebp0 = [Reshaped_trace_profiling[valLoc:,:,:], Reshaped_bp0_1hot_profiling[valLoc:,:]]
    #xVal_wavebp1 = [Reshaped_trace_profiling[valLoc:,:,:], Reshaped_bp1_1hot_profiling[valLoc:,:]]
    #xVal_wavebp01 = [Reshaped_trace_profiling[valLoc:,:,:], Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:]]
    #yVal_fqmul0 = y_train_fqmul0[valLoc:,:]
    #yVal_fqmul1 = y_train_fqmul1[valLoc:,:]
    #yVal_skpv = y_train_skpv[valLoc:,:]

    # Input data creation
    if xType == 'wave':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]]]#xTrain_wave
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]]]#xVal_wave
    elif xType == 'wavebp0':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:]]]#xTrain_wavebp0
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:]]]#xVal_wavebp0
    elif xType == 'wavebp1':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp1_1hot_profiling[0:trainSize,:]]]#xTrain_wavebp1
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp1_1hot_profiling[valLoc:,:]]]#xVal_wavebp1
    elif xType == 'wavebp01':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:]]]#xTrain_wavebp01
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:]]]#xVal_wavebp01
    elif xType == 'wavebp01next0':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[0:trainSize,:]]]#xTrain_wavebp01next0
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[valLoc:,:]]]#xVal_wavebp01next0
    elif xType == 'wavebp01next01':
        xTrain = [[Reshaped_trace_profiling[0:trainSize,:,:]], [Reshaped_bp0_1hot_profiling[0:trainSize,:], Reshaped_bp1_1hot_profiling[0:trainSize,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[0:trainSize,:], Reshaped_bp1_1hot_profiling_next_sKeyNo[0:trainSize,:]]]#xTrain_wavebp01next01
        xVal = [[Reshaped_trace_profiling[valLoc:,:,:]], [Reshaped_bp0_1hot_profiling[valLoc:,:], Reshaped_bp1_1hot_profiling[valLoc:,:], Reshaped_bp0_1hot_profiling_next_sKeyNo[valLoc:,:], Reshaped_bp1_1hot_profiling_next_sKeyNo[valLoc:,:]]]#xVal_wavebp01next01
        #print('Created xType, len(xTrain) :', xType, len(xTrain))
        #print('len(xTrain[0]) =', len(xTrain[0]), ';    len(xTrain[1]) =', len(xTrain[1]), 'PRESS ENTER TO CONTINUE')
        #print('PRESS ENTER TO CONTINUE')
        #input()
    # Category creation
    if yType == 'fqmul0':
        yTrain = yTrain_fqmul0
        yTrain_value = fqmul_profiling[0][0:trainSize]
        yVal = yVal_fqmul0
        yVal_value = fqmul_profiling[0][valLoc:]
    elif yType == 'fqmul1':
        yTrain = yTrain_fqmul1
        yTrain_value = fqmul_profiling[1][0:trainSize]
        yVal = yVal_fqmul1
        yVal_value = fqmul_profiling[1][valLoc:]
    elif yType == 'skpv':
        yTrain = y_train_skpv[0:trainSize,:]#yTrain_skpv
        yTrain_value = skpv_profiling[0:trainSize]
        yVal = y_train_skpv[valLoc:,:]#yVal_skpv
        yVal_value = skpv_profiling[valLoc:]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def get_subset(idxs, X_train, y_train):
    """Slice the traces, the two auxiliary inputs and the labels by ``idxs``.

    NOTE: a later definition in this file with the same name (taking an extra
    ``xType`` argument) rebinds ``get_subset`` at import time.

    Args:
        idxs: Index expression applied to the first axis of every array.
        X_train: Nested list ``[[traces], [bp0, bp1]]``; ``traces`` is indexed
            with three axes, ``bp0``/``bp1`` with two.
        y_train: Label array indexed with two axes.

    Returns:
        ``(sub_X_train, sub_y_train)`` where ``sub_X_train`` keeps the same
        nested layout as ``X_train``.
    """
    traces = X_train[0][0]
    # Debug output: size of the full trace set before sub-sampling.
    print(len(traces))
    aux_group = X_train[1]
    bp0, bp1 = aux_group[0], aux_group[1]

    picked_traces = traces[idxs, :, :]
    picked_aux = [bp0[idxs, :], bp1[idxs, :]]

    return [[picked_traces], picked_aux], y_train[idxs, :]

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Compute key-rank traces of ``model`` over ``nruns`` sampled batches.

    For each run ``krun`` the rows ``batches[krun, :]`` of the test inputs are
    passed to ``model.predict``. The log-predictions are accumulated trace by
    trace and, after every trace, the number of hypotheses scoring strictly
    higher than the true one is stored as its rank.

    NOTE: the set of model inputs is selected by the module-level global
    ``xType``; it is not a parameter of this function.

    Args:
        model: Object exposing ``predict(list_of_inputs)``.
        nruns: Number of runs (rows of ``batches`` that are used).
        maxtrc: Number of traces per run; the prediction of one run must be
            assignable to an array of shape ``(maxtrc, noClasses)``.
        batches: 2-D index array; row ``krun`` selects the traces of that run.
        xTest: Nested list ``[[traces], [aux0, aux1, ...]]``.
        yTest_value: Label values; only element 0 is read and used as the
            true key/class for every trace.
        noHypoKeys: Number of key hypotheses (length of the accumulated
            log-likelihood vector).
        noClasses: Number of model output classes.

    Returns:
        Tuple ``(rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns)``.
        The two rank matrices have shape ``(nruns, maxtrc)``; the others have
        shape ``(maxtrc, noClasses or noHypoKeys, nruns)``.

    Raises:
        ValueError: If the global ``xType`` is not one of the supported
            input layouts.
    """
    # Positions inside xTest[1] of the auxiliary inputs fed next to the traces.
    aux_slots_by_type = {
        'wave': (),
        'wavebp0': (0,),
        # NOTE(review): the 'wavebp1' branch that builds xTrain/xVal earlier in
        # this file stores bp1 as the only auxiliary entry (slot 0). Slot 1 is
        # kept from the original code here -- confirm how xTest is built.
        'wavebp1': (1,),
        'wavebp01': (0, 1),
        'wavebp01next0': (0, 1, 2),
        'wavebp01next01': (0, 1, 2, 3),
    }
    if xType not in aux_slots_by_type:
        raise ValueError('Unsupported xType: %r' % (xType,))
    aux_slots = aux_slots_by_type[xType]

    realkey = int(yTest_value[0])
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))

    for krun in range(nruns):
        samp = batches[krun,:]
        model_inputs = [xTest[0][0][samp,:,:]]
        model_inputs.extend(xTest[1][slot][samp,:] for slot in aux_slots)
        ps = model.predict(model_inputs)

        # Exact zeros in ps become -inf here (NumPy emits a RuntimeWarning).
        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            # The class index is used directly as the key hypothesis, so the
            # true class of every trace is simply `realkey`.
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i,:,krun] = lpsums
            # Rank = number of hypotheses scoring strictly above the true one.
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realkey])
        ps_AllClasses_Nruns[:,:,krun] = ps
        lps_AllClasses_Nruns[:,:,krun] = lps
        # Without a class-to-key mapping the per-hypothesis log-probabilities
        # are the per-class ones.
        lps_AllHypoKeys_Nruns[:,:,krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the mean key rank reached after the last trace of each run.

    All arguments are forwarded unchanged to ``mk_rankmat``; only its first
    result (the per-run key-rank matrix) is used.

    Returns:
        The key rank in the last trace column, averaged over the runs.
    """
    rankmat_byKey = mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)[0]
    # Average over runs (axis 0), then keep the value for the final trace.
    mean_rank_per_trace = np.mean(rankmat_byKey, 0)

    return mean_rank_per_trace[-1]

import pandas as pd

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically evaluates the attack rank on test data.

    On every epoch whose index is a multiple of ``eval_interval`` the current
    model is scored with ``eval_model`` on freshly drawn random batches. The
    model is saved to ``save_model_name + '.keras'`` whenever the score is the
    lowest seen so far (the first evaluation included). At the end of training
    the score history is written to ``attack_rank.csv`` inside
    ``database_folder_train``.

    Relies on the module-level globals ``nruns_default``, ``maxtrc_default``,
    ``noHypoKeys`` and ``noClasses``.
    """

    def __init__(self, xTest, yTest_value, save_model_name, database_folder_train, eval_interval):
        """Store the evaluation data and output locations.

        Args:
            xTest: Nested test inputs; ``xTest[0][0]`` is the trace array.
            yTest_value: Test label values, forwarded to ``eval_model``.
            save_model_name: Path prefix for the saved best model.
            database_folder_train: Directory receiving ``attack_rank.csv``.
            eval_interval: Evaluate every this many epochs.
        """
        super().__init__()
        self.xTest = xTest
        self.yTest_value = yTest_value
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []
        # Snapshot of the base class's `model` attribute at construction time;
        # this class never reassigns it.
        self.best_model = self.model
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the model on evaluation epochs and keep the best one."""
        if epoch % self.eval_interval != 0:
            return

        # One row of distinct random trace indices per run.
        batches = np.zeros((nruns_default, maxtrc_default), 'int')
        for i in range(nruns_default):
            batches[i,:] = np.random.choice(len(self.xTest[0][0]), maxtrc_default, False)
        test_rank = eval_model(self.model, nruns_default, maxtrc_default, batches, self.xTest, self.yTest_value, noHypoKeys, noClasses)

        # The first evaluation has nothing to compare against, so it counts as
        # the best so far; otherwise a best model might never be written.
        if not self.all_rank or test_rank < np.min(np.array(self.all_rank)):
            print(test_rank)
            self.model.save(self.save_model_name + '.keras')
        self.all_rank.append(test_rank)

    def on_train_end(self, logs=None):
        """Write the recorded rank history to ``attack_rank.csv``."""
        df = pd.DataFrame({'Attack Mean Rank': self.all_rank})
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank.csv'))

class MultiKeyCallback(tf.keras.callbacks.Callback):
    """Keras callback that evaluates the attack rank separately for several keys.

    On every epoch whose index is a multiple of ``eval_interval`` each key's
    test set is scored with ``eval_model`` using a single run over its first
    ``self.maxtrc`` traces. The model is saved to
    ``save_model_name + '_overall.keras'`` whenever the rank averaged over all
    keys is the lowest seen so far (the first evaluation included). At the
    end of training the final model is saved to ``save_model_name +
    '_end.keras'`` and the per-key history is written to
    ``attack_rank_multi.csv`` inside ``database_folder_train``.

    Relies on the module-level globals ``noHypoKeys`` and ``noClasses``.
    """

    def __init__(self, xTest_multi, yTest_vals, save_model_name, database_folder_train, eval_interval):
        """Store the per-key evaluation data and output locations.

        Args:
            xTest_multi: Sequence with one trace array per key.
            yTest_vals: Sequence with one label value per key.
            save_model_name: Path prefix for the saved models.
            database_folder_train: Directory receiving the CSV history.
            eval_interval: Evaluate every this many epochs.
        """
        super().__init__()
        self.xTest_multi = xTest_multi
        self.yTest_vals = yTest_vals
        self.save_model_name = save_model_name
        self.database_folder_train = database_folder_train
        self.all_rank = []
        # Snapshot of the base class's `model` attribute at construction time;
        # this class never reassigns it.
        self.best_model = self.model
        self.maxtrc = 100 #Max trace num for multi label
        self.eval_interval = eval_interval

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate every key on evaluation epochs and keep the best model."""
        if epoch % self.eval_interval != 0:
            return

        # Data for multi-label is limited, so a single run over the first
        # `maxtrc` traces (in order) is used for every key.
        nruns = 1
        batches = np.zeros((nruns, self.maxtrc), 'int')
        batches[0] = np.arange(self.maxtrc)

        multi_rank = []
        for key_no, key_value in enumerate(self.yTest_vals):
            # eval_model expects the nested [[traces]] layout and a label
            # sequence whose first element is the true key.
            test_rank = eval_model(self.model, nruns, self.maxtrc, batches, [[self.xTest_multi[key_no]]], [key_value], noHypoKeys, noClasses)
            multi_rank.append(test_rank)

        current_avg = np.mean(np.array(multi_rank))
        # The first evaluation has nothing to compare against, so it counts as
        # the best so far; later ones must beat the best per-epoch average.
        if not self.all_rank:
            is_best = True
        else:
            previous_avgs = np.mean(np.array(self.all_rank), axis = 1)
            is_best = current_avg < np.min(previous_avgs)
        if is_best:
            self.model.save(self.save_model_name + '_overall.keras')
        self.all_rank.append(multi_rank)

    def on_train_end(self, logs=None):
        """Save the final model and write the per-key rank history to CSV."""
        all_cols = ['Mean Rank Key No.' + str(i) for i in range(len(self.yTest_vals))]
        df = pd.DataFrame(self.all_rank, columns=all_cols)
        self.model.save(self.save_model_name + '_end.keras')
        df.to_csv(os.path.join(self.database_folder_train, 'attack_rank_multi.csv'))

def get_subset(idxs, X_train, y_train, xType):
    """Slice the model inputs and the labels by ``idxs``.

    The nested layout of ``X_train`` follows the ``xType`` branches that build
    ``xTrain``/``xVal`` earlier in this file: the trace array alone for
    ``'wave'``, and the trace array plus a group of 1 to 4 auxiliary arrays
    for the ``'wavebp...'`` variants.

    Args:
        idxs: Index expression applied to the first axis of every array.
        X_train: Nested list ``[[traces]]`` or ``[[traces], [aux0, ...]]``.
        y_train: Label array indexed with two axes.
        xType: Input layout name.

    Returns:
        ``(sub_X_train, sub_y_train)`` with ``sub_X_train`` in the same nested
        layout as ``X_train``.

    Raises:
        ValueError: If ``xType`` is not a known layout.
    """
    # Number of auxiliary arrays in X_train[1] for every supported layout.
    aux_count_by_type = {
        'wave': 0,
        'wavebp0': 1,
        'wavebp1': 1,
        'wavebp01': 2,
        'wavebp01next0': 3,
        'wavebp01next01': 4,
    }
    if xType not in aux_count_by_type:
        raise ValueError('Unsupported xType: %r' % (xType,))

    profile_data = X_train[0][0]
    sub_X_train = [[profile_data[idxs,:,:]]]
    aux_count = aux_count_by_type[xType]
    if aux_count:
        sub_X_train.append([X_train[1][slot][idxs,:] for slot in range(aux_count)])
    sub_y_train = y_train[idxs,:]

    return sub_X_train, sub_y_train

from sklearn.cluster import KMeans
import copy

def min_max_sampling(data, num_cluster, num_sample):
    """Pick the ``num_sample`` samples farthest from their nearest centroid.

    K-Means is fitted on a two-column feature made of each sample's sum along
    axis 1 (duplicated once). For every sample the distance to the closest
    centroid is computed, and the samples with the largest such distance are
    returned.

    Args:
        data: Array of samples; summed along axis 1 to build the clustering
            feature.
        num_cluster: Number of K-Means clusters.
        num_sample: Number of indices to return.

    Returns:
        Array with ``num_sample`` indices into ``data`` (unordered).
    """
    sum_data = np.sum(data, axis = 1)
    exp_data = np.expand_dims(sum_data, axis=1)
    exp_data = np.squeeze(np.hstack((exp_data,exp_data)))
    kmeans = KMeans(n_clusters=num_cluster, random_state=0, n_init="auto").fit(exp_data)
    centroids = kmeans.cluster_centers_

    # NOTE(review): centroids live in the two-column "sum" feature space while
    # `sample` is a raw row of `data`; the subtraction only works through
    # broadcasting. Kept as in the original -- confirm this is intended.
    min_dist = []
    for sample in data:
        all_dist = np.array([np.linalg.norm(centroid-sample) for centroid in centroids])
        min_dist.append(np.min(all_dist))

    # Largest nearest-centroid distances, i.e. the least well covered samples.
    max_idxs = np.argpartition(min_dist, -num_sample)[-num_sample:]

    return max_idxs

import random

def random_sampling(data, exclude_ids, num_sample):
    """Draw ``num_sample`` distinct random indices of ``data``.

    Args:
        data: Sized collection; candidate indices are ``0 .. len(data) - 1``.
        exclude_ids: Indices that must not be drawn (may be empty).
        num_sample: Number of indices to draw.

    Returns:
        Integer array with ``num_sample`` distinct indices, none of which is
        in ``exclude_ids``.

    Raises:
        ValueError: If fewer than ``num_sample`` indices remain after the
            exclusion (raised by ``np.random.choice``).
    """
    all_ids = np.arange(len(data))
    print(len(all_ids))
    if len(exclude_ids) > 0:
        # Set difference removes exactly the excluded indices, whatever the
        # length or order of `exclude_ids`.
        available_ids = np.setdiff1d(all_ids, exclude_ids)
    else:
        available_ids = all_ids
    print(len(available_ids))
    # Draw without replacement so that no index is selected twice.
    rand_ids = np.random.choice(available_ids, num_sample, replace=False)
    return rand_ids

def unceratainty_sampling_by_label(model, xTrain_in, yTrain_in, num_sample, chunk_size = 20000):
    """Return the samples whose true label got the lowest predicted score.

    Args:
        model: Object exposing ``predict``; called on a batched
            ``tf.data.Dataset`` built from ``xTrain_in``.
        xTrain_in: Inputs to score.
        yTrain_in: Labels; ``yTrain_in[k]`` indexes the prediction of
            sample ``k``.
        num_sample: Number of indices to return.
        chunk_size: Not used by this function.

    Returns:
        Array with the ``num_sample`` positions (unordered) whose prediction
        for the true label is smallest.
    """
    # The dataset is created under the CPU device scope, in batches of 192.
    with tf.device("CPU"):
        pool = tf.data.Dataset.from_tensor_slices(xTrain_in).batch(192)
    preds = model.predict(pool)
    print(preds.shape)

    # Score assigned to the true label of every sample.
    true_label_probs = np.squeeze(np.array(
        [row[yTrain_in[pos]] for pos, row in enumerate(preds)]
    ))

    return np.argpartition(true_label_probs, num_sample)[:num_sample]

from tqdm import tqdm as tqdm

'''
def unceratainty_sampling_by_balance(model, xTrain_in, yTrain_in, exist_labels, num_sample):
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_in).batch(192)

    preds = model.predict(val)    
    exist_labels = np.squeeze(exist_labels)
    all_label_count = np.zeros(3329)
    alpha = 0.0
    smooth = 5
    sample_indexes = []
    sampling_rate = 10 #Take top 10 samples
    num_iter = int(num_sample/sampling_rate)
    print(exist_labels.shape)
    for i in tqdm(range(num_iter)):
        unique, counts = np.unique(exist_labels, return_counts=True) #Get the counts for sample in S
        all_label_count[unique] = counts
        print(np.std(all_label_count))
        print('----')
        range_x2 = (np.max(all_label_count) - np.min(all_label_count) ) / len(exist_labels)
        #print(range_x1)
        #print(range_x2)
        min_prob = np.min(all_label_count)
        
        pred_probs = 1- np.min(preds, axis = 1)
        sample_scores = []
        idx = 0
        for sample in preds:
            label_idx = yTrain_in[idx]
            pred_prob = 1- np.min(preds[idx]) #Get reversed prob
            idx += 1
            #pred_probs.append(pred_prob)
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1-(sample_prob- min_prob)/ range_x2


            sample_score = alpha * math.exp( - pred_prob * smooth) + (1-alpha) * (1-math.exp(-x2[0] * smooth))
            if sample_indexes is not None and idx in sample_indexes:
                sample_scores.append(0) #Assign min score to exist indexes
            else:
                sample_scores.append(sample_score)
                #pred_probs = np.min(preds, axis = 1)
        sample_index = np.argpartition(sample_scores, -sampling_rate)[-sampling_rate:]
        
        exist_labels = np.concatenate((exist_labels, np.squeeze(yTrain_in[sample_index])))
    #pred_probs = np.squeeze(np.array(pred_probs))
    #max_idxs = np.argpartition(pred_probs, -num_sample)[-num_sample:]
    #min_idxs = np.argpartition(pred_probs, num_sample)[:num_sample]
        #sample_indexes.append(sample_index)
        sample_indexes = np.hstack((sample_indexes, sample_index))
    #print(exist_labels.shape)
    #print(sample_indexes.shape)

    return sample_indexes
'''

def unceratainty_sampling_by_balance(model, xTrain_in, yTrain_in, exist_labels, num_sample, alpha):
    """Greedily select samples by a mix of prediction score and label rarity.

    The model is run once on ``xTrain_in``. Selection then proceeds in rounds
    of 10 samples: label frequencies of the already selected set are
    recounted, every remaining candidate is scored, the 10 best are taken and
    their labels are added to the selected set before the next round.

    Per-candidate score::

        alpha * exp(-smooth * (1 - min(prediction)))
        + (1 - alpha) * (1 - exp(-smooth * x2))

    where ``x2`` is 1 for the rarest label in the selected set and 0 for the
    most frequent one.

    Args:
        model: Object exposing ``predict``.
        xTrain_in: Candidate inputs.
        yTrain_in: Candidate labels, indexed per sample; values are used as
            integer indices into a table of 3329 label counters.
        exist_labels: Labels of the samples selected so far.
        num_sample: Requested number of samples; ``int(num_sample / 10)``
            rounds of 10 are performed.
        alpha: Weight of the prediction term (``1 - alpha`` weighs rarity).

    Returns:
        Integer array of distinct positions into ``xTrain_in``.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_in).batch(192)

    preds = model.predict(val)
    pred_index = np.arange(len(preds))
    exist_labels = np.squeeze(exist_labels)
    all_label_count = np.zeros(3329)
    smooth = 5
    sampling_rate = 10 #Take top 10 samples
    num_iter = int(num_sample/sampling_rate)
    print(exist_labels.shape)
    score = np.zeros(len(preds))
    # Integer from the start so that the result can be used for indexing.
    sample_indexes = np.array([], dtype=int)
    for _ in tqdm(range(num_iter)):
        unique, counts = np.unique(exist_labels, return_counts=True) #Get the counts for sample in S
        all_label_count[unique] = counts
        n_exist = len(exist_labels)
        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / n_exist
        min_prob = np.min(all_label_count) / n_exist

        for idx in pred_index:
            label_idx = yTrain_in[idx]
            # NOTE(review): this uses the smallest entry of the prediction
            # vector, not the true-label or top entry -- confirm intent.
            pred_prob = 1 - np.min(preds[idx]) #Get reversed prob

            if range_x2 > 0:
                sample_prob = all_label_count[label_idx] / n_exist #Label balancing
                x2 = 1 - (sample_prob - min_prob) / range_x2
            else:
                # Every label is equally frequent: treat all as equally rare
                # instead of dividing by zero.
                x2 = 1.0
            rarity = float(np.ravel(x2)[0])

            score[idx] = alpha * math.exp(-pred_prob * smooth) + (1 - alpha) * (1 - math.exp(-rarity * smooth))

        sample_index = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        exist_labels = np.concatenate((exist_labels, np.squeeze(yTrain_in[sample_index])))
        sample_indexes = np.hstack((sample_indexes, sample_index))
        # Retire the picked positions immediately: drop them from the
        # candidates and make their stale scores unselectable.
        pred_index = np.setdiff1d(pred_index, sample_index)
        score[sample_index] = -np.inf

    return sample_indexes

def unceratainty_sampling_by_balance_label(model, xTrain_in, yTrain_in, exist_labels, num_sample, alpha=0.5):
    """Greedily select samples by true-label score and label rarity.

    The model is run once on ``xTrain_in`` and, for every sample, only the
    prediction assigned to its own label is kept. Selection then proceeds in
    rounds of 10 samples: label frequencies of the already selected set are
    recounted, every remaining candidate is scored, the 10 best are taken and
    their labels are added to the selected set before the next round.

    Per-candidate score::

        alpha * exp(-smooth * (1 - p_true))
        + (1 - alpha) * (1 - exp(-smooth * x2))

    where ``p_true`` is the prediction for the sample's own label and ``x2``
    is 1 for the rarest label in the selected set and 0 for the most frequent.

    Args:
        model: Object exposing ``predict``.
        xTrain_in: Candidate inputs.
        yTrain_in: Candidate labels, indexed per sample; values are used as
            integer indices into a table of 3329 label counters.
        exist_labels: Labels of the samples selected so far.
        num_sample: Requested number of samples; ``int(num_sample / 10)``
            rounds of 10 are performed.
        alpha: Weight of the prediction term (``1 - alpha`` weighs rarity).
            Defaults to 0.5.

    Returns:
        Integer array of distinct positions into ``xTrain_in``.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_in).batch(192)

    pred_vector = model.predict(val)
    print(pred_vector.shape)
    # Keep only the prediction assigned to each sample's own label.
    preds = np.array([pred[yTrain_in[lb_idx]] for lb_idx, pred in enumerate(pred_vector)])
    print(np.min(preds))
    print(np.max(preds))
    pred_index = np.arange(len(preds))
    exist_labels = np.squeeze(exist_labels)
    all_label_count = np.zeros(3329)
    smooth = 5
    sampling_rate = 10 #Take top 10 samples
    num_iter = int(num_sample/sampling_rate)
    print(exist_labels.shape)
    score = np.zeros(len(preds))
    # Integer from the start so that the result can be used for indexing.
    sample_indexes = np.array([], dtype=int)
    for _ in tqdm(range(num_iter)):
        unique, counts = np.unique(exist_labels, return_counts=True) #Get the counts for sample in S
        all_label_count[unique] = counts
        n_exist = len(exist_labels)
        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / n_exist
        min_prob = np.min(all_label_count) / n_exist

        for idx in pred_index:
            label_idx = yTrain_in[idx]
            pred_prob = 1 - np.min(preds[idx]) #Get reversed prob

            if range_x2 > 0:
                sample_prob = all_label_count[label_idx] / n_exist
                x2 = 1 - (sample_prob - min_prob) / range_x2
            else:
                # Every label is equally frequent: treat all as equally rare
                # instead of dividing by zero.
                x2 = 1.0
            rarity = float(np.ravel(x2)[0])

            score[idx] = alpha * math.exp(-pred_prob * smooth) + (1 - alpha) * (1 - math.exp(-rarity * smooth))

        sample_index = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        exist_labels = np.concatenate((exist_labels, np.squeeze(yTrain_in[sample_index])))
        sample_indexes = np.hstack((sample_indexes, sample_index))
        # Retire the picked positions immediately: drop them from the
        # candidates and make their stale scores unselectable.
        pred_index = np.setdiff1d(pred_index, sample_index)
        score[sample_index] = -np.inf

    return sample_indexes

def unceratainty_sampling_by_balance_medoids(model, xTrain, yTrain, all_active_ids, num_sample, medoids_path, alpha=0.5):
    """Greedily select among ``all_active_ids`` by score and label rarity.

    The model is run once on ``xTrain[all_active_ids]``. Selection then
    proceeds in rounds of 10 samples: label frequencies are recounted, every
    remaining candidate is scored, the 10 best are taken and their labels are
    added to the counted set before the next round.

    Per-candidate score::

        alpha * exp(-smooth * (1 - min(prediction)))
        + (1 - alpha) * (1 - exp(-smooth * x2))

    where ``x2`` is 1 for the rarest counted label and 0 for the most
    frequent one.

    NOTE(review): both the scored pool and the initial label counts come from
    ``all_active_ids``; confirm whether the pool was meant to be the samples
    that are *not* active yet.

    Args:
        model: Object exposing ``predict``.
        xTrain: Full input array, indexed by ``all_active_ids``.
        yTrain: Full label array, indexed by ``all_active_ids``; values are
            used as integer indices into a table of 3329 label counters.
        all_active_ids: Indices into ``xTrain``/``yTrain`` forming the pool.
        num_sample: Requested number of samples; ``int(num_sample / 10)``
            rounds of 10 are performed.
        medoids_path: Accepted for interface compatibility; no medoid data is
            used by the scoring below.
        alpha: Weight of the prediction term (``1 - alpha`` weighs rarity).
            Defaults to 0.5.

    Returns:
        Integer array of distinct positions into ``all_active_ids`` (not
        indices into ``xTrain``).
    """
    xTrain_in = xTrain[all_active_ids]
    yTrain_in = yTrain[all_active_ids]
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_in).batch(192)

    preds = model.predict(val)
    pred_index = np.arange(len(preds))
    # yTrain_in is already restricted to all_active_ids; indexing it with the
    # same ids a second time would select the wrong rows or go out of bounds.
    exist_labels = np.squeeze(yTrain_in)
    all_label_count = np.zeros(3329)
    smooth = 5
    sampling_rate = 10 #Take top 10 samples
    num_iter = int(num_sample/sampling_rate)
    print(exist_labels.shape)
    score = np.zeros(len(preds))
    # Integer from the start so that the result can be used for indexing.
    sample_indexes = np.array([], dtype=int)
    for _ in tqdm(range(num_iter)):
        unique, counts = np.unique(exist_labels, return_counts=True) #Get the counts for sample in S
        all_label_count[unique] = counts
        n_exist = len(exist_labels)
        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / n_exist
        min_prob = np.min(all_label_count) / n_exist

        for idx in pred_index:
            label_idx = yTrain_in[idx]
            # NOTE(review): this uses the smallest entry of the prediction
            # vector, not the true-label or top entry -- confirm intent.
            pred_prob = 1 - np.min(preds[idx]) #Get reversed prob

            if range_x2 > 0:
                sample_prob = all_label_count[label_idx] / n_exist
                x2 = 1 - (sample_prob - min_prob) / range_x2
            else:
                # Every label is equally frequent: treat all as equally rare
                # instead of dividing by zero.
                x2 = 1.0
            rarity = float(np.ravel(x2)[0])

            score[idx] = alpha * math.exp(-pred_prob * smooth) + (1 - alpha) * (1 - math.exp(-rarity * smooth))

        sample_index = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        exist_labels = np.concatenate((exist_labels, np.squeeze(yTrain_in[sample_index])))
        sample_indexes = np.hstack((sample_indexes, sample_index))
        # Retire the picked positions immediately: drop them from the
        # candidates and make their stale scores unselectable.
        pred_index = np.setdiff1d(pred_index, sample_index)
        score[sample_index] = -np.inf

    return sample_indexes

#def GL_sampling(model, xTrain_in, yTrain_in, exist_labels, num_sample):
    #Sample by gradient length #https://papers.nips.cc/paper_files/paper/2007/file/a1519de5b5d44b31a01de013b9b51a80-Paper.pdf
    #I leave it here for later implementation
    #This is somewhat similar to uncertainty sampling, so I won't include it here atm

#I need some representative/density-based sampling here -> K_Medoids

def K_Medoids_sampling(model, xTrain, yTrain, non_active_ids, num_sample, medoids_path):
    """Rank non-active samples by the least frequent label of their cluster.

    Cluster assignments are read from ``'200k_2000cluster/clara_labels.npy'``.
    For each of the 2000 clusters the label occurring least often among its
    members is determined; every id in ``non_active_ids`` is then mapped to
    that label of its cluster, and the ``num_sample`` smallest mapped values
    are selected.

    NOTE(review): the final selection orders candidates by the numeric value
    of that label, not by how rare it is -- confirm this is intended.

    Args:
        model: Not used.
        xTrain: Not used.
        yTrain: Label array, indexed by sample id.
        non_active_ids: Candidate sample ids (indices into the cluster
            assignment array).
        num_sample: Number of positions to return.
        medoids_path: Not used; the cluster file path is hard-coded.

    Returns:
        Array with ``num_sample`` positions into ``non_active_ids``
        (unordered).
    """
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')
    least_dom_labels = []
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        curr_label = yTrain[indexes]
        unique, counts = np.unique(curr_label, return_counts=True)
        least_dom_labels.append(unique[np.argmin(counts)])
    least_dom_labels = np.array(least_dom_labels)

    # One entry per candidate: least frequent label of the candidate's cluster.
    all_dom_labels = [least_dom_labels[medoid_labels[idx]] for idx in non_active_ids]

    min_idxs = np.argpartition(all_dom_labels, num_sample)[:num_sample]

    return min_idxs

from sklearn.metrics import pairwise_distances

# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract class for sampling methods.

Provides interface to sampling methods that allow same signature
for select_batch.  Each subclass implements select_batch_ with the desired
signature for readability.
"""


import abc
import numpy as np

class SamplingMethod(object):
  """Base class for sampling methods sharing the ``select_batch`` entry point.

  Subclasses implement ``select_batch_`` with whatever keyword arguments they
  need; ``select_batch`` forwards its keyword arguments to it unchanged.
  """
  # Python 2 spelling of a metaclass declaration. Under Python 3 this is an
  # ordinary class attribute, so the abstractmethod markers below are not
  # enforced at instantiation time.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def __init__(self, X, y, seed, **kwargs):
    """Store the data, the labels and the seed on the instance."""
    self.X = X
    self.y = y
    self.seed = seed

  def flatten_X(self):
    """Return ``self.X`` with all axes after the first merged into one.

    Arrays with at most two dimensions are returned as they are.
    """
    shape = self.X.shape
    flat_X = self.X
    if len(shape) > 2:
      # np.prod replaces the np.product alias, which NumPy 2.0 removed.
      flat_X = np.reshape(self.X, (shape[0], np.prod(shape[1:])))
    return flat_X


  @abc.abstractmethod
  def select_batch_(self):
    """Select a batch; to be overridden by subclasses."""
    return

  def select_batch(self, **kwargs):
    """Forward all keyword arguments to ``select_batch_``."""
    return self.select_batch_(**kwargs)

  def to_dict(self):
    """Return ``None``; the base class has nothing to export."""
    return None

class kCenterGreedy(SamplingMethod):
  """Greedy k-center selection over the flattened features of ``X``.

  The instance keeps, for every observation, the distance to its nearest
  selected center (``min_distances``, shape ``(n_obs, 1)``), and repeatedly
  selects the observation for which that distance is largest.
  """

  def __init__(self, X, y, seed = 2025, metric='euclidean'):
    """Store the data and prepare the flattened feature matrix.

    Args:
      X: data array; flattened to two dimensions for distance computation.
      y: labels (stored, not used by the selection).
      seed: accepted for interface compatibility; not used here.
      metric: metric name passed to ``pairwise_distances``.
    """
    self.X = X
    self.y = y
    self.flat_X = self.flatten_X()
    self.name = 'kcenter'
    self.features = self.flat_X
    self.metric = metric
    self.min_distances = None
    self.n_obs = self.X.shape[0]
    self.already_selected = []

  def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
    """Update min distances given cluster centers.

    Args:
      cluster_centers: indices of cluster centers
      only_new: only calculate distance for newly selected points and update
        min_distances.
      reset_dist: whether to reset min_distances.
    """

    if reset_dist:
      self.min_distances = None
    if only_new:
      cluster_centers = [d for d in cluster_centers
                         if d not in self.already_selected]
    if len(cluster_centers) > 0:
      # Update min_distances for all examples given new cluster center.
      x = self.features[cluster_centers]
      dist = pairwise_distances(self.features, x, metric=self.metric)
      # Collapse to one column (nearest of the new centers) before merging,
      # so that min_distances keeps its (n_obs, 1) shape even when several
      # centers are added at once.
      nearest_new = np.min(dist, axis=1).reshape(-1,1)

      if self.min_distances is None:
        self.min_distances = nearest_new
      else:
        self.min_distances = np.minimum(self.min_distances, nearest_new)

  def select_batch_(self, model, already_selected, N, **kwargs):
    """
    Diversity promoting active learning method that greedily forms a batch
    to minimize the maximum distance to a cluster center among all unlabeled
    datapoints.

    Args:
      model: model with scikit-like API with decision_function implemented
        (not used by this implementation)
      already_selected: index of datapoints already selected
      N: batch size

    Returns:
      indices of points selected to minimize distance to cluster centers
    """
    print('Using flat_X as features.')
    self.update_distances(already_selected, only_new=True, reset_dist=False)

    new_batch = []

    for _ in tqdm(range(N)):
      if self.min_distances is None:
        # No center exists yet, so there is no distance to maximise:
        # initialize centers with a randomly selected datapoint.
        ind = np.random.choice(np.arange(self.n_obs))
      else:
        ind = np.argmax(self.min_distances)
      # New examples should not be in already selected since those points
      # should have min_distance of zero to a cluster center.
      assert ind not in already_selected

      self.update_distances([ind], only_new=True, reset_dist=False)
      new_batch.append(ind)
    if self.min_distances is not None:
      print('Maximum distance from cluster centers is %0.2f'
              % np.max(self.min_distances))


    self.already_selected = already_selected

    return new_batch


def K_Medoids_sampling_balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample, medoids_path):
    """Select samples by a cluster-composition score mixed with a class-balance score.

    Works in rounds of 10 samples. In each round every non-active sample gets
        alpha * exp(-(1 - share) * smooth) + (1 - alpha) * (1 - exp(-x2 * smooth))
    where ``share`` is the fraction of the sample's cluster that carries the
    sample's label, and ``x2`` is 1 for the label that is rarest among the
    currently active samples and 0 for the most frequent one. The 10 highest
    scores are moved from the non-active pool to the active pool.

    Args:
        model: unused here.
        xTrain: unused here.
        yTrain: labels; indexed as ``yTrain[idx][0]``, so one label per row.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Integer array with the picked sample ids (previously this came back
        as float64 because it was accumulated onto an empty Python list).
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # count_perc[c] maps label -> fraction of cluster c carrying that label.
    # The cluster count (2000) is hard-coded.
    count_perc = []
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        unique, counts = np.unique(yTrain[indexes], return_counts=True)
        count_perc.append(dict(zip(unique, counts / len(indexes))))

    alpha = 0.5
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            label_perc = 1 - count_perc[medoid_labels[idx]][label_idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2
            sample_score = alpha * math.exp(-label_perc * smooth) + (1 - alpha) * (1 - math.exp(-x2 * smooth))
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral.
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(len(sample_indexes))
    return sample_indexes

def K_Medoids_sampling_uncertainty_balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample, medoids_path):
    """Select samples by cluster composition, class balance and a prediction score.

    Works in rounds of 10 samples. Each non-active sample is scored as
        0.5 * exp(-(1 - share) * smooth)
      + 0.3 * (1 - exp(-x2 * smooth))
      + 0.2 * exp(-(1 - p_min) * smooth)
    where ``share`` is the fraction of the sample's cluster that carries the
    sample's label, ``x2`` is 1 for the label rarest in the active set and 0
    for the most frequent one, and ``p_min`` is the smallest value in the
    model's prediction row for that sample. The 10 highest scores per round
    are moved to the active pool.

    Args:
        model: object whose ``predict`` is called on a batched tf dataset.
        xTrain: traces; ``xTrain[non_active_ids]`` is fed to the model.
        yTrain: labels; indexed as ``yTrain[idx][0]``.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Integer array with the picked sample ids (previously float64 because
        it was accumulated onto an empty Python list).
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # count_perc[c] maps label -> fraction of cluster c carrying that label.
    count_perc = []
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        unique, counts = np.unique(yTrain[indexes], return_counts=True)
        count_perc.append(dict(zip(unique, counts / len(indexes))))

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain[non_active_ids]).batch(192)

    preds = model.predict(val)

    # Scatter each candidate's minimum predicted value into a table indexed
    # by global sample id; prediction row k belongs to non_active_ids[k].
    pred_scores = np.zeros(len(yTrain))
    pred_scores[non_active_ids] = np.min(preds, axis=1)

    alpha_balance = 0.3
    alpha_medoids = 0.5
    alpha_uncertain = 0.2
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            label_perc = 1 - count_perc[medoid_labels[idx]][label_idx]
            uncertain_score = 1 - pred_scores[idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2

            sample_score = alpha_medoids * math.exp(-label_perc * smooth) + alpha_balance * (1 - math.exp(-x2 * smooth)) + alpha_uncertain * math.exp(-uncertain_score * smooth)
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral.
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(len(sample_indexes))
    return sample_indexes

def K_Medoids_Active(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample, medoids_path):
    """Pick the candidates whose (cluster, label) pair is least covered so far.

    For every cluster the function counts how many ACTIVE samples carry each
    label. A candidate's score is the count stored for its own cluster and
    label; the ``num_sample`` candidates with the lowest scores are returned.

    Args:
        model: unused here.
        xTrain: unused here.
        yTrain: labels; indexed as ``yTrain[idx][0]``.
        all_active_ids: integer ids of already labelled samples.
        non_active_ids: integer array of candidate sample ids.
        num_sample: how many candidates to return; must be smaller than
            ``len(non_active_ids)`` for ``np.argpartition`` to accept it.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Array with ``num_sample`` entries of ``non_active_ids``.
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # 1.0 for samples that are already active, 0.0 otherwise.
    masked_index = np.zeros(len(yTrain))
    masked_index[all_active_ids] = 1
    print(medoid_labels.shape)
    print(masked_index.shape)

    # chosen_label_wrt_medoids[c, l] = number of active samples in cluster c
    # with label l. Both dimensions are hard-coded.
    chosen_label_wrt_medoids = np.zeros((2000, 3329))
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        chosen_indexes = indexes[masked_index[indexes] == 1]
        if len(chosen_indexes) > 0:
            unique_active, counts_active = np.unique(yTrain[chosen_indexes], return_counts=True)
            chosen_label_wrt_medoids[i][unique_active] = counts_active

    sample_scores = [
        chosen_label_wrt_medoids[medoid_labels[idx]][yTrain[idx][0]]
        for idx in non_active_ids
    ]
    min_idxs = np.argpartition(sample_scores, num_sample)[:num_sample]

    return non_active_ids[min_idxs]

def K_Medoids_Active_Balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample, medoids_path):
    """Select samples using active-set coverage per cluster plus class balance.

    A table ``chosen[c, l]`` holds how many active samples of label ``l`` sit
    in cluster ``c``. Selection runs in rounds of 10; each candidate gets
        alpha * exp(-(1 - chosen[c, l]) * smooth)
      + (1 - alpha) * (1 - exp(-x2 * smooth))
    with ``x2`` equal to 1 for the label rarest in the active set and 0 for
    the most frequent one. The 10 highest scores are moved to the active
    pool and the table is incremented for each of them.

    Args:
        model: unused here.
        xTrain: unused here.
        yTrain: labels; indexed as ``yTrain[idx][0]``.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Integer array with the picked sample ids.
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # 1.0 for samples that are already active, 0.0 otherwise.
    masked_index = np.zeros(len(yTrain))
    masked_index[all_active_ids] = 1
    print(medoid_labels.shape)
    print(masked_index.shape)

    # chosen_label_wrt_medoids[c, l] = number of active samples in cluster c
    # with label l. Both dimensions are hard-coded.
    chosen_label_wrt_medoids = np.zeros((2000, 3329))
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        chosen_indexes = indexes[masked_index[indexes] == 1]
        if len(chosen_indexes) > 0:
            unique_active, counts_active = np.unique(yTrain[chosen_indexes], return_counts=True)
            chosen_label_wrt_medoids[i][unique_active] = counts_active

    alpha = 0.5
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            label_perc = 1 - chosen_label_wrt_medoids[medoid_labels[idx]][label_idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2
            sample_score = alpha * math.exp(-label_perc * smooth) + (1 - alpha) * (1 - math.exp(-x2 * smooth))
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Record each pick under ITS OWN label. The earlier code reused the
        # label of whichever candidate the scoring loop visited last, so all
        # ten increments of a round landed in the wrong label column.
        for picked_id in sample_index:
            chosen_label_wrt_medoids[medoid_labels[picked_id]][yTrain[picked_id][0]] += 1

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral
    # (accumulating onto an empty list produced float64 ids).
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(np.sum(chosen_label_wrt_medoids))
    print(len(sample_indexes))
    return sample_indexes

def K_Medoids_Active_Uncertain_Balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample, medoids_path):
    """Select samples using active-set coverage, a prediction score and class balance.

    A table ``chosen[c, l]`` holds how many active samples of label ``l`` sit
    in cluster ``c``. Selection runs in rounds of 10; each candidate gets
        0.5 * exp(-(1 - chosen[c, l]) * smooth)
      + 0.2 * exp(-(1 - p_min) * smooth)
      + 0.3 * (1 - exp(-x2 * smooth))
    where ``p_min`` is the smallest value in the model's prediction row for
    the candidate and ``x2`` is 1 for the label rarest in the active set and
    0 for the most frequent one. The 10 highest scores are moved to the
    active pool and the table is incremented for each of them.

    Args:
        model: object whose ``predict`` is called on a batched tf dataset.
        xTrain: traces; ``xTrain[non_active_ids]`` is fed to the model.
        yTrain: labels; indexed as ``yTrain[idx][0]``.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Integer array with the picked sample ids.
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # 1.0 for samples that are already active, 0.0 otherwise.
    masked_index = np.zeros(len(yTrain))
    masked_index[all_active_ids] = 1
    print(medoid_labels.shape)
    print(masked_index.shape)
    chosen_label_wrt_medoids = np.zeros((2000, 3329))

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain[non_active_ids]).batch(192)

    preds = model.predict(val)

    # Scatter each candidate's minimum predicted value into a table indexed
    # by global sample id; prediction row k belongs to non_active_ids[k].
    pred_scores = np.zeros(len(yTrain))
    pred_scores[non_active_ids] = np.min(preds, axis=1)

    # chosen_label_wrt_medoids[c, l] = number of active samples in cluster c
    # with label l. Both dimensions are hard-coded.
    for i in tqdm(range(2000)):
        indexes = np.where(medoid_labels == i)[0]
        chosen_indexes = indexes[masked_index[indexes] == 1]
        if len(chosen_indexes) > 0:
            unique_active, counts_active = np.unique(yTrain[chosen_indexes], return_counts=True)
            chosen_label_wrt_medoids[i][unique_active] = counts_active

    alpha_balance = 0.3
    alpha_medoids = 0.5
    alpha_uncertain = 0.2
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            label_perc = 1 - chosen_label_wrt_medoids[medoid_labels[idx]][label_idx]
            uncertain_score = 1 - pred_scores[idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2
            sample_score = alpha_medoids * math.exp(-label_perc * smooth) + alpha_uncertain * math.exp(-uncertain_score * smooth) + alpha_balance * (1 - math.exp(-x2 * smooth))
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Record each pick under ITS OWN label. The earlier code reused the
        # label of whichever candidate the scoring loop visited last, so all
        # ten increments of a round landed in the wrong label column.
        for picked_id in sample_index:
            chosen_label_wrt_medoids[medoid_labels[picked_id]][yTrain[picked_id][0]] += 1

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral
    # (accumulating onto an empty list produced float64 ids).
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(np.sum(chosen_label_wrt_medoids))
    print(len(sample_indexes))
    return sample_indexes

def K_Medoids_Update(all_active_ids, sampled_ids, medoids_path):
    """Return the active ids that share a cluster with any newly sampled id.

    Args:
        all_active_ids: integer array of active sample ids.
        sampled_ids: integer ids of the newly sampled points.
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Entries of ``all_active_ids`` whose cluster also contains at least one
        of ``sampled_ids``, grouped by cluster in ascending cluster order.
        Empty when ``sampled_ids`` is empty (this used to raise
        AttributeError).
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    # Clusters touched by the new samples, and the cluster of every active id
    # (looked up once instead of once per cluster).
    touched_clusters = np.unique(medoid_labels[sampled_ids])
    active_clusters = medoid_labels[all_active_ids]

    per_cluster_hits = [
        np.where(active_clusters == cluster)[0] for cluster in touched_clusters
    ]
    if per_cluster_hits:
        update_indexes = np.concatenate(per_cluster_hits)
    else:
        update_indexes = np.array([], dtype=int)
    print(update_indexes.shape)

    return all_active_ids[update_indexes]

def K_Medoids_Label_Update(all_active_ids, sampled_ids, yTrain, medoids_path):
    """Collect active ids that match a sampled id in both label and cluster.

    For each sampled id in turn, every remaining active id with the same
    label and the same cluster is collected and then removed from the pool,
    so an active id is collected at most once across sampled ids.

    Args:
        all_active_ids: integer ids of active samples.
        sampled_ids: integer ids of the newly sampled points.
        yTrain: labels, one per sample (flattened here for comparison).
        medoids_path: unused; the cluster file path is hard-coded below.

    Returns:
        Integer array: the collected active ids followed by ``sampled_ids``.
    """
    # NOTE(review): medoids_path is ignored; the assignment is always read from this file.
    medoid_labels = np.load('200k_2000cluster/clara_labels.npy')

    all_active_ids = np.asarray(all_active_ids)
    # One label per sample, so a flat view lines up with the sample ids.
    flat_labels = np.asarray(yTrain).reshape(-1)

    update_indexes = []
    for sampled_id in tqdm(sampled_ids):
        # Whole-array comparison replaces the per-active-id Python loop.
        same_label = flat_labels[all_active_ids] == flat_labels[sampled_id]
        same_cluster = medoid_labels[all_active_ids] == medoid_labels[sampled_id]
        update_indexes.extend(all_active_ids[same_label & same_cluster].tolist())
        # Drop everything collected so far from the pool (this also sorts and
        # de-duplicates the pool, as before).
        all_active_ids = np.setdiff1d(all_active_ids, np.array(update_indexes, dtype=int))

    print(len(update_indexes))
    # dtype=int keeps the result integral even when nothing matched.
    update_indexes = np.concatenate((np.array(update_indexes, dtype=int), sampled_ids))
    print(len(update_indexes))
    return update_indexes

#(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, num_sample = args.num_sample)
def margin_sampling_by_balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample):
    """Select samples by top-two prediction margin mixed with class balance.

    The model is run on the candidate (non-active) samples and the margin
    ``|p_top1 - p_top2|`` is stored per candidate. Selection runs in rounds
    of 10; each candidate gets
        (1 - alpha) * (1 - exp(-x2 * smooth)) + alpha * exp(-(1 - margin) * smooth)
    with ``x2`` equal to 1 for the label rarest in the active set and 0 for
    the most frequent one. The 10 highest scores are moved to the active pool.

    Args:
        model: object whose ``predict`` is called on a batched tf dataset.
        xTrain: traces; ``xTrain[non_active_ids]`` is fed to the model.
        yTrain: labels; indexed as ``yTrain[idx][0]``.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.

    Returns:
        Integer array with the picked sample ids.
    """
    # Predict on the candidates: the scores are looked up by candidate id
    # below. The earlier code predicted on the active samples and wrote those
    # margins into the candidates' slots, pairing unrelated samples.
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain[non_active_ids]).batch(192)

    preds = model.predict(val)
    print(preds[0])

    # After partitioning at -2 the last two columns hold the two largest
    # probabilities of each row.
    top_two = np.partition(preds, -2, axis=1)[:, -2:]
    margins = np.abs(top_two[:, 0] - top_two[:, 1])

    # Table indexed by global sample id; row k of preds is non_active_ids[k].
    pred_scores = np.zeros(len(yTrain))
    pred_scores[non_active_ids] = margins

    alpha = 0.5
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            uncertain_score = 1 - pred_scores[idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2

            sample_score = (1 - alpha) * (1 - math.exp(-x2 * smooth)) + alpha * math.exp(-uncertain_score * smooth)
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral
    # (accumulating onto an empty list produced float64 ids).
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(len(sample_indexes))
    unique, counts = np.unique(sample_indexes, return_counts=True)
    print(len(unique))
    return sample_indexes

def margin_label_sampling_by_balance(model, xTrain, yTrain, all_active_ids, non_active_ids, num_sample):
    """Select samples by a label-aware prediction margin mixed with class balance.

    The model is run on the candidate (non-active) samples. For each one the
    margin is ``|p_top1 - p_top2|`` when the true label's probability is one
    of the two largest, and ``|p_true - p_top1|`` otherwise. Selection runs
    in rounds of 10; each candidate gets
        (1 - alpha) * (1 - exp(-x2 * smooth)) + alpha * exp(-(1 - margin) * smooth)
    with ``x2`` equal to 1 for the label rarest in the active set and 0 for
    the most frequent one. The 10 highest scores are moved to the active pool.

    Args:
        model: object whose ``predict`` is called on a batched tf dataset.
        xTrain: traces; ``xTrain[non_active_ids]`` is fed to the model.
        yTrain: labels; each row is a length-1 integer array.
        all_active_ids: integer array of already labelled sample ids.
        non_active_ids: integer array of candidate sample ids.
        num_sample: number of samples requested; ``int(num_sample / 10)``
            rounds are run.

    Returns:
        Integer array with the picked sample ids.
    """
    # Predict on the candidates: the scores are looked up by candidate id
    # below. The earlier code predicted on the active samples and wrote those
    # margins into the candidates' slots, pairing unrelated samples.
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain[non_active_ids]).batch(192)

    preds = model.predict(val)
    print(preds[0])

    candidate_labels = yTrain[non_active_ids]
    margins = []
    for row_no, probs in enumerate(preds):
        label_idx = candidate_labels[row_no]
        top_two = np.partition(probs, -2)[-2:]  # two largest probabilities
        true_prob = probs[label_idx]  # length-1 array
        if true_prob not in top_two:
            margin = np.abs(true_prob[0] - np.max(top_two))
        else:
            margin = np.abs(top_two[0] - top_two[1])
        margins.append(margin)

    # Table indexed by global sample id; row k of preds is non_active_ids[k].
    pred_scores = np.zeros(len(yTrain))
    pred_scores[non_active_ids] = margins

    alpha = 0.5
    smooth = 2
    sampling_rate = 10  # samples taken per round
    num_iter = int(num_sample / sampling_rate)
    all_label_count = np.zeros(3329)  # per-label counts over the active set
    picked_rounds = []

    for _ in tqdm(range(num_iter)):
        # Label histogram of the samples that are active right now.
        exist_labels = yTrain[all_active_ids]
        unique, counts = np.unique(exist_labels, return_counts=True)
        all_label_count[unique] = counts

        range_x2 = (np.max(all_label_count) - np.min(all_label_count)) / len(exist_labels)
        min_prob = np.min(all_label_count) / len(exist_labels)

        sample_scores = []
        for idx in non_active_ids:
            label_idx = yTrain[idx][0]
            uncertain_score = 1 - pred_scores[idx]
            sample_prob = all_label_count[label_idx] / len(exist_labels)
            x2 = 1 - (sample_prob - min_prob) / range_x2

            sample_score = (1 - alpha) * (1 - math.exp(-x2 * smooth)) + alpha * math.exp(-uncertain_score * smooth)
            sample_scores.append(sample_score)

        score = np.array(sample_scores)
        best_positions = np.argpartition(score, -sampling_rate)[-sampling_rate:]
        sample_index = non_active_ids[best_positions]

        # Move this round's picks from the candidate pool to the active pool.
        non_active_ids = np.setdiff1d(non_active_ids, sample_index)
        all_active_ids = np.hstack((all_active_ids, sample_index))
        picked_rounds.append(sample_index)

    # Concatenating the per-round integer arrays keeps the ids integral
    # (accumulating onto an empty list produced float64 ids).
    if picked_rounds:
        sample_indexes = np.concatenate(picked_rounds)
    else:
        sample_indexes = np.array([], dtype=int)

    print(len(sample_indexes))
    unique, counts = np.unique(sample_indexes, return_counts=True)
    print(len(unique))
    print(sample_indexes[:10])
    return sample_indexes

from scipy.special import rel_entr
from math import log2
 
# calculate the kl divergence
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) in bits.

    Args:
        p: sequence of probabilities.
        q: sequence of probabilities, indexable at every position of ``p``;
            entries paired with a non-zero ``p[i]`` must be non-zero.

    Returns:
        Sum over i of ``p[i] * log2(p[i] / q[i])``. Terms with ``p[i] == 0``
        contribute 0 (the usual 0 * log 0 = 0 convention) instead of making
        ``log2`` raise ValueError.
    """
    return sum(p_i * log2(p_i / q[i]) for i, p_i in enumerate(p) if p_i != 0)

# calculate the js divergence
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence of ``p`` and ``q``.

    Both inputs must support element-wise ``+`` and scalar ``*`` (e.g. NumPy
    arrays). The result is the mean of the two KL divergences taken against
    the midpoint distribution ``0.5 * (p + q)``.
    """
    midpoint = 0.5 * (p + q)
    p_to_mid = kl_divergence(p, midpoint)
    q_to_mid = kl_divergence(q, midpoint)
    return 0.5 * p_to_mid + 0.5 * q_to_mid

def epsilon_onehot(label, epsilon):
    """Build a smoothed one-hot vector of length ``NumSKPVclasses``.

    Every position other than ``label`` holds ``epsilon``; position ``label``
    holds ``1 - (NumSKPVclasses - 1) * epsilon``, so the entries sum to 1.
    """
    peak = 1 - (NumSKPVclasses - 1) * epsilon
    off_label = np.arange(NumSKPVclasses) != label
    return np.where(off_label, epsilon, peak).astype(np.float64)

def KL_Update(model, xTrain, yTrain, all_active_ids, sampled_ids, use_medoids=True, KL_STD = 1):
    """Return the active ids whose prediction diverges most from their label.

    For every (remaining) active sample the KL divergence between the model's
    prediction row and an epsilon-smoothed one-hot vector of the sample's
    label is computed. Samples whose divergence exceeds
    ``mean + KL_STD * std`` of all divergences are returned.

    Args:
        model: object whose ``predict`` is called on a batched tf dataset.
        xTrain: traces; ``xTrain[all_active_ids]`` is fed to the model.
        yTrain: labels, indexed by sample id.
        all_active_ids: array of active sample ids.
        sampled_ids: ids removed from the active set first when
            ``use_medoids`` is true.
        use_medoids: whether to exclude ``sampled_ids`` from the active set.
        KL_STD: number of standard deviations above the mean used as cut-off.

    Returns:
        Array of the selected entries of the (possibly reduced) active set.
    """
    epsilon = 1e-10
    print(len(np.unique(all_active_ids)))
    if use_medoids:
        # Work on the active set minus the freshly sampled ids.
        all_active_ids = np.setdiff1d(all_active_ids, sampled_ids)

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain[all_active_ids]).batch(192)
    preds = model.predict(val)

    # Prediction row k belongs to all_active_ids[k].
    all_kld = np.array([
        kl_divergence(pred_row, epsilon_onehot(yTrain[sample_id], epsilon))
        for sample_id, pred_row in zip(tqdm(all_active_ids), preds)
    ])

    threshold = np.mean(all_kld) + KL_STD * np.std(all_kld)
    chosen_indexes = np.where(all_kld > threshold)[0]
    return all_active_ids[chosen_indexes]


def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from an ``.npz`` archive.

    Args:
        data_path: path (or file object) accepted by ``np.load``.

    Returns:
        Tuple ``(data, labels)`` holding the archive's ``'data'`` and
        ``'label'`` entries.
    """
    # The archive object keeps a file handle open; the with-block closes it
    # once both arrays have been read into memory.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def unceratainty_sampling(model, xTrain_in, yTrain_in, num_sample):
    """Return positions of the samples with the smallest minimum prediction.

    The model is run on ``xTrain_in``; for every row the smallest predicted
    value is taken, and the ``num_sample`` rows with the lowest such value
    are returned (in no particular order). ``yTrain_in`` is not used.
    """
    predictions = model.predict(xTrain_in)
    lowest_per_row = np.min(predictions, axis=1)

    # argpartition places the num_sample smallest entries in the leading slots.
    partitioned = np.argpartition(lowest_per_row, num_sample)
    return partitioned[:num_sample]

def margin_sampling(model, xTrain_in, yTrain_in, num_sample):
    """Return positions of the ``num_sample`` samples with the smallest margin.

    The margin of a sample is ``|p_top1 - p_top2|`` when the true label's
    probability is one of the two largest predictions, and
    ``|p_true - p_top1|`` otherwise. Each entry of ``yTrain_in`` is expected
    to be a length-1 integer array.
    """
    predictions = model.predict(xTrain_in)

    margins = []
    for row_no, probs in enumerate(predictions):
        true_label = yTrain_in[row_no]
        top_two = np.partition(probs, -2)[-2:]  # two largest probabilities
        true_prob = probs[true_label]
        if true_prob in top_two:
            gap = np.abs(top_two[0] - top_two[1])
        else:
            gap = np.abs(true_prob[0] - np.max(top_two))
        margins.append(gap)

    print(len(margins))
    # argpartition places the num_sample smallest margins in the leading slots.
    return np.argpartition(margins, num_sample)[:num_sample]

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    Args:
        timeseries: array-like exposing ``min()`` and ``max()`` (e.g. a NumPy
            array).

    Returns:
        ``(timeseries - min) / (max - min)``. A constant series maps to all
        zeros instead of NaN from a 0/0 division.
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    # A zero span means every value equals the minimum; divide by 1 so the
    # result is 0 everywhere rather than NaN.
    return (timeseries - lowest) / (span if span != 0 else 1)

def z_norm(timeseries):
    """Return the z-scores of ``timeseries`` as computed by ``scipy.stats.zscore``."""
    standardized = stats.zscore(timeseries)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalise every trace (row) of ``data`` independently.

    Args:
        data: Array whose first axis enumerates traces.

    Returns:
        The normalised array. Floating-point (or complex) input is modified
        in place and returned. Any other dtype is first copied to ``float``
        so the scaled values are not truncated when written back; in that
        case the caller's array is left untouched and the copy is returned.
    """
    print(data.shape)
    if not np.issubdtype(data.dtype, np.inexact):
        # Writing values in [0, 1] into an integer array would truncate them
        # to 0/1, so work on a floating-point copy instead.
        data = data.astype(float)
    for row, trace in enumerate(data):
        data[row] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-score normalise every trace (row) of ``data`` independently.

    Args:
        data: Array whose first axis enumerates traces.

    Returns:
        The normalised array. Floating-point (or complex) input is modified
        in place and returned. Any other dtype is first copied to ``float``
        so the fractional z-scores are not truncated when written back; in
        that case the caller's array is left untouched and the copy is
        returned.
    """
    print(data.shape)
    if not np.issubdtype(data.dtype, np.inexact):
        # Z-scores are fractional; storing them in an integer array would
        # silently truncate them, so work on a floating-point copy instead.
        data = data.astype(float)
    for row, trace in enumerate(data):
        data[row] = z_norm(trace)

    return data

def cummulative_transform(data):
    """Replace every trace (row) of ``data`` by its cumulative sum.

    Args:
        data: Array whose first axis enumerates traces.

    Returns:
        The transformed array. Floating-point (or complex) input is modified
        in place and returned. Any other dtype is first copied to ``float``
        because the running sums are computed as floats and can exceed the
        range of a narrow integer dtype; in that case the caller's array is
        left untouched and the copy is returned.
    """
    if not np.issubdtype(data.dtype, np.inexact):
        # e.g. an int8 row cannot hold a running sum above 127.
        data = data.astype(float)
    for row, trace in enumerate(data):
        data[row] = np.cumsum(trace, dtype=float)
    return data


def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces and labels and split them into training and validation arrays.

    Args:
        args: Parsed options; ``args.normalize`` (0 = none, 1 = min-max,
            anything else = z-score) and ``args.cummulative`` (1 = cumulative
            sum before scaling, 2 = after scaling, otherwise none) are read.
        data_path: Passed to the trace loader.
        sKeyNo: Passed to ``load_meta_trace_file_from_test`` when ``is_test``.
        trainPortion: Unused; kept for interface compatibility.
        xType: Unused; kept for interface compatibility.
        yType: Unused; kept for interface compatibility.
        is_test: Selects the loader and how the validation part is sliced.
        start_trace: First trace index requested from the loader.
        end_trace: End trace index of the training part; the loader is asked
            for ``300000`` additional traces beyond it.

    Returns:
        ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)`` where the
        ``x*`` arrays have a trailing axis of length 1 added, ``yTrain`` /
        ``yVal`` are one-hot encoded and the ``*_value`` arrays are the raw
        labels.
    """
    val_num = 300000
    # Context manager so the .npz archive handle is closed once the ids
    # have been read into memory.
    with np.load('val_ids.npz') as val_archive:
        val_ids = val_archive['ids']
    print(len(val_ids))
    print(val_ids[:10])
    end_val_trace = end_trace + val_num
    if is_test:
        trace_profiling, bp_profiling, skpv_profiling = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        trace_profiling, bp_profiling, skpv_profiling = load_meta_trace_files(data_path, start_trace, end_val_trace)

    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    y_train_skpv = to_categorical(skpv_profiling, num_classes=NumSKPVclasses)

    # Optional per-trace scaling, combined with an optional cumulative sum
    # applied either before (1) or after (2) the scaling step.
    train_traces = trace_profiling[:dataSize]
    if args.normalize != 0:
        scale = normalize_data_per_trace if args.normalize == 1 else z_normalize_data_per_trace
        if args.cummulative == 1:
            train_traces = scale(cummulative_transform(train_traces))
        elif args.cummulative == 2:
            train_traces = cummulative_transform(scale(train_traces))
        else:
            train_traces = scale(train_traces)
    xTrain = np.expand_dims(train_traces, axis=2)
    print(xTrain[0][:10])

    # NOTE(review): the validation traces are taken straight from
    # trace_profiling and never pass through the scaling helpers above;
    # whether any of them end up scaled depends on in-place mutation and on
    # the contents of val_ids — confirm this is intended.
    if is_test:
        xVal = np.expand_dims(trace_profiling[dataSize:end_val_trace], axis=2)
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        xVal = np.expand_dims(trace_profiling[val_ids], axis=2)
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

def train_model_multiEpochs(xType, database_folder_train, modelLogFolder, logTrainedModel_byFile_folder, logTrainedModel_byEp_folder, logFilename, MLmodel_detail, sKeyNo, class_weight, period, maxEpochs, train_batch_size, args):
    """Run the active-learning sample-selection loop and fit a model on the last iteration.

    The function loads the training pool, a test set ("KYBER51.H5") and a
    multi-key evaluation set, draws an initial random set of labelled ids,
    and then, for each iteration, picks ``args.num_sample`` new ids with the
    strategy named by ``args.sampling``, optionally reduces the training ids
    with the strategy named by ``args.update_sampling``, and saves the
    accumulated ids to ``<database_folder_train>/it_<n>/all_ids.npy``.
    ``model.fit`` is only called on the final iteration, on a freshly built
    model.

    Besides its parameters the body reads several names that are not
    defined in this function (``trained_model_path``, ``data_path``,
    ``yType``, ``testPortion``, ``noClasses``, ``NumSKPVclasses``,
    ``tracelen``, ``NumBPinput``, ``noLayers``, ``noBPbranch``,
    ``noConv1Dbranch``); they must exist in the enclosing module scope.

    Args:
        xType: Input type tag; the tf.data pipelines are only built when it
            equals ``'wave'``.
        database_folder_train: Output folder for the CSV log and the
            per-iteration sub-folders.
        modelLogFolder: Forwarded to ``subModels_gen``.
        logTrainedModel_byFile_folder: Not used in this function.
        logTrainedModel_byEp_folder: Not used in this function.
        logFilename: Forwarded to ``subModels_gen``.
        MLmodel_detail: Forwarded to ``subModels_gen``.
        sKeyNo: Forwarded to ``create_training_data_optimize``.
        class_weight: Forwarded to ``model.fit``.
        period: Not used in this function.
        maxEpochs: Not used in this function.
        train_batch_size: Batch size for the tf.data pipelines and ``model.fit``.
        args: Parsed command-line options (``train_type``, ``sampling``,
            ``update_sampling``, ``start_trace``, ``end_trace``,
            ``num_sample``, ``num_iteration``, ``resume_it``, ``eval_path``,
            ``eval_interval``, ``subtrain_interval`` and strategy-specific
            fields are read).

    Returns:
        None.
    """
    print(args.train_type)
    print(trained_model_path)
    # 'baseline' builds a new model; any other train_type loads a saved one.
    if args.train_type == 'baseline':
        model = subModels_gen(xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
    else:
        print('load model from {}'.format(trained_model_path))
        model = load_model(trained_model_path)
    gc.collect()
    model.summary()

    print('load success')
    trainPortion = 1.0
    
    xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args,data_path,sKeyNo, trainPortion, xType, yType,False, args.start_trace, args.end_trace)
    #xTrain_Pool, yTrain_Pool, _, _, _, _ = create_training_data_form(data_path,sKeyNo, trainPortion, xType, yType,False, args.end_trace, 100000)
    # From here on the labels are the raw class values (as an (N, 1) column for
    # training, 1-D for validation), not the one-hot arrays returned above.
    yTrain_original = np.expand_dims(yTrain_value, axis = 1)
    yVal = yVal_value
    # NOTE: this call rebinds yVal_value to the test file's validation labels.
    xTest, yTest, _xVal_, yVal_ , yTest_value, yVal_value = create_training_data_optimize(args,"KYBER51.H5", sKeyNo, testPortion, xType, yType,True, args.start_trace, args.end_trace)
    xTest = [[xTest]]
    print(xTest[0][0].shape)
    print(yTest_value)
    print(yTest.shape)
    xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
    print(yTest_value)
    print(yTest_multi)
    xTest_multi = np.expand_dims(xTest_multi, axis = 3)
    print(xTest_multi.shape)

    #print(yTrain.shape)
    csv_logger = CSVLogger(filename=database_folder_train+'/log.csv', append=True, separator=';')
    #all_active_ids = np.load(args.all_ids) #Load existing samples
    # Start from scratch by selecting a random initial set of samples.
    # NOTE(review): the ids span [start_trace, end_trace) but are used below to
    # index xTrain_original, which create_training_data_optimize slices to
    # end_trace - start_trace rows; a non-zero start_trace looks like it would
    # index out of range — confirm.
    all_sample_ids = np.arange(args.start_trace, args.end_trace)
    all_active_ids = np.random.choice(all_sample_ids, args.num_sample, replace=False)
    all_active_ids = all_active_ids.astype(int)

    unique = np.unique(all_active_ids)
    print(len(unique))

    # Everything not yet selected forms the unlabelled pool.
    non_active_ids = np.setdiff1d(all_sample_ids, all_active_ids)
    print(all_active_ids.shape)
    print(non_active_ids.shape)
    #print(len(all_sample_ids))
    #print(len(non_active_ids))
    #print(non_active_ids[:10])
    #print(len(xTrain[0][0]))
    #print(xTrain[0][0][non_active_ids[:10]])
    #xTrain_Pool_original, yTrain_Pool_original = copy.deepcopy(xTrain), copy.deepcopy(yTrain) #Need them for correct index
    
    #xTrain_Pool = [[xTrain_Pool]]
    xTrain_Pool, yTrain_Pool = xTrain_original[non_active_ids], yTrain_original[non_active_ids]
    #print(len(xTrain_Pool[0][0]))
    xTrain, yTrain = xTrain_original[all_active_ids], yTrain_original[all_active_ids]
    xTrainset, yTrainset = np.copy(xTrain), np.copy(yTrain)
    yTrainset = np.squeeze(yTrainset)
    print(len(xTrain))
    print('Len training:')
    print(xTrain.shape)
    print(yTrain.shape)
    #Set seed for fair comparison
    #np.random.seed(0)
    # Pre-draw the ids consumed slice by slice by the 'random_fair' strategy.
    all_random_ids = np.random.choice(non_active_ids, args.num_sample * args.num_iteration, replace=False)

    start_time = time.time()
    start_it = 0
    # resume_it only shifts the first iteration number; the active/pool id
    # sets above are still drawn afresh.
    if args.resume_it is not None:
        start_it = args.resume_it + 1
    for it in range(start_it, args.num_iteration):
        print('Iteration: ', it)
        database_folder_train_it = os.path.join(database_folder_train, 'it_'+str(it))
        Path(database_folder_train_it).mkdir(parents=True, exist_ok=True)
        save_model_name = (database_folder_train_it+'/model_best')
        save_ids_name = database_folder_train_it+'/all_ids.npy'
        attack_callback = CustomCallback(xTest, yTest_value, save_model_name, database_folder_train_it, args.eval_interval)
        attack_multi_callback = MultiKeyCallback(xTest_multi, yTest_multi, save_model_name, database_folder_train_it, args.eval_interval)

        callbacks=[csv_logger, attack_callback, attack_multi_callback]
        #callbacks=[csv_logger, attack_callback]

        max_ids = [] #Init all ids, the ids depend on the sampling method
        #Need to fix so they don't choose already chosen samples
        sampled_ids = []
        print(args.sampling)
        print(xTrain_Pool.shape)
        print(yTrain_Pool.shape)
        
        # Strategy dispatch. Branches that assign max_ids treat the result as
        # positions into the pool and map them through non_active_ids; the
        # other branches use the returned values directly as trace ids.
        # An unrecognised args.sampling leaves sampled_ids as the empty list.
        if args.sampling == 'random':
            max_ids = np.random.choice(non_active_ids, args.num_sample, replace=False)
            sampled_ids = max_ids.astype(int)
        elif args.sampling == 'uncertainty':
            max_ids = unceratainty_sampling(model, xTrain_Pool, yTrain_Pool, num_sample = args.num_sample)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'margin':
            max_ids = margin_sampling(model, xTrain_Pool, yTrain_Pool, num_sample = args.num_sample)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'uncertainty_label':
            max_ids = unceratainty_sampling_by_label(model, xTrain_Pool, yTrain_Pool, num_sample = args.num_sample)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'uncertainty_balance':
            max_ids = unceratainty_sampling_by_balance(model, xTrain_Pool, yTrain_Pool, yTrain, args.num_sample, args.alpha)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'uncertainty_balance_label':
            max_ids = unceratainty_sampling_by_balance_label(model, xTrain_Pool, yTrain_Pool, yTrain, num_sample = args.num_sample)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'uncertainty_balance_medoids':
            max_ids = unceratainty_sampling_by_balance_medoids(model, xTrain_original, yTrain_original, all_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'k_medoids':
            max_ids = K_Medoids_sampling(model, xTrain_original, yTrain_original, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = non_active_ids[max_ids.astype(int)]
        elif args.sampling == 'k_medoids_balance':
            sampled_ids = K_Medoids_sampling_balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'k_medoids_uncertainty_balance':
            sampled_ids = K_Medoids_sampling_uncertainty_balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'k_medoids_active':
            sampled_ids = K_Medoids_Active(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'k_medoids_active_balance':
            sampled_ids = K_Medoids_Active_Balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'k_medoids_active_uncertainty_balance':
            sampled_ids = K_Medoids_Active_Uncertain_Balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, args.num_sample, args.medoids_path)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'margin_balance':
            sampled_ids = margin_sampling_by_balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, num_sample = args.num_sample)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'margin_label_balance':
            sampled_ids = margin_label_sampling_by_balance(model, xTrain_original, yTrain_original,all_active_ids, non_active_ids, num_sample = args.num_sample)
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'random_fair':
            sampled_ids = all_random_ids[it * args.num_sample : (it+1) * args.num_sample]
            sampled_ids = sampled_ids.astype(int)
        elif args.sampling == 'k_center':
            #sampled_ids = K_Centre_sampling(model, xTrain_original, yTrain_original, all_active_ids, non_active_ids, args.num_sample)
            #sampled_ids = sampled_ids.astype(int)
            kc_sampler = kCenterGreedy(xTrain_original, yTrain_original)
            sampled_ids = kc_sampler.select_batch_(model, all_active_ids, args.num_sample)
            

        '''
        if it == 0:
            all_ids = max_ids
        else:
            all_active_ids = np.concatenate((all_active_ids, max_ids))
        '''
        #TODO: Save labels
        #sub_xTrain, sub_yTrain = get_subset(max_ids.astype(int), xTrain_Pool, yTrain_Pool, xType)
        # Move the newly sampled ids from the pool to the active set.
        sub_xTrain, sub_yTrain = xTrain_original[sampled_ids], yTrain_original[sampled_ids]
        non_active_ids = np.setdiff1d(non_active_ids, sampled_ids) #np.delete(non_active_ids, sampled_ids)
        all_active_ids = np.hstack((all_active_ids, sampled_ids))
        print(non_active_ids[0])
        print(all_active_ids[0])
        print(len(non_active_ids))
        print(len(all_active_ids))
        #xTrain_Pool[0][0], yTrain_Pool = np.delete(xTrain_Pool[0][0], max_ids, 0), np.delete(yTrain_Pool, max_ids, 0) #---> If you want to remove samples in pool each iter
        #xTrain, yTrain = np.concatenate((xTrain, sub_xTrain)), np.concatenate((yTrain, sub_yTrain))

        #xTrain[0][0] = np.concatenate((xTrain[0][0], sub_xTrain[0][0]))
        #yTrain = np.concatenate((yTrain, sub_yTrain))

        xTrain_Pool, yTrain_Pool = xTrain_original[non_active_ids], yTrain_original[non_active_ids]
        #print(len(xTrain_Pool[0][0]))
        # Optional reduction of the training ids; with no matching
        # update_sampling value the whole active set is used for training.
        if args.update_sampling == 'k_medoids':
            update_ids = K_Medoids_Update(all_active_ids, sampled_ids, args.medoids_path)
            np.save(database_folder_train_it+'/medoids_ids.npy', update_ids)
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        elif args.update_sampling == 'k_medoids_label':
            update_ids= K_Medoids_Label_Update(all_active_ids, sampled_ids, yTrain_original, args.medoids_path)
            np.save(database_folder_train_it+'/medoids_ids.npy', update_ids)
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        elif args.update_sampling == 'KL':
            update_ids= KL_Update(model, xTrain_original, yTrain_original, all_active_ids, sampled_ids, use_medoids=False)
            np.save(database_folder_train_it+'/medoids_KL_ids.npy', update_ids)
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        elif args.update_sampling == 'medoids_KL':
            medoids_ids=K_Medoids_Update(all_active_ids, sampled_ids, args.medoids_path)
            np.save(database_folder_train_it+'/medoids_ids.npy', medoids_ids)
            KL_ids= KL_Update(model, xTrain_original, yTrain_original, all_active_ids, medoids_ids, KL_STD=args.kl_std)
            np.save(database_folder_train_it+'/medoids_KL_ids.npy', KL_ids)
            update_ids = np.concatenate((medoids_ids, KL_ids))
            print(len(update_ids))
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        elif args.update_sampling == 'medoids_KL_label':
            medoids_ids= K_Medoids_Label_Update(all_active_ids, sampled_ids, yTrain_original, args.medoids_path)
            np.save(database_folder_train_it+'/medoids_ids.npy', medoids_ids)
            KL_ids= KL_Update(model, xTrain_original, yTrain_original, all_active_ids, medoids_ids, KL_STD=args.kl_std)
            np.save(database_folder_train_it+'/medoids_KL_ids.npy', KL_ids)
            update_ids = np.concatenate((medoids_ids, KL_ids))
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        elif args.update_sampling == 'medoids_KL_cluster':
            medoids_ids=K_Medoids_Update(all_active_ids, sampled_ids, args.medoids_path)
            np.save(database_folder_train_it+'/medoids_ids.npy', medoids_ids)
            #concat medoids ids and KL ids
            print('---------------')
            print(len(medoids_ids))
            KL_ids= KL_Update(model, xTrain_original, yTrain_original, all_active_ids, medoids_ids, KL_STD=args.kl_std)
            np.save(database_folder_train_it+'/medoids_KL_ids.npy', KL_ids)
            #Get C(KL): Cluster-related samples of cluster
            print(len(KL_ids))
            KL_cluster_ids=K_Medoids_Update(all_active_ids, KL_ids, args.medoids_path)
            print(len(KL_cluster_ids))
            np.save(database_folder_train_it+'/medoids_KL_cluster_ids.npy', KL_cluster_ids)
            update_ids = np.concatenate((medoids_ids, KL_cluster_ids))
            print(len(update_ids))
            xTrain, yTrain = xTrain_original[update_ids], yTrain_original[update_ids]
        else:
            xTrain, yTrain = xTrain_original[all_active_ids], yTrain_original[all_active_ids]
        yTrain = np.squeeze(yTrain)
        sub_yTrain = np.squeeze(sub_yTrain)
        # Per-class histogram of the newly sampled labels (3329 bins); its
        # standard deviation is printed below as a balance indicator.
        all_counts = np.zeros(3329)
        unique, counts = np.unique(sub_yTrain, return_counts=True)
        for value, number in zip(unique,counts):
            all_counts[value] = number

        print(xTrain.shape)
        print(yTrain.shape)
        print(yVal.shape)
        print(sub_yTrain.shape)
        # NOTE: indexing element 10 requires at least 11 sampled labels.
        print(sub_yTrain[10])
        print('STD')
        print(np.std(all_counts))
            
        #print(yTrain.shape)
        print(yTrain[-10:])
        #exit()
        #exit()


        #epoch_step = int(args.num_epoch/args.train_interval)
        with tf.device("CPU"):
            #train = tf.data.Dataset.from_tensor_slices(({"input_1": xTrain[0][0], "input_2": xTrain[1][0], "input_3": xTrain[1][1]}, yTrain)).shuffle(4*64).batch(64)
            #val = tf.data.Dataset.from_tensor_slices(({"input_1": xVal[0][0], "input_2": xVal[1][0], "input_3": xVal[1][1]}, yVal)).shuffle(4*64).batch(64)
            # train/val are only bound for xType == 'wave'; with any other
            # xType the fit call on the last iteration has nothing to use.
            if xType == 'wave':
                train = tf.data.Dataset.from_tensor_slices((xTrain, tf.one_hot(yTrain, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
                val = tf.data.Dataset.from_tensor_slices((xVal, tf.one_hot(yVal, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
            #if args.schedule_iteration is not None:
                #if it in args.schedule_iteration: #I made the train here to be eval to assess and avoid writing a bunch of code for verification                        #np.save(save_ids_name, sampled_ids)
                #else:
                    #np.save(save_ids_name, all_active_ids)
            #else:
        # Persist the accumulated active ids for this iteration.
        np.save(save_ids_name, all_active_ids)
            #elif xType == 'wavebp01':
            #    train = tf.data.Dataset.from_tensor_slices(({"input_layer": xTrain[0][0], "input_layer_1": xTrain[1][0], "input_layer_2": xTrain[1][1]}, yTrain)).shuffle(4*64).batch(train_batch_size)
            #    val = tf.data.Dataset.from_tensor_slices(({"input_layer": xVal[0][0], "input_layer_1": xVal[1][0], "input_layer_2": xVal[1][1]}, yVal)).shuffle(4*64).batch(train_batch_size)
        train_ep = args.subtrain_interval
        '''
        if it in args.schedule_iteration:
            print('XXXXXXXXXXXXX')
            model = subModels_gen(xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
            model.summary()
        '''
        #Retrain model to fit with original paper, since using raw data -> train only on last layer
        # Only the final iteration trains: a fresh model is built and fitted
        # on the training ids selected so far.
        if it == args.num_iteration-1:
            model = subModels_gen(xType, noConv1Dbranch, noBPbranch, noLayers, tracelen, NumBPinput, MLmodel_detail, modelLogFolder, logFilename, classes=noClasses)
            model.fit(train, batch_size=train_batch_size, verbose = 1, epochs=train_ep, callbacks=callbacks, class_weight=class_weight, validation_data=val)        
        #else:
        #print('ZZZZZZZZZZZZZ')
        #    sub_train = tf.data.Dataset.from_tensor_slices((sub_xTrain, tf.one_hot(sub_yTrain, NumSKPVclasses))).shuffle(4*64).batch(train_batch_size)
        #    model.fit(sub_train, batch_size=train_batch_size, verbose = 1, epochs=train_ep, callbacks=callbacks, class_weight=class_weight, validation_data=val)        
        #test_rank = eval_model(model, nruns_default, maxtrc_default, batches, xTest, yTest_value, noHypoKeys, noClasses)
            #print('Test Rank:, ', test_rank)
        #xTrain_Pool[0][0]
        

    print('Training took {}s'.format(time.time() - start_time))
        #model.save(database_folder_train+'/model_iteration_{}.keras'.format(it))

if work == 'train':
    # Training entry point: report the visible GPUs, build uniform integer
    # class weights (one entry per class) and start the training run.
    gpu_devices = tf.config.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(gpu_devices))
    print('Work =', work)
    classWeights = np.ones(noClasses).astype(int)
    class_weight = {class_id: weight for class_id, weight in enumerate(classWeights)}

    train_model_multiEpochs(
        xType,
        database_folder_train,
        modelLogFolder,
        logTrainedModel_byFile_folder,
        logTrainedModel_byEp_folder,
        logFilename,
        MLmodel_detail,
        sKeyNo,
        class_weight,
        period,
        maxEpochs,
        train_batch_size,
        args,
    )