import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance

def parse_arguments():
    """Build the command-line parser for the min-max sampling script.

    Every option is optional and is ``None`` when omitted.  ``--min_dist`` is
    registered but is not read by any live code visible in this file.

    Returns:
        argparse.ArgumentParser: the configured (not yet invoked) parser.
    """
    option_types = (
        ('--dm_path', str),
        ('--medoids_path', str),
        ('--save_folder', str),
        ('--num_sample', int),
        ('--min_dist', int),
        ('--alpha', float),
    )
    cli = argparse.ArgumentParser(description='')
    for flag, converter in option_types:
        cli.add_argument(flag, type=converter)
    return cli

# Parse the command line once at import time; the sampler and the driver
# below read these module-level values.
parser = parse_arguments()
args = parser.parse_args()
alpha = args.alpha  # weight of the distance term in the sampling score
smooth = 5  # not used by the active (linear) sampling score

def _candidate_scores(curr_matrix, label_count, flat_labels, candidates, weight):
    """Return the min-max sampling score of every index in ``candidates``.

    score = weight * x1 + (1 - weight) * x2, where
      x1 -- the candidate's distance-to-S, min-max rescaled over the whole
            ``curr_matrix`` (farther from S -> larger x1);
      x2 -- one minus the count of the candidate's label in S, min-max
            rescaled over ``label_count`` (rarer label -> larger x2).
    A term whose range is zero cannot rank candidates; it is set to 0 (x1)
    or 1 (x2) instead of producing 0/0 = nan.
    """
    min_dist = np.min(curr_matrix)
    dist_range = np.max(curr_matrix) - min_dist
    if dist_range > 0:
        x1 = (curr_matrix[candidates] - min_dist) / dist_range
    else:
        x1 = np.zeros(candidates.shape[0])

    # Rescale raw counts against the raw min/max count.  Dividing the counts
    # by |S| while subtracting the undivided minimum count would push x2
    # outside [0, 1] as soon as every key is present in S.
    min_count = np.min(label_count)
    count_range = np.max(label_count) - min_count
    if count_range > 0:
        x2 = 1.0 - (label_count[flat_labels[candidates]] - min_count) / count_range
    else:
        x2 = np.ones(candidates.shape[0])

    return weight * x1 + (1.0 - weight) * x2


def fast_min_max_sampling(data, labels, args, num_sample, *, num_keys=3330,
                          balance_weight=None, progress=None):
    """Greedily grow a selected set S until it holds ``num_sample`` indexes.

    Each step scores every not-yet-selected point by mixing its distance to S
    with how rare its label is in S (see ``_candidate_scores``), appends the
    best-scoring point to S, and lowers every other candidate's distance-to-S
    with its Euclidean distance to the newly selected point.

    Args:
        data: per-point samples; after ``np.squeeze`` the first axis indexes
            points and any remaining axes are flattened into one vector.
        labels: one non-negative integer label (key) per point.
        args: namespace providing ``medoids_path`` (.npy with the initial S
            indexes) and ``dm_path`` (.npy with one starting distance-to-S
            per point -- presumably the distance to the nearest initial
            medoid; verify against whatever produced the file).
        num_sample: target size of S, initial medoids included.
        num_keys: number of label bins in the balance histogram (3330 keys by
            default); grown automatically when a label is larger.
        balance_weight: weight of the distance term (``1 - weight`` goes to
            the label term); defaults to the module-level ``alpha``.
        progress: callable wrapping the iteration range; defaults to ``tqdm``.

    Returns:
        (indexes, distances): the indexes of S in selection order (initial
        medoids first) and the float64 running distance-to-S array.
    """
    weight = alpha if balance_weight is None else balance_weight
    wrap = tqdm if progress is None else progress

    points = np.atleast_1d(np.squeeze(data))
    points = points.reshape(points.shape[0], -1)
    if not np.issubdtype(points.dtype, np.floating):
        # Small integer dtypes (e.g. uint8) wrap around when subtracted, so
        # measure distances in float64.
        points = points.astype(np.float64)

    medoids_indexes = np.load(args.medoids_path)
    dist_mm = np.load(args.dm_path)
    print(np.max(dist_mm))
    print(np.min(dist_mm))
    print(dist_mm.shape)

    # Running distance-to-S, one value per point.  A float copy keeps the
    # loaded array intact and never truncates updates to an integer dtype.
    curr_matrix = np.array(dist_mm, dtype=np.float64).ravel()

    flat_labels = np.ravel(labels)
    selected = np.ravel(medoids_indexes).astype(np.intp)
    is_selected = np.zeros(points.shape[0], dtype=bool)
    is_selected[selected] = True

    key_space = num_keys
    if flat_labels.size:
        key_space = max(num_keys, int(flat_labels.max()) + 1)
    # label_count[k] = how many points currently in S carry label k.
    label_count = np.bincount(flat_labels[selected], minlength=key_space)

    print('Min max sampling')
    start_time = time.time()
    for _ in wrap(range(num_sample - selected.shape[0])):
        candidates = np.flatnonzero(~is_selected)
        if candidates.size == 0:
            # S already holds every point; another pick would be a duplicate.
            break
        scores = _candidate_scores(curr_matrix, label_count, flat_labels,
                                   candidates, weight)
        # Only candidates are scored, so an all-zero score vector can no
        # longer resolve to an index that is already in S.
        chosen = candidates[np.argmax(scores)]

        selected = np.append(selected, chosen)
        is_selected[chosen] = True
        label_count[flat_labels[chosen]] += 1

        # Tighten each candidate's distance-to-S with its distance to the new
        # member (the chosen point itself drops to 0).  Points selected in
        # earlier steps keep their current values.
        new_dist = np.linalg.norm(points[candidates] - points[chosen], axis=1)
        curr_matrix[candidates] = np.minimum(curr_matrix[candidates], new_dist)

    end_time = time.time()
    print('Took {} seconds'.format(end_time - start_time))
    return selected, curr_matrix

# NOTE(review): the triple-quoted literal below is an inert expression
# statement holding a scratch copy of an older exponential scoring variant.
# It is never executed; names it uses such as exist_prob, ranking, i and
# all_res are not defined at this point of the module.
'''
alpha = 1
a = 5
#x2 = - (exist_prob[labels[i]] - 1) / exist_prob[labels[i]]
range_x2 = np.max(exist_prob) - np.min(exist_prob)
range_x1 = np.max(ranking)-np.min(ranking)
#x2 = 1 - ((exist_prob[labels[i]] - np.min(exist_prob)))
x2 = 1-(exist_prob[labels[i]]- np.min(exist_prob)) / range_x2
x1 = 1-((ranking[i] - np.min(ranking)) / range_x1)
print('-----------')
#print(x1)
print(x2)
#all_res.append( alpha * math.exp(x1 * a) + (1-alpha) * math.exp( x2 * a))
all_res.append( alpha * math.exp( -x1 * a) + (1-alpha) * math.exp(-x2 * a))
'''

# Driver: run the sampler on the first 200000 samples of data.npz and save
# both results -- the selected indexes and the updated distance-to-S array.
with np.load('data.npz') as data:  # closes the underlying zip file
    datas = data['data'][:200000]
    labels = data['label'][:200000]
# fast_min_max_sampling returns (indexes, distances); unpack so each result is
# saved to its own file (dist_matrix was previously never defined).
data_indexes, dist_matrix = fast_min_max_sampling(datas, labels, args, num_sample=args.num_sample)
np.save(os.path.join(args.save_folder,'minmax_balance_uniform_{}_index.npy'.format(alpha)), data_indexes)
np.save(os.path.join(args.save_folder,'dist_matrix_alpha_{}_{}_uniform_index.npy'.format(alpha, args.num_sample)), dist_matrix)

sys.exit()

# Run with pseudo distance for testing.
# NOTE: unreachable while the driver above exits the script; run it by hand
# to plot the linear score of each toy point.

generated_labels = np.array([0, 0, 1, 2, 2, 2, 3, 3])  # labels already in S
unique, counts = np.unique(generated_labels, return_counts=True)

all_label = np.zeros(5)
all_label[unique] = counts  # Label distribution

print(all_label)

labels = np.array([0, 0, 1, 1, 2])
distances = np.arange(5) * 2
print('labels: ', labels)
print('distances: ', distances)
smooth = 5
alpha = 0

# The normalisation constants do not depend on the point, so compute them once.
# The label term is rescaled on raw counts; mixing counts divided by
# len(generated_labels) with the raw minimum count would skew x2.
min_count = np.min(all_label)
count_range = np.max(all_label) - min_count
min_distance = np.min(distances)
distance_range = np.max(distances) - min_distance

all_res = []
for label, dist in zip(labels, distances):
    # A zero range cannot rank points: use the neutral values 1 (x2) / 0 (x1).
    x2 = (1 - (all_label[label] - min_count) / count_range) if count_range else 1.0
    x1 = ((dist - min_distance) / distance_range) if distance_range else 0.0
    # Uniform (linear) mix of the distance and label-rarity terms
    all_res.append(alpha * x1 + (1 - alpha) * x2)


plt.plot(all_res)
plt.savefig('plot_alpha_uniform_{}.png'.format(alpha))