from sklearn.neighbors import KNeighborsClassifier
import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import os.path
import sys
import h5py
import math
import gc
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from sklearn.cluster import KMeans
from math import log2
import itertools

def parse_arguments():
    """Build the command-line parser for this sampling script.

    Returns:
        argparse.ArgumentParser: the parser (not yet parsed); the caller
        invokes ``parse_args()`` itself.
    """
    parser = argparse.ArgumentParser(
        description='Select training samples similar to reference attack traces.')
    parser.add_argument('--train_folder', type=str,
                        help='folder containing model_800.pt; results are saved here')
    parser.add_argument('--data_path', type=str,
                        help='.npz attack file holding "data" and "label" arrays')
    parser.add_argument('--sampling_type', type=str, default = 'KL',
                        help='selection method: euclid, KL, KL_geo, KL_set or KNN_feature')
    parser.add_argument('--num_sample', type=int, default=5,
                        help='candidates selected per reference trace (per key for KL_set)')
    parser.add_argument('--num_trace', type=int, default=5,
                        help='number of reference traces used per key')
    parser.add_argument('--start_key', type=int, default=0,
                        help='first key index (inclusive) of the attack set to process')
    parser.add_argument('--end_key', type=int, default=300,
                        help='last key index (exclusive) of the attack set to process')
    parser.add_argument('--is_wavelet', type=int, default=0,
                        help='1 to use the wavelet features from data_wavelet.npz')
    parser.add_argument('--normalize', type=int, default=1,
                        help='1 to min-max normalize every trace')
    parser.add_argument('--cont_file', type=str, default=None,
                        help='previously saved disjoint .npy file to resume KL_geo from')
    parser.add_argument('--save_interval', type=int, default=30,
                        help='KL_geo checkpoint interval, in keys')
    parser.add_argument('--add_num', type=int, default=0, help='additional number of extend pool')

    return parser

def load_multi_attack(data_path):
    """Load an attack set stored as an ``.npz`` archive.

    Args:
        data_path: path to an ``.npz`` file containing ``data`` and ``label``
            arrays.

    Returns:
        Tuple ``(data, labels)`` with both arrays fully read into memory.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # the context manager closes it once both arrays have been read.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def get_similar_samples(xTrain, target_samples, num_sample):
    """Select, for every reference trace, its nearest training traces (Euclidean).

    Distances are computed in raw trace space. Picks for different traces of
    the same key may overlap (no exclusion between traces).

    Args:
        xTrain: 2-D array, one training trace per row.
        target_samples: reference traces indexed as ``[key][trace]``; every
            trace has the same length as a row of ``xTrain``.
        num_sample: training traces kept per reference trace (capped at
            ``len(xTrain)``).

    Returns:
        (sample_indexes, sample_index_disjoint): the sorted unique integer
        training indexes selected for any trace, and an integer array of
        shape ``(num_keys, num_traces, k)`` with the picks of every trace.
    """
    xTrain = np.asarray(xTrain)
    # argpartition requires kth < len; kth = k - 1 still puts the k smallest
    # first and also works when every training trace is requested.
    k = min(num_sample, len(xTrain))
    print(target_samples.shape)

    all_picks = []
    sample_index_disjoint = []
    for key_sample in tqdm(target_samples):
        sample_key = []
        for target_sample in key_sample:
            # One vectorised call instead of a Python loop of distance.euclidean.
            sim_score = distance.cdist(
                xTrain, np.asarray(target_sample).reshape(1, -1), 'euclidean')[:, 0]
            if k > 0:
                sample_index = np.argpartition(sim_score, k - 1)[:k]
            else:
                sample_index = np.empty(0, dtype=np.intp)
            all_picks.append(sample_index)
            sample_key.append(sample_index)
        sample_index_disjoint.append(sample_key)

    # Integer ids (seeding the accumulator with [] used to turn them into floats).
    if all_picks:
        sample_indexes = np.unique(np.concatenate(all_picks))
    else:
        sample_indexes = np.empty(0, dtype=np.intp)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

# calculate the kl divergence
def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence KL(p || q) in bits (base-2 log).

    Terms with ``p[i] == 0`` contribute 0 (the convention 0 * log 0 = 0)
    instead of raising a math domain error from ``log2(0)``. A term with
    ``p[i] > 0`` and ``q[i] == 0`` is infinite: it yields ``inf`` for numpy
    inputs and raises ``ZeroDivisionError`` for plain Python floats.
    ``q`` must be at least as long as ``p``; an empty ``p`` gives 0.
    """
    return sum(p_i * log2(p_i / q[i]) for i, p_i in enumerate(p) if p_i != 0)

def KL(a, b):
    """Return the natural-log KL divergence ``sum(a * log(a / b))``.

    Entries with ``a == 0`` contribute 0. Inputs may be lists or arrays; they
    are converted with ``np.asarray`` (plain lists previously failed because
    ``a != 0`` was a single Python bool and ``a / b`` is undefined on lists).
    An entry with ``a > 0`` and ``b == 0`` makes the result ``inf``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # np.where evaluates both branches, so log(0) and 0 * -inf are computed for
    # the masked-out a == 0 entries; their values are discarded, so silence the
    # resulting RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.sum(np.where(a != 0, a * np.log(a / b), 0))

# calculate the js divergence
def js_divergence(p, q):
    """Return the Jensen-Shannon divergence between distributions ``p`` and ``q``.

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = (p + q) / 2,
    evaluated with ``kl_divergence`` (base-2 logs, so the result is in bits).

    ``p`` and ``q`` may be lists or arrays of equal length. They are converted
    to float arrays first: with plain lists ``p + q`` concatenates the lists
    rather than adding element-wise.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def epsilon_onehot(label, epsilon, num_classes=3329):
    """Return a label-smoothed one-hot vector.

    Every class gets probability ``epsilon`` except ``label``, which gets
    ``1 - (num_classes - 1) * epsilon``, so the vector sums to 1.

    Args:
        label: class index; a value matching no index in
            ``range(num_classes)`` leaves the vector all-``epsilon``.
        epsilon: probability assigned to every non-target class.
        num_classes: vector length (defaults to the 3329 classes used
            throughout this file).

    Returns:
        float64 numpy array of shape ``(num_classes,)``.
    """
    g_x = np.full(num_classes, epsilon, dtype=float)
    # Compare every class index with `label` (exactly what the original
    # per-element loop did) instead of indexing directly, so out-of-range or
    # negative labels do not wrap around or raise.
    g_x[np.arange(num_classes) == label] = 1 - (num_classes - 1) * epsilon
    return g_x

#------------------

def cul_mul(arr):
    """Multiply the entries of ``arr`` together along its first axis.

    For a 2-D array this is the element-wise product of all rows; for a flat
    sequence it is the product of its items. A single entry is returned as-is,
    and an empty input yields an empty list.
    """
    if len(arr) == 0:
        return []

    product = arr[0]
    for item in arr[1:]:
        product = product * item
    return product

def iterative_key(key_index, num_set):
    """Return every contiguous window of ``num_set`` consecutive entries.

    Args:
        key_index: sliceable sequence (list or 1-D array) of indexes.
        num_set: window length.

    Returns:
        List of ``len(key_index) - num_set + 1`` slices (empty when the input
        is shorter than ``num_set``).
    """
    # A sequence of length n has n - num_set + 1 windows; the former
    # range(n - num_set) dropped the last one and produced none when
    # n == num_set.
    return [key_index[start:start + num_set]
            for start in range(len(key_index) - num_set + 1)]

def norm_data(values):
    """Shift ``values`` so their minimum is 0 and rescale them to sum to 1.

    This equals min-max scaling followed by sum-normalisation (the max-min
    factor cancels). The minimum entry always receives probability 0.
    A constant input, which previously produced 0/0 NaNs, is mapped to a
    uniform distribution.

    Returns:
        numpy array with the same shape as ``values``.
    """
    values = np.asarray(values)
    shifted = values - np.min(values)
    total = np.sum(shifted)
    if total == 0:
        return np.full(values.shape, 1.0 / values.size)
    return shifted / total

from scipy import stats
import torch
from gs_divergence import gs_div
# Usage example for gs_div, kept inside an unassigned string literal so it is
# never executed. It compares two 4-bin distributions with alpha=-1 and
# lmd=0.5, the settings every gs_div call in this file uses.
'''
a = torch.Tensor([0.1, 0.2, 0.3, 0.4])
b = torch.Tensor([0.2, 0.2, 0.4, 0.2])

div = gs_div(a, b, alpha=-1, lmd=0.5)
'''

def get_similar_alpha_KL_GT(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Greedily pick, per reference trace, training samples with the lowest weighted KL score.

    For every key in ``xTest_multi`` the model's prediction for each reference
    trace is compared with the training-set predictions using
    ``stats.entropy(pred, train_pred) * KL_GT[i]``. The ``num_sample`` lowest
    scores are taken from the training indexes not yet chosen for that key,
    so the picks within one key are disjoint.

    Returns:
        (sample_indexes, sample_index_disjoint): ``np.unique`` of the picks
        made for the LAST key only (``sample_indexes`` is reset per key), and
        an array with one row of chosen training indexes per key (float64,
        because ``chosen_indexes`` starts as an empty list).

    NOTE(review): this is Keras/TensorFlow-era code. ``tf`` is never imported
    in this file and ``model.predict`` is a Keras API, while the model the
    script builds is a torch ``nn.Module`` (no ``.predict``). As written, the
    first statement raises NameError. ``yTest_multi`` is unused.
    """
    
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    
    input_preds = model.predict(val)

    epsilon = 1e-10
    #D1, D2, GT
    #Calculate KL(D1,D2) then a_KL(D1, D2):
    print(xTest_multi.shape)
    sample_index_disjoint = []
    # (len(yTrain_original), 3329) matrix filled with epsilon; for large
    # training sets this is a multi-GB float64 allocation.
    ep_yTrain = np.ones((len(yTrain_original), 3329)) * epsilon
    # Per-training-sample weight: |KL(train prediction || q_x)|.
    KL_GT = []
    for i in range(len(yTrain_original)):
        # NOTE(review): this overwrites the whole ROW whose number equals the
        # label value, not element (i, label). q_x below is therefore a
        # constant row, not a smoothed one-hot for sample i. Probably meant
        # ep_yTrain[i, yTrain_original[i]] -- confirm.
        ep_yTrain[yTrain_original[i]] = 1 - (3329 - 1) * epsilon
        p_x = input_preds[i]
        q_x = ep_yTrain[yTrain_original[i]]
        #ep_yTrain[yTrain_original[i]] = 1
        #q_x = torch.Tensor(ep_yTrain[yTrain_original[i]])
        #KLD = gs_div(p_x, q_x, alpha=-1, lmd=0.5).detach().cpu().numpy()
        KLD = stats.entropy(p_x, q_x) #+ KL_GT[i]
        KL_GT.append(np.abs(KLD))
        #ep_yTrain.append(epsilon_onehot(yTrain, epsilon))

    for key_sample in tqdm(xTest_multi):
        # Adds a trailing channel axis for the (Keras) model input.
        sample = np.expand_dims(key_sample, axis = 2)
        print(sample.shape)
        preds = model.predict(sample)
        print(preds.shape)
        sample_key = []
        sample_indexes = []
        #For each target trace
        chosen_indexes = []
        # Training indexes still available for this key.
        curr_indexes = np.arange(len(input_preds))
        for pred in preds:
            l_x = pred
            KL_score = []
            curr_input = input_preds[curr_indexes]
            for i in range(len(curr_input)):
                g_x = curr_input[i]
                #q_x = ep_yTrain[i]
                #print(l_x)
                #print(g_x)
                #print(g_x.shape)
                #print(l_x.shape)
                #print(np.sum(l_x))
                #KLD = kl_divergence(l_x, g_x)
                #KLD = KL(l_x, g_x)
                # NOTE(review): i indexes curr_input, which shrinks after each
                # trace (setdiff1d below). From the second trace on, KL_GT[i]
                # is the weight of training sample i, not of curr_indexes[i].
                KLD = stats.entropy(l_x, g_x) * KL_GT[i]
                #print(KLD)
                #print(kld)
                #exit()
                KL_score.append(KLD)
            KL_score = np.array(KL_score)
            #print('---------------')
            #print(len(sim_score))
            # Positions of the num_sample smallest scores (unordered);
            # argpartition raises if num_sample >= len(KL_score).
            sample_index = np.argpartition(KL_score, num_sample)[:num_sample]
            chosen_indexes = np.concatenate((chosen_indexes, curr_indexes[sample_index]))
            sample_key.append(curr_indexes[sample_index])
            sample_indexes.append(curr_indexes[sample_index])
            # Remove everything chosen so far from the pool for this key.
            curr_indexes = np.setdiff1d(curr_indexes, chosen_indexes)
            #print(len(curr_indexes))
            #print(len(chosen_indexes))
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)

    # NOTE(review): sample_indexes here only holds the last key's picks.
    print(len(sample_indexes))
    sample_indexes = np.unique(sample_indexes)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint

def get_similar_alpha_KL_GT_geo(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample, cont_file):
    """Greedy per-trace selection of training samples by |gs_div| between model outputs.

    For every key in ``xTest_multi`` the model is run on that key's reference
    traces. For each trace in turn, the ``num_sample`` training samples whose
    model output has the smallest ``|gs_div(trace_out, train_out, alpha=-1,
    lmd=0.5)|`` are picked from the pool not yet chosen for that key. The
    picks within one key are therefore disjoint.

    Args:
        model: torch module mapping a (batch, length) float tensor to outputs.
        xTrain_original: training traces, one row per sample.
        yTrain_original: training labels (only needed to build the DataLoader).
        xTest_multi: reference traces indexed as ``[key][trace]``.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: samples selected per reference trace (capped at the
            remaining pool size).
        cont_file: optional ``.npy`` file of previously computed per-key
            selections. Its rows are kept and processing resumes after them.

    Side effects:
        Saves a checkpoint of the selections to ``args.train_folder`` on the
        first processed key and then every ``args.save_interval`` keys (uses
        the module-level ``args``).

    Returns:
        (sample_indexes, sample_index_disjoint): the sorted unique training
        indexes chosen for any key, and an int array with one row of chosen
        indexes per key.
    """
    device = 'cuda'
    model.eval()
    model = model.to(device)

    # shuffle must be False: row j of input_preds is reported as training
    # index j. With shuffle=True the rows came back in random order and every
    # returned index pointed at an unrelated training sample.
    train_loader = DataLoader(SCADataset([xTrain_original, yTrain_original]),
                              batch_size=256, shuffle=False)
    input_preds = []
    with torch.no_grad():
        for trace_data, _ in tqdm(train_loader):
            input_preds.append(model(trace_data.to(device)).cpu().numpy())
    input_preds = np.concatenate(input_preds, axis=0)
    print(input_preds.shape)
    print(xTest_multi.shape)

    # (The former per-sample KL_GT weights were dropped: their use was
    # commented out, yet building them needed a len(train) x 3329 float64
    # matrix and one gs_div call per training sample.)

    sample_index_disjoint = []
    if cont_file is not None:
        # Resume: keep the rows computed by the earlier run and continue after
        # them, so checkpoints and the final result cover every key.
        sample_index_disjoint = [np.asarray(row).astype(int) for row in np.load(cont_file)]
    curr_key = len(sample_index_disjoint)

    for idx, key_pos in enumerate(tqdm(range(curr_key, len(xTest_multi)))):
        sample = torch.from_numpy(xTest_multi[key_pos]).to(device).float()
        print(sample.shape)
        with torch.no_grad():
            preds = model(sample).cpu().numpy()
        print(preds.shape)

        picked_per_trace = []
        # Training indexes still available for this key.
        curr_indexes = np.arange(len(input_preds))
        for pred in tqdm(preds):
            l_x = torch.Tensor(pred)
            KL_score = np.array([
                np.abs(gs_div(l_x, torch.Tensor(g_x), alpha=-1, lmd=0.5).detach().cpu().numpy())
                for g_x in input_preds[curr_indexes]
            ])
            k = min(num_sample, len(KL_score))
            if k == 0:
                break
            # kth = k - 1 is valid even when the whole remaining pool is taken.
            picked = curr_indexes[np.argpartition(KL_score, k - 1)[:k]]
            picked_per_trace.append(picked)
            curr_indexes = np.setdiff1d(curr_indexes, picked)
        if picked_per_trace:
            chosen_indexes = np.concatenate(picked_per_trace)
        else:
            chosen_indexes = np.empty(0, dtype=int)
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)

        if idx % args.save_interval == 0:
            ref_fname = args.data_path[:-4]
            np.save(os.path.join(args.train_folder, 'disjoint_{}_{}_{}_{}_{}_{}'.format(args.start_key, idx, args.num_trace, args.num_sample, args.sampling_type, ref_fname)), np.array(sample_index_disjoint))

    sample_index_disjoint = np.array(sample_index_disjoint)
    # Union over all keys (previously only the last key's picks were returned).
    sample_indexes = np.unique(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint


def get_similar_alpha_KL_GT_geo_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Debugging variant of the gs_div-based greedy selection (not called by the script).

    Training-set predictions and the predictions for every reference trace
    are compared with ``|gs_div(pred, train_pred, alpha=-1, lmd=0.5)|``. The
    ``num_sample`` lowest scores are then taken from the training indexes not
    yet chosen for the current key.

    NOTE(review): not runnable as written:
      * ``tf`` is never imported in this file, so the first statement raises
        NameError;
      * ``model.predict`` is a Keras API, whereas the script's model is a
        torch ``nn.Module``;
      * ``exit()`` inside the per-trace loop terminates the whole program after
        the first reference trace (leftover debugging).
    ``yTest_multi`` is unused, and ``KL_GT`` is computed but never used (its
    use in the score is commented out).
    """
    
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    
    input_preds = model.predict(val)

    epsilon = 1e-10
    #D1, D2, GT
    #Calculate KL(D1,D2) then a_KL(D1, D2):
    print(xTest_multi.shape)
    sample_index_disjoint = []
    # (len(yTrain_original), 3329) float64 matrix filled with epsilon.
    ep_yTrain = np.ones((len(yTrain_original), 3329)) * epsilon
    KL_GT = []
    for i in tqdm(range(len(yTrain_original))):
        #ep_yTrain[yTrain_original[i]] = 1 - (3329 - 1) * epsilon
        p_x = torch.Tensor(input_preds[i])
        q_x = ep_yTrain[yTrain_original[i]]
        # NOTE(review): sets the whole ROW indexed by the label value to 1, so
        # q_x is an all-ones row rather than a one-hot vector for sample i.
        ep_yTrain[yTrain_original[i]] = 1
        q_x = torch.Tensor(ep_yTrain[yTrain_original[i]])
        KLD = gs_div(p_x, q_x, alpha=-1, lmd=0.5).detach().cpu().numpy()
        #KLD = stats.entropy(p_x, q_x) #+ KL_GT[i]
        KL_GT.append(np.abs(KLD))
        #ep_yTrain.append(epsilon_onehot(yTrain, epsilon))

    for key_sample in tqdm(xTest_multi):
        sample = np.expand_dims(key_sample, axis = 2)
        print(sample.shape)
        preds = model.predict(sample)
        print(preds.shape)
        sample_key = []
        sample_indexes = []
        chosen_indexes = []
        # Training indexes still available for this key.
        curr_indexes = np.arange(len(input_preds))
        for pred in preds:
            print('ZZZ')
            l_x = pred
            KL_score = []
            curr_input = input_preds[curr_indexes]
            for i in range(len(curr_input)):
                g_x = torch.Tensor(curr_input[i])
                l_x = torch.Tensor(l_x)
                #q_x = ep_yTrain[i]
                #print(l_x)
                #print(g_x)
                #print(g_x.shape)
                #print(l_x.shape)
                #print(np.sum(l_x))
                #KLD = kl_divergence(l_x, g_x)
                KLD = np.abs(gs_div(l_x, g_x, alpha=-1, lmd=0.5).detach().cpu().numpy()) #* KL_GT[i] 
                #KLD = stats.entropy(l_x, g_x) * KL_GT[i]
                #print(KLD)
                #print(kld)
                #exit()
                KL_score.append(KLD)
            KL_score = np.array(KL_score)
            #print('---------------')
            #print(len(sim_score))
            # Positions of the num_sample smallest scores (unordered);
            # argpartition raises if num_sample >= len(KL_score).
            sample_index = np.argpartition(KL_score, num_sample)[:num_sample]
            chosen_indexes = np.concatenate((chosen_indexes, curr_indexes[sample_index]))
            sample_key.append(curr_indexes[sample_index])
            sample_indexes.append(curr_indexes[sample_index])
            curr_indexes = np.setdiff1d(curr_indexes, chosen_indexes)
            print(chosen_indexes[:10])
            # NOTE(review): debugging stop -- terminates the process here.
            exit()
            #print(len(curr_indexes))
            #print(len(chosen_indexes))
        print(len(np.unique(chosen_indexes)))
        sample_index_disjoint.append(chosen_indexes)

    # NOTE(review): sample_indexes here only holds the last key's picks.
    print(len(sample_indexes))
    sample_indexes = np.unique(sample_indexes)
    sample_index_disjoint = np.array(sample_index_disjoint)
    print(len(sample_indexes))
    print(len(sample_index_disjoint))
    print(sample_index_disjoint.shape)
    return sample_indexes, sample_index_disjoint



def get_similar_samples_KNN_Feature(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Greedy per-trace nearest-neighbour selection in ``fc9`` feature space.

    The output of ``model.fc9``, captured with a forward hook, is the feature
    vector. For every key in ``xTest_multi`` and every reference trace of that
    key, the ``num_sample`` training samples with the smallest Euclidean
    feature distance are taken from the pool not yet chosen for that key. The
    picks within one key are therefore disjoint.

    Args:
        model: torch module with an ``fc9`` submodule.
        xTrain_original: training traces, one row per sample.
        yTrain_original: training labels (only needed to build the DataLoader).
        xTest_multi: reference traces indexed as ``[key][trace]``.
        yTest_multi: unused; kept for signature compatibility.
        num_sample: samples selected per reference trace (capped at the
            remaining pool size).

    Returns:
        (sample_indexes, sample_index_disjoint): the sorted unique training
        indexes chosen for any key (as the other selectors return), and an
        int array with one row of chosen indexes per key.
    """
    device = 'cuda'
    print(xTrain_original.shape)
    model.eval()
    model = model.to(device)

    activations = {}

    def hook_fn(module, inputs, output):
        activations['near_last'] = output

    # Register the hook once and always remove it. Registering it inside the
    # loops (as before) left one more hook on model.fc9 per batch and per key.
    handle = model.fc9.register_forward_hook(hook_fn)
    try:
        # shuffle must be False: row j of input_preds is reported as training
        # index j, so it has to be the features of training sample j.
        train_loader = DataLoader(SCADataset([xTrain_original, yTrain_original]),
                                  batch_size=256, shuffle=False)
        input_preds = []
        with torch.no_grad():
            for trace_data, _ in tqdm(train_loader):
                model(trace_data.to(device))
                input_preds.append(activations['near_last'].cpu().numpy())
        input_preds = np.concatenate(input_preds, axis=0)
        print(input_preds.shape)
        print(xTest_multi.shape)

        sample_index_disjoint = []
        for key_sample in tqdm(xTest_multi):
            sample = torch.from_numpy(key_sample).to(device).float()
            with torch.no_grad():
                model(sample)
            preds = activations['near_last'].cpu().numpy()
            print(preds.shape)

            picked_per_trace = []
            # Training indexes still available for this key.
            curr_indexes = np.arange(len(input_preds))
            for pred in preds:
                # Vectorised Euclidean distance to every still-available sample.
                sim_score = np.linalg.norm(input_preds[curr_indexes] - pred, axis=1)
                k = min(num_sample, len(sim_score))
                if k == 0:
                    break
                # kth = k - 1 is valid even when the whole remaining pool is taken.
                picked = curr_indexes[np.argpartition(sim_score, k - 1)[:k]]
                picked_per_trace.append(picked)
                curr_indexes = np.setdiff1d(curr_indexes, picked)
            if picked_per_trace:
                chosen_indexes = np.concatenate(picked_per_trace)
            else:
                chosen_indexes = np.empty(0, dtype=int)
            print(len(np.unique(chosen_indexes)))
            sample_index_disjoint.append(chosen_indexes)
    finally:
        handle.remove()

    sample_index_disjoint = np.array(sample_index_disjoint)
    print(sample_index_disjoint.shape)
    sample_indexes = np.unique(sample_index_disjoint)
    return sample_indexes, sample_index_disjoint

def get_similar_alpha_KL_by_set(model, xTrain_original, yTrain_original, xTest_multi, yTest_multi, num_sample):
    """Select training-sample *sets* whose combined prediction best matches each key.

    The training indexes of each class ``i`` in ``range(3329)`` are split into
    sliding windows of ``num_set`` (= 5) consecutive indexes by
    ``iterative_key``. Each window's predictions are multiplied together
    (``cul_mul``) and rescaled into a distribution (``norm_data``). For every
    key, the predictions for its reference traces are combined the same way,
    and the ``num_sample`` windows with the smallest ``|gs_div|`` to it are
    chosen.

    Returns:
        (all_indexes, sample_index_disjoint): every window as an array of
        shape (num_windows, num_set), and, per key, the chosen windows as an
        int array of shape (num_keys, num_sample, num_set).

    NOTE(review): not runnable as written:
      * ``tf`` is never imported in this file and ``model.predict`` is a Keras
        API, whereas the script's model is a torch ``nn.Module``;
      * the script's 'KL_set' branch calls this with an extra
        ``args.cont_file`` argument, which raises TypeError (six parameters).
    ``all_GT``, ``all_labels``, ``epsilon`` and ``curr_key`` are computed but
    do not affect the selection; ``yTest_multi`` is only read into ``curr_key``.
    """
    
    # Window length for the per-class sliding windows.
    num_set = 5
    with tf.device("CPU"):
        val = tf.data.Dataset.from_tensor_slices(xTrain_original).batch(192)
    
    input_preds = model.predict(val)
    print(input_preds.shape)
    epsilon = 1e-10
    #Set instances
    all_set = []
    all_labels = []
    all_indexes= []
    all_argmax = []
    all_GT = []
    num_trace = len(xTest_multi[0])
    print(num_trace)

    for i in tqdm(range(3329)):
        # Training indexes whose label is class i.
        key_index = np.where(yTrain_original == i)[0]
        instance_idxs = iterative_key(key_index, num_set)
        for instance_idx in instance_idxs:
            all_indexes.append(instance_idx)
            # Element-wise product of the window's predictions, then rescaled
            # to sum to 1.
            p_x = cul_mul(input_preds[instance_idx,:])
            #print(np.sum(p_x))
            p_x = norm_data(p_x)
            q_x = np.zeros(3329)
            q_x[i] = 1
            #print(np.sum(p_x))
            #print(p_x)
            KLD = gs_div(torch.Tensor(p_x), torch.Tensor(q_x), alpha=-1, lmd=0.5).detach().cpu().numpy()
            all_GT.append(KLD)
            all_argmax.append(np.argmax(p_x))
            all_set.append(p_x)
            all_labels.append(i)

    all_indexes = np.array(all_indexes)
    all_labels = np.array(all_labels)
    print(len(all_labels))
    
    # Diagnostics: how evenly the windows' argmax classes are spread.
    unique, counts = np.unique(all_argmax, return_counts=True)
    print(len(unique))
    print(np.std(counts))
    print(np.min(counts))
    print(np.max(counts))
    sample_index_disjoint = []
    idx = 0
    for key_sample in tqdm(xTest_multi):
        #print('------------------------')
        #print(yTest_multi[idx])
        curr_key = yTest_multi[idx]
        idx += 1
        sample = np.expand_dims(key_sample, axis = 2)
        #print(sample.shape)
        preds = model.predict(sample)
        # Combined distribution of all of this key's reference traces.
        q_x = norm_data(cul_mul(preds))
        #print(np.argmax(q_x))
        #print(sum(q_x))
        KL_score = []
        #exit()
        for i in range(len(all_set)):
            p_x = all_set[i]
            #KLD = stats.entropy(p_x, q_x) #+ KL_GT[i]
            #KLD = stats.entropy(torch.Tensor(p_x), torch.Tensor(q_x))
            KLD = gs_div(torch.Tensor(p_x), torch.Tensor(q_x), alpha=-1, lmd=0.5).detach().cpu().numpy()
            #print(KLD)
            #KLD = KLD * all_GT[i]
            #print(kld)
            KL_score.append(np.abs(KLD))
        KL_score = np.array(KL_score)
        #print(np.mean(KL_score))
        #print(np.max(KL_score))
        #print(np.min(KL_score))
        #print('---------------')
        #print(len(sim_score))
        # Positions of the num_sample best-matching windows (unordered).
        sample_index = np.argpartition(KL_score, num_sample)[:num_sample]
        sample_index = np.array(sample_index).astype(int)#
        all_index = all_indexes[sample_index]
        print(all_index.shape)
        '''
        print(sample_index)
        print(sample_index.shape)
        all_index = all_indexes[sample_index].flatten()
        print(all_index.shape)
        all_lb = all_labels[sample_index]
        print('----')
        print(np.unique(all_lb))
        print(len(np.unique(all_lb)))
        print('----')
        #if curr_key in all_lb:
        #    print('yes')
        print(all_lb)
        print(all_set[sample_index[0]])
        print(q_x)
        print(sum(all_set[sample_index[0]]))
        print(sum(q_x))
        print(np.argmax(all_set[sample_index[0]]))
        #exit()
        '''
        sample_index_disjoint.append(all_index)
    sample_index_disjoint = np.array(sample_index_disjoint).astype(int)
    print(sample_index_disjoint.shape)

    return all_indexes, sample_index_disjoint

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale a numpy array to the range [0, 1].

    A constant series (max == min) is mapped to all zeros instead of the NaNs
    (and RuntimeWarning) the bare 0/0 division produced.
    """
    lo = timeseries.min()
    span = timeseries.max() - lo
    if span == 0:
        return np.zeros_like(timeseries, dtype=float)
    return (timeseries - lo) / span

def z_norm(timeseries):
    """Return the z-score of ``timeseries`` (zero mean, unit standard deviation).

    Thin wrapper around ``scipy.stats.zscore`` with its default arguments.
    """
    standardized = stats.zscore(timeseries)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalize, in place, every trace ``data[j][i]`` via ``normalize``.

    Prints ``data.shape`` first and returns the same (mutated) ``data`` object.
    """
    print(data.shape)
    for traces in data:
        for pos, trace in enumerate(traces):
            traces[pos] = normalize(trace)
    return data

def normalize_data_per_trace_train(data):
    """Min-max normalize, in place, each row ``data[j]`` via ``normalize``.

    Prints ``data.shape`` first and returns the same (mutated) ``data`` object.
    """
    print(data.shape)
    for row_idx, trace in enumerate(data):
        data[row_idx] = normalize(trace)
    return data

def z_normalize_data_per_trace(data):
    """Z-score normalize, in place, every trace ``data[j][i]`` via ``z_norm``.

    Prints ``data.shape`` first and returns the same (mutated) ``data`` object.
    """
    print(data.shape)
    for traces in data:
        for pos, trace in enumerate(traces):
            traces[pos] = z_norm(trace)
    return data

class CNNModel(nn.Module):
    """1-D CNN classifier producing raw logits over 3329 classes.

    Four Conv1d -> ReLU -> MaxPool(3) -> BatchNorm stages are followed by ten
    fully connected layers (ReLU then BatchNorm, with dropout 0.2 after some).
    ``forward`` returns the ``fc10`` output directly (no softmax).

    Input: a float tensor of shape (batch, length); ``forward`` adds the
    channel axis itself. Each MaxPool1d(3, 3) floors the length by 3, so after
    four of them the length is floor(length / 81). ``fc1`` expects
    64 * 7 = 448 features, which implies length in [567, 647].

    Attribute names (conv1..fc10, bn1..bn13) are the ``state_dict`` keys of
    the saved checkpoint, and ``fc9`` is hooked elsewhere in this file for
    feature extraction, so they must not be renamed.
    """
    def __init__(self):
        super(CNNModel, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm1d(512)
        
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm1d(128)
        
        self.conv4 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding='same')
        self.pool4 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm1d(64)
        
        # Fully connected layers
        # 448 = 64 channels * 7 remaining time steps after four /3 poolings.
        self.fc1 = nn.Linear(448, 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.2)
        
        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)
        
        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)
        
        self.fc4 = nn.Linear(256, 128)
        self.bn8 = nn.BatchNorm1d(128)
        
        self.fc5 = nn.Linear(128, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout3 = nn.Dropout(0.2)
        
        self.fc6 = nn.Linear(1024, 1024)
        self.bn10 = nn.BatchNorm1d(1024)
        self.dropout4 = nn.Dropout(0.2)
        
        self.fc7 = nn.Linear(1024, 512)
        self.bn11 = nn.BatchNorm1d(512)
        
        self.fc8 = nn.Linear(512, 256)
        self.bn12 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.2)
        
        # 128-d layer whose (pre-ReLU) output is used as a feature vector by
        # the KNN_feature selector via a forward hook.
        self.fc9 = nn.Linear(256, 128)
        self.bn13 = nn.BatchNorm1d(128)
        
        # Output layer: one logit per class (3329 classes).
        self.fc10 = nn.Linear(128, 3329)
    
    def forward(self, x):
        """Map a (batch, length) tensor to (batch, 3329) logits."""
        # Add channel dimension
        x = x.unsqueeze(1)
        
        # Convolutional layers
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.bn1(x)
        
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.bn2(x)
        
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = self.bn3(x)
        
        x = F.relu(self.conv4(x))
        x = self.pool4(x)
        x = self.bn4(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.bn5(x)
        x = self.dropout1(x)
        
        x = F.relu(self.fc2(x))
        x = self.bn6(x)
        
        x = F.relu(self.fc3(x))
        x = self.bn7(x)
        x = self.dropout2(x)
        
        x = F.relu(self.fc4(x))
        x = self.bn8(x)
        
        x = F.relu(self.fc5(x))
        x = self.bn9(x)
        x = self.dropout3(x)
        
        x = F.relu(self.fc6(x))
        x = self.bn10(x)
        x = self.dropout4(x)
        
        x = F.relu(self.fc7(x))
        x = self.bn11(x)
        
        x = F.relu(self.fc8(x))
        x = self.bn12(x)
        x = self.dropout5(x)
        
        x = F.relu(self.fc9(x))
        x = self.bn13(x)
        
        # Raw logits; no softmax is applied.
        x = self.fc10(x)
        
        return x

class SCADataset():
    """Minimal map-style dataset over paired ``(traces, labels)`` arrays.

    Usable with ``torch.utils.data.DataLoader`` because it implements
    ``__getitem__`` and ``__len__``.

    Args:
        data: sequence whose first item holds the traces (cast to float32)
            and whose second item holds the matching labels.
    """

    def __init__(self, data):
        traces, targets = data[0], data[1]
        self.x = traces.astype(np.float32)
        self.y = targets
        self.n_samples = traces.shape[0]

    def __getitem__(self, index):
        """Return the ``(trace, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

parser = parse_arguments()
args = parser.parse_args()

# Every supported --sampling_type. Validated up front so that a typo fails
# before the slow data loading, instead of as a NameError at the final save.
SAMPLING_TYPES = ('euclid', 'KL', 'KL_geo', 'KL_set', 'KNN_feature')
if args.sampling_type not in SAMPLING_TYPES:
    parser.error('unknown --sampling_type {!r}; expected one of: {}'.format(
        args.sampling_type, ', '.join(SAMPLING_TYPES)))

# Size of the base training pool: rows [0, NUM_SAMPLE) of data.npz.
NUM_SAMPLE = 200000

# Ids listed in attack_300key_ids.npz are treated as positions after
# NUM_SAMPLE (hence the shift below). The optional extra pool (--add_num) is
# drawn from the first 300000 such positions that those ids do not use.
with np.load('attack_300key_ids.npz') as ids_file:
    data_ids = ids_file['ids']
sorted_ids = np.sort(np.hstack(data_ids))
non_overlap_ids = np.setdiff1d(np.arange(300000), sorted_ids) + NUM_SAMPLE
print(sorted_ids.shape)
print(non_overlap_ids.shape)
print(np.max(non_overlap_ids))
additional_ids = non_overlap_ids[:args.add_num]
all_ids = np.concatenate((np.arange(NUM_SAMPLE), additional_ids))
print(len(all_ids))
print(all_ids[:10])

with np.load('data.npz') as data:
    xTrain_original = data['data'][all_ids]
    labels = data['label'][all_ids]
xTest_multi, yTest_multi = load_multi_attack(args.data_path)

print(xTrain_original.shape)
print(xTest_multi.shape)

if args.normalize == 1:
    print('Normalize')
    xTest_multi = normalize_data_per_trace(xTest_multi)
    xTrain_original = normalize_data_per_trace_train(xTrain_original)
ref_fname = args.data_path[:-4]

# Add a trailing axis at index 2 (undone for input_data below).
xTrain_original = np.expand_dims(xTrain_original, axis = 2)
print(xTrain_original.shape)
print(xTest_multi.shape)


if args.is_wavelet == 1:
    num_sample_per_key = 100
    with np.load('data_wavelet.npz') as wavelet_file:
        wavelet_data = wavelet_file['cA1']
        wavelet_labels = wavelet_file['label']
    # Index with all_ids (not just the first NUM_SAMPLE rows) so the training
    # traces stay aligned with `labels` when --add_num > 0.
    xTrain_original = np.expand_dims(wavelet_data[all_ids], axis = 2)
    test_data = wavelet_data[NUM_SAMPLE:]
    test_labels = wavelet_labels[NUM_SAMPLE:]
    print(test_data.shape)
    # For every attack key, take up to num_sample_per_key wavelet traces of it.
    xTest_multi = []
    for lb in yTest_multi:
        lb_indexes = np.where(test_labels == lb)[0]
        xTest_multi.append(test_data[lb_indexes[:num_sample_per_key]])

    xTest_multi = np.array(xTest_multi)
    print(xTrain_original.shape)
    print(xTest_multi.shape)


input_data = xTrain_original[:,:,0]
ref_data = xTest_multi[args.start_key:args.end_key,:args.num_trace,:]
model = CNNModel()
# map_location='cpu' lets the checkpoint load on machines without CUDA (e.g.
# for 'euclid'); the torch-based selectors move the model to CUDA themselves.
model.load_state_dict(torch.load(os.path.join(args.train_folder, 'model_800.pt'), map_location='cpu', weights_only=True))
model.eval()


print(args.sampling_type)

if args.sampling_type == 'euclid':
    sampled_ids, disjoint_ids = get_similar_samples(input_data, ref_data, args.num_sample)
elif args.sampling_type == 'KL':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KL_geo':
    sampled_ids, disjoint_ids = get_similar_alpha_KL_GT_geo(model, input_data, labels, ref_data, yTest_multi, args.num_sample, args.cont_file)
elif args.sampling_type == 'KL_set':
    # get_similar_alpha_KL_by_set takes six parameters; the extra cont_file
    # argument that used to be passed here raised TypeError.
    sampled_ids, disjoint_ids = get_similar_alpha_KL_by_set(model, xTrain_original, labels, ref_data, yTest_multi, args.num_sample)
elif args.sampling_type == 'KNN_feature':
    sampled_ids, disjoint_ids = get_similar_samples_KNN_Feature(model, input_data, labels, ref_data, yTest_multi, args.num_sample)

fname='ids_{}_{}_{}_{}'.format(args.start_key, args.end_key, args.num_trace, args.num_sample)
np.save(os.path.join(args.train_folder, fname), sampled_ids)
np.save(os.path.join(args.train_folder, 'disjoint_{}_{}_{}_{}_{}_{}_{}'.format(args.start_key, args.end_key, args.num_trace, args.num_sample, args.sampling_type, ref_fname, args.add_num)), disjoint_ids)