import os
import torch
import core.cluster_utils as cluster_utils
# Directory containing the saved embedding maps (*.pt) to be clustered.
pt_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/distance_cluster_results"
# Output directory handed to every cluster_utils call below; relative to the
# current working directory.
save_path = "distance_cluster_results"
# Sorted so that the processing (and log) order is reproducible across runs;
# os.listdir returns entries in arbitrary order.
pt_files = sorted(f for f in os.listdir(pt_path) if f.endswith(".pt"))

# Make sure the output directory exists before listing it, otherwise the
# cleanup below raises FileNotFoundError on a first run.
os.makedirs(save_path, exist_ok=True)

# Drop png/json artifacts left by a previous run so results are not mixed.
# The suffixes deliberately have no leading dot, matching the original filter.
for stale_name in os.listdir(save_path):
    if stale_name.endswith(("png", "json")):
        os.remove(os.path.join(save_path, stale_name))

for pt in pt_files:
    # SECURITY NOTE: torch.load unpickles the file; only run this on trusted
    # .pt files (consider weights_only=True if the payload allows it).
    embed_maps = torch.load(os.path.join(pt_path, pt))
    # Keep the part of the file name after the last "openai_" and before the
    # first "_embedding"; it is used as the label for this run.
    file_name = os.path.basename(pt).split("openai_")[-1].split("_embedding")[0]

    # Value forwarded as use_extra / extra_name to the spectral step; it is
    # switched off (None) for files whose label contains "solvent".
    spectral_extra = None if "solvent" in file_name else "nothing"

    print("="*20, file_name, "="*20)

    # Distance-based clustering, once per algorithm.
    # NOTE(review): extra_name=True is passed here while the spectral step
    # below receives "nothing"/None — confirm cluster_with_distance really
    # expects a boolean for extra_name.
    for algorithm in ("kmeans", "hierarchical"):
        dis_loss = cluster_utils.cluster_with_distance(
            embed_maps,
            file_name,
            extra_name=True,
            save_path=save_path,
            return_raw_distance_list=False,
            algorithm=algorithm,
        )
        print(algorithm, file_name, dis_loss)

    # Similarity-matrix based clustering (logged as "Spectral").
    cluster_cos_sim_np, extra_sims = cluster_utils.get_sim_matrix(embed_maps, use_extra=spectral_extra)
    dis_loss = cluster_utils.cluster_similarity_virtual_point(
        cluster_cos_sim_np,
        list(embed_maps.keys()),
        file_name, embed_maps,
        extra_sims=extra_sims,
        extra_name=spectral_extra,
        save_path=save_path
    )
    print('Spectral', file_name, dis_loss)

"""
base: [34.635754, 0.0, 23.22731]
solvent: [25.621775, 23.232212, 0.0]
ligand: [33.727528, 34.972168, 27.5693]
"""