import os
import torch
from loguru import logger
import core.cluster_utils as cluster_utils
# Root folder holding one sub-directory of ``.pt`` embedding files per experiment.
dir_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results"

# Tag forwarded to ``cluster_utils.cluster_with_distance`` as ``extra_name``.
EXTRA_NAME = "nothing"


def _clear_previous_outputs(save_path):
    """Delete ``.png`` / ``.json`` files left in *save_path* by an earlier run."""
    for entry in os.listdir(save_path):
        full_path = os.path.join(save_path, entry)
        # Only plain files with those extensions; directories are left alone.
        if entry.endswith((".png", ".json")) and os.path.isfile(full_path):
            os.remove(full_path)


def _embedding_name(pt_file, exp_name):
    """Derive the short embedding name from a ``.pt`` file name.

    Drops everything up to the last ``"<exp_name>_"`` occurrence and
    everything from the first ``"_embedding"`` onward, e.g.
    ``_embedding_name("expA_foo_embedding.pt", "expA") == "foo"``.
    """
    return pt_file.split(exp_name + "_")[-1].split("_embedding")[0]


def cluster_all_experiments(
    src_root=dir_path,
    out_root="exp_embed_cluster_results",
    extra_name=EXTRA_NAME,
    algorithm="kmeans",
):
    """Run distance-based clustering on every ``.pt`` embedding file under *src_root*.

    Each immediate sub-directory of *src_root* is treated as one experiment.
    Non-directory entries are skipped. Results for an experiment are written
    to ``<out_root>/<experiment>``, which is created if missing. Any
    ``.png``/``.json`` files already present there are deleted first, so
    stale outputs from a previous run do not linger. Experiments and files
    are processed in sorted order for reproducible runs and logs.

    Args:
        src_root: Directory whose sub-directories contain ``.pt`` files.
        out_root: Output root. A relative path resolves against the current
            working directory, as in the original script.
        extra_name: Forwarded to ``cluster_with_distance`` as ``extra_name``.
        algorithm: Forwarded to ``cluster_with_distance`` as ``algorithm``.
    """
    for exp_name in sorted(os.listdir(src_root)):
        pt_dir = os.path.join(src_root, exp_name)
        # A stray file at the top level would make os.listdir(pt_dir) raise.
        if not os.path.isdir(pt_dir):
            continue

        logger.success(exp_name)

        save_path = os.path.join(out_root, exp_name)
        os.makedirs(save_path, exist_ok=True)
        _clear_previous_outputs(save_path)

        pt_files = sorted(f for f in os.listdir(pt_dir) if f.endswith(".pt"))
        for pt_file in pt_files:
            # NOTE(review): torch.load unpickles; only point this at trusted files.
            embed_maps = torch.load(os.path.join(pt_dir, pt_file))
            file_name = _embedding_name(pt_file, exp_name)

            print("=" * 20, file_name, "=" * 20)

            cluster_utils.cluster_with_distance(
                embed_maps,
                file_name=file_name,
                extra_name=extra_name,
                algorithm=algorithm,
                save_path=save_path,
            )


# NOTE(review): runs at import time, exactly like the original flat script.
# Consider guarding with `if __name__ == "__main__":` if this is ever imported.
cluster_all_experiments()