import sklearn
import matplotlib
import random
from prototype import get_prototype, get_prototype_median
import numpy as np
import pandas as pd
import pickle

def get_distance(save=True, features_path="features.bin", output_path="distance.bin"):
    """Compute each example's Euclidean distance to its class prototype.

    Loads a pickled dict with ``"features"`` and ``"label"`` entries from
    *features_path*, builds prototypes with ``get_prototype(features, labels)``,
    and measures ``||features[n] - prots[labels[n]]||`` for every example n.

    Args:
        save: If True, pickle ``{"distance": dis, "label": labels}`` to
            *output_path*.
        features_path: Pickle file holding the features and labels.
        output_path: Destination pickle file used when *save* is True.

    Returns:
        ``np.ndarray`` with one distance per example (row of ``features``).
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load feature
    # files produced by a trusted pipeline.
    with open(features_path, "rb") as f:
        data = pickle.load(f)
    features = np.array(data["features"])
    labels = np.array(data["label"])
    prots = get_prototype(features, labels)
    print("features shape: ", features.shape)

    # Row n holds the prototype of example n's class.
    # NOTE(review): assumes prots[i] is the prototype of label i and labels are
    # 0..num_classes-1 — TODO confirm against get_prototype. An example whose
    # label falls outside that range keeps an all-zero prototype row.
    prots_for_each_example = np.zeros(shape=(features.shape[0], prots.shape[-1]))
    num_classes = len(np.unique(labels))
    for class_id in range(num_classes):
        prots_for_each_example[np.nonzero(labels == class_id)[0], :] = prots[class_id]
    dis = np.linalg.norm(features - prots_for_each_example, axis=1)

    if save:
        with open(output_path, "wb") as f:
            pickle.dump({"distance": dis, "label": labels}, f)
    # Return the result so callers using save=False still get the distances.
    return dis
        
if __name__ == "__main__":
    # Script entry point: compute distances and write them to disk.
    get_distance(save=True)