import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial import distance
import argparse
from tqdm import tqdm as tqdm
import os

def parse_arguments():
    """Build the command-line parser for this script.

    ``--medoids_path`` and ``--save_folder`` are required: without them the
    script cannot load its medoids or write its output, so argparse should
    reject the invocation up front with a clear message.

    Returns:
        argparse.ArgumentParser with ``--medoids_path``, ``--save_folder``
        and ``--num_sample`` options.
    """
    parser = argparse.ArgumentParser(
        description='Compute per-sample Euclidean distances to the nearest '
                    'medoid and save them as distance_matrix.npy.')
    parser.add_argument('--medoids_path', type=str, required=True,
                        help='file loadable by numpy.load holding medoid '
                             'row indices into data.npz')
    parser.add_argument('--save_folder', type=str, required=True,
                        help='directory in which distance_matrix.npy is written')
    parser.add_argument('--num_sample', type=int,
                        help='number of leading rows of data.npz to use')

    return parser
def _nearest_medoid_distances(points, medoid_indices):
    """Return, for each point, the Euclidean distance to its nearest medoid.

    Distances that are exactly zero are ignored, so a point that coincides
    with a medoid (e.g. the medoid itself) gets the distance to the nearest
    *other* medoid. A point with no non-zero distance to any medoid
    (including the case of no medoids at all) gets 0.

    Args:
        points: array whose first axis indexes samples; each sample is
            flattened to a 1-D vector before distances are measured.
        medoid_indices: integer indices into ``points`` selecting the medoids.

    Returns:
        1-D float64 array of length ``len(points)``.
    """
    n_points = len(points)
    medoid_indices = np.asarray(medoid_indices, dtype=np.intp)
    if n_points == 0 or medoid_indices.size == 0:
        return np.zeros(n_points)
    flat = np.asarray(points).reshape(n_points, -1)
    # (n_points, n_medoids) matrix of Euclidean distances.
    pairwise = distance.cdist(flat, flat[medoid_indices], metric='euclidean')
    # Zero distance means "same point": exclude it from the minimum instead of
    # letting it (or a skipped first medoid) pin the result at 0.
    pairwise[pairwise == 0] = np.inf
    nearest = pairwise.min(axis=1)
    # No other medoid to compare against: fall back to 0.
    nearest[np.isinf(nearest)] = 0.0
    return nearest


parser = parse_arguments()
args = parser.parse_args()
medoid_indices_ = np.load(args.medoids_path)

# np.load on an .npz returns an NpzFile holding an open file; close it once
# the array has been read. A None num_sample keeps every row.
with np.load('data.npz') as npz:
    data = npz['data'][:args.num_sample]

# Sized from the loaded data, so requesting more rows than exist no longer
# overruns the array.
dist_matrix = _nearest_medoid_distances(data, medoid_indices_)

if dist_matrix.size:
    print(np.min(dist_matrix))
    print(np.max(dist_matrix))
os.makedirs(args.save_folder, exist_ok=True)
np.save(os.path.join(args.save_folder, 'distance_matrix.npy'), dist_matrix)