import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import os.path
import sys
import h5py
import math
import gc
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from sklearn.cluster import KMeans
import pandas as pd

def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def parse_arguments():
    """Build the command-line parser for the training / sample-selection experiments.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random, minmax or uncertainty')
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--xType', type=str,
                        help="input type: 'wave', 'wavebp0', 'wavebp1', 'wavebp01', "
                             "'wavebp01next0' or 'wavebp01next01'")
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int, help='end trace (exclusive)')
    parser.add_argument('--batch_size', type=int, help='batch_size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)
    parser.add_argument('--trained_model_path', type=str)
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)
    parser.add_argument('--all_ids', type=str)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--num_sample', type=int, help='number of samples to select', default=5)
    parser.add_argument('--eval_interval', type=int, help='save a model checkpoint every N epochs', default=1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--mlp', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--use_bp', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.00001)

    parser.add_argument('--use_dyn', type=int, default=0)

    parser.add_argument('--n_neighbor_post', type=int, default=20)
    parser.add_argument('--jtt_multiplier', type=int, default=1, help='multiplier hyperparams for JTT')
    parser.add_argument('--jtt_threshold', type=int, default=1, help='threshold hyperparams for JTT')
    parser.add_argument('--gamma', type=float, default=-1)
    parser.add_argument('--coreset_ratio', type=float, default=0.1)
    parser.add_argument('--coreset_key', type=str, default='entropy')
    parser.add_argument('--mis_ratio', type=float, default=0.1)
    parser.add_argument('--mis_key', type=str, default='entropy')
    # type=bool would turn '--label_balanced False' into True; parse the text instead.
    parser.add_argument('--label_balanced', type=_str2bool, default=True)

    parser.add_argument('--node', type=int, default=256)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--batch_norm', type=int, default=1)
    parser.add_argument('--dropout', type=int, default=1)
    parser.add_argument('--dropout_rate', type=float, default=0.2)

    return parser

# Parse the command line once, at import time; the module-level settings that
# follow (seed, batch size, output paths, ...) are read from `args`.
parser = parse_arguments()
args = parser.parse_args(sys.argv[1:])

def set_seeds(seed):
    """Seed the random number generators used by this training pipeline.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (which
    drives e.g. ``nn`` layer weight initialisation), so a given ``--seed``
    reproduces a run.

    Args:
        seed: integer seed, typically ``args.seed``.
    """
    import random

    # Only affects interpreters started after this point (e.g. subprocesses);
    # hash randomisation of the running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the RNGs of all devices (CPU and CUDA).
    torch.manual_seed(seed)

set_seeds(args.seed)

# ---------------------------------------------------------------------------
# Experiment constants. Some are taken from `args`; the rest are fixed here.
# ---------------------------------------------------------------------------

# Inclusive value ranges of the ciphertext input (bp), the secret-key
# coefficient (skpv) and the fqmul(skpv, bp) intermediate; the class counts
# below are hi - lo + 1.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600  # samples per trace (the data files below are '..._600samples.h5')
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv (3329, the output width of CNNModel)
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input values for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3; which subkeys these are is decided by the code on the m4 (NOT by the code on the PC)
work = 'train'  # 'train' or 'attack'
# HDF5 training trace files; 2 of the 5 x 100k-trace files are currently enabled.
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',\
'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5']#,\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5']

trained_model_path = args.trained_model_path
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232


# training parameters
train_batch_size = args.batch_size  # values used previously: 100, 150, 200, 250, 500, 640 (80 for mars45, 170 for mars56)
period = 8
maxEpochs = args.num_epoch  # values used previously: 3072, 2048, 1536, 1280, 1024, 512, 256
attack_byModel_fileNo = int(attack_byModel_epNo/period)  # 232 / 8 -> 29
N_TRACE = 20000
Threshold_Save = 200

# model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly trained
noClassificationLayer = 1
GPU_clear = True    # or False

# training data type
xType = args.xType  # one of 'wave', 'wavebp0', 'wavebp1', 'wavebp01', 'wavebp01next0', 'wavebp01next01'
yType = 'skpv'    # alternatives: 'fqmul0', 'fqmul1'
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
# Target board name, used only to build folder names here.
# NOTE(review): `train` calls `.to(device)` on tensors; unless `device` is
# rebound to a torch device later in the file (outside this view), that call
# receives this string and fails.
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

# Short and detailed names of the model structure (conv widths / FC widths per branch).
MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    # hyper-parameters contain 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'  # alternatives: 'skvp0_0_700points100kDatax5train', 'skvp0_0_100kDatax5', 'skvp0_0_100kDatax1_attack', '20kDatax25'
dataFile_attack = '100kDatax1_test'  # alternatives: 'skvp0_0_700points100kDatax1attack', 'skvp0_0_100kDatax5', 'skvp0_0_100kDatax1_attack', '20kDatax25'
model_input_type = '_in[[][]]_tf2'  # alternatives: 'in[]_tf2', '[[][]]_tf2'
#data_type = dataFile_train + model_input_type
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
# Run-specific output folder name built from the CLI options; it is created immediately.
save_path = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_lr_{}_mlp_{}'.format(args.name ,args.train_type, args.sampling ,xType, args.start_trace, args.end_trace, args.num_sample, args.seed, args.normalize, args.cummulative, args.lr, args.mlp)
print(save_path)
#database_folder_train = os.path.join('trained_models', save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
# Per-epoch checkpoints currently share the per-data-file folder.
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Create the model / log output folders. Path.mkdir(parents=True, exist_ok=True)
# (as already used for database_folder_train) is a no-op for existing folders
# and avoids the check-then-create race of os.path.isdir() followed by os.mkdir().
for _out_folder in (DLmodel_folder, modelLogFolder,
                    logTrainedModel_byFile_folder, logTrainedModel_byEp_folder):
    Path(_out_folder).mkdir(parents=True, exist_ok=True)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# Architecture description tables: one slot per layer (or per sub-model);
# 0 entries appear to mark unused slots.
# NOTE(review): CNNModel below hard-codes its layers and does not read these
# lists in the visible part of the file (e.g. it pools with kernel_size=3,
# while subMods_convPoolSizes lists 2) — keep the two in sync by hand.
# Input BatchNormalization for each PoI size
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# number of nodes (output channels) in each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   512,    256,    128,    64,     0,      0]
# Convolutional filter sizes
# filter (kernel) size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [   3,      3,      3,      3,      0,      0] # subModel0

###############################################
# Pooling size in convolutional layers
# MaxPooling size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   2,      2,      2,      2,      0,      0] # subModel0

# Pooling stride in convolutional layers
# MaxPooling stride of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [   3,      3,      3,      3,      0,      0] # subModel0

# BatchNormalization in convolutional layers
# whether each convolutional layer is followed by BatchNormalization (1 = yes)
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [   1,      1,      1,      1,      0,      0] # subModel0
# Dropout in convolutional layers
# Dropout value of each convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   0,      0,      0,      0,      0,      0] # subModel0

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten Convolutional feature map before Fully connected
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected layers applied to the convolutional features before adding the plaintext
# width of each fully-connected layer
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [   1024,   512,    256,    128,    0,      0] # subModel0
# BatchNormalization for those fully-connected layers (1 = yes)
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   1,      1,      1,      1,      0,      0] # subModel0

# Dropout for those fully-connected layers
# Dropout value of each fully-connected layer before adding the plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [   0.2,    0,      0.2,    0,      0,      0] # subModel0

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext (bp) extension is added here: per input type, which sub-models
# receive it (1 = yes), and for some types the number of bp branches.
# NOTE(review): noBPbranch is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01'. For 'wavebp0', 'wavebp1' and 'wavebp01next0' it stays
# undefined, and when xType matches none of these strings (e.g. --xType is
# omitted) subMods_Pext is undefined too. Any later read of these names
# would then raise NameError.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected layers applied after adding the plaintext
# width of each fully-connected layer, one row per sub-model
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]    # subModel5
# BatchNormalization for the fully-connected layers after adding the plaintext (1 = yes)
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5
# Dropout for the fully-connected layers after adding the plaintext
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5

# Softmax for each sub-model if available
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [  0,      0,      0,      0,      0,      0]

# Join the branches for every input type except the pure trace ('wave').
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]

# Width of each fully-connected layer after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FCs =  [  1024,   1024,   512,    256,    128,    0]
# BatchNormalization for the fully-connected layers after joining PoIs (1 = yes)
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [  1,      1,      1,      1,      1,      0]
# Dropout for the fully-connected layers after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [  0.2,        0.2,        0,      0.1,        0,      0]

# Softmax for joined-model if available
subMods_join_classification =   [   1]

class CNNModel(nn.Module):
    """1-D CNN that maps a batch of traces to ``num_classes`` logits.

    Architecture: four conv blocks (Conv1d 'same' -> ReLU -> MaxPool1d(3, stride 3)
    -> BatchNorm1d, with 512/256/128/64 channels), flatten, nine hidden
    Linear -> ReLU -> BatchNorm1d blocks (five of them followed by Dropout(0.2))
    and a final Linear output layer.

    The layer attribute names (``conv1``..``conv4``, ``pool1``..``pool4``,
    ``bn1``..``bn13``, ``fc1``..``fc10``, ``dropout1``..``dropout5``) define the
    ``state_dict`` keys; renaming them would break loading saved checkpoints.

    Args:
        trace_len: samples per input trace; must be >= 81 so that four /3
            pooling stages leave at least one step. The default 600 gives the
            original flatten size 64 * 7 = 448.
        num_classes: number of output logits. The default 3329 is the original
            hard-coded value (NumSKPVclasses in this file).
    """

    def __init__(self, trace_len=600, num_classes=3329):
        super(CNNModel, self).__init__()

        # Convolutional layers (module creation order kept identical to the
        # original so seeded weight initialisation is unchanged)
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm1d(512)

        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm1d(256)

        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm1d(128)

        self.conv4 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding='same')
        self.pool4 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm1d(64)

        # Fully connected layers; fc1's input size follows from trace_len.
        self.fc1 = nn.Linear(self._flat_features(trace_len), 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc4 = nn.Linear(256, 128)
        self.bn8 = nn.BatchNorm1d(128)

        self.fc5 = nn.Linear(128, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout3 = nn.Dropout(0.2)

        self.fc6 = nn.Linear(1024, 1024)
        self.bn10 = nn.BatchNorm1d(1024)
        self.dropout4 = nn.Dropout(0.2)

        self.fc7 = nn.Linear(1024, 512)
        self.bn11 = nn.BatchNorm1d(512)

        self.fc8 = nn.Linear(512, 256)
        self.bn12 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.2)

        self.fc9 = nn.Linear(256, 128)
        self.bn13 = nn.BatchNorm1d(128)

        self.fc10 = nn.Linear(128, num_classes)

    @staticmethod
    def _flat_features(trace_len, n_pools=4, channels=64):
        """Return the flattened conv-stack output size for a ``trace_len``-sample trace."""
        length = trace_len
        for _ in range(n_pools):
            # 'same' convolutions keep the length; MaxPool1d(kernel_size=3, stride=3)
            # without padding needs length >= 3 and yields length // 3.
            if length < 3:
                raise ValueError('trace_len=%d is too short for %d max-pooling stages of size 3'
                                 % (trace_len, n_pools))
            length //= 3
        return channels * length

    def forward(self, x):
        """Map traces of shape ``(batch, trace_len)`` to logits ``(batch, num_classes)``."""
        # Add channel dimension: (batch, trace_len) -> (batch, 1, trace_len)
        x = x.unsqueeze(1)

        # Conv blocks: Conv1d -> ReLU -> MaxPool1d -> BatchNorm1d
        for conv, pool, bn in (
            (self.conv1, self.pool1, self.bn1),
            (self.conv2, self.pool2, self.bn2),
            (self.conv3, self.pool3, self.bn3),
            (self.conv4, self.pool4, self.bn4),
        ):
            x = bn(pool(F.relu(conv(x))))

        # Flatten to (batch, channels * remaining_length)
        x = x.view(x.size(0), -1)

        # Hidden FC blocks: Linear -> ReLU -> BatchNorm1d [-> Dropout]
        for fc, bn, drop in (
            (self.fc1, self.bn5, self.dropout1),
            (self.fc2, self.bn6, None),
            (self.fc3, self.bn7, self.dropout2),
            (self.fc4, self.bn8, None),
            (self.fc5, self.bn9, self.dropout3),
            (self.fc6, self.bn10, self.dropout4),
            (self.fc7, self.bn11, None),
            (self.fc8, self.bn12, self.dropout5),
            (self.fc9, self.bn13, None),
        ):
            x = bn(F.relu(fc(x)))
            if drop is not None:
                x = drop(x)

        # Raw logits (no softmax applied)
        return self.fc10(x)

import time

# Training loop (example)
def train(args, save_folder, model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Train ``model`` and record per-epoch mean train / validation losses.

    After every epoch the model is evaluated on ``test_loader`` in eval mode
    without gradients. A checkpoint ``model_<epoch>.pt`` (0-based epoch) is
    written to ``save_folder`` whenever ``epoch % args.eval_interval == 0``.
    At the end, the per-epoch mean losses are written to
    ``save_folder/losses.csv``.

    Args:
        args: namespace providing ``eval_interval``.
        save_folder: existing directory for checkpoints and the loss CSV.
        model: the ``nn.Module`` to train. Batches are moved to the device of
            its parameters (CPU if it has none).
        train_loader, test_loader: iterables of ``(trace_data, target)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss function ``criterion(output, target)``.
        epochs: number of training epochs.

    Returns:
        The same ``model`` object, trained and left in training mode.
    """
    start_time = time.time()
    # Follow the model's own device instead of the module-level `device` name,
    # which in this file is also used for a board-name string.
    run_device = next((p.device for p in model.parameters()), torch.device('cpu'))

    model.train()
    losses = []
    for epoch in range(epochs):
        train_loss = []
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = target.long().to(run_device)
            trace_data = trace_data.to(run_device)
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')
            train_loss.append(loss.item())

        # Validation in eval mode: BatchNorm uses (and does not update) its running
        # statistics and Dropout is disabled; no_grad skips building autograd graphs.
        model.eval()
        val_loss = []
        with torch.no_grad():
            for batch_idx, (trace_data, target) in enumerate(test_loader):
                target = target.long().to(run_device)
                trace_data = trace_data.to(run_device)
                output = model(trace_data)
                loss = criterion(output, target)
                if batch_idx % 100 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Val Loss: {loss.item()}')
                val_loss.append(loss.item())
        model.train()

        if epoch % args.eval_interval == 0:
            save_path = os.path.join(save_folder, 'model_{}.pt'.format(epoch))
            torch.save(model.state_dict(), save_path)

        losses.append({"Epoch": epoch + 1, "Train Loss": np.mean(train_loss), "Validation Loss": np.mean(val_loss)})

    df = pd.DataFrame(losses)
    df.to_csv(os.path.join(save_folder, "losses.csv"), index=False)
    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

def load_multi_attack(data_path):
    """Load traces and labels from an ``.npz`` archive.

    Args:
        data_path: path to an ``.npz`` file with ``data`` and ``label`` arrays.

    Returns:
        tuple: ``(data, labels)`` as in-memory NumPy arrays.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once both arrays have been read into memory.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load the ``[start_trace, end_trace)`` range of traces, bp inputs and labels.

    Args:
        data_path: ``.npz`` archive with ``data``, ``bp`` and ``label`` arrays.
        start_trace: first trace index (inclusive).
        end_trace: last trace index (exclusive).

    Returns:
        tuple: ``(traces, bp, skpv)``. ``traces`` and ``skpv`` are sliced along
        axis 0, while ``bp`` is sliced along axis 1 (presumably one row per bp
        input — TODO confirm the archive layout).
    """
    # Close the NpzFile handle once the arrays have been read into memory.
    with np.load(data_path) as archive:
        trace_profiling = archive['data']
        bp_profiling = archive['bp']
        skpv_profiling = archive['label']

    return (trace_profiling[start_trace:end_trace],
            bp_profiling[:, start_trace:end_trace],
            skpv_profiling[start_trace:end_trace])

def random_sampling(data, num_sample, *, seed=None, x_type=None):
    """Draw ``num_sample`` distinct random indices into ``data``.

    The global NumPy RNG is re-seeded first, so the same seed always yields
    the same indices. This also resets the RNG state for any later
    ``np.random`` calls.

    Args:
        data: indexable dataset. For ``x_type == 'wavebp01'`` the indices are
            drawn over ``data[0]`` instead.
        num_sample: number of indices to draw (without replacement).
        seed: RNG seed; defaults to the module-level ``args.seed``.
        x_type: input type; defaults to the module-level ``args.xType``.

    Returns:
        numpy.ndarray of ``num_sample`` unique indices in ``range(len(data))``.

    Raises:
        ValueError: from ``np.random.choice`` if ``num_sample`` exceeds the
            number of items.
    """
    if seed is None:
        seed = args.seed
    if x_type is None:
        x_type = args.xType
    np.random.seed(seed)
    if x_type == 'wavebp01':
        # Assumption: for this input type `data` bundles several arrays and
        # data[0] holds the per-sample entries — TODO confirm with callers.
        data = data[0]
    rand_ids = np.random.choice(len(data), num_sample, replace=False)
    print(len(rand_ids))
    print('---')
    return rand_ids

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    A constant input (max == min) maps to all zeros instead of the NaNs a
    0/0 division would produce.

    Args:
        timeseries: array-like supporting ``.min()``, ``.max()`` and
            element-wise arithmetic (e.g. a NumPy array).

    Returns:
        The scaled values, with the same shape as ``timeseries``.
    """
    lo = timeseries.min()
    span = timeseries.max() - lo
    # For a constant trace the offsets (timeseries - lo) are all zero, so
    # dividing by 1 yields zeros while keeping the usual result dtype.
    return (timeseries - lo) / (span if span != 0 else 1)

def normalize_data_per_trace(data):
    """Min-max scale every trace (first-axis entry) of ``data`` in place.

    Prints the array shape, overwrites each ``data[i]`` with
    ``normalize(data[i])``, and returns the same mutated array.
    """
    print(data.shape)
    for idx, trace in enumerate(data):
        data[idx] = normalize(trace)
    return data

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces and split them into training and validation arrays.

    The first ``end_trace - start_trace`` loaded traces form the training set.
    The validation set depends on ``is_test``:

    - false: the indices stored in ``val_ids.npz`` are used.
    - true: the contiguous slice following the training traces is used.

    Args:
        args: parsed CLI namespace; only ``args.normalize`` is read. A value
            of 1 min-max normalises every loaded trace in place.
        data_path: trace archive handed to the loader.
        sKeyNo: forwarded to the test-file loader only.
        trainPortion, xType, yType: accepted for interface compatibility; unused.
        is_test: selects the test-file loader and the contiguous validation slice.
        start_trace, end_trace: bounds of the training range.

    Returns:
        tuple: ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``.
        Labels are used as-is (not one-hot encoded), so the ``*_value`` arrays
        equal ``yTrain`` / ``yVal``.
    """
    val_num = 300000
    # NpzFile keeps its file handle open until closed; read the ids and release it.
    with np.load('val_ids.npz') as val_file:
        val_ids = val_file['ids']
    print(len(val_ids))
    print(val_ids[:10])
    end_val_trace = end_trace + val_num
    if is_test:
        # NOTE(review): load_meta_trace_file_from_test is not defined in the
        # visible part of this file; it must exist elsewhere for is_test=True.
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        # Loads traces [start_trace, end_trace + val_num); every index used
        # below (dataSize, val_ids) is therefore relative to start_trace.
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)
    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    print(skpv_profiling[:10])

    y_train_skpv = skpv_profiling
    if args.normalize == 1:
        # Normalises training and validation traces alike.
        trace_profiling = normalize_data_per_trace(trace_profiling)
    xTrain = trace_profiling[:dataSize]
    print(xTrain[0][:10])

    if is_test:
        # NOTE(review): dataSize is relative to start_trace, while end_val_trace
        # (end_trace + val_num) is absolute. The slice is val_num long only
        # when start_trace == 0.
        xVal = trace_profiling[dataSize:end_val_trace]
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        # NOTE(review): confirm val_ids were generated relative to start_trace
        # and do not overlap the training range [0, dataSize); not checked here.
        xVal = trace_profiling[val_ids]
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

"""
Adapted from D2Pruning
"""

import copy
from sklearn.neighbors import kneighbors_graph
from sklearn.metrics import pairwise_distances
import abc
import numpy as np
import time
from collections import defaultdict
from tqdm import tqdm
import sys

class SamplingMethod(object):
  """Abstract base for sample-selection strategies (from D2Pruning).

  Subclasses set ``self.X`` and implement ``select_batch_``; callers use
  ``select_batch``, which forwards its keyword arguments.
  """
  # NOTE: ``__metaclass__`` is Python 2 syntax and has no effect in Python 3,
  # so the @abc.abstractmethod markers below are not enforced at
  # instantiation time. Kept as-is: switching to abc.ABC would make any
  # subclass that omits one of them fail to instantiate.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def __init__(self, X, y, seed, **kwargs):
    self.X = X
    self.y = y
    self.seed = seed

  def flatten_X(self):
    """Return ``self.X`` as a 2-D ``(n_samples, n_features)`` array.

    Inputs with more than two dimensions are reshaped, keeping the first axis;
    1-D and 2-D inputs are returned unchanged.
    """
    shape = self.X.shape
    flat_X = self.X
    if len(shape) > 2:
      # np.prod rather than the np.product alias, which NumPy 2.0 removed.
      flat_X = np.reshape(self.X, (shape[0], np.prod(shape[1:])))
    return flat_X

  @abc.abstractmethod
  def select_batch_(self):
    return

  def select_batch(self, **kwargs):
    """Select a batch by forwarding all keyword arguments to ``select_batch_``."""
    return self.select_batch_(**kwargs)

  def to_dict(self):
    """Return a serialisable description of the sampler (none by default)."""
    return None


class InfoGraphDensitySampler(SamplingMethod):
  """Adapted from GraphDensitySampler@(D2Pruning)

  Builds a k-nearest-neighbour graph over the samples and re-weights its edges
  with per-sample importance scores. It then greedily selects the samples with
  the highest ``graph_density``, discounting the density of each selected
  sample's neighbours.

  NOTE(review): ``graph_density`` is initialised to zeros in
  ``Build_GCCG_Graph`` and never recomputed from the edge weights. As a
  result, ``select_batch_`` returns the indices 0..N-1. Compare with the
  upstream D2Pruning sampler before relying on this class.
  """
  def __init__(self, X, y, seed, gamma=None, importance_scores=None, args=None):
    """Build the weighted neighbour graph for ``X``.

    Args:
      X: sample features. ``X.squeeze(0)`` is taken first, so a leading
        singleton axis is presumably expected (TODO confirm with callers).
      y, seed: accepted for interface compatibility; not stored (the base
        ``__init__`` is not called).
      gamma: kernel coefficient used by ``select_batch_``; defaults to
        ``1 / self.X.shape[1]``.
      importance_scores: per-sample scores, read as ``importance_scores[i].item()``.
        Required in practice; see ``Build_GCCG_Graph``.
      args: namespace providing ``n_neighbor_post`` (neighbours per node).
    """
    self.name = 'infomax_graph_density'
    max_solver_it = 5
    self.args = args
    # NOTE(review): squeeze runs before the None check below, so X=None fails here.
    self.X = X.squeeze(0)
    if self.X is not None:
      self.flat_X = self.flatten_X()
    if gamma is not None:
      self.gamma = gamma
    else:
      self.gamma = 1. / self.X.shape[1]
    #self.graph_mode = args.graph_mode
    #self.graph_sampling_mode = args.graph_sampling_mode
    self.Build_GCCG_Graph(args.n_neighbor_post, importance_scores, max_solver_it)

  def Build_GCCG_Graph(self, n_neighbor=10, importance_scores=None, max_solver_it=1):
    """Compute pairwise distances and the importance-weighted kNN graph.

    Sets ``self.distances`` (n x n pairwise distances), ``self.connect`` (dense
    weighted adjacency matrix), ``self.connect_post``, ``self.graph_density``
    and ``self.starting_density``.

    NOTE(review): ``importance_scores`` is effectively mandatory. Its
    ``.shape`` is printed unconditionally, and ``connect`` is only bound inside
    the ``is not None`` branch, so ``self.connect = connect`` would raise
    NameError without it.
    """
    self.distances = pairwise_distances(self.flat_X, self.flat_X)
    self.graph_density = np.zeros(self.X.shape[0])
    print(self.X.shape)
    print(self.flat_X.shape)
    print(importance_scores.shape)
    print(n_neighbor)
    #if importance_scores is not None and self.graph_mode in ['sum', 'product']:
    if importance_scores is not None:
      epsilon = 0.0000001
      n_samples = self.flat_X.shape[0]
      # NOTE(review): kneighbors_graph presumably needs n_neighbor < n_samples
      # (self excluded); this clamp misses n_neighbor == n_samples and becomes
      # <= 0 when n_samples <= 10 — confirm.
      if n_neighbor > self.flat_X.shape[0]:
            n_neighbor = self.flat_X.shape[0] -10
      #print(self.flat_X.shape[0])
      #print(n_neighbor)
      # kNN connectivity graph (Minkowski p=2), densified so that its entries
      # can be overwritten with edge weights below.
      connect = kneighbors_graph(self.flat_X, n_neighbor,p=2)
      connect = connect.todense()
      neighbors = connect.nonzero()
      # NOTE(review): zip() is a one-shot iterator, so only iteration 0 of the
      # loop below visits any edges; later iterations do nothing.
      inds = zip(neighbors[0], neighbors[1])
      for iteration in range(max_solver_it):
        for entry in inds:
            i = entry[0]
            j = entry[1]
            distance = self.distances[i, j]
            if iteration == 0:
                # Weight grows with distance and is scaled by the endpoint's
                # importance score (floored at epsilon).
                weight_j = (1.0-np.exp(-distance + epsilon)) * max(importance_scores[j].item(), epsilon)
                weight_i = (1.0-np.exp(-distance + epsilon)) * max(importance_scores[i].item(), epsilon)
            else:
                # Later iterations scale by the current graph density instead.
                weight_j = (1.0-np.exp(-distance + epsilon)) * (self.graph_density[j].item())
                weight_i = (1.0-np.exp(-distance + epsilon)) * (self.graph_density[i].item())
            connect[i, j] = weight_j
            connect[j, i] = weight_i
        self.connect_post = connect
    self.connect = connect
    self.starting_density = copy.deepcopy(self.graph_density)

  def select_batch_(self, N, **kwargs):
    """Greedily pick ``N`` distinct sample indices by descending graph density.

    After each pick, the density of every graph neighbour of the picked sample
    is reduced by ``exp(-distance * gamma) * density[picked]``. All picked
    samples are then pushed below the current minimum so they are not picked
    again.

    Returns:
      list of selected indices, built from a set (not in selection order).

    NOTE(review): this mutates ``self.graph_density`` (it is not reset from
    ``starting_density``), and the while loop never ends if ``N`` exceeds the
    number of samples.
    """
    self.connect_post = self.connect
    # Here, we suggest to use the greedy selection strategy as a post-processing step
    # (we find that it is more stable than the just pick out those samples with largest final scores,
    # especially when the selection ratio is small)
    post_processing = True
    if post_processing:
      batch = set()
      while len(batch) < N:
        selected = np.argmax(self.graph_density)
        # NOTE(review): for a dict connect_post, `neighbors` is never assigned.
        if type(self.connect_post) == dict:
          pass
        else:
          # Column indices of the positive-weight edges in the selected row.
          neighbors = (self.connect_post[selected,:] > 0).nonzero()[1]
        self.graph_density[neighbors] = self.graph_density[neighbors] - np.exp(-self.distances[selected, neighbors]*self.gamma)*self.graph_density[selected]
        batch.add(selected)
        self.graph_density[list(batch)] = min(self.graph_density) - 1
    else:
      # Unreachable while post_processing is hard-coded to True.
      batch = set()
      while len(batch) < N:
        selected = np.argmax(self.graph_density)
        batch.add(selected)
        self.graph_density[list(batch)] = min(self.graph_density) - 1e10
    return list(batch)



def get_median(features, targets):
    """Return the per-class median feature vector (class prototype).

    Args:
        features: ``(n_samples, n_features)`` array.
        targets: ``n_samples`` class labels (NumPy array, list, or CPU tensor).

    Returns:
        ``(n_classes, n_features)`` array with the dtype of ``features``. Row
        ``k`` is the median over the samples whose label is the ``k``-th
        smallest distinct label, so for labels ``0..C-1`` row ``k`` belongs
        to class ``k``.
    """
    features = np.asarray(features)
    targets = np.asarray(targets).reshape(-1)
    classes = np.unique(targets)
    prot = np.zeros((len(classes), features.shape[-1]), dtype=features.dtype)
    for k, label in enumerate(classes):
        # Boolean-mask indexing keeps the result 2-D even for a class with a
        # single sample, so the median is taken per feature.
        prot[k] = np.median(features[targets == label], axis=0)
    return prot


def get_distance(features, labels):
    """Return each sample's Euclidean distance to its own class prototype.

    Prototypes come from ``get_median``. Labels need not be contiguous
    integers starting at 0.

    Args:
        features: ``(n_samples, n_features)`` array.
        labels: ``n_samples`` class labels.

    Returns:
        float64 array of shape ``(n_samples,)``.
    """
    features = np.asarray(features)
    labels = np.asarray(labels).reshape(-1)
    prots = get_median(features, labels)
    # inverse[n] is the row of `prots` that belongs to sample n's label
    # (np.unique sorts labels the same way get_median does).
    _, inverse = np.unique(labels, return_inverse=True)
    prots_for_each_example = prots[inverse.reshape(-1)].astype(np.float64)
    return np.linalg.norm(features - prots_for_each_example, axis=1)


def bin_allocate(num, bins, mode='uniform', initial_budget=None):
    """Split a budget of ``num`` samples across bins without exceeding any bin's size.

    Args:
        num: total number of samples to allocate.
        bins: 1-D torch tensor of per-bin capacities.
        mode: ``'uniform'`` gives each non-empty bin, smallest first, an equal
            share of the budget still left, capped by its capacity. Any other
            value uses ``initial_budget`` instead.
        initial_budget: per-bin target budgets, indexed like ``bins``. Only
            used when ``mode != 'uniform'``, and then mutated in place.

    Returns:
        torch int tensor of per-bin budgets, aligned with ``bins``.

    Raises:
        AssertionError: if some budget exceeds its bin's capacity.
    """
    sorted_index = torch.argsort(bins)
    sort_bins = bins[sorted_index]

    num_bin = bins.shape[0]

    rest_exp_num = num
    budgets = []
    # Visit bins from smallest to largest capacity; empty bins get 0.
    for i in range(num_bin):
        if sort_bins[i] == 0:
            budgets.append(0)
            continue
        rest_bins = torch.count_nonzero(sort_bins[i:])
        if mode == 'uniform':
            # Equal share of what is left, capped by this bin's capacity.
            avg = rest_exp_num // rest_bins
            cur_num = min(sort_bins[i].item(), avg)
            rest_exp_num -= cur_num
        else:
            avg = initial_budget[sorted_index[i]]
            cur_num = min(sort_bins[i].item(), avg)
            # Spread the part of this bin's budget it cannot use over all later
            # bins (integer division; writes into initial_budget).
            delta = int((avg - cur_num)/max(1, (rest_bins - 1)))
            for j in range(i+1, num_bin):
                initial_budget[sorted_index[j]] += delta
        budgets.append(cur_num)

    budgets = torch.tensor(budgets)
    if torch.sum(budgets) < num: # TODO: check again
        # Top up with at most +1 per bin, starting from the largest bins. This
        # single pass may still leave the total below `num`.
        delta = num - torch.sum(budgets)
        i = 1
        while delta and i <= num_bin:
            if budgets[-i] < sort_bins[-i]:
                budgets[-i] += 1
                delta -= 1
            i += 1

    # Scatter the budgets (computed in sorted order) back to the original bin order.
    # NOTE(review): `budgets` is already a tensor here; torch.tensor(budgets)
    # (also in the assert message) makes a redundant copy and emits a
    # UserWarning — `budgets.clone()` / `budgets` would be the idiomatic form.
    rst = torch.zeros((num_bin,)).type(torch.int)
    rst[sorted_index] = torch.tensor(budgets).type(torch.int)

    assert all([b<= r for r, b in zip(bins, rst)]), ([(r.item(),b.item()) for r, b in zip(bins, rst)], bins, [x.item() for x in torch.tensor(budgets)[sorted_index]])
    return rst


class CoresetSelection(object):
    """Static coreset-selection strategies.

    The class is only a namespace: every method is a ``@staticmethod``.
    Several methods call module-level names defined elsewhere in this file
    (``get_distance``, ``bin_allocate``, ``InfoGraphDensitySampler``,
    ``get_aucpr``, ``tqdm``).
    """

    @staticmethod
    def moderate_selection(data_score, ratio, features):
        """Return the indices lying in the two distance tails.

        Distances are computed by ``get_distance(features, data_score['targets'])``.
        After sorting by distance, the central band covering a ``1 - ratio``
        fraction is skipped. The indices below and above that band (about a
        ``ratio`` fraction of all samples) are returned.
        """

        def get_prune_idx(rate, distance):
            rate = 1 - rate
            low = 0.5 - rate / 2
            high = 0.5 + rate / 2

            sorted_idx = distance.argsort()
            low_idx = round(distance.shape[0] * low)
            high_idx = round(distance.shape[0] * high)

            return np.concatenate((sorted_idx[:low_idx], sorted_idx[high_idx:]))

        distance = get_distance(features, data_score['targets'])
        return get_prune_idx(ratio, distance)

    @staticmethod
    def score_monotonic_selection(data_score, key, ratio, descending, class_balanced):
        """Keep the top ``ratio`` fraction of samples ranked by ``data_score[key]``.

        Args:
            data_score: Mapping of per-sample score tensors. ``'targets'`` is
                required when ``class_balanced`` is true.
            key: Score used for ranking.
            ratio: Fraction of samples to keep.
            descending: Rank from the highest score when true.
            class_balanced: Keep a ``ratio`` fraction of every class separately.

        Returns:
            Tensor of selected sample indices in ranking order (grouped per
            class when ``class_balanced``).
        """
        score = data_score[key]
        score_sorted_index = score.argsort(descending=descending)
        total_num = ratio * data_score[key].shape[0]
        print("Selecting from %s samples" % total_num)
        if class_balanced:
            print('Class balance mode.')
            # Positions into the *sorted* order; mapped back at the end.
            all_index = torch.arange(data_score['targets'].shape[0])
            selected_index = []
            targets_list = data_score['targets'][score_sorted_index]
            targets_unique = torch.unique(targets_list)
            for target in targets_unique:
                target_index_mask = (targets_list == target)
                target_index = all_index[target_index_mask]
                targets_num = target_index_mask.sum()
                target_coreset_num = targets_num * ratio
                selected_index = selected_index + list(target_index[:int(target_coreset_num)])
                print("Selected %s samples for %s label" % (len(selected_index), target))
            selected_index = torch.tensor(selected_index)
            print(f'High priority {key}: {score[score_sorted_index[selected_index][:15]]}')
            print(f'Low priority {key}: {score[score_sorted_index[selected_index][-15:]]}')
            return score_sorted_index[selected_index]
        else:
            print(f'High priority {key}: {score[score_sorted_index[:15]]}')
            print(f'Low priority {key}: {score[score_sorted_index[-15:]]}')
            return score_sorted_index[:int(total_num)]

    @staticmethod
    def mislabel_mask(data_score, mis_key, mis_num, mis_descending, coreset_key):
        """Drop the ``mis_num`` highest-priority samples by ``data_score[mis_key]``.

        ``data_score[coreset_key]`` is replaced in place by its entries for the
        kept samples, so it stays aligned with the returned index.

        Returns:
            ``(data_score, easy_index)``. ``easy_index`` lists the kept sample
            indices in ``mis_key`` ranking order.
        """
        print(mis_key)
        mis_score = data_score[mis_key]
        mis_score_sorted_index = mis_score.argsort(descending=mis_descending)
        hard_index = mis_score_sorted_index[:mis_num]
        print(f'Bad data -> High priority {mis_key}: {data_score[mis_key][hard_index][:15]}')
        print(f'Prune {hard_index.shape[0]} samples.')

        easy_index = mis_score_sorted_index[mis_num:]
        data_score[coreset_key] = data_score[coreset_key][easy_index]

        return data_score, easy_index

    @staticmethod
    def stratified_sampling(data_score, coreset_num, args, data_embeds=None):
        """Stratify samples by score and draw a per-stratum budget from each.

        The score range is split into ``args.stratas`` equal-width, half-open
        strata. Per-stratum budgets come from ``bin_allocate`` according to
        ``args.budget_mode`` (``'uniform'``, ``'confidence'`` or ``'aucpr'``).
        Samples in each stratum are drawn at random or with
        ``InfoGraphDensitySampler``, depending on ``args.sampling_mode``.

        Returns:
            ``(selected_index, (pools, budgets))``:
            * ``selected_index`` is a list of the selected sample indices.
            * ``pools`` holds the candidate indices of every stratum.
            * ``budgets`` holds the per-stratum budgets.

        Raises:
            ValueError: on an unknown ``args.budget_mode`` / ``args.sampling_mode``.
        """
        if args.sampling_mode == 'graph' and args.coreset_key in ['accumulated_margin']: # TODO: check again
            score = data_score[args.coreset_key]
            score = score - torch.min(score)
            data_score[args.coreset_key] = -score

        print('Using stratified sampling...')
        score = data_score[args.coreset_key]
        if args.graph_score:
            graph = InfoGraphDensitySampler(X=data_embeds, y=None, gamma=args.gamma,
                                            seed=0, importance_scores=score, args=args)
            score = torch.tensor(graph.graph_density)

        min_score = torch.min(score)
        max_score = torch.max(score)
        if max_score > 0:
            max_score = max_score * 1.0001
        else:
            # Strata are half-open [start, end). Scaling cannot lift a
            # non-positive maximum above itself, so pad additively; otherwise
            # the top-scoring samples would fall outside every stratum.
            max_score = max_score + torch.clamp(max_score.abs() * 1e-4, min=1e-6)
        print("Min score: %s, max score: %s" % (min_score.item(), max_score.item()))
        step = (max_score - min_score) / args.stratas

        def bin_range(k):
            return min_score + k * step, min_score + (k + 1) * step

        def stratum_mask(k):
            """Boolean mask of the samples that fall in stratum ``k``."""
            start, end = bin_range(k)
            return torch.logical_and(score >= start, score < end)

        def select_from_pool(pool, budget):
            """Return positions (into ``pool``) of up to ``budget`` samples."""
            if args.sampling_mode == 'random':
                rand_index = torch.randperm(pool.shape[0])
                return [idx.item() for idx in rand_index[:budget]]
            if args.sampling_mode == 'graph':
                if pool.shape[0] <= args.n_neighbor:
                    # Too few samples to build the graph: sample at random.
                    rand_index = torch.randperm(pool.shape[0])
                    return rand_index[:budget].numpy().tolist()
                sampler = InfoGraphDensitySampler(
                    X=None if data_embeds is None else data_embeds[pool], y=None,
                    gamma=args.gamma, seed=0, importance_scores=score[pool], args=args)
                return sampler.select_batch_(budget)
            raise ValueError

        ##### calculate number of samples in each strata #####
        strata_num = torch.tensor([stratum_mask(i).sum() for i in range(args.stratas)])

        if args.budget_mode == 'uniform':
            budgets = bin_allocate(coreset_num, strata_num)
        elif args.budget_mode == 'confidence':
            confs = data_score['confidence']
            mean_confs = []
            for i in range(args.stratas):
                # view(-1) keeps a 1-D index even for single-sample strata
                # (squeeze() would produce a 0-d tensor).
                sample_idxs = stratum_mask(i).nonzero().view(-1)
                if sample_idxs.numel() != 0:
                    mean_confs.append(1 - torch.mean(confs[sample_idxs]).item())
                else:
                    mean_confs.append(0)
            total_conf = np.sum(mean_confs)
            budgets = [int(n * coreset_num / total_conf) for n in mean_confs]
            print("Initial budget", budgets)
            budgets = bin_allocate(coreset_num, strata_num, mode='confidence', initial_budget=budgets)
        elif args.budget_mode == 'aucpr':
            budgets = bin_allocate(coreset_num, strata_num)
            sample_index = torch.arange(data_score[args.coreset_key].shape[0])
            aucpr_values = []
            min_budgets = {}
            for i in tqdm(range(args.stratas), desc='Getting k-centers for aucpr-based budgeting'):
                if budgets[i] == 0:
                    aucpr_values.append(0)
                    continue
                pool = sample_index[stratum_mask(i)]
                kcenters = pool[select_from_pool(pool, budgets[i])]
                non_coreset = list(set(pool.tolist()).difference(set(kcenters.tolist())))
                if len(non_coreset) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[kcenters], data_embeds[non_coreset])
                aucpr_values.append(round(aucpr, 3))
                if aucpr == 0:
                    min_budgets[i] = budgets[i]

            print("Initial AUCpr values: ", aucpr_values)
            print("Initial mean AUCpr: ", np.mean(aucpr_values))
            total_aucpr = np.sum(aucpr_values)
            print("Uniform budget", budgets)
            if total_aucpr != 0:
                budgets = [int(n * (coreset_num - sum(min_budgets.values())) / total_aucpr) if i not in min_budgets
                           else min_budgets[i] for i, n in enumerate(aucpr_values)]
                print("Initial budget", budgets)
                budgets = bin_allocate(coreset_num, strata_num, mode='aucpr', initial_budget=budgets)
        else:
            raise ValueError
        print(budgets, budgets.sum())

        ##### sampling in each strata #####
        selected_index = []
        sample_index = torch.arange(data_score[args.coreset_key].shape[0])

        # `pools` and `kcenters` are kept index-aligned: one entry per stratum.
        pools, kcenters = [], []
        for i in tqdm(range(args.stratas), desc='sampling from each strata'):
            pool = sample_index[stratum_mask(i)]
            pools.append(pool)
            if pool.numel() == 0 or budgets[i] == 0:
                kcenters.append(pool[:0])
                continue
            kcenters.append(pool[select_from_pool(pool, budgets[i])])

        if args.aucpr:
            final_aucpr_values = []
            for pool, samples in zip(pools, kcenters):
                non_coreset = list(set(pool.tolist()).difference(set(samples.tolist())))
                if samples.numel() == 0 or len(non_coreset) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[samples], data_embeds[non_coreset])
                final_aucpr_values.append(round(aucpr, 3))
            print("Final AUCpr values: ", final_aucpr_values)
            print("Final mean AUCpr: ", np.mean(final_aucpr_values))

        for samples in kcenters:
            selected_index += samples

        return selected_index, (pools, budgets)

    @staticmethod
    def density_sampling(data_score, bins, coreset_num, args, data_embeds=None):
        """Draw ``coreset_num`` samples from precomputed bin assignments.

        ``bins[k]`` is the integer bin id of sample ``k``. Ids range over
        ``0..max(bins)``, and empty ids are skipped. Per-bin budgets follow
        ``args.budget_mode`` (``'density'``, ``'uniform'``, ``'confidence'`` or
        ``'aucpr'``). Samples in a bin are drawn at random or with
        ``InfoGraphDensitySampler``, depending on ``args.sampling_mode``.
        The union is then shuffled, topped up at random if it is short, and
        truncated to ``coreset_num``.

        Returns:
            ``(selected_idxs, None)``, where ``selected_idxs`` is a list of
            sample indices.

        Raises:
            ValueError: on an unknown ``args.budget_mode`` / ``args.sampling_mode``.
        """
        import random
        from collections import Counter

        if args.sampling_mode == 'graph' and args.coreset_key in ['accumulated_margin']: # TODO: check again
            score = data_score[args.coreset_key]
            data_score[args.coreset_key] = score - torch.min(score)

        hist = np.histogram(bins, bins=np.arange(0, np.amax(bins) + 2, 1))[0]
        n_bins = np.amax(bins) + 1
        bin_pop_density = Counter(hist.tolist())
        print("Frequency of bin counts", bin_pop_density.most_common(20))

        non_empty_bins = np.where(hist != 0)[0]
        print("Skipping %s empty bins in total %s bins" % ((n_bins - non_empty_bins.shape[0]), n_bins))

        bin2size = {bin_idx: hist[bin_idx] for bin_idx in non_empty_bins}
        strata_num = torch.tensor([bin2size[i] for i in non_empty_bins])

        if args.budget_mode == 'density':
            # Budget proportional to the bin population, rounded up.
            total_num = sum(bin2size.values())
            bin2budget = {bin_idx: math.ceil(bin2size[bin_idx] * coreset_num / total_num) for bin_idx in non_empty_bins}
        elif args.budget_mode == 'uniform':
            budgets = bin_allocate(coreset_num, strata_num)
            bin2budget = {bin_i: budgets[i] for i, bin_i in enumerate(non_empty_bins)}
        elif args.budget_mode == 'confidence':
            confs = data_score['confidence']
            mean_confs = []
            for i in non_empty_bins:
                sample_idxs = np.where(bins == i)[0]
                mean_confs.append(1 - torch.mean(confs[sample_idxs]).item())
            total_conf = np.sum(mean_confs)
            budgets = [int(n * coreset_num / total_conf) for n in mean_confs]
            print("Initial budget", budgets)
            budgets = bin_allocate(coreset_num, strata_num, mode='confidence', initial_budget=budgets)
            print("Final budget", budgets)
            bin2budget = {bin_idx: budgets[i] for i, bin_idx in enumerate(non_empty_bins)}
        elif args.budget_mode == 'aucpr':
            budgets = bin_allocate(coreset_num, strata_num)
            aucpr_values = []
            # Keyed by position in `non_empty_bins`, matching the lookup below.
            min_budgets = {}
            for i, bin_idx in tqdm(enumerate(non_empty_bins), desc='Getting k-centers for aucpr-based budgeting'):
                if budgets[i] == 0:
                    aucpr_values.append(0)
                    continue
                sample_idxs = np.where(bins == bin_idx)[0]

                if args.sampling_mode == 'random':
                    rand_index = np.random.permutation(sample_idxs.shape[0])
                    selected_idx = rand_index[:budgets[i]]
                elif args.sampling_mode == 'graph':
                    if sample_idxs.shape[0] <= args.n_neighbor:
                        selected_idx = np.random.permutation(sample_idxs.shape[0])
                    else:
                        sampling_method = InfoGraphDensitySampler(
                            X=None if data_embeds is None else data_embeds[sample_idxs], y=None,
                            gamma=args.gamma, seed=0,
                            importance_scores=data_score[args.coreset_key][sample_idxs], args=args)
                        selected_idx = sampling_method.select_batch_(budgets[i])
                else:
                    raise ValueError

                kcenters = sample_idxs[selected_idx]
                non_coreset = list(set(sample_idxs.tolist()).difference(set(kcenters.tolist())))
                if len(non_coreset) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[kcenters], data_embeds[non_coreset])
                aucpr_values.append(aucpr)
                if aucpr == 0:
                    min_budgets[i] = budgets[i]

            print("Initial AUCpr values: ", aucpr_values)
            print("Initial mean AUCpr: ", np.mean(aucpr_values))
            total_aucpr = np.sum(aucpr_values)
            print("Uniform budget", budgets)
            if total_aucpr != 0:
                budgets = [int(n * (coreset_num - sum(min_budgets.values())) / total_aucpr) if i not in min_budgets
                           else min_budgets[i] for i, n in enumerate(aucpr_values)]
                print("Initial budget", budgets)
                budgets = bin_allocate(coreset_num, strata_num, mode='aucpr', initial_budget=budgets)
            bin2budget = {bin_idx: budgets[i] for i, bin_idx in enumerate(non_empty_bins)}
        else:
            raise ValueError

        print('Using density sampling...')
        pools = []
        final_idxs = []

        def get_coreset_for_bin(bin_idx):
            """Return the selected sample indices (ndarray) for one bin."""
            sample_idxs = np.where(bins == bin_idx)[0]
            pools.append(sample_idxs)
            print("Starting process for label  ", bin_idx, "with %s samples" % len(sample_idxs))
            if len(sample_idxs) == 0:
                return np.array([], dtype=np.int64)
            if bin2budget[bin_idx] > len(sample_idxs):
                # Budget exceeds the bin: take every sample.
                kcenters = np.random.permutation(sample_idxs.shape[0])
            elif args.sampling_mode == 'random':
                rand_index = np.random.permutation(sample_idxs.shape[0])
                kcenters = rand_index[:bin2budget[bin_idx]]
            elif args.sampling_mode == 'graph':
                if sample_idxs.shape[0] <= args.n_neighbor:
                    # Too small for the graph: take all of them (the final
                    # result is truncated to `coreset_num`).
                    kcenters = np.random.permutation(sample_idxs.shape[0])
                else:
                    sampling_method = InfoGraphDensitySampler(
                        X=None if data_embeds is None else data_embeds[sample_idxs], y=None,
                        gamma=args.gamma, seed=0,
                        importance_scores=data_score[args.coreset_key][sample_idxs], args=args)
                    kcenters = sampling_method.select_batch_(bin2budget[bin_idx])
            else:
                raise ValueError
            return sample_idxs[kcenters]

        for bin_idx in non_empty_bins:
            final_idxs.append(get_coreset_for_bin(bin_idx).tolist())

        if args.aucpr:
            final_aucpr_values = []
            for pool, selected in zip(pools, final_idxs):
                # `selected` is already a plain list of ints.
                non_coreset = list(set(pool.tolist()).difference(set(selected)))
                if len(non_coreset) == 0 or len(selected) == 0:
                    aucpr = 0
                else:
                    aucpr = get_aucpr(data_embeds[selected], data_embeds[non_coreset])
                final_aucpr_values.append(round(aucpr, 3))
            print("Final AUCpr values: ", final_aucpr_values)
            print("Final mean AUCpr: ", np.mean(final_aucpr_values))

        selected_idxs = [idx for idxs in final_idxs for idx in idxs]
        random.shuffle(selected_idxs)
        if len(selected_idxs) < coreset_num:
            extra_sample_set = list(set(range(len(bins))).difference(set(selected_idxs)))
            selected_idxs = selected_idxs + random.sample(extra_sample_set, k=min(len(extra_sample_set), coreset_num - len(selected_idxs)))

        return selected_idxs[:coreset_num], None

    @staticmethod
    def random_selection(total_num, num):
        """Return ``int(num)`` distinct random indices from ``range(total_num)``."""
        print('Random selection.')
        score_random_index = torch.randperm(total_num)

        return score_random_index[:int(num)]

"""Calculate loss and entropy"""
def post_training_metrics(args, model, dataloader, device):
    """Compute per-sample entropy/confidence and the ``fc9`` features.

    Runs ``model`` over ``dataloader`` in eval mode without gradient tracking.
    The output of ``model.fc9`` is captured with a temporary forward hook,
    which is always removed before returning.

    Args:
        args: Unused; kept for call-site compatibility.
        model: Network exposing an ``fc9`` submodule whose output is captured.
        dataloader: Yields ``(inputs, target)`` batches. Rows of the outputs
            follow iteration order, so use ``shuffle=False`` when they must
            line up with dataset indices.
        device: Device the batches are moved to.

    Returns:
        ``(data_importance, all_feat)``:
        * ``data_importance`` holds ``'entropy'`` and ``'confidence'`` (the
          softmax probability of the target class), one entry per sample.
        * ``all_feat`` is a CPU tensor with the flattened ``fc9`` output of
          every sample.
    """
    model.eval()
    num_samples = len(dataloader.dataset)
    data_importance = {
        'entropy': torch.zeros(num_samples),
        'confidence': torch.zeros(num_samples),
    }
    all_feat = None
    activation = {}

    def hook(module, inputs, output):
        activation['fc9'] = output.detach()

    handle = model.fc9.register_forward_hook(hook)
    try:
        idx = 0  # global offset of the current batch
        with torch.no_grad():
            for inputs, target in dataloader:
                target = target.long().to(device)
                inputs = inputs.to(device)
                batch = inputs.size(0)

                logits = model(inputs)
                feat_map = activation['fc9'].reshape(batch, -1).cpu()
                if all_feat is None:
                    # Size the feature buffer from the first batch instead of
                    # assuming a fixed width.
                    all_feat = torch.zeros((num_samples, feat_map.shape[1]))
                all_feat[idx: idx + batch] = feat_map

                prob = torch.softmax(logits, dim=1)
                entropy = -(prob * torch.log(prob + 1e-10)).sum(dim=1).cpu()
                confidence = prob[torch.arange(batch, device=prob.device), target].cpu()

                data_importance['entropy'][idx: idx + batch] = entropy
                data_importance['confidence'][idx: idx + batch] = confidence
                idx += batch
    finally:
        # Do not leave hooks behind: repeated calls would otherwise stack them.
        handle.remove()

    if all_feat is None:  # empty loader
        all_feat = torch.zeros((num_samples, 128))
    return data_importance, all_feat

"""Extract feature map"""
def feature_maps(model, dataloader, device, layer=4):
    """Extract ``model.feature_map(inputs, layer=layer)`` for every batch.

    Args:
        model: Object with ``eval()`` and ``feature_map(inputs, layer=...)``.
        dataloader: Iterable yielding ``(idx, (inputs, targets))`` items; only
            ``inputs`` is used.
        device: Device the inputs are moved to.
        layer: Forwarded to ``model.feature_map``.

    Returns:
        NumPy array with the per-batch features concatenated along axis 0.
    """
    model.eval()
    features = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _idx, (inputs, _targets) in dataloader:
            feats = model.feature_map(inputs.to(device), layer=layer)
            features.append(feats.detach().cpu().numpy())

    return np.concatenate(features)


def infomax(args, model, trainset, trainloader, threshold, device):
    """Select a coreset of ``trainset`` with an information-graph density sampler.

    Steps:
      1. Score every sample with ``post_training_metrics``.
      2. Drop the ``int(args.mis_ratio * len(trainset))`` highest-priority
         samples by ``args.mis_key`` via ``CoresetSelection.mislabel_mask``.
      3. Run ``InfoGraphDensitySampler`` on the remaining samples, weighted by
         ``data_score[args.coreset_key]``, to pick
         ``int(args.coreset_ratio * len(trainset))`` of them.
      4. If the sampler returns fewer, top up with random surviving samples.

    Args:
        args: Namespace providing ``coreset_ratio``, ``coreset_key``,
            ``mis_ratio``, ``mis_key`` and ``gamma``; it is also forwarded to
            the scorer and the sampler. ``args.data_score_descending`` is set
            to 1 when ``coreset_key`` is a descending-order score.
        model: Model used for scoring.
        trainset: Dataset being reduced; only its length is used here.
        trainloader: Loader over ``trainset`` in dataset order.
        threshold: Unused; kept for call-site compatibility.
        device: Device used for scoring.

    Returns:
        1-D ``torch.long`` tensor of selected indices into ``trainset``.
    """
    descending_keys = ['entropy', 'forgetting', 'el2n', 'ssl']
    total_num = len(trainset)
    coreset_num = int(args.coreset_ratio * total_num)

    data_score, features = post_training_metrics(args, model, trainloader, device)

    if args.coreset_key in descending_keys:
        print("Using descending order")
        args.data_score_descending = 1

    # Drop likely-mislabelled samples. `score_index` lists the survivors, and
    # `data_score[args.coreset_key]` is re-aligned to them in place.
    mis_num = int(args.mis_ratio * total_num)
    _, score_index = CoresetSelection.mislabel_mask(data_score, mis_key=args.mis_key,
                                                    mis_num=mis_num,
                                                    mis_descending=args.mis_key in descending_keys,
                                                    coreset_key=args.coreset_key)
    features = features[score_index]
    print(data_score[args.coreset_key].shape)

    sampling_method = InfoGraphDensitySampler(X=features, y=None,
                                              gamma=args.gamma,
                                              seed=0, importance_scores=data_score[args.coreset_key], args=args)
    local_index = sampling_method.select_batch_(coreset_num)
    # Map sampler positions (into the survivors) back to `trainset` indices.
    coreset_index = torch.as_tensor(score_index[local_index]).reshape(-1).long()

    if len(coreset_index) < coreset_num:
        chosen = set(coreset_index.tolist())
        extra_pool = torch.tensor([j for j in score_index.tolist() if j not in chosen], dtype=torch.long)
        num_extra = min(len(extra_pool), coreset_num - len(coreset_index))
        extra = extra_pool[torch.randperm(len(extra_pool))[:num_extra]]
        coreset_index = torch.cat((coreset_index, extra))
        print("Added extra %s samples" % num_extra)
        print(coreset_index.shape)
    print(coreset_index.shape)

    return coreset_index

def dynamic_uncertainty(model, xTrain_in, yTrain_in, num_sample):
    """Dynamic-uncertainty coreset selection -- not implemented yet.

    Intended algorithm (from an earlier draft of this function):
      * Load a sequence of saved per-sample prediction-probability files.
      * Keep a sliding window of each sample's last ``window`` probabilities.
      * After the window has filled, record ``std(window) * 10`` per sample
        for every step.
      * Average these values over time to get the dynamic uncertainty.
      * Keep the ``fraction`` of samples with the highest uncertainty.

    The draft depended on inputs this signature does not provide (the
    probability-history files, a window size and a keep fraction). It is
    therefore not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    # Raise instead of calling exit(): a library function must not terminate
    # the interpreter.
    raise NotImplementedError(
        "dynamic_uncertainty is not implemented: it needs prediction-probability "
        "history files, a window size and a keep fraction")

# class to represent dataset
class SCADataset():
    """Minimal map-style dataset over a ``(features, labels)`` pair.

    ``data[0]`` is cast to ``float32`` once, at construction. ``data[1]`` is
    stored as given. Instances are passed directly to ``DataLoader``.
    """

    def __init__(self, data):
        features, labels = data[0], data[1]
        self.n_samples = features.shape[0]
        self.x = features.astype(np.float32)
        self.y = labels

    def __len__(self):
        """Number of samples (rows of ``data[0]``)."""
        return self.n_samples

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target

#MAIN
# NOTE(review): this section relies on names expected to be defined earlier in
# the file: `args`, `data_path`, `sKeyNo`, `trainPortion`, `xType`, `yType`,
# `create_training_data_optimize`, `load_meta_trace_files`, `CNNModel`,
# `train` and `database_folder_train` -- confirm they all exist above.

xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(
    args, data_path, sKeyNo, trainPortion, xType, yType, False, args.start_trace, args.end_trace)
print(xTrain_original.shape)
print(yTrain_original.shape)
print(xVal.shape)
print(yVal.shape)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model / data configuration constants (presumably read by CNNModel -- confirm).
xType = 'wavebp'
noConv1Dbranch = 1
noBPbranch = 2
noLayers = 6
tracelen = 600
NumBPinput = 3330
classes = 3329

# Pretrained model used to score samples for coreset selection.
model = CNNModel()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(args.trained_model_path, map_location=device, weights_only=True))
model = model.to(device)

trace_profiling, bp_profiling, skpv_profiling = load_meta_trace_files(data_path, args.start_trace, args.end_trace)
num_sample = args.num_sample

# The validation set does not depend on the training chunk: build it once.
SCAdataset_val = SCADataset([xVal, yVal])
val_loader = DataLoader(SCAdataset_val, batch_size=args.batch_size, shuffle=False)

# Score and select the coreset in chunks of `num_batch` samples.
num_batch = 10000
all_indexes = []
for i in range(0, num_sample, num_batch):
    # Clamp the last chunk so it never indexes past `num_sample`.
    all_ids = np.arange(i, min(i + num_batch, num_sample))
    SCAdataset = SCADataset([xTrain_original[all_ids], yTrain_original[all_ids]])
    # shuffle=False keeps row order so selected positions map back to `all_ids`.
    train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=False)
    print('--------------')
    error_indexes = infomax(args, model, SCAdataset, train_loader, threshold=args.jtt_threshold, device=device)
    # Shift chunk-local positions to global sample indices.
    error_indexes = torch.as_tensor(error_indexes) + i
    print(len(error_indexes))
    print(torch.max(error_indexes))
    all_indexes.append(error_indexes)

# Chunks can differ in size (the last one is shorter, and the number of
# selected samples may vary), so concatenate instead of stacking.
if all_indexes:
    coreset_indexes = torch.cat([idx.reshape(-1).long() for idx in all_indexes]).cpu().numpy()
else:
    coreset_indexes = np.array([], dtype=np.int64)

# Training set made of the selected coreset only.
train_data = [xTrain_original[coreset_indexes], yTrain_original[coreset_indexes]]
print('Len data')
print(len(train_data[0]))
SCAdataset = SCADataset(train_data)
train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=True)

# Train a freshly initialised model on the coreset.
model = CNNModel()
model = model.to(device)
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()

model = train(args, database_folder_train, model, train_loader, val_loader, optimizer, criterion, epochs=args.num_epoch)