import datasets
from scripts.utils import load_single_dataset, save_dataset
import argparse


def repeat_dataset(dataset: datasets.Dataset, n_repeat: int):
    """Return a dataset in which every row of ``dataset`` appears ``n_repeat`` times.

    Each source row is tagged with a fresh ``unique_id`` (its position in
    ``dataset``). The tagged dataset is concatenated with itself ``n_repeat``
    times and then sorted by ``unique_id``, so all copies of a given row end
    up next to each other in the result.

    Args:
        dataset: The dataset to repeat. A pre-existing ``unique_id`` column is
            discarded and regenerated.
        n_repeat: Number of copies of each row in the output.

    Returns:
        The repeated dataset, sorted by the regenerated ``unique_id`` column.
    """
    if "unique_id" in dataset.column_names:
        # remove_columns returns a new Dataset instead of mutating in place,
        # so the result must be re-bound; otherwise the stale column is still
        # present when add_column runs below.
        dataset = dataset.remove_columns("unique_id")
    dataset = dataset.add_column("unique_id", list(range(len(dataset))))

    new_dataset = datasets.concatenate_datasets([dataset] * n_repeat)
    # Sorting on the id groups the n_repeat copies of each row together.
    new_dataset = new_dataset.sort("unique_id")
    return new_dataset


def _run_cli() -> None:
    """Parse the command line, repeat the input dataset, and save the result.

    Flags:
        --data: passed to ``load_single_dataset`` to obtain the input.
        --bon: number of times each row is repeated (default 64).
        --out_data: passed to ``save_dataset`` as the output destination.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data", type=str)
    cli.add_argument("--bon", default=64, type=int)
    cli.add_argument("--out_data", type=str)
    options = cli.parse_args()

    loaded = load_single_dataset(options.data)
    # A DatasetDict holds several splits; merge them into one flat dataset
    # before repeating.
    if isinstance(loaded, datasets.DatasetDict):
        splits = list(loaded.values())
        loaded = datasets.concatenate_datasets(splits)

    repeated = repeat_dataset(loaded, n_repeat=options.bon)
    save_dataset(repeated, options.out_data)


if __name__ == "__main__":
    _run_cli()

"""
~/verl_cs/.conda/bin/python ~/verl_cs/scripts/prepare_bob_testset.py \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --bon 64 \
    --out_data ~/datasets/matheval_64.jsonl


~/verl_cs/.conda/bin/python ~/verl_cs/scripts/prepare_bob_testset.py \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-temperature1-not-exceeed/train.parquet \
    --bon 5 \
    --out_data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-temperature1-not-exceeed/train10.parquet
"""
