import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance

def parse_arguments():
    """Build the command-line parser for greedy min-max sampling.

    The distance file, seed-index file, output folder and sample count are
    required because the sampling code uses every one of them
    unconditionally. Missing them should fail here with a clear usage
    message, not later inside ``np.load`` or ``range``.

    Returns:
        argparse.ArgumentParser: the configured (not yet invoked) parser.
    """
    parser = argparse.ArgumentParser(
        description='Extend a seed set of sample indexes by greedy min-max '
                    '(farthest-point) sampling.')
    parser.add_argument('--dm_path', type=str, required=True,
                        help='.npy file with the starting distance value '
                             'of every data row.')
    parser.add_argument('--medoids_path', type=str, required=True,
                        help='.npy file with the seed (already selected) '
                             'row indexes.')
    parser.add_argument('--save_folder', type=str, required=True,
                        help='Folder where checkpoints and results are saved.')
    parser.add_argument('--num_sample', type=int, required=True,
                        help='Total number of indexes to select, seeds '
                             'included.')
    # Kept so existing command lines still parse; nothing reads it now.
    parser.add_argument('--min_dist', type=int,
                        help='Currently unused.')
    return parser

# Parse the command line once, at import time; the resulting ``args``
# namespace is what the script section at the end of the file consumes.
parser = parse_arguments()
args = parser.parse_args(sys.argv[1:])

def _distances_to_point(points, point, chunk_size=8192):
    """Return the Euclidean distance from every row of *points* to *point*.

    Each row is flattened to a vector, and rows are processed ``chunk_size``
    at a time so the temporary difference array stays bounded. Arithmetic is
    done in float64, so integer features (e.g. uint8) cannot wrap around
    when subtracted.
    """
    point = np.asarray(point, dtype=np.float64).reshape(1, -1)
    out = np.empty(len(points), dtype=np.float64)
    for start in range(0, len(points), chunk_size):
        block = np.asarray(points[start:start + chunk_size], dtype=np.float64)
        block = block.reshape(block.shape[0], -1)
        out[start:start + block.shape[0]] = np.linalg.norm(block - point, axis=1)
    return out


def fast_min_max_sampling(data, args, num_sample, *, alpha=None,
                          checkpoint_every=5000, verbose=False):
    """Extend a seed index set by greedy min-max (farthest-point) selection.

    The seed indexes are loaded from ``args.medoids_path``. The starting
    value of every row is loaded from ``args.dm_path``; presumably this is
    each row's distance to its nearest seed (TODO confirm against the
    producer of that file). Each step then:

    1. picks the not-yet-selected row with the largest value;
    2. lowers every row's value to its Euclidean distance from the new pick
       when that distance is strictly smaller.

    Args:
        data: samples, one per row after ``np.squeeze``.
        args: namespace providing ``medoids_path``, ``dm_path`` and
            ``save_folder``.
        num_sample: total number of indexes wanted, seeds included. Nothing
            is added when it does not exceed the number of seeds.
        alpha: optional tag inserted into checkpoint file names.
        checkpoint_every: save the selection and distance vector whenever
            the selection size is a multiple of this; 0/None disables it.
        verbose: print per-iteration debugging information.

    Returns:
        ``(medoids_indexes, curr_matrix)``:
        - ``medoids_indexes``: the seeds followed by the new picks, in
          selection order.
        - ``curr_matrix``: the final value vector, in the shape loaded from
          ``args.dm_path``. Rows marked as selected hold -1. The most recent
          pick keeps its updated value, normally 0.

    Raises:
        ValueError: if the loaded distance array does not have exactly one
            entry per data row.
    """
    data = np.squeeze(data)
    n = len(data)

    medoids_indexes = np.load(args.medoids_path)
    print(len(medoids_indexes))
    dist_mm = np.load(args.dm_path)
    print(dist_mm.shape)

    # Work on a private C-contiguous copy. -1 sentinels and fractional
    # distances are written into it, so integer input is promoted to
    # float64 rather than silently truncating the distances.
    if np.issubdtype(dist_mm.dtype, np.floating):
        dtype = dist_mm.dtype
    else:
        dtype = np.float64
    curr_matrix = np.array(dist_mm, dtype=dtype, order='C')
    if curr_matrix.size != n:
        raise ValueError(
            'distance array {} has {} entries but data has {} rows'.format(
                args.dm_path, curr_matrix.size, n))
    # 1-D view sharing memory with curr_matrix, which keeps its loaded shape.
    flat = curr_matrix.reshape(-1)

    print('curr_matrix.shape: ', curr_matrix.shape)
    print('data.shape: ', data.shape)
    if verbose:
        print('curr_matrix[:10]: ', flat[:10])
        print('data[:10]: ', data[:10])
    print('Min max sampling')

    start_time = time.time()
    init_num = len(medoids_indexes)
    selected = np.zeros(n, dtype=bool)
    # Rows to exclude before the next argmax. The newest pick is only
    # marked at the start of the following iteration, so the returned
    # vector leaves the last pick with its updated value.
    to_mark = medoids_indexes

    for _ in tqdm(range(num_sample - init_num)):
        # -1 is below every real distance, so marked rows can never win.
        flat[to_mark] = -1
        selected[to_mark] = True
        if selected.all():
            print('All {} samples are already selected; stopping early.'.format(n))
            break

        max_idx = np.argmax(flat)
        if verbose:
            print('remaining:', n - np.count_nonzero(selected),
                  'selected:', len(medoids_indexes),
                  'max_idx:', max_idx, 'max distance:', flat[max_idx])
        medoids_indexes = np.append(medoids_indexes, max_idx)

        # Replace a row's value only when the new distance is strictly
        # smaller. Marked rows hold -1, which no distance beats, and NaN
        # comparisons are False, so both are left untouched.
        dists = _distances_to_point(data, data[max_idx])
        np.copyto(flat, dists, where=dists < flat)
        to_mark = max_idx

        total = len(medoids_indexes)
        if checkpoint_every and total % checkpoint_every == 0:
            tag = str(total) if alpha is None else '{}_{}'.format(alpha, total)
            os.makedirs(args.save_folder, exist_ok=True)
            np.save(os.path.join(args.save_folder,
                                 'minmax_balance_{}_index.npy'.format(tag)),
                    medoids_indexes)
            np.save(os.path.join(args.save_folder,
                                 'dist_matrix_{}_index.npy'.format(tag)),
                    curr_matrix)

    end_time = time.time()
    print('Took {} seconds'.format(end_time - start_time))
    return medoids_indexes, curr_matrix

if __name__ == '__main__':
    # Only the first 200000 rows of data.npz are sampled from. Presumably
    # the --dm_path vector was computed for this same subset (TODO confirm).
    # The context manager closes the .npz archive once the array is read.
    with np.load('data.npz') as npz:
        data = npz['data'][:200000]
    data_indexes, dist_matrix = fast_min_max_sampling(data, args, num_sample=args.num_sample)
    print(data_indexes.shape)
    # Create the output folder so the final saves cannot fail after the
    # (long) sampling run has already finished.
    os.makedirs(args.save_folder, exist_ok=True)
    np.save(os.path.join(args.save_folder, 'minmax_index_{}.npy'.format(args.num_sample)), data_indexes)
    np.save(os.path.join(args.save_folder, 'dist_matrix_{}.npy'.format(args.num_sample)), dist_matrix)
