from sklearn.neighbors import KNeighborsClassifier
import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
from math import log2
import itertools

def parse_arguments():
    """Build the command-line parser for the sampling script.

    Returns:
        argparse.ArgumentParser: parser with the script's options and defaults;
        call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(
        description='Select training traces that are similar to attack traces.')
    parser.add_argument('--train_folder', type=str,
                        help='folder holding model_best_end.keras; outputs are saved here')
    parser.add_argument('--data_path', type=str,
                        help='attack .npz file with "data" and "label" arrays')
    parser.add_argument('--sampling_type', type=str, default='KL',
                        help='one of: euclid, KL, KL_geo, KL_set, KNN_feature')
    parser.add_argument('--num_sample', type=int, default=5,
                        help='training traces selected per attack trace')
    parser.add_argument('--num_trace', type=int, default=5,
                        help='attack traces used per key')
    parser.add_argument('--start_key', type=int, default=0,
                        help='first key index to process (inclusive)')
    parser.add_argument('--end_key', type=int, default=300,
                        help='last key index to process (exclusive)')
    parser.add_argument('--is_wavelet', type=int, default=0,
                        help='1 to use the cA1 features from data_wavelet.npz')
    parser.add_argument('--normalize', type=int, default=1,
                        help='1 to min-max normalise every trace')
    parser.add_argument('--cont_file', type=str, default=None,
                        help='saved disjoint .npy to resume KL_geo sampling from')
    parser.add_argument('--save_interval', type=int, default=30,
                        help='checkpoint every N keys (KL_geo only)')
    parser.add_argument('--add_num', type=int, default=0,
                        help='extra non-attack rows appended to the training pool')

    return parser

def load_multi_attack(data_path):
    """Load attack traces and their labels from an ``.npz`` archive.

    Args:
        data_path: path to an ``.npz`` file containing ``data`` and ``label``.

    Returns:
        tuple: ``(data, labels)`` arrays read from the archive.
    """
    # NpzFile keeps the file open until closed; the context manager releases it.
    # Indexing an NpzFile reads the whole array, so the results outlive the file.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def get_similar_samples(xTrain, target_samples, num_sample):
    """For every target trace, select the ``num_sample`` Euclidean-nearest training rows.

    Args:
        xTrain: training traces, one per row (rows are flattened before comparing).
        target_samples: target traces indexed ``[key][trace]``; each trace must
            hold as many values as one training row.
        num_sample: rows selected per target trace. If it is not smaller than the
            number of training rows, every row is selected.

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)``. ``sample_indexes`` is
        the sorted, de-duplicated union of all selected row indexes (integers);
        ``sample_index_disjoint[key][trace]`` holds the indexes selected for that
        target trace.
    """
    # Convert once so cdist does not re-copy the whole pool for every query.
    train_rows = np.asarray(xTrain, dtype=np.float64)
    train_rows = train_rows.reshape(len(train_rows), -1)

    selected = []
    sample_index_disjoint = []
    print(target_samples.shape)
    for key_sample in tqdm(target_samples):
        sample_key = []
        for target_sample in key_sample:
            query = np.asarray(target_sample, dtype=np.float64).reshape(1, -1)
            sim_score = distance.cdist(train_rows, query)[:, 0]
            if num_sample < len(sim_score):
                sample_index = np.argpartition(sim_score, num_sample)[:num_sample]
            else:
                # argpartition requires kth < len; fall back to every row.
                sample_index = np.arange(len(sim_score))
            selected.append(sample_index)
            sample_key.append(sample_index)
        sample_index_disjoint.append(sample_key)

    # Concatenate integer index arrays (starting from [] would promote to float).
    if selected:
        sample_indexes = np.unique(np.concatenate(selected))
    else:
        sample_indexes = np.array([], dtype=int)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) in bits (log base 2).

    Terms where ``p[i] == 0`` contribute nothing (the standard 0*log 0 = 0
    convention) instead of raising ``ValueError`` from ``log2(0)``.

    Args:
        p, q: equally long sequences of probabilities.
    """
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)) if p[i] != 0)

def KL(a, b):
    """Return the natural-log KL divergence ``sum(a * ln(a / b))``.

    Entries with ``a == 0`` are skipped (0*log 0 = 0). Only those entries reach
    ``log``, so no divide/invalid warnings are raised for zero probabilities.
    Lists are accepted as well as arrays, and ``b`` may broadcast against ``a``.
    """
    a, b = np.broadcast_arrays(np.asarray(a, dtype=float), np.asarray(b, dtype=float))
    nonzero = a != 0
    return np.sum(a[nonzero] * np.log(a[nonzero] / b[nonzero]))

def js_divergence(p, q):
    """Return the Jensen-Shannon divergence of ``p`` and ``q`` (bits, via ``kl_divergence``).

    Inputs are converted to float arrays first. With plain lists, ``p + q``
    would concatenate instead of adding element-wise, and ``0.5 * list`` would raise.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def epsilon_onehot(label, epsilon, num_classes=3329):
    """Return an epsilon-smoothed one-hot vector.

    Every class gets ``epsilon`` except ``label``, which gets
    ``1 - (num_classes - 1) * epsilon``, so the vector sums to one.

    Args:
        label: index of the true class, ``0 <= label < num_classes``.
        epsilon: probability mass given to every other class.
        num_classes: vector length (default 3329, the previous hard-coded size).

    Raises:
        ValueError: if ``label`` is outside ``[0, num_classes)``. Such a label
            previously produced an all-epsilon vector that does not sum to one.
    """
    if not 0 <= label < num_classes:
        raise ValueError('label {} out of range for {} classes'.format(label, num_classes))
    g_x = np.full(num_classes, epsilon, dtype=float)
    g_x[label] = 1 - (num_classes - 1) * epsilon

    return g_x

#------------------

def cul_mul(arr):
    """Multiply all entries of ``arr`` together, left to right.

    For a 2-D array this is the element-wise product of its rows. A
    single-entry input returns that entry unchanged; an empty input returns ``[]``.
    """
    if len(arr) == 0:
        return []
    product = arr[0]
    for factor in arr[1:]:
        product = product * factor
    return product

def iterative_key(key_index, num_set):
    """Return every contiguous window of ``num_set`` items from ``key_index``.

    A sequence of length n yields ``n - num_set + 1`` windows (none if
    ``n < num_set``). The previous ``range(n - num_set)`` dropped the last
    window, so a class with exactly ``num_set`` traces produced no set at all.

    Raises:
        ValueError: if ``num_set`` is smaller than 1.
    """
    if num_set < 1:
        raise ValueError('num_set must be >= 1, got {}'.format(num_set))
    return [key_index[start:start + num_set]
            for start in range(len(key_index) - num_set + 1)]

def norm_data(values):
    """Min-max scale ``values`` and rescale so the result sums to one.

    A constant input (max == min) has no preferred entry and previously
    produced NaN from 0/0; it now yields the uniform distribution.
    """
    values = np.asarray(values)
    low = np.min(values)
    value_range = np.max(values) - low
    if value_range == 0:
        return np.full(values.shape, 1.0 / values.size)
    arr = (values - low) / value_range
    return arr / np.sum(arr)

from scipy import stats
import torch
from gs_divergence import gs_div
'''
a = torch.Tensor([0.1, 0.2, 0.3, 0.4])
b = torch.Tensor([0.2, 0.2, 0.4, 0.2])

div = gs_div(a, b, alpha=-1, lmd=0.5)
'''

def _kl_to_rows(pred, rows, chunk_size=4096):
    """Return ``KL(pred || row)`` (natural log) for every row of ``rows``.

    ``pred`` and each row are first normalised to sum to one, as
    ``scipy.stats.entropy`` does. Rows are processed in chunks so the
    temporaries stay small for a large training pool.
    """
    from scipy.special import rel_entr

    p = np.asarray(pred, dtype=np.float64)
    p = p / np.sum(p)
    scores = np.empty(len(rows))
    for start in range(0, len(rows), chunk_size):
        block = np.asarray(rows[start:start + chunk_size], dtype=np.float64)
        block = block / np.sum(block, axis=1, keepdims=True)
        scores[start:start + chunk_size] = np.sum(rel_entr(p, block), axis=1)
    return scores


def get_similar_alpha_KL_GT(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Select training traces whose predictions are KL-close to each attack trace.

    Each training trace ``j`` gets a weight ``KL_GT[j] = |KL(pred_j || t_j)|``,
    where ``t_j`` is an epsilon-smoothed one-hot vector of its label. For every
    attack trace, a still-available training trace is scored
    ``KL(attack_pred || pred_j) * KL_GT[j]``. The ``num_sample`` lowest scores
    are picked and removed from the pool, so the picks for the traces of one
    key are disjoint.

    Args:
        model: Keras classifier; ``predict`` returns per-class scores.
        xTrain_original: training traces fed to ``model.predict``.
        yTrain_original: integer class label of each training trace.
        xTest_multi: attack traces indexed ``[key, trace, ...]``.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: training traces picked per attack trace (all remaining
            ones if it is not smaller than the pool).

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)``. ``sample_indexes`` is
        the sorted unique union over all keys; row ``k`` of
        ``sample_index_disjoint`` is the flat concatenation of the picks for key ``k``.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    num_classes = input_preds.shape[1]
    epsilon = 1e-10
    print(xTest_multi.shape)

    # Build each trace's own smoothed one-hot target. The old code assigned a
    # whole ROW of an (N, classes) matrix, which turned the target uniform.
    labels = np.asarray(yTrain_original).astype(int)
    KL_GT = np.empty(len(labels))
    for j, label in enumerate(labels):
        target = np.full(num_classes, epsilon)
        target[label] = 1 - (num_classes - 1) * epsilon
        KL_GT[j] = np.abs(stats.entropy(input_preds[j], target))

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.expand_dims(key_sample, axis=2)
        print(sample.shape)
        preds = model.predict(sample)
        print(preds.shape)
        available = np.ones(len(input_preds), dtype=bool)
        chosen = []
        for pred in preds:
            # Scores cover every training index, so KL_GT lines up by absolute
            # index (not by position among the remaining candidates).
            scores = _kl_to_rows(pred, input_preds) * KL_GT
            candidates = np.flatnonzero(available)
            cand_scores = scores[candidates]
            if num_sample < len(candidates):
                picked = candidates[np.argpartition(cand_scores, num_sample)[:num_sample]]
            else:
                picked = candidates
            available[picked] = False
            chosen.append(picked)
        chosen_indexes = np.concatenate(chosen) if chosen else np.array([], dtype=int)
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)

    # Union over ALL keys (previously only the last key's picks survived).
    if sample_index_disjoint:
        sample_indexes = np.unique(np.concatenate(sample_index_disjoint))
    else:
        sample_indexes = np.array([], dtype=int)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def get_similar_alpha_KL_GT_geo(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample, cont_file):
    """Select training traces by ``gs_div`` distance between model predictions.

    For each key, every attack trace in turn picks the ``num_sample`` still
    available training traces with the smallest
    ``|gs_div(attack_pred, train_pred, alpha=-1, lmd=0.5)|``. Picked traces are
    removed from the pool for the remaining traces of that key.

    Progress is checkpointed every ``args.save_interval`` keys into
    ``args.train_folder`` (module-level ``args``). Passing such a checkpoint as
    ``cont_file`` resumes after the keys it contains, and its rows are kept in
    the result.

    Args:
        model: Keras classifier; ``predict`` returns per-class scores.
        xTrain_original: training traces fed to ``model.predict``.
        yTrain_original, yTest_multi: unused; kept for signature compatibility.
        xTest_multi: attack traces indexed ``[key, trace, ...]``.
        num_sample: training traces picked per attack trace.
        cont_file: path of a saved disjoint ``.npy`` to resume from, or None.

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)``. ``sample_indexes`` is
        the sorted unique union over all keys; row ``k`` of
        ``sample_index_disjoint`` is the flat concatenation of the picks for key ``k``.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    print(xTest_multi.shape)

    # Resume: keep the rows already computed so the result (and every later
    # checkpoint) still starts at the first key.
    sample_index_disjoint = []
    curr_key = 0
    if cont_file is not None:
        print('ZZZ')
        sample_index_disjoint = [np.asarray(row).astype(int) for row in np.load(cont_file)]
        curr_key = len(sample_index_disjoint)

    for key_pos in tqdm(range(curr_key, len(xTest_multi))):
        sample = np.expand_dims(xTest_multi[key_pos], axis=2)
        print(sample.shape)
        preds = model.predict(sample)
        print(preds.shape)
        available = np.ones(len(input_preds), dtype=bool)
        chosen = []
        for pred in tqdm(preds):
            l_x = torch.Tensor(pred)  # converted once per attack trace
            candidates = np.flatnonzero(available)
            KL_score = np.array([
                np.abs(gs_div(l_x, torch.Tensor(input_preds[j]), alpha=-1, lmd=0.5)
                       .detach().cpu().numpy())
                for j in candidates
            ])
            if num_sample < len(candidates):
                picked = candidates[np.argpartition(KL_score, num_sample)[:num_sample]]
            else:
                picked = candidates
            available[picked] = False
            chosen.append(picked)
        chosen_indexes = np.concatenate(chosen) if chosen else np.array([], dtype=int)
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)
        if key_pos % args.save_interval == 0:
            # Path(...).stem: a directory in data_path must not leak into the file name.
            ref_fname = Path(args.data_path).stem
            fname = 'disjoint_{}_{}_{}_{}_{}_{}'.format(
                args.start_key, key_pos, args.num_trace, args.num_sample,
                args.sampling_type, ref_fname)
            np.save(os.path.join(args.train_folder, fname), np.array(sample_index_disjoint))

    if sample_index_disjoint:
        sample_indexes = np.unique(np.concatenate([np.ravel(row) for row in sample_index_disjoint]))
    else:
        sample_indexes = np.array([], dtype=int)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

from keras.models import Model

def get_similar_alpha_KL_GT_geo_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Select training traces by ``gs_div`` distance (no checkpointing).

    For each key, every attack trace in turn picks the ``num_sample`` still
    available training traces with the smallest
    ``|gs_div(attack_pred, train_pred, alpha=-1, lmd=0.5)|``. Picked traces are
    removed from the pool for the remaining traces of that key. A leftover
    debugging ``exit()`` that stopped the process after the first trace has
    been removed.

    Args:
        model: Keras classifier; ``predict`` returns per-class scores.
        xTrain_original: training traces fed to ``model.predict``.
        yTrain_original, yTest_multi: unused; kept for signature compatibility.
        xTest_multi: attack traces indexed ``[key, trace, ...]``.
        num_sample: training traces picked per attack trace.

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` as in
        ``get_similar_alpha_KL_GT_geo``.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    print(xTest_multi.shape)

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.expand_dims(key_sample, axis=2)
        print(sample.shape)
        preds = model.predict(sample)
        print(preds.shape)
        available = np.ones(len(input_preds), dtype=bool)
        chosen = []
        for pred in preds:
            l_x = torch.Tensor(pred)  # converted once per attack trace
            candidates = np.flatnonzero(available)
            KL_score = np.array([
                np.abs(gs_div(l_x, torch.Tensor(input_preds[j]), alpha=-1, lmd=0.5)
                       .detach().cpu().numpy())
                for j in candidates
            ])
            if num_sample < len(candidates):
                picked = candidates[np.argpartition(KL_score, num_sample)[:num_sample]]
            else:
                picked = candidates
            available[picked] = False
            chosen.append(picked)
        chosen_indexes = np.concatenate(chosen) if chosen else np.array([], dtype=int)
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)

    if sample_index_disjoint:
        sample_indexes = np.unique(np.concatenate(sample_index_disjoint))
    else:
        sample_indexes = np.array([], dtype=int)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def get_similar_samples_KNN_Feature(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Select training traces nearest to each attack trace in feature space.

    Features are the outputs of ``model.layers[-2]``. For each key, every attack
    trace in turn picks the ``num_sample`` still available training traces with
    the smallest Euclidean feature distance. Picked traces are removed from the
    pool for the remaining traces of that key.

    Args:
        model: Keras model; its second-to-last layer provides the features.
        xTrain_original: training traces.
        yTrain_original, yTest_multi: unused; kept for signature compatibility.
        xTest_multi: attack traces indexed ``[key, trace, ...]``.
        num_sample: training traces picked per attack trace.

    Returns:
        tuple: the per-key selection array twice,
        ``(sample_index_disjoint, sample_index_disjoint)``.
    """
    feature_extractor = Model(model.input, model.layers[-2].output)

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = feature_extractor.predict(val)
    # 2-D float64 once, so cdist does not copy the pool for every query.
    train_features = np.asarray(input_preds, dtype=np.float64).reshape(len(input_preds), -1)
    print(xTest_multi.shape)

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.expand_dims(key_sample, axis=2)
        preds = feature_extractor.predict(sample)
        print(preds.shape)
        available = np.ones(len(train_features), dtype=bool)
        chosen = []
        for pred in preds:
            query = np.asarray(pred, dtype=np.float64).reshape(1, -1)
            sim_score = distance.cdist(train_features, query)[:, 0]
            candidates = np.flatnonzero(available)
            cand_score = sim_score[candidates]
            if num_sample < len(candidates):
                picked = candidates[np.argpartition(cand_score, num_sample)[:num_sample]]
            else:
                picked = candidates
            available[picked] = False
            chosen.append(picked)
        chosen_indexes = np.concatenate(chosen) if chosen else np.array([], dtype=int)
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(sample_index_disjoint.shape)
    return sample_index_disjoint, sample_index_disjoint

def get_similar_alpha_KL_by_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample, cont_file=None):
    """Select windows (sets) of same-label training traces that match each key.

    For every class ``i``, each window of ``num_set`` consecutive training traces
    labelled ``i`` becomes a set. The set's prediction is the element-wise
    product of its members' predictions, normalised with ``norm_data``. For each
    key, the attack traces' predictions are combined the same way. The
    ``num_sample`` sets with the smallest ``|gs_div(set, key, alpha=-1, lmd=0.5)|``
    are selected.

    Args:
        model: Keras classifier; ``predict`` returns per-class scores.
        xTrain_original: training traces fed to ``model.predict``.
        yTrain_original: integer class label of each training trace.
        xTest_multi: attack traces indexed ``[key, trace, ...]``.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: number of sets selected per key.
        cont_file: ignored. It is accepted because the script's ``KL_set``
            branch passes ``args.cont_file``, which used to raise ``TypeError``.

    Returns:
        tuple: ``(all_indexes, sample_index_disjoint)``. ``all_indexes[s]`` holds
        the training indexes of set ``s``; ``sample_index_disjoint[k]`` holds the
        training indexes of the sets chosen for key ``k``.
    """
    num_set = 5
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    print(input_preds.shape)
    num_classes = input_preds.shape[1]
    labels = np.asarray(yTrain_original)

    all_set = []
    all_labels = []
    all_indexes = []
    all_argmax = []
    for i in tqdm(range(num_classes)):
        key_index = np.where(labels == i)[0]
        for instance_idx in iterative_key(key_index, num_set):
            p_x = norm_data(cul_mul(input_preds[instance_idx, :]))
            all_indexes.append(instance_idx)
            all_argmax.append(np.argmax(p_x))
            all_set.append(p_x)
            all_labels.append(i)

    all_indexes = np.array(all_indexes)
    all_labels = np.array(all_labels)
    print(len(all_labels))

    # Spread of the sets' argmax classes (diagnostic output only).
    unique, counts = np.unique(all_argmax, return_counts=True)
    print(len(unique))
    print(np.std(counts))
    print(np.min(counts))
    print(np.max(counts))

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.expand_dims(key_sample, axis=2)
        preds = model.predict(sample)
        q_x = norm_data(cul_mul(preds))
        KL_score = np.array([
            np.abs(gs_div(torch.Tensor(p_x), torch.Tensor(q_x), alpha=-1, lmd=0.5)
                   .detach().cpu().numpy())
            for p_x in all_set
        ])
        if num_sample < len(KL_score):
            sample_index = np.argpartition(KL_score, num_sample)[:num_sample]
        else:
            # argpartition requires kth < len; take every set instead.
            sample_index = np.arange(len(KL_score))
        all_index = all_indexes[sample_index]
        print(all_index.shape)
        sample_index_disjoint.append(all_index)
    sample_index_disjoint = np.array(sample_index_disjoint).astype(int)
    print(sample_index_disjoint.shape)

    return all_indexes, sample_index_disjoint

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` to the range [0, 1].

    A constant series (max == min) maps to all zeros instead of NaN from 0/0.
    scikit-learn's MinMaxScaler handles a zero range the same way.
    """
    low = timeseries.min()
    span = timeseries.max() - low
    if span == 0:
        return np.zeros_like(timeseries, dtype=float)
    return (timeseries - low) / span

def z_norm(timeseries):
    """Standardise ``timeseries`` to zero mean and unit population std (ddof=0) along axis 0."""
    standardized = stats.zscore(timeseries, axis=0, ddof=0)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalise, in place, every trace ``data[j][i]`` of a two-level collection.

    Prints ``data.shape`` first and returns the same (mutated) ``data``.
    """
    print(data.shape)
    for group in data:
        for pos, trace in enumerate(group):
            group[pos] = normalize(trace)

    return data

def normalize_data_per_trace_train(data):
    """Min-max normalise, in place, every row ``data[j]`` (one trace per row).

    Prints ``data.shape`` first and returns the same (mutated) ``data``.
    """
    print(data.shape)
    for row, trace in enumerate(data):
        data[row] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-score, in place, every trace ``data[j][i]`` of a two-level collection.

    Prints ``data.shape`` first and returns the same (mutated) ``data``.
    """
    print(data.shape)
    for group in data:
        for pos, trace in enumerate(group):
            group[pos] = z_norm(trace)

    return data

# ---------------------------------------------------------------------------
# Script: load training/attack data, run the selected sampler, save indexes.
# ---------------------------------------------------------------------------
parser = parse_arguments()
args = parser.parse_args()

# Fail fast on an unknown sampler, before loading data and the model; the
# dispatch below would otherwise leave sampled_ids undefined.
SAMPLING_TYPES = ('euclid', 'KL', 'KL_geo', 'KL_set', 'KNN_feature')
if args.sampling_type not in SAMPLING_TYPES:
    parser.error('unknown --sampling_type {!r}; expected one of {}'.format(
        args.sampling_type, ', '.join(SAMPLING_TYPES)))

data = np.load('data.npz')

# The first NUM_SAMPLE rows of data.npz form the base training pool.
NUM_SAMPLE = 200000

# Attack ids are offsets into the rows after NUM_SAMPLE (they are shifted by
# NUM_SAMPLE below); assumes 300000 such rows - TODO confirm against data.npz.
# Rows not used by the attack set may extend the pool via --add_num.
data_ids = np.load('attack_300key_ids.npz')['ids']
sorted_ids = np.sort(np.hstack(data_ids))
non_overlap_ids = np.setdiff1d(np.arange(300000), sorted_ids) + NUM_SAMPLE
print(sorted_ids.shape)
print(non_overlap_ids.shape)
print(np.max(non_overlap_ids))
addtional_ids = non_overlap_ids[:args.add_num]
init_ids = np.arange(NUM_SAMPLE)
all_ids = np.concatenate((init_ids, addtional_ids))
print(len(all_ids))
print(all_ids[:10])

xTrain_original = data['data'][all_ids]
labels = data['label'][all_ids]
xTest_multi, yTest_multi = load_multi_attack(args.data_path)

print(xTrain_original.shape)
print(xTest_multi.shape)

if args.normalize == 1:
    print('Normalize')
    xTest_multi = normalize_data_per_trace(xTest_multi)
    xTrain_original = normalize_data_per_trace_train(xTrain_original)
# File stem only: slicing off the extension kept any directory part, which
# then ended up inside the output file name under train_folder.
ref_fname = Path(args.data_path).stem

xTrain_original = np.expand_dims(xTrain_original, axis = 2)
print(xTrain_original.shape)
print(xTest_multi.shape)


if args.is_wavelet == 1:
    num_sample_per_key = 100
    wavelet_data = np.load('data_wavelet.npz')
    xTrain_original = np.expand_dims(wavelet_data['cA1'][:NUM_SAMPLE], axis = 2)
    # The wavelet pool has only NUM_SAMPLE rows (--add_num is not applied),
    # so trim the labels to stay aligned with it.
    labels = labels[:NUM_SAMPLE]
    test_data = wavelet_data['cA1'][NUM_SAMPLE:]
    test_labels = wavelet_data['label'][NUM_SAMPLE:]
    print(test_data.shape)
    xTest_multi = []
    for lb in yTest_multi:
        lb_indexes = np.where(test_labels == lb)[0]
        xTest_multi.append(test_data[lb_indexes[:num_sample_per_key]])

    xTest_multi = np.array(xTest_multi)
    print(xTrain_original.shape)
    print(xTest_multi.shape)


input_data = xTrain_original[:,:,0]
ref_data = xTest_multi[args.start_key:args.end_key,:args.num_trace,:]
model = load_model(os.path.join(args.train_folder, 'model_best_end.keras'))
model.summary()

print(args.sampling_type)

if args.sampling_type == 'euclid':
    sampled_ids, disjoint_ids = get_similar_samples(input_data, ref_data, args.num_sample)
elif args.sampling_type == 'KL':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KL_geo':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT_geo(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample, args.cont_file)
elif args.sampling_type == 'KL_set':
    # cont_file is not used by the set sampler; passing it here raised TypeError.
    sampled_ids, disjoint_ids = get_similar_alpha_KL_by_set(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KNN_feature':
    sampled_ids, disjoint_ids = get_similar_samples_KNN_Feature(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample)

fname='ids_{}_{}_{}_{}'.format(args.start_key, args.end_key, args.num_trace, args.num_sample)
np.save(os.path.join(args.train_folder, fname), sampled_ids)
np.save(os.path.join(args.train_folder, 'disjoint_{}_{}_{}_{}_{}_{}_{}'.format(args.start_key, args.end_key, args.num_trace, args.num_sample, args.sampling_type, ref_fname, args.add_num)), disjoint_ids)