import os, json
from collections import defaultdict

# Root directory that holds one sub-directory of JSON result files per dataset.
results_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results"
# Destination directory for the merged per-dataset summary files.
save_path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results/summary"

# NOTE(review): not referenced by the visible code in this script — confirm
# whether it is still needed.
USE_EXTRA = False

# Create the summary directory up front; an existing directory is not an error.
os.makedirs(name=save_path, exist_ok=True)

def _merge_json_dir(json_dir):
    """Merge the top-level objects of every ``*.json`` file in *json_dir*.

    Files are read in sorted filename order. When two files define the same
    key, the result is therefore deterministic: the value from the file whose
    name sorts last wins. ``os.listdir`` order alone is arbitrary.

    Args:
        json_dir: Directory to scan (non-recursive).

    Returns:
        A ``(merged, n_files)`` tuple: the merged dict and the number of
        JSON files that were read.

    Raises:
        TypeError: If a file's top-level JSON value is not an object.
    """
    json_files = sorted(name for name in os.listdir(json_dir) if name.endswith(".json"))
    merged = {}
    for name in json_files:
        path = os.path.join(json_dir, name)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Check explicitly: dict.update would silently accept a list of
        # key/value pairs instead of rejecting it.
        if not isinstance(data, dict):
            raise TypeError(
                f"{path}: expected a JSON object at top level, got {type(data).__name__}"
            )
        # In-place update avoids re-copying the accumulated dict per file.
        merged.update(data)
    return merged, len(json_files)


for dataset in ["suzuki_50", "arylation", "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv", "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"]:
    # Merge every JSON result file in this dataset's directory into one dict.
    merged, n_files = _merge_json_dir(os.path.join(results_path, dataset))

    # As before, no summary file is written for a directory with no JSON files.
    if n_files == 0:
        continue

    # NOTE(review): dataset names that end in ".csv" yield names like
    # "<name>.csv_clusters.json"; kept unchanged for downstream consumers.
    output_file = os.path.join(save_path, f"{dataset}_clusters.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(merged, f, indent=2)