from sklearn.neighbors import KNeighborsClassifier
import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd
import os

def parse_arguments():
    """Build the command-line parser used by this script.

    Registers three string options (``--target_folder``, ``--KL_file``,
    ``--KNN_file``) and two integer options (``--num_sample`` defaulting to
    200, ``--add_num`` defaulting to 0).

    Returns:
        The configured ``argparse.ArgumentParser``; the caller is
        responsible for invoking ``parse_args()`` on it.
    """
    cli = argparse.ArgumentParser(description='')

    # String-valued options without a default (they parse to None if omitted).
    for flag in ('--target_folder', '--KL_file', '--KNN_file'):
        cli.add_argument(flag, type=str)

    # Integer-valued options together with their defaults.
    for flag, fallback in (('--num_sample', 200), ('--add_num', 0)):
        cli.add_argument(flag, type=int, default=fallback)

    return cli

def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from an ``.npz`` archive.

    Args:
        data_path: Path to an ``.npz`` archive containing the entries
            ``'data'`` and ``'label'``.

    Returns:
        A tuple ``(data, labels)`` with the two arrays read from the archive.

    Raises:
        KeyError: If either entry is missing from the archive.
    """
    # np.load returns an NpzFile that keeps the underlying file open until it
    # is closed. Each array is fully read on item access, so the archive can
    # be closed before returning without invalidating the results.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

# --- Script setup -----------------------------------------------------------
# Parse the CLI options, assemble the index set of training samples, load the
# multi-attack test data, and merge/save the KNN and KL selection arrays.
parser = parse_arguments()
args = parser.parse_args()
# Archive indexed further down through its 'data' and 'label' entries.
# NOTE(review): the returned archive object is never closed explicitly.
data = np.load('data.npz')
NUM_SAMPLE = 200000
data_ids = np.load('attack_300key_ids.npz')['ids']
# Flatten the stored ids into one sorted 1-D array.
sorted_ids = np.sort(np.hstack(data_ids))
# Integers in [0, 300000) that do not occur in sorted_ids, shifted by NUM_SAMPLE.
# NOTE(review): presumably these address rows of data.npz located after the
# first NUM_SAMPLE rows - confirm against how data.npz was built.
non_overlap_ids = np.setdiff1d(np.arange(300000), sorted_ids) + NUM_SAMPLE
print(sorted_ids.shape)
print(non_overlap_ids.shape)
print(np.max(non_overlap_ids))
# Take the first --add_num of the non-overlapping ids as extra samples.
addtional_ids = non_overlap_ids[:args.add_num]
# NOTE(review): init_ids is not referenced again in the visible code.
init_ids = np.arange(NUM_SAMPLE)
# Final index set: the first NUM_SAMPLE rows followed by the extra ids.
all_ids = np.concatenate((np.arange(NUM_SAMPLE), addtional_ids))
print(len(all_ids))
print(all_ids[:10])

# Select the chosen rows from both entries of the archive.
xTrain_original = data['data'][all_ids]
labels = data['label'][all_ids]

xTest_multi, yTest_multi = load_multi_attack('attack_multi_data_300key.npz')
# Insert a new axis at position 2 of the training array.
xTrain_original = np.expand_dims(xTrain_original, axis = 2)
print(xTrain_original.shape)
print(xTest_multi.shape)
print(yTest_multi.shape)
print(yTest_multi[:10])

# First 5 entries along axis 1 of the (3-D indexed) test array.
# NOTE(review): target_samples is not referenced again in the visible code.
target_samples = xTest_multi[:,:5,:]

# Load the two selection arrays named on the command line.
dj_set_KL = np.load(os.path.join(args.target_folder,args.KL_file))
dj_set_KNN = np.load(os.path.join(args.target_folder,args.KNN_file))
print(dj_set_KNN.shape)
print(dj_set_KL.shape)
# Join both selections along axis 1 (KNN columns first) and save the result.
dj_set = np.concatenate((dj_set_KNN, dj_set_KL), axis = 1)
print(dj_set.shape)
# Output name: the common leading characters of the two file names
# (os.path.commonprefix compares character by character) + 'KNN_KL.npy'.
fname = os.path.commonprefix([args.KL_file, args.KNN_file]) + 'KNN_KL.npy'
np.save(os.path.join(args.target_folder, fname), dj_set)

# Keep only the first --num_sample columns of each selection for the code
# that follows; the combined array saved above is not truncated.
dj_set_KNN = dj_set_KNN[:,:args.num_sample]
dj_set_KL = dj_set_KL[:,:args.num_sample]
# Per-key statistics of the label distribution inside each selected sample set.
# Every list below receives exactly one entry per row of dj_set_KL / dj_set_KNN.
all_std = []
all_min = []
all_max = []
all_mean = []
all_max_perc = []
all_min_perc = []
all_mean_perc = []
all_std_perc = []
is_max_key = []  # never populated in this section
all_unique = []
same_label_count = []
same_label_percentage = []
same_label_rank = []
num_samples_per_key = []
overlap = []
full_data = []

# Length of the per-label count vector; label values index directly into it.
NUM_LABEL_BINS = 3329

for i, (key_KL, key_KNN) in enumerate(zip(dj_set_KL, dj_set_KNN)):
    # Number of indices shared by both selections, divided by twice the size
    # of the KL selection.
    overlap_num = len(np.intersect1d(key_KL, key_KNN))
    overlap.append(overlap_num / (len(key_KL) * 2))

    # Only the KNN selection feeds the label statistics; the KL selection is
    # used for the overlap figure alone.
    key = key_KNN
    curr_key = yTest_multi[i]

    # De-duplicated sample indices chosen for this key.
    unique_indexes = np.unique(key).astype(int)
    num_samples_per_key.append(len(unique_indexes))

    # Count how often each label occurs among the selected samples.
    key_labels = labels[unique_indexes]
    unique, counts = np.unique(key_labels, return_counts=True)
    all_labels = np.zeros(NUM_LABEL_BINS)
    all_labels[unique] = counts
    all_perc = all_labels / len(unique_indexes)
    all_unique.append(len(unique))
    full_data.append(all_labels)

    # Summary statistics are taken over the labels that actually occur,
    # so compute each non-zero selection once and reuse it.
    nonzero_counts = all_labels[np.nonzero(all_labels)]
    nonzero_perc = all_perc[np.nonzero(all_perc)]

    # Rank = number of labels with a strictly larger share than curr_key's.
    same_label_rank.append(np.sum(all_perc > all_perc[curr_key]))
    all_max_perc.append(np.max(nonzero_perc))
    all_mean_perc.append(np.mean(nonzero_perc))
    all_std_perc.append(np.std(nonzero_perc))
    all_min_perc.append(np.min(nonzero_perc))

    all_mean.append(np.mean(nonzero_counts))
    all_std.append(np.std(nonzero_counts))
    all_min.append(np.min(nonzero_counts))
    all_max.append(np.max(nonzero_counts))

    same_label_count.append(all_labels[curr_key])
    same_label_percentage.append(all_labels[curr_key] / len(unique_indexes))

# Assemble one table row per key and write the table into the target folder.

# Rank of the key's own label relative to the number of distinct labels seen.
rank_fraction = [
    rank / n_unique for rank, n_unique in zip(same_label_rank, all_unique)
]

# Column order here is the column order of the resulting CSV.
stat_columns = {
    'MEAN': all_mean,
    'STD': all_std,
    'MIN': all_min,
    'MAX': all_max,
    'Unique_Label': all_unique,
    'Num_sample': num_samples_per_key,
    'Same_Label': same_label_count,
    'Same_Label_Perc': same_label_percentage,
    'MEAN PERC': all_mean_perc,
    'MAX PERC': all_max_perc,
    'MIN PERC': all_min_perc,
    'STD PERC': all_std_perc,
    'Rank' : same_label_rank,
    'Rank_Perc': rank_fraction,
    'Overlap' : overlap,
}
df = pd.DataFrame(stat_columns)

# The file name encodes the fixed tag and the --num_sample value.
csv_name = 'Stats_{}_{}.csv'.format('KL_KNN_geo', args.num_sample)
csv_path = os.path.join(args.target_folder, csv_name)
df.to_csv(csv_path, index=False)