from core.utils import *
import torch, os

# Forwarded to get_sim_matrix() as its `num_clusterd` keyword argument.
# NOTE(review): presumably the number of clusters to form — confirm against
# core.utils. The spelling matches that keyword's name, so it is kept as-is.
NUM_CLUSTERD = 4

# json_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/json_files"
# save_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/saved_data_maps"
# for js in os.listdir(json_path):
#     if not js.endswith(".json"):
#         continue
#     if 'old' in js:
#         continue
#     js_path = os.path.join(json_path, js)
#     data_maps = get_embedding(js_path, save_path)
#     file_name = js.split(".")[0]

#     cos_sim_np = get_sim_matrix(data_maps, file_name, "results/", num_clusterd=NUM_CLUSTERD)
#     cluster_similarity(cos_sim_np, list(data_maps.keys()), file_name, "results/")
#     cluster_and_visualize(data_maps, file_name, "results/")

# Directory holding the saved embedding maps (one ``*.pt`` file per dataset).
pt_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/saved_data_maps"
# Output location handed to every analysis helper below.
results_dir = "results/"

# Run the similarity analysis for each saved embedding map.
# os.listdir() returns entries in arbitrary order; sorting makes the
# processing order (and therefore logs / partial results) reproducible.
for pt in sorted(os.listdir(pt_path)):
    # Skip anything that is not a ``.pt`` dump, and any file whose name
    # contains "solvent".
    if "solvent" in pt or not pt.endswith(".pt"):
        continue

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code — only point pt_path at trusted files.
    data_maps = torch.load(os.path.join(pt_path, pt))

    # Derive a short name: keep the text after the last "_dry_" (the whole
    # name if that marker is absent) and before the first "_embedding".
    # os.listdir() already yields bare file names, so no basename() is needed.
    file_name = pt.split("_dry_")[-1].split("_embedding")[0]

    cos_sim_np, extra_sims = get_sim_matrix(
        data_maps, file_name, results_dir, num_clusterd=NUM_CLUSTERD
    )
    cluster_similarity(
        cos_sim_np,
        list(data_maps.keys()),
        file_name,
        results_dir,
        extra_sims=extra_sims,
    )
    # cluster_and_visualize(data_maps, file_name, "results/")