from scipy.spatial import distance
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace`` with ``vocab_file``,
        ``point_file``, ``clustering_file`` and ``delimeter``.
    """
    parser = argparse.ArgumentParser()
    # The three input paths are mandatory: without them the script would only
    # fail later with an opaque TypeError raised by np.load()/open().
    parser.add_argument("--vocab-file", required=True, help="input vocabaulary file")
    parser.add_argument("--point-file", required=True, help="input vocabulary representations")
    parser.add_argument("--clustering-file", required=True, help="clustering results file")
    parser.add_argument("--delimeter", "-d", default='|||', help="delimeter (default: '|||')")
    return parser.parse_args(argv)


def parse_cluster_line(line, delimeter):
    """Split one clustering-file line into ``(word, label)``.

    The label is the text after the LAST delimiter, so a word may itself
    contain the delimiter.

    Args:
        line: Raw line from the clustering file.
        delimeter: Separator between the word and its cluster label.

    Returns:
        Tuple ``(word, label)``; ``word`` is empty when the line holds no
        delimiter.
    """
    parts = line.strip().split(delimeter)
    return delimeter.join(parts[:-1]), parts[-1].strip()


def read_clusters(clustering_file, delimeter):
    """Read the word -> cluster label mapping from ``clustering_file``.

    Args:
        clustering_file: Path of a text file with one ``word<delim>label``
            entry per line.
        delimeter: Separator between the word and its cluster label.

    Returns:
        Dict mapping each word to its cluster label (a string).
    """
    clusters = {}
    with open(clustering_file, 'r') as f:
        for line in f:
            if not line.strip():
                # A blank line carries no assignment.
                continue
            word, label = parse_cluster_line(line, delimeter)
            clusters[word] = label
    return clusters


def group_rows_by_cluster(vocab, clusters):
    """Group vocabulary row indices by cluster label.

    Row indices are collected directly, so every row is assigned exactly
    once even if the same word occurs more than once in ``vocab``.

    Args:
        vocab: Iterable of words; position ``i`` corresponds to row ``i`` of
            the point matrix.
        clusters: Dict mapping each word to its cluster label.

    Returns:
        Dict mapping each cluster label to the list of its row indices.

    Raises:
        KeyError: If a vocabulary word has no cluster label; the word is the
            exception argument.
    """
    members = defaultdict(list)
    for row, word in enumerate(vocab):
        members[clusters[word]].append(row)
    return members


def compute_centroids(points, members):
    """Compute the centroid of every cluster.

    Args:
        points: 2-D array-like with one vector per row.
        members: Dict mapping each cluster label to its row indices.

    Returns:
        Dict mapping each cluster label to its centroid vector.
    """
    points = np.asarray(points)
    # axis=0 averages over the member rows and keeps one value per
    # dimension; without it np.mean collapses everything into one scalar.
    return {
        label: points[np.array(rows), :].mean(axis=0)
        for label, rows in members.items()
    }


def compute_inertia(points, centroids):
    """Sum the distance of every point to its nearest centroid.

    The distance is the plain (unsquared) Euclidean norm.

    Args:
        points: 2-D array-like with one vector per row.
        centroids: Dict mapping each cluster label to its centroid vector.

    Returns:
        The accumulated distance as a Python float.
    """
    points = np.asarray(points)
    centroid_matrix = np.array([centroids[label] for label in centroids])
    inertia = 0.0
    for vec in points:
        # Distances from this point to all centroids in one vectorised call.
        dists = np.linalg.norm(centroid_matrix - vec, axis=1)
        inertia += float(dists.min())
    return inertia


def main(argv=None):
    """Load the inputs, compute the elbow score and print it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    vocab = np.load(args.vocab_file)
    points = np.load(args.point_file)
    clusters = read_clusters(args.clustering_file, args.delimeter)

    try:
        members = group_rows_by_cluster(vocab, clusters)
    except KeyError as err:
        print(f"{err.args[0]} NOT FOUND!")
        # Exit with a non-zero status so callers can detect the failure.
        raise SystemExit(1)

    centroids = compute_centroids(points, members)
    # Only rows that have a vocabulary entry are scored.
    inertia = compute_inertia(points[:len(vocab)], centroids)

    print("Elbow score = "+str(inertia))


if __name__ == "__main__":
    main()