import pandas as pd
import torch

# Fraction of rows assigned to the training split: 50 out of every 5790.
rate = 50 / 5790
# Source CSV containing every row of the dataset.
csv_file = "all.csv"

# Load the complete dataset. No index_col is given, so pandas assigns a
# default 0-based RangeIndex to the loaded rows.
df = pd.read_csv(filepath_or_buffer=csv_file)

# Number of rows that go into the training split, rounded half-up.
# Python 3's built-in round() uses round-half-to-even ("banker's rounding"),
# e.g. round(2.5) == 2, which does not match half-up rounding. So half-up is
# computed explicitly here. The product is non-negative, so int() truncation
# is the same as floor().
train_count = int(len(df) * rate + 0.5)

# Randomly draw train_count rows as the training set. The fixed random_state
# makes the draw repeatable for the same input file (and pandas/numpy version).
train_df = df.sample(n=train_count, random_state=42)

# Every row that was not drawn becomes the test set. drop() removes rows by
# index label, so this relies on df having unique labels. That holds for the
# default RangeIndex that read_csv assigns when no index_col is given.
test_df = df.drop(train_df.index)

# Write both splits to disk, omitting the pandas index column.
train_df.to_csv(path_or_buf="train.csv", index=False)
test_df.to_csv(path_or_buf="test.csv", index=False)

# Record which row labels of the loaded DataFrame ended up in each split.
# "train" lists labels in the order the random sample produced them;
# "test" lists the remaining labels in their original order.
split_idx = {
    "train": train_df.index.tolist(),
    "test": test_df.index.tolist(),
}
torch.save(split_idx, "split_idx.pt")