import argparse
from utils import load_single_dataset, save_dataset
from datasets import Dataset, DatasetDict
from verl.workers.reward_manager.prime import run_reward_scoring
from verl.utils.reward_score import default_compute_score


def filter_outlength_row(row):
    """Return True when at least one entry of ``row["finish_reasons"]`` is ``"stop"``.

    Rows whose finish reasons contain no ``"stop"`` (or are empty) yield False.
    """
    for reason in row["finish_reasons"]:
        if reason == "stop":
            return True
    return False


def is_all_true(row):
    """Return True when every value in ``row['scores']`` is strictly above 0.5.

    An empty ``scores`` sequence yields True.
    """
    for value in row['scores']:
        # Negate the original comparison (rather than using ``<=``) so that
        # non-comparable-as-greater values such as NaN still count as a miss.
        if not value > 0.5:
            return False
    return True

def is_all_false(row):
    """Return True when every value in ``row['scores']`` is strictly below 0.5.

    An empty ``scores`` sequence yields True.
    """
    for value in row['scores']:
        # Negate the original comparison (rather than using ``>=``) so that
        # values such as NaN, which are not below 0.5, still count as a miss.
        if not value < 0.5:
            return False
    return True

def is_partial_true(row):
    """Return True when the row is neither all-true nor all-false.

    A row qualifies only if both ``is_all_true(row)`` and
    ``is_all_false(row)`` are False.
    """
    if is_all_true(row):
        return False
    return not is_all_false(row)


if __name__ == "__main__":
    # Script entry point: load a dataset, drop rows without any "stop" finish
    # reason, score every response, attach the per-row scores, and save three
    # splits (all_true / all_false / true_and_false) to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--num_responses_per_prompt', type=int, required=True)
    parser.add_argument('--save_ds', type=str, required=True)
    args = parser.parse_args()

    responses_per_prompt = args.num_responses_per_prompt
    if responses_per_prompt <= 0:
        # Zero would crash later in range(); a negative value would silently
        # replicate the ground truth zero times. Reject both up front.
        parser.error("--num_responses_per_prompt must be a positive integer")

    ds: Dataset = load_single_dataset(args.data)
    ds = ds.filter(filter_outlength_row, num_proc=64)
    print(len(ds))

    # Flatten the dataset into three parallel lists, one entry per response.
    test_set_split = {"sequences_str": [], "ground_truth": [], "data_source": []}
    for row_index, row in enumerate(ds):
        responses = row["responses"]
        if len(responses) != responses_per_prompt:
            # The ground truth / data source are replicated a fixed number of
            # times and the scores are later re-chunked with the same fixed
            # size, so a row with a different response count would leave the
            # parallel lists misaligned and attach scores to the wrong rows.
            raise ValueError(
                f"Row {row_index} has {len(responses)} responses, "
                f"expected {responses_per_prompt} (--num_responses_per_prompt)"
            )
        test_set_split["sequences_str"].extend(responses)
        test_set_split["ground_truth"].extend(
            [row["reward_model"]["ground_truth"]] * responses_per_prompt
        )
        test_set_split["data_source"].extend(
            [row["data_source"]] * responses_per_prompt
        )

    scores = run_reward_scoring(
        default_compute_score,
        test_set_split["sequences_str"],
        test_set_split["ground_truth"],
        test_set_split["data_source"],
        num_processes=16,
    )
    expected_scores = len(test_set_split["sequences_str"])
    if len(scores) != expected_scores:
        # Guard the re-chunking below: it only lines up with the rows when
        # exactly one score came back per submitted response.
        raise RuntimeError(
            f"Expected {expected_scores} scores, got {len(scores)}"
        )

    # Regroup the flat score list into one fixed-size chunk per dataset row.
    scores_chunk = [
        scores[start:start + responses_per_prompt]
        for start in range(0, len(scores), responses_per_prompt)
    ]
    ds = ds.add_column("scores", scores_chunk)

    # Partition rows by score pattern using the 0.5 threshold predicates.
    ds1 = DatasetDict({
        "all_true": ds.filter(is_all_true, num_proc=64),
        "all_false": ds.filter(is_all_false, num_proc=64),
        "true_and_false": ds.filter(is_partial_true, num_proc=64),
    })

    ds1.save_to_disk(args.save_ds)


"""
Example usage:

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/filter_rl_dataset.py \
    --data ~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime-rl-t1n5.jsonl \
    --num_responses_per_prompt 5 \
    --save_ds ~/LLaMA-Factory-250514/saves/qwen3-1.7B/prime_filtered_dataset
"""
