import pandas as pd
import torch

# Split the Suzuki regression dataset into a 60-row train set and a test set,
# and export the matching rows of the experiment-index table for each split.
#
# Both CSVs are read without `index_col`, so pandas assigns a default
# RangeIndex (0..n-1); the split is expressed in those row labels.
# NOTE(review): this assumes row i of all.csv and row i of
# experiment_index.csv describe the same experiment — confirm the two files
# are generated in the same order.

file_path = "/fs-computility/ai4phys/caipengxiang.p/v1.0_ChemBOMAS_WetExp_TongJi/mas/data/exp/suzuki/experiment_index.csv"
all_csv_path = "/fs-computility/ai4phys/caipengxiang.p/v1.0_ChemBOMAS_WetExp_TongJi/train_regression/data4regression/suzuki_60/all.csv"

# Size of the training split and the seed that makes the sample reproducible.
N_TRAIN = 60
SEED = 1111

# Randomly sample N_TRAIN rows as the training set and save them to a new CSV.
# `df.sample` returns the rows in random order; that order is kept below so
# every train-side output file lines up row-for-row.
df = pd.read_csv(all_csv_path)
df_sampled = df.sample(n=N_TRAIN, random_state=SEED)
df_sampled.to_csv("train.csv", index=False, encoding="utf-8")

# All remaining rows form the test set (original file order is preserved).
df_test = df.drop(df_sampled.index)
df_test.to_csv("test.csv", index=False, encoding="utf-8")

# Persist the row labels of each split so the split can be reproduced/reused.
df_index = df_sampled.index.tolist()
split_idx = {
    "train": df_index,
    "test": df_test.index.tolist(),
}
torch.save(split_idx, "split_idx.pt")

# Save the matching rows of file_path for each split.
# `.loc` (rather than a boolean `isin` mask) is used on purpose:
#   * it returns rows in the order of the given labels, so train_file.csv is
#     aligned row-for-row with train.csv (an `isin` mask would re-sort them
#     into ascending index order and break the correspondence);
#   * it raises KeyError if a label is missing from file_path, instead of
#     silently dropping that row.
file_df = pd.read_csv(file_path)
sampled_file_df = file_df.loc[df_index]
sampled_file_df.to_csv("train_file.csv", index=False, encoding="utf-8")

sampled_file_df_test = file_df.loc[df_test.index]
sampled_file_df_test.to_csv("test_file.csv", index=False, encoding="utf-8")
