from sklearn.neighbors import KNeighborsClassifier
import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd

def parse_arguments():
    """Build the command-line parser for this statistics script.

    Options:
        --target_folder: directory containing the index array; the statistics
            CSV is written into the same directory. Required.
        --target_file: file name (inside --target_folder) of the array loaded
            with np.load. Required.
        --num_sample: number of columns kept from each row of that array
            (default 200).

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description='Compute label statistics for a saved array of sample indexes.')
    # Both paths are used unconditionally by the script; requiring them makes a
    # bad invocation fail here instead of after the large data files are loaded.
    parser.add_argument('--target_folder', type=str, required=True,
                        help='directory holding --target_file; output CSV goes here')
    parser.add_argument('--target_file', type=str, required=True,
                        help='name of the index array file inside --target_folder')
    parser.add_argument('--num_sample', type=int, default=200,
                        help='number of columns kept from each row (default: 200)')

    return parser

def load_multi_attack(data_path):
    """Load the 'data' and 'label' arrays from an .npz archive.

    Args:
        data_path: path to an archive written by np.savez / np.savez_compressed
            that contains 'data' and 'label' entries.

    Returns:
        tuple: (data, labels) as in-memory numpy arrays.

    Raises:
        KeyError: if either entry is missing from the archive.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once both arrays have been read into memory.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

parser = parse_arguments()
args = parser.parse_args()

# Profiling set: traces and labels, truncated to the first N_PROFILING entries.
# load_multi_attack reads the same 'data'/'label' entries and closes the archive.
N_PROFILING = 200000
xTrain_original, labels = load_multi_attack('data.npz')
xTrain_original = xTrain_original[:N_PROFILING]
labels = labels[:N_PROFILING]
# Attack set: traces and the reference labels later paired with rows of dj_set.
xTest_multi, yTest_multi = load_multi_attack('attack_multi_data_300key.npz')
# Insert a new axis at position 2 of the profiling traces.
xTrain_original = np.expand_dims(xTrain_original, axis=2)
print(xTrain_original.shape)
print(xTest_multi.shape)
print(yTest_multi.shape)
print(yTest_multi[:10])

# First 5 positions along axis 1 of the attack traces.
# NOTE(review): not used anywhere else in the visible code.
target_samples = xTest_multi[:, :5, :]

# Index array to analyse; its rows are paired with yTest_multi by position.
dj_set = np.load(os.path.join(args.target_folder, args.target_file))
print(dj_set.shape)
dj_set = dj_set[:, :args.num_sample]

# Per-row label statistics. Each row of dj_set holds indexes into `labels`;
# row i is compared with the reference label yTest_multi[i].
#   count lists (all_mean/std/min/max): over the non-zero bins of the row's
#       label histogram.
#   *_perc lists: over the same bins divided by the row's number of distinct
#       indexes.
all_std = []
all_min = []
all_max = []
all_mean = []
all_max_perc = []
all_min_perc = []
all_mean_perc = []
all_std_perc = []
is_max_key = []  # NOTE(review): never appended to in the visible code
all_unique = []  # number of distinct labels per row
same_label_count = []  # histogram count of the reference label
same_label_percentage = []  # that count / number of distinct indexes
same_label_rank = []  # labels with a strictly larger share than the reference
num_samples_per_key = []  # number of distinct indexes per row

# Histogram width: one bin per label value 0..NUM_LABELS-1.
# NOTE(review): 3329 presumably is the number of possible label values; confirm.
NUM_LABELS = 3329

full_data = []
for i, key in enumerate(dj_set):
    curr_key = yTest_multi[i]
    # Distinct indexes for this row (np.unique flattens multi-dimensional rows).
    unique_indexes = np.unique(key).astype(int)
    n_indexes = len(unique_indexes)
    if n_indexes == 0:
        # Without this, the reductions below fail on empty arrays with an
        # opaque numpy message.
        raise ValueError('row {} of dj_set has no sample indexes '
                         '(--num_sample={})'.format(i, args.num_sample))
    num_samples_per_key.append(n_indexes)

    # Label histogram over the sampled indexes, as counts and as shares.
    unique, counts = np.unique(labels[unique_indexes], return_counts=True)
    all_labels = np.zeros(NUM_LABELS)
    all_labels[unique] = counts
    all_perc = all_labels / n_indexes

    # Bins of labels that actually occur; computed once, reused by every statistic.
    present = np.nonzero(all_labels)
    present_counts = all_labels[present]
    present_perc = all_perc[present]

    # 0 when no label has a strictly larger share than the reference label.
    same_label_rank.append(np.sum(all_perc > all_perc[curr_key]))
    all_max_perc.append(np.max(present_perc))
    all_mean_perc.append(np.mean(present_perc))
    all_std_perc.append(np.std(present_perc))
    all_min_perc.append(np.min(present_perc))
    all_unique.append(len(unique))
    full_data.append(all_labels)
    all_mean.append(np.mean(present_counts))
    all_std.append(np.std(present_counts))
    all_min.append(np.min(present_counts))
    all_max.append(np.max(present_counts))
    same_label_count.append(all_labels[curr_key])
    same_label_percentage.append(all_labels[curr_key] / n_indexes)

# Disabled: export of every row's full label histogram.
# NOTE(review): the header assumes 3329 histogram columns plus a 'D2_Key'
# column, but full_data rows contain only the histogram, so to_csv would
# reject the header as written.
# full_data = np.array(full_data)
# all_header = [ 'key_' + str(i) for i in range(3329)] + ['D2_Key']
# df = pd.DataFrame(full_data)
# df.to_csv(os.path.join(args.target_folder,"full_labels.csv"), header=all_header,index=False)

# One output row per row of dj_set, built from the per-row statistics lists.
df = pd.DataFrame({
    'MEAN': all_mean,
    'STD': all_std,
    'MIN': all_min,
    'MAX': all_max,
    'Unique_Label': all_unique,
    'Num_sample': num_samples_per_key,
    'Same_Label': same_label_count,
    'Same_Label_Perc': same_label_percentage,
    'MEAN PERC': all_mean_perc,
    'MAX PERC': all_max_perc,
    'MIN PERC': all_min_perc,
    'STD PERC': all_std_perc,
    'Rank': same_label_rank,
    # Rank normalised by the number of distinct labels in the row.
    'Rank_Perc': [rank / n_unique for rank, n_unique in zip(same_label_rank, all_unique)],
})

# --- Number-of-samples influence (disabled) ---
# Would add columns num_std_<n> / num_perc_<n> to `df`: the std of the label
# histogram and the reference label's share over each window of `step` columns.
# NOTE(review): fix before re-enabling:
#   * `curr_key = yTest_multi[i]` uses whatever `i` currently holds (left over
#     from earlier code on the first row, the inner window counter afterwards)
#     instead of the row index; iterate with enumerate(dj_set).
#   * `key.shape[1]` / `key[:, ...]` assume each row of dj_set is 2-D.
#   * windows are disjoint (i*step:(i+1)*step) although the column names read
#     like cumulative sample counts.
#
# step_std = []
# step_perc = []
# step = 20 #Every 20 traces
# for key in dj_set:
#     curr_key = yTest_multi[i]
#     key_std = []
#     key_same_perc = []
#     for i in range(int(key.shape[1]/step)):
#         all_labels = np.zeros(3329)
#         unique_indexes = np.unique(np.concatenate(key[:,i*step:(i+1)*step]))
#         key_labels = labels[unique_indexes]
#         unique, counts = np.unique(key_labels, return_counts=True)
#         all_labels[unique] = counts
#         key_std.append(np.std(all_labels))
#         key_same_perc.append(all_labels[curr_key]/len(unique_indexes))
#     step_std.append(key_std)
#     step_perc.append(key_same_perc)
#
# step_std = np.transpose(np.array(step_std))
# step_perc = np.transpose(np.array(step_perc))
# print(step_std.shape)
# print(step_perc.shape)
#
# for i in range(len(step_std)):
#     colname = 'num_std_{}'.format((i+1)*step)
#     df[colname] = step_std[i]
#
# for i in range(len(step_perc)):
#     colname = 'num_perc_{}'.format((i+1)*step)
#     df[colname] = step_perc[i]

# Output name: Stats_<target_file without extension>_<num_sample>.csv.
# splitext strips the extension whatever its length (slicing off 4 characters
# only worked for extensions such as '.npy').
out_name = 'Stats_{}_{}.csv'.format(os.path.splitext(args.target_file)[0], args.num_sample)
df.to_csv(os.path.join(args.target_folder, out_name), index=False)