from sklearn.neighbors import KNeighborsClassifier
import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
from math import log2
import itertools

def parse_arguments():
    """Build the command-line parser for the sample-selection script.

    Returns:
        argparse.ArgumentParser: parser holding every option the script reads.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--train_model', type=str,
                        help='path to the trained Keras model; output .npy files are written to its directory')
    parser.add_argument('--sampling_type', type=str, default='KL',
                        help="selection strategy: 'euclid', 'KL', 'KL_geo', 'KL_set' or 'KNN_feature'")
    parser.add_argument('--num_sample', type=int, default=5,
                        help='number of profiling samples to select per reference trace')
    parser.add_argument('--num_trace', type=int, default=5,
                        help='number of attack traces used as references')
    parser.add_argument('--is_wavelet', type=int, default=0)
    parser.add_argument('--normalize', type=int, default=1)
    parser.add_argument('--cont_file', type=str, default=None)
    parser.add_argument('--save_interval', type=int, default=30)
    parser.add_argument('--add_num', type=int, default=0, help='additional number of extend pool')

    return parser

def load_multi_attack(data_path):
    """Load traces and labels from an ``.npz`` archive.

    Args:
        data_path: path to an ``.npz`` file holding ``data`` and ``label`` arrays.

    Returns:
        tuple: ``(data, labels)`` as in-memory NumPy arrays.
    """
    # Indexing an NpzFile reads the member fully into memory, so the arrays
    # remain valid after the context manager closes the archive handle.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def _euclidean_to_pool(pool, target, chunk_size=4096):
    """Return the Euclidean distance from every row of ``pool`` to ``target``."""
    # Rows are promoted to float64 chunk by chunk: subtracting integer traces
    # (e.g. int8) in their own dtype wraps around, while converting the whole
    # pool at once would need a very large temporary buffer.
    target = np.asarray(target, dtype=np.float64).ravel()
    dists = np.empty(len(pool), dtype=np.float64)
    for start in range(0, len(pool), chunk_size):
        rows = np.asarray(pool[start:start + chunk_size], dtype=np.float64)
        rows = rows.reshape(len(rows), -1)
        dists[start:start + len(rows)] = np.linalg.norm(rows - target, axis=1)
    return dists


def get_similar_samples(xTrain, target_samples, num_sample):
    """Select, for every reference trace, its nearest traces in ``xTrain``.

    Args:
        xTrain: candidate pool, one trace per row.
        target_samples: reference traces grouped as
            ``(num_groups, traces_per_group, trace_len)``. A 2-D array
            ``(num_traces, trace_len)`` is treated as one single group.
        num_sample: number of nearest pool rows kept per reference trace
            (clamped to the pool size).

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` where
        ``sample_indexes`` is the sorted, de-duplicated union of all selected
        pool indices (integers) and ``sample_index_disjoint`` is an integer
        array ``(num_groups, traces_per_group, k)`` of per-trace selections.
        Selections for different reference traces may overlap.
    """
    targets = np.asarray(target_samples)
    if targets.ndim == 2:
        targets = targets[np.newaxis]
    print(targets.shape)

    k = min(num_sample, len(xTrain))
    selected = []
    sample_index_disjoint = []
    for key_sample in tqdm(targets):
        sample_key = []
        for target_sample in key_sample:
            sim_score = _euclidean_to_pool(xTrain, target_sample)
            if k > 0:
                # kth = k - 1 guarantees the first k entries are the k smallest
                # and stays in bounds even when k equals the pool size.
                sample_index = np.argpartition(sim_score, k - 1)[:k]
            else:
                sample_index = np.array([], dtype=np.intp)
            selected.append(sample_index)
            sample_key.append(sample_index)
        sample_index_disjoint.append(sample_key)

    if selected:
        sample_indexes = np.unique(np.concatenate(selected))
    else:
        sample_indexes = np.array([], dtype=np.intp)
    sample_index_disjoint = np.array(sample_index_disjoint, dtype=np.intp)
    print(len(sample_indexes))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

# calculate the kl divergence
def kl_divergence(p, q):
    """Return KL(p || q) in bits (base-2 logarithm).

    ``p`` and ``q`` are indexable sequences of probabilities; only the first
    ``len(p)`` entries of ``q`` are read. Nothing is renormalised, and zero
    entries of ``p`` are not skipped (``math.log2(0)`` raises).
    """
    total = 0
    for i, p_i in enumerate(p):
        total += p_i * log2(p_i / q[i])
    return total

def KL(a, b):
    """Return ``sum(a * ln(a / b))`` over the entries where ``a`` is non-zero.

    Natural-log, element-wise NumPy form of the KL divergence. The term
    ``a * np.log(a / b)`` is still evaluated for zero entries of ``a`` (NumPy
    may emit RuntimeWarnings) before ``np.where`` replaces those terms with 0.
    """
    nonzero = a != 0
    terms = np.where(nonzero, a * np.log(a / b), 0)
    return np.sum(terms)

# calculate the js divergence
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence of ``p`` and ``q`` (bits).

    The mixture is ``0.5 * (p + q)`` (so the inputs must support element-wise
    ``+``, e.g. NumPy arrays); the result averages ``kl_divergence`` of each
    input against that mixture.
    """
    midpoint = 0.5 * (p + q)
    kl_p = kl_divergence(p, midpoint)
    kl_q = kl_divergence(q, midpoint)
    return 0.5 * kl_p + 0.5 * kl_q

def epsilon_onehot(label, epsilon, num_classes=3329):
    """Return a label-smoothed one-hot vector.

    Every entry equals ``epsilon`` except position ``label``, which holds
    ``1 - (num_classes - 1) * epsilon`` so the vector sums to one.

    Args:
        label: class position to mark. Positions are compared with ``==``, so a
            label outside ``[0, num_classes)`` yields an all-``epsilon`` vector.
        epsilon: mass assigned to every non-target class.
        num_classes: vector length; defaults to 3329, the previously
            hard-coded size.

    Returns:
        np.ndarray: float64 vector of length ``num_classes``.
    """
    positions = np.arange(num_classes)
    smoothed = np.where(positions == label, 1 - (num_classes - 1) * epsilon, epsilon)
    return smoothed.astype(np.float64)

#------------------

def cul_mul(arr):
    """Return ``arr[0] * arr[1] * ... * arr[-1]``.

    For a 2-D array this is the element-wise product of its rows. An empty
    input returns an empty list; a single element is returned unchanged.
    """
    if len(arr) == 0:
        return []
    product = arr[0]
    for factor in arr[1:]:
        product = product * factor
    return product

def iterative_key(key_index, num_set):
    """Return every contiguous window of ``num_set`` consecutive entries.

    Windows slide by one position, so a sequence of length ``n`` yields
    ``n - num_set + 1`` windows (none when ``n < num_set``). In particular, a
    sequence of exactly ``num_set`` entries yields one window.

    Args:
        key_index: sliceable sequence (list or 1-D array) of indices.
        num_set: window length; must be at least 1.

    Returns:
        list: the windows, each a slice of ``key_index``.

    Raises:
        ValueError: if ``num_set`` is smaller than 1.
    """
    if num_set < 1:
        raise ValueError('num_set must be >= 1')
    return [key_index[start:start + num_set]
            for start in range(len(key_index) - num_set + 1)]

def norm_data(values):
    """Min-max scale ``values``, then rescale the result so it sums to one.

    The smallest entry maps to 0 before the sum normalisation. When all
    entries are equal the min-max step has no range (it would produce NaNs),
    so a uniform distribution is returned instead. Integer input is promoted
    to float64 to avoid wrap-around on subtraction; float input keeps its dtype.
    """
    values = np.asarray(values)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)
    lowest = np.min(values)
    value_range = np.max(values) - lowest
    if value_range == 0:
        return np.full(values.shape, 1.0 / values.size, dtype=values.dtype)
    arr = (values - lowest) / value_range
    return arr / np.sum(arr)

from scipy import stats
import torch
from gs_divergence import gs_div
'''
a = torch.Tensor([0.1, 0.2, 0.3, 0.4])
b = torch.Tensor([0.2, 0.2, 0.4, 0.2])

div = gs_div(a, b, alpha=-1, lmd=0.5)
'''

def get_similar_alpha_KL_GT(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Greedy, label-weighted KL selection of profiling traces per reference group.

    First, each profiling trace ``i`` gets a weight ``KL_GT[i]``: the absolute
    KL divergence (``scipy.stats.entropy``) between the model prediction for
    that trace and an epsilon-smoothed one-hot vector of its label. Then, for
    every group in ``xTest_multi`` and every trace of that group, the
    ``num_sample`` remaining pool traces that minimise
    ``KL(pred_ref || pred_pool) * KL_GT[pool_index]`` are picked. Picked traces
    are removed from the pool for the rest of that group; the pool is reset
    for each group.

    Args:
        model: Keras-style model exposing ``predict``. Its output width is
            used as the number of classes.
        xTrain_original: profiling traces (candidate pool).
        yTrain_original: integer labels of ``xTrain_original``.
        xTest_multi: reference traces grouped as
            ``(num_groups, traces_per_group, trace_len)``. A 1-D entry is
            treated as a group containing a single trace.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: traces picked per reference trace (clamped to what is
            left in the pool).

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` — the sorted union
        of pool indices picked for any group, and an integer array with one
        row of picks (in pick order) per group.
    """
    from scipy.special import rel_entr

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    num_classes = input_preds.shape[1]

    epsilon = 1e-10
    labels = np.asarray(yTrain_original).astype(np.intp).ravel()
    print(np.shape(xTest_multi))

    # Smoothed one-hot targets are built one row at a time instead of
    # materialising an (n_train, num_classes) matrix.
    KL_GT = np.empty(len(labels))
    for i, label in enumerate(labels):
        q_x = np.full(num_classes, epsilon)
        q_x[label] = 1 - (num_classes - 1) * epsilon
        KL_GT[i] = np.abs(stats.entropy(input_preds[i], q_x))

    def kl_to_pool(ref_pred, pool_preds):
        # Vectorised stats.entropy(ref_pred, row) for every pool row: both
        # distributions are normalised to sum to one before rel_entr.
        p = ref_pred / np.sum(ref_pred)
        q = pool_preds / np.sum(pool_preds, axis=1, keepdims=True)
        return np.sum(rel_entr(p[np.newaxis, :], q), axis=1)

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.asarray(key_sample)
        if sample.ndim == 1:
            sample = sample[np.newaxis, :]
        preds = model.predict(np.expand_dims(sample, axis=2))
        remaining = np.arange(len(input_preds))
        chosen = []
        for pred in preds:
            k = min(num_sample, len(remaining))
            if k == 0:
                break
            # Weight by the GT term of the actual pool index, not the position
            # inside the shrinking pool.
            scores = kl_to_pool(pred, input_preds[remaining]) * KL_GT[remaining]
            picked = remaining[np.argpartition(scores, k - 1)[:k]]
            chosen.append(picked)
            remaining = np.setdiff1d(remaining, picked, assume_unique=True)
        chosen = np.concatenate(chosen) if chosen else np.array([], dtype=np.intp)
        print(len(np.unique(chosen)))
        sample_index_disjoint.append(chosen)

    if sample_index_disjoint:
        sample_indexes = np.unique(np.concatenate(sample_index_disjoint))
    else:
        sample_indexes = np.array([], dtype=np.intp)
    sample_index_disjoint = np.array(sample_index_disjoint, dtype=np.intp)
    print(len(sample_indexes))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def get_similar_alpha_KL_GT_geo(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample, cont_file):
    """Greedy, non-overlapping selection of profiling traces by ``gs_div``.

    All reference traces in ``xTest_multi`` are predicted in one batch and
    processed in order. For each one, the ``num_sample`` remaining profiling
    traces with the smallest ``|gs_div(pred_ref, pred_pool, alpha=-1, lmd=0.5)|``
    are picked and removed from the pool, so no profiling trace is picked twice.

    Args:
        model: Keras-style model exposing ``predict``.
        xTrain_original: profiling traces (candidate pool).
        yTrain_original: unused; kept for signature compatibility.
        xTest_multi: reference traces passed directly to ``model.predict``.
        yTest_multi: unused.
        num_sample: traces picked per reference trace (clamped to what is
            left in the pool).
        cont_file: unused; kept for signature compatibility.

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` — sorted unique
        picked indices, and an integer array of shape ``(1, total_picked)``
        holding the picks in pick order.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    print(np.shape(xTest_multi))

    preds = model.predict(xTest_multi)

    chosen = []
    remaining = np.arange(len(input_preds))
    for pred in tqdm(preds):
        k = min(num_sample, len(remaining))
        if k == 0:
            break
        # gs_div is evaluated pair by pair on torch tensors; the reference
        # tensor is built once per reference trace.
        l_x = torch.Tensor(pred)
        scores = np.array([
            np.abs(gs_div(l_x, torch.Tensor(g_x), alpha=-1, lmd=0.5).detach().cpu().numpy())
            for g_x in input_preds[remaining]
        ])
        picked = remaining[np.argpartition(scores, k - 1)[:k]]
        chosen.append(picked)
        remaining = np.setdiff1d(remaining, picked, assume_unique=True)

    chosen = np.concatenate(chosen) if chosen else np.array([], dtype=np.intp)
    print(len(np.unique(chosen)))

    sample_indexes = np.unique(chosen)
    sample_index_disjoint = np.array([chosen], dtype=np.intp)
    print(len(sample_indexes))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

from keras.models import Model

def get_similar_alpha_KL_GT_geo_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Greedy per-group selection of profiling traces by ``gs_div``.

    For every group of reference traces in ``xTest_multi``, each trace's
    prediction is compared with the predictions of the remaining profiling
    traces using ``|gs_div(pred_ref, pred_pool, alpha=-1, lmd=0.5)|``. The
    ``num_sample`` smallest are picked and removed from the pool for the rest
    of that group; the pool is reset for each group.

    Args:
        model: Keras-style model exposing ``predict``.
        xTrain_original: profiling traces (candidate pool).
        yTrain_original: unused; kept for signature compatibility.
        xTest_multi: reference traces grouped as
            ``(num_groups, traces_per_group, trace_len)``. A 1-D entry is
            treated as a group containing a single trace.
        yTest_multi: unused.
        num_sample: traces picked per reference trace (clamped to what is
            left in the pool).

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` — the sorted union
        of pool indices picked for any group, and an integer array with one
        row of picks (in pick order) per group.
    """
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    input_preds = model.predict(val)
    print(np.shape(xTest_multi))

    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.asarray(key_sample)
        if sample.ndim == 1:
            sample = sample[np.newaxis, :]
        preds = model.predict(np.expand_dims(sample, axis=2))
        remaining = np.arange(len(input_preds))
        chosen = []
        for pred in preds:
            k = min(num_sample, len(remaining))
            if k == 0:
                break
            l_x = torch.Tensor(pred)
            scores = np.array([
                np.abs(gs_div(l_x, torch.Tensor(g_x), alpha=-1, lmd=0.5).detach().cpu().numpy())
                for g_x in input_preds[remaining]
            ])
            picked = remaining[np.argpartition(scores, k - 1)[:k]]
            chosen.append(picked)
            remaining = np.setdiff1d(remaining, picked, assume_unique=True)
        chosen = np.concatenate(chosen) if chosen else np.array([], dtype=np.intp)
        print(len(np.unique(chosen)))
        sample_index_disjoint.append(chosen)

    if sample_index_disjoint:
        sample_indexes = np.unique(np.concatenate(sample_index_disjoint))
    else:
        sample_indexes = np.array([], dtype=np.intp)
    sample_index_disjoint = np.array(sample_index_disjoint, dtype=np.intp)
    print(len(sample_indexes))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def get_similar_samples_KNN_Feature(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Greedy nearest-neighbour selection in the model's penultimate-layer space.

    An auxiliary model exposes the output of ``model.layers[-2]`` next to the
    model outputs; only that layer output is used as the feature vector. Each
    reference trace (processed in order) picks the ``num_sample`` remaining
    profiling traces closest in Euclidean distance, and those are removed from
    the pool so no profiling trace is picked twice.

    Args:
        model: Keras functional/sequential model.
        xTrain_original: profiling traces (candidate pool).
        yTrain_original: unused; kept for signature compatibility.
        xTest_multi: reference traces passed directly to ``predict``.
        yTest_multi: unused.
        num_sample: traces picked per reference trace (clamped to what is
            left in the pool).

    Returns:
        tuple: ``(sample_indexes, sample_index_disjoint)`` — the sorted unique
        picked indices (same convention as the other sampling functions), and
        an integer array of shape ``(1, total_picked)`` in pick order.
    """
    feature_extractor = tf.keras.Model(inputs=model.inputs,
                                       outputs=model.outputs + [model.layers[-2].output])

    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    # The extractor returns the model output(s) followed by the features.
    *_, input_feats = feature_extractor.predict(val)
    input_feats = input_feats.reshape(len(input_feats), -1)

    print(np.shape(xTest_multi))
    *_, ref_feats = feature_extractor.predict(xTest_multi)
    ref_feats = ref_feats.reshape(len(ref_feats), -1)

    chosen = []
    remaining = np.arange(len(input_feats))
    for ref in tqdm(ref_feats):
        k = min(num_sample, len(remaining))
        if k == 0:
            break
        sim_score = np.linalg.norm(input_feats[remaining] - ref, axis=1)
        picked = remaining[np.argpartition(sim_score, k - 1)[:k]]
        chosen.append(picked)
        remaining = np.setdiff1d(remaining, picked, assume_unique=True)

    chosen = np.concatenate(chosen) if chosen else np.array([], dtype=np.intp)
    print(len(np.unique(chosen)))

    return np.unique(chosen), np.array([chosen], dtype=np.intp)

def get_similar_alpha_KL_by_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample, cont_file=None):
    """Select candidate *sets* of same-label profiling traces by ``gs_div``.

    Candidate sets are built for every class (0 .. model output width - 1)
    from windows of ``num_set`` (5) consecutive same-label traces
    (``iterative_key``). Each set is summarised by the element-wise product
    of its members' predictions (``cul_mul``) scaled with ``norm_data``.
    Every reference group in ``xTest_multi`` is summarised the same way, and
    the ``num_sample`` candidate sets with the smallest
    ``|gs_div(set, group, alpha=-1, lmd=0.5)|`` are selected.

    Args:
        model: Keras-style model exposing ``predict``.
        xTrain_original: profiling traces.
        yTrain_original: integer labels of ``xTrain_original``.
        xTest_multi: reference traces grouped as
            ``(num_groups, traces_per_group, trace_len)``. A 1-D entry is
            treated as a group containing a single trace.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: number of candidate sets selected per group (clamped to
            the number of candidate sets).
        cont_file: unused; accepted so callers passing it (as the script's
            ``KL_set`` branch does) keep working.

    Returns:
        tuple: ``(all_indexes, sample_index_disjoint)`` — the index windows of
        ALL candidate sets, shape ``(num_sets, num_set)``, and an integer
        array ``(num_groups, k, num_set)`` with the training indices of the
        selected sets for each group.

    Raises:
        ValueError: if no class has at least ``num_set`` traces.
    """
    num_set = 5
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)

    input_preds = model.predict(val)
    print(input_preds.shape)
    num_classes = input_preds.shape[1]
    labels = np.asarray(yTrain_original).ravel()
    print(len(xTest_multi[0]))

    all_set = []
    all_indexes = []
    all_argmax = []
    for label in tqdm(range(num_classes)):
        key_index = np.where(labels == label)[0]
        for instance_idx in iterative_key(key_index, num_set):
            p_x = norm_data(cul_mul(input_preds[instance_idx, :]))
            all_indexes.append(instance_idx)
            all_argmax.append(np.argmax(p_x))
            # Store as a tensor once so each group comparison reuses it.
            all_set.append(torch.Tensor(p_x))

    if not all_set:
        raise ValueError('no class has at least %d profiling traces' % num_set)
    all_indexes = np.array(all_indexes)
    print(len(all_set))

    unique, counts = np.unique(all_argmax, return_counts=True)
    print(len(unique))
    print(np.std(counts))
    print(np.min(counts))
    print(np.max(counts))

    k = min(num_sample, len(all_set))
    sample_index_disjoint = []
    for key_sample in tqdm(xTest_multi):
        sample = np.asarray(key_sample)
        if sample.ndim == 1:
            sample = sample[np.newaxis, :]
        preds = model.predict(np.expand_dims(sample, axis=2))
        q_x = torch.Tensor(norm_data(cul_mul(preds)))
        KL_score = np.array([
            np.abs(gs_div(p_x, q_x, alpha=-1, lmd=0.5).detach().cpu().numpy())
            for p_x in all_set
        ])
        sample_index = np.argpartition(KL_score, k - 1)[:k]
        all_index = all_indexes[sample_index]
        print(all_index.shape)
        sample_index_disjoint.append(all_index)
    sample_index_disjoint = np.array(sample_index_disjoint).astype(int)
    print(sample_index_disjoint.shape)

    return all_indexes, sample_index_disjoint

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` to the range [0, 1].

    Integer input (e.g. int8 traces) is promoted to float64 first, because
    subtracting the minimum in a narrow integer dtype wraps around. Float
    input keeps its dtype. A constant series has no range and is mapped to
    all zeros instead of NaN.
    """
    ts = np.asarray(timeseries)
    if not np.issubdtype(ts.dtype, np.floating):
        ts = ts.astype(np.float64)
    lowest = ts.min()
    span = ts.max() - lowest
    if span == 0:
        return np.zeros_like(ts)
    return (ts - lowest) / span

def z_norm(timeseries):
    """Return the z-score of ``timeseries`` computed by ``scipy.stats.zscore``.

    Each value becomes ``(x - mean) / std`` using scipy's default population
    standard deviation (``ddof=0``).
    """
    standardized = stats.zscore(timeseries)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalise every trace ``data[j][i]`` with ``normalize``.

    Float arrays are updated in place and returned. Integer arrays are first
    copied to float64, because writing [0, 1] results back into an integer
    array truncates them; in that case the caller's array is left untouched
    and the normalised copy is returned.

    Args:
        data: NumPy array whose entries ``data[j][i]`` are traces.

    Returns:
        np.ndarray: the normalised data.
    """
    print(data.shape)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    for group in data:
        for i in range(len(group)):
            group[i] = normalize(group[i])

    return data

def normalize_data_per_trace_train(data):
    """Min-max normalise every row ``data[j]`` with ``normalize``.

    Float arrays are updated in place and returned. Integer arrays are first
    copied to float64, because writing [0, 1] results back into an integer
    array truncates them; in that case the caller's array is left untouched
    and the normalised copy is returned.

    Args:
        data: NumPy array with one trace per row.

    Returns:
        np.ndarray: the normalised data.
    """
    print(data.shape)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    for j, trace in enumerate(data):
        data[j] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-score every trace ``data[j][i]`` with ``z_norm``.

    Float arrays are updated in place and returned. Integer arrays are first
    copied to float64, because writing z-scores back into an integer array
    truncates them; in that case the caller's array is left untouched and
    the normalised copy is returned.

    Args:
        data: NumPy array whose entries ``data[j][i]`` are traces.

    Returns:
        np.ndarray: the z-normalised data.
    """
    print(data.shape)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float64)
    for group in data:
        for i in range(len(group)):
            group[i] = z_norm(group[i])

    return data

def check_file_exists(file_path):
    """Exit the process with status -1 when ``file_path`` does not exist.

    The path is normalised with ``os.path.normpath`` for both the check and
    the error message. Returns ``None`` when the path exists.
    """
    normalized = os.path.normpath(file_path)
    if not os.path.exists(normalized):
        print("Error: provided file path '%s' does not exist!" % normalized)
        sys.exit(-1)

def load_sca_model(model_file):
    """Load a Keras model file, exiting with status -1 on failure.

    Exits via ``check_file_exists`` if the path is missing, and prints the
    underlying error before exiting if ``load_model`` fails.

    Returns:
        The loaded Keras model.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as err:
        # A bare ``except`` would also swallow KeyboardInterrupt; report the
        # cause so load failures can be diagnosed.
        print("Error: can't load Keras model file '%s' (%s)" % (model_file, err))
        sys.exit(-1)
    return model

def load_ascad(ascad_database_file, load_metadata=False):
    """Load profiling and attack traces/labels from an ASCAD HDF5 file.

    Exits with status -1 if the file is missing or cannot be opened. All
    requested data is read into memory and the file is closed before
    returning.

    Args:
        ascad_database_file: path to the HDF5 database.
        load_metadata: also return the ``metadata`` datasets when true.

    Returns:
        ``((X_profiling, Y_profiling), (X_attack, Y_attack))`` with traces
        as int8 arrays. When ``load_metadata`` is true, a third tuple
        ``(profiling_metadata, attack_metadata)`` is appended; these are
        in-memory NumPy arrays read from the corresponding datasets.
    """
    check_file_exists(ascad_database_file)
    try:
        in_file = h5py.File(ascad_database_file, "r")
    except OSError:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % ascad_database_file)
        sys.exit(-1)
    with in_file:
        X_profiling = np.array(in_file['Profiling_traces/traces'], dtype=np.int8)
        Y_profiling = np.array(in_file['Profiling_traces/labels'])
        X_attack = np.array(in_file['Attack_traces/traces'], dtype=np.int8)
        Y_attack = np.array(in_file['Attack_traces/labels'])
        if not load_metadata:
            return (X_profiling, Y_profiling), (X_attack, Y_attack)
        # Metadata is copied into memory so the file can be closed safely.
        metadata = (np.array(in_file['Profiling_traces/metadata']),
                    np.array(in_file['Attack_traces/metadata']))
    return (X_profiling, Y_profiling), (X_attack, Y_attack), metadata


parser = parse_arguments()
args = parser.parse_args()
if args.train_model is None:
    parser.error('--train_model is required')
# Output files are written next to the model file.
train_folder = os.path.dirname(args.train_model)

fpath = 'ASCAD_variable.h5'
(X_profiling, Y_profiling), (X_attack, Y_attack) = load_ascad(fpath)
input_data = X_profiling
labels = Y_profiling
ref_data = X_attack[:args.num_trace]
# Keep the attack labels aligned with the reference traces actually used.
yTest_multi = Y_attack[:args.num_trace]
model = load_sca_model(args.train_model)
model.summary()

print(args.sampling_type)

if args.sampling_type == 'euclid':
    sampled_ids, disjoint_ids = get_similar_samples(input_data, ref_data, args.num_sample)
elif args.sampling_type == 'KL':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT(model, input_data, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KL_geo':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT_geo(model, input_data, labels, ref_data, yTest_multi, args.num_sample, args.cont_file)
elif args.sampling_type == 'KL_set':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_by_set(model, input_data, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KNN_feature':
    sampled_ids, disjoint_ids = get_similar_samples_KNN_Feature(model, input_data, labels, ref_data, yTest_multi, args.num_sample)
else:
    parser.error("unknown --sampling_type '%s'" % args.sampling_type)

fname = 'ids_{}_{}'.format(args.num_trace, args.num_sample)
np.save(os.path.join(train_folder, fname), sampled_ids)
# NOTE(review): [:-6] strips a 6-character extension such as '.keras'; with
# other extensions (e.g. '.h5') it also cuts into the model name.
np.save(os.path.join(train_folder, 'disjoint_{}_{}_{}_{}.npy'.format(args.num_trace, args.num_sample, args.sampling_type, os.path.basename(args.train_model)[:-6])), disjoint_ids)