import os
import pandas as pd

# Dataset to split. Must be one of the keys of _DATASETS below
# ("suzuki", "tandem", "buchwald", "cpa"); edit this single line to switch.
data_name = "cpa"

# Directory that holds one sub-directory of experiment data per dataset.
_EXP_ROOT = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/data/exp"

# dataset name -> (input CSV file name inside _EXP_ROOT/<name>/,
#                  names of the columns to group by)
_DATASETS = {
    "suzuki": ("experiment_index.csv", ["electrophile", "nucleophile"]),
    "tandem": ("processed.csv", ["Product"]),
    "buchwald": ("processed.csv", ["Product"]),
    "cpa": ("processed.csv", ["Product"]),
}

if data_name not in _DATASETS:
    # Fail right here with a clear message; otherwise file_path / group_cols
    # would stay undefined and the script would die later with a NameError.
    raise ValueError(
        f"Unknown data_name {data_name!r}; expected one of {sorted(_DATASETS)}"
    )

_csv_name, _group_cols = _DATASETS[data_name]
# Input file path.
file_path = f"{_EXP_ROOT}/{data_name}/{_csv_name}"
# Names of the columns to group by (copied so the table above stays untouched).
group_cols = list(_group_cols)



# Output directory for the per-group CSV files (created if it does not exist).
output_dir = f"/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/data/grouped_exp/{data_name}"
os.makedirs(output_dir, exist_ok=True)

# Load the input data.
df = pd.read_csv(file_path)
mp = dict()  # NOTE(review): not used by any code visible here; kept in case later code reads it.
# Split the rows by group_cols and write one CSV file per group.
# NOTE: groupby() drops rows whose key contains NaN by default (dropna=True),
# so such rows end up in no output file.
for group_vals, group_df in df.groupby(group_cols):
    # With a single grouping column, pandas releases before 2.0 yield a bare
    # scalar key instead of a 1-tuple. Iterating a scalar string would split
    # it into characters (and a numeric key would raise TypeError), so
    # normalise the key to a tuple first.
    key_parts = group_vals if isinstance(group_vals, tuple) else (group_vals,)
    # '/' is the path separator and cannot appear inside a file name.
    # NOTE: keys that differ only by '/' vs '_' map to the same file name;
    # the later group then overwrites the earlier one.
    safe_names = [str(v).replace('/', '_') for v in key_parts]
    filename = "_".join(safe_names) + ".csv"
    out_path = os.path.join(output_dir, filename)

    # Save this group's rows, without the DataFrame index column.
    group_df.to_csv(out_path, index=False)