import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda

import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
from natsort import natsorted

def parse_arguments():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with one optional string option,
        ``--train_folder`` (``None`` when the flag is omitted).
    """
    parser = argparse.ArgumentParser(
        description='Count the sample ids stored in each subfolder of a '
                    'training folder and write the totals to num_sample.csv.')
    # The value is the root folder of training runs: its immediate subfolders
    # are scanned, and the summary CSV is written into it.
    parser.add_argument(
        '--train_folder', type=str,
        help='folder whose subfolders hold the training runs to summarize')
    return parser

# --- Command-line arguments and run-wide settings ---------------------------
parser = parse_arguments()
args = parser.parse_args()
# NOTE(review): `basepath` is not referenced anywhere else in the visible code.
basepath = 'trained_models/'

# Loads 'data.npz' from the current working directory and reads its 'label'
# array. NOTE(review): neither `data` nor `labels` is used later in the visible
# code, yet the script aborts here if data.npz is absent -- confirm it is needed.
data = np.load('data.npz')
labels = data['label']

# Disabled snippet: this is a bare string literal, so it is never executed.
# It refers to load_model, get_meanrank, xTest_multi, yTest_multi and maxtrc,
# none of which are defined or imported in this file.
'''
end_model_path = os.path.join('multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy/model_best.keras')
end_model = load_model(end_model_path)
test_end = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
print(test_end)
'''
# Hard-coded reference results. NOTE(review): unused in the visible code.
baseline_mr_best = 76.37
baseline_240k_best = 23.1
baseline_160k_best = 36.78

# NOTE(review): unused in the visible code.
test_key = 1733

# Root folder taken from --train_folder (None if the flag is omitted); its
# immediate subfolders are scanned by get_info() and the CSV is written here.
folder_path = args.train_folder


def _count_npy_entries(npy_path):
    """Return ``len()`` of the array saved at *npy_path*, or 0 if no such file exists."""
    if not os.path.isfile(npy_path):
        return 0
    return len(np.load(npy_path))


def get_info(folder_path, infer=False):
    """Count the saved sample ids in every immediate subdirectory of *folder_path*.

    Subdirectories are visited in natural-sort order (``natsorted``), and their
    names are printed. For each subdirectory the lengths of three ``.npy``
    arrays are collected:

    * ``all_ids.npy`` -- required; a missing file raises.
    * ``medoids_ids.npy`` -- optional; counted as 0 when absent.
    * ``medoids_KL_ids.npy`` -- optional; counted as 0 when absent.

    Args:
        folder_path: Directory whose direct subdirectories are scanned.
        infer: Accepted for backward compatibility; currently unused.

    Returns:
        tuple: ``(all_ids, all_cluster, all_kl, all_path)`` -- three lists of
        counts plus the list of subdirectory names, all aligned by index.

    Raises:
        FileNotFoundError: if *folder_path* is not an existing directory, or a
            subdirectory lacks ``all_ids.npy``.
    """
    # os.walk() silently yields nothing for a missing or non-directory path,
    # which would otherwise surface as an unhelpful bare StopIteration.
    top_level = next(os.walk(folder_path), None)
    if top_level is None:
        raise FileNotFoundError(f"not an existing directory: {folder_path!r}")
    all_path = natsorted(top_level[1])
    print(all_path)

    all_ids = []
    all_cluster = []
    all_kl = []
    for fpath in all_path:
        fname = os.path.join(folder_path, fpath)
        # all_ids.npy is mandatory, so it is loaded without an existence check.
        all_ids.append(len(np.load(os.path.join(fname, 'all_ids.npy'))))
        all_cluster.append(_count_npy_entries(os.path.join(fname, 'medoids_ids.npy')))
        all_kl.append(_count_npy_entries(os.path.join(fname, 'medoids_KL_ids.npy')))

    return all_ids, all_cluster, all_kl, all_path

all_ids, all_cluster, all_kl, all_path = get_info(folder_path, infer=True)

# Per-run summary table: one row per subfolder of folder_path, with columns
# in the order Name, Ids, Cluster_ids, KL_ids.
df = pd.DataFrame(
    data={
        'Name': all_path,
        'Ids': all_ids,
        'Cluster_ids': all_cluster,
        'KL_ids': all_kl,
    },
)
# Saved inside the scanned folder; to_csv's defaults keep the integer row index.
df.to_csv(path_or_buf=os.path.join(folder_path, 'num_sample.csv'))
#np.savez('res_random.npz', rd_multi)
#np.savez('res_uncertain.npz', all_multi)
#np.savez('res_bal.npz', bal_multi)
#np.savez('res_bal_update.npz', bal_multi_notrain)
#np.savez('res_bal_best.npz', bal_multi)
#np.savez('res_bal_update_best.npz', bal_multi_notrain)