#!/usr/bin/env python
import argparse
from datasets import load_from_disk, DatasetDict
from utils import load_single_dataset, save_dataset
from transformers import AutoTokenizer

def main():
    parser = argparse.ArgumentParser(description="处理 HuggingFace DatasetDict：新增原始索引列并用固定种子打乱。")
    parser.add_argument("--input_path", type=str, help="输入数据集路径（load_from_disk 的路径）")
    parser.add_argument("--output_path", type=str, help="保存数据集路径（save_to_disk 的目标路径）")
    args = parser.parse_args()

    # 读取已有的数据集
    ds = load_single_dataset(args.input_path, dataset_split="train")
    tokenizer = AutoTokenizer.from_pretrained("~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft")

    # 新增原始索引列
    ds = ds.map(lambda ex, idx: {"orig_idx": idx}, with_indices=True, num_proc=64)
    ds = ds.filter(lambda row: row["ability"] == "math", num_proc=64)
    ds = ds.filter(lambda row: len(tokenizer.apply_chat_template(row["prompt"], add_generation_prompt=True, tokenize=True)) < 400, num_proc=64)
    ds = ds.shuffle(seed=42)
    print(ds)
    # 保存处理后的数据集
    save_dataset(ds, args.output_path)
    print(f"处理完成，数据已保存到 {args.output_path}")

if __name__ == "__main__":
    main()

"""
~/verl_cs/.conda/bin/python ~/verl_cs/scripts/shuffle_dataset.py \
    --input_path ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet \
    --output_path ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math_400.parquet
"""