import pandas as pd
import torch

# file_path = "/fs-computility/ai4phys/caipengxiang.p/v1.0_ChemBOMAS_WetExp_TongJi/mas/data/exp/suzuki/experiment_index.csv"
all_csv_path = "all.csv"

# Load the full dataset. No index_col is given, so pandas assigns its default
# integer index and each label identifies a row of all.csv.
df = pd.read_csv(all_csv_path)

# Training split: 50 rows drawn at random with a fixed seed so the split is
# reproducible across runs.
df_sampled = df.sample(n=50, random_state=500)

# Test split: every row that was not drawn for training, in its original order.
df_test = df.drop(df_sampled.index)

# Persist each split as CSV, dropping the pandas index column from the output.
for frame, out_path in ((df_sampled, "train.csv"), (df_test, "test.csv")):
    frame.to_csv(out_path, index=False, encoding="utf-8")

# Record which original row labels ended up in each split, so the same
# partition can be reapplied to other files aligned with all.csv.
df_index = df_sampled.index.tolist()
split_idx = {
    "train": df_index,
    "test": df_test.index.tolist(),
}

torch.save(split_idx, "split_idx.pt")

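# Save the rows of file_path that correspond to each split.
# NOTE: index.isin() keeps file_df's original row order, so train_file.csv will
# not follow the shuffled row order of train.csv; use file_df.loc[df_index] if
# row-aligned output is needed.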
# file_df = pd.read_csv(file_path)
# sampled_file_df = file_df[file_df.index.isin(df_index)]
# sampled_file_df.to_csv("train_file.csv", index=False, encoding="utf-8")

# sampled_file_df_test = file_df[file_df.index.isin(df_test.index)]
# sampled_file_df_test.to_csv("test_file.csv", index=False, encoding="utf-8")
