import argparse
from utils import load_single_dataset, save_dataset
from datasets import Dataset, DatasetDict
from verl.workers.reward_manager.prime import run_reward_scoring
from verl.utils.reward_score import default_compute_score


def filter_outlength_row(row) -> bool:
    """Return True when the row has at least one "stop" finish reason.

    Used as the predicate for ``Dataset.filter``: rows for which this
    returns False are dropped.

    Args:
        row: A mapping with a ``"finish_reasons"`` key holding an iterable
            of finish-reason values (presumably one per sampled
            generation -- TODO confirm against the data producer).

    Returns:
        True if any finish reason equals ``"stop"``, False otherwise
        (including when ``finish_reasons`` is empty).

    Raises:
        KeyError: If the row has no ``"finish_reasons"`` key.
    """
    # Generator expression instead of a list: no intermediate list is built
    # and ``any`` stops at the first match.
    return any(reason == "stop" for reason in row["finish_reasons"])


if __name__ == "__main__":
    # Script entry point: load a dataset, keep only the rows accepted by
    # filter_outlength_row, print how many rows remain, and save the result.
    parser = argparse.ArgumentParser(
        description='Keep only dataset rows that have at least one "stop" finish reason.'
    )
    parser.add_argument('--data', type=str, required=True,
                        help='Dataset path, passed to load_single_dataset.')
    parser.add_argument('--data_split', type=str, required=False, default=None,
                        help='Optional dataset split, passed to load_single_dataset.')
    parser.add_argument('--save_ds', type=str, required=True,
                        help='Destination passed to save_dataset for the filtered dataset.')
    # Previously hard-coded to 64; the default keeps the old behavior while
    # letting smaller machines lower the worker count.
    parser.add_argument('--num_proc', type=int, required=False, default=64,
                        help='Number of worker processes used for filtering (default: 64).')
    args = parser.parse_args()

    ds: Dataset = load_single_dataset(dataset_path=args.data, dataset_split=args.data_split)
    ds = ds.filter(filter_outlength_row, num_proc=args.num_proc)
    # Report the number of rows that survived the filter.
    print(len(ds))
    save_dataset(ds, args.save_ds)


"""



Example usage:

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/dsfilter_1_out_length.py \
    --data    ~/datasets/PRIME-RL-Eurus-2-RL-Data/validation.parquet \
    --save_ds ~/datasets/PRIME-RL-Eurus-2-RL-Data/validation_math_subset.json

"""
