import os
import pandas as pd
import torch
import numpy as np

def split_data_and_save(dir_list, rates):
    """Write random train/test splits of each dataset at several train fractions.

    For every ``dir_path`` in ``dir_list`` the table ``<dir_path>/all.csv`` is
    read once. Then, for every ``rate`` in ``rates``, the directory
    ``f"{dir_path}_{rate*100}"`` is created (if it does not exist yet) and
    filled with:

    * ``train.csv``    -- ``round(len(df) * rate)`` rows sampled without
      replacement using the fixed seed 500,
    * ``test.csv``     -- every row that is not in the training sample,
    * ``split_idx.pt`` -- ``torch.save`` of ``{"train": [...], "test": [...]}``
      holding the row labels of the two parts,
    * ``all.csv``      -- a copy of the full table.

    Args:
        dir_list: Iterable of dataset directory paths, each containing
            ``all.csv``.
        rates: Iterable of training fractions (e.g. ``0.02`` for 2 %).

    Returns:
        None. Results are written to disk only.

    Raises:
        FileNotFoundError: If ``<dir_path>/all.csv`` does not exist. The file
            is read before any output directory is created, so no empty
            output directory is left behind in that case.
        ValueError: Propagated from ``DataFrame.sample`` when the computed
            training size is negative or larger than the table.
    """
    for dir_path in dir_list:
        print("="*30 + dir_path + "="*30)

        # The source table does not depend on the rate: load it once per
        # dataset instead of re-parsing the same CSV for every rate.
        df = pd.read_csv(os.path.join(dir_path, "all.csv"))

        for rate in rates:
            # Output directory sits next to the source one. The suffix uses
            # the default float formatting of ``rate*100``.
            new_dir = f"{dir_path}_{rate*100}"
            os.makedirs(new_dir, exist_ok=True)

            # Number of rows that go into the training split.
            train_count = round(len(df) * rate)

            # Fixed seed keeps the split reproducible across runs.
            train_df = df.sample(n=train_count, random_state=500)

            # Everything that was not sampled becomes the test split.
            test_df = df.drop(train_df.index)

            train_df.to_csv(os.path.join(new_dir, "train.csv"), index=False)
            test_df.to_csv(os.path.join(new_dir, "test.csv"), index=False)

            # Persist the row labels of each split so the split can be
            # reproduced against ``all.csv`` later.
            split_idx = {
                "train": train_df.index.tolist(),
                "test": test_df.index.tolist()
            }
            torch.save(split_idx, os.path.join(new_dir, "split_idx.pt"))

            # Keep a copy of the full table alongside the split files.
            df.to_csv(os.path.join(new_dir, "all.csv"), index=False)

# Example usage.
# Each active entry is passed to split_data_and_save as a dataset directory;
# the commented-out entries are datasets that are currently not processed.
dir_list = [
    # "suzuki_50",
    # "arylation",
    "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv",
    "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv",
    "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv",
    "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv",
    "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv",
    # "grouped_exp"
]


# Training fractions: the listed values are percentages, converted to
# fractions by dividing by 100.
rates = np.divide(np.array([0.25, 0.5, 2, 4]), 100)

split_data_and_save(dir_list=dir_list, rates=rates)