import numpy as np
import argparse
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager.prime_backup import run_reward_scoring
from utils import save_dataset
from utils import load_single_dataset

def resolve_slice_bounds(begin, end, length):
    """Return the validated ``(begin, end)`` row range for a dataset slice.

    Args:
        begin: First row to keep, or ``None`` to start at row 0.
        end: One past the last row to keep, or ``None`` to run to ``length``.
        length: Number of rows in the dataset being sliced.

    Returns:
        A ``(begin, end)`` tuple satisfying ``0 <= begin <= end <= length``.

    Raises:
        ValueError: If a bound is negative, exceeds ``length``, or ``begin``
            is greater than ``end``.
    """
    begin = 0 if begin is None else begin
    end = length if end is None else end
    if not 0 <= begin <= end <= length:
        raise ValueError(f"Invalid slice range: begin={begin}, end={end}, length={length}")
    return begin, end


def collect_finished_responses(rows, num_responses_per_prompt):
    """Flatten every response whose finish reason is ``"stop"`` into parallel lists.

    Args:
        rows: Iterable of mappings providing ``"responses"``,
            ``"finish_reasons"``, ``"reward_model"["ground_truth"]`` and
            ``"data_source"``.
        num_responses_per_prompt: Width of the score matrix that will be
            filled later; a finished response at or beyond this position
            cannot be stored.

    Returns:
        A tuple ``(sequences_strs, ground_truths, data_sources, indices)``
        where ``indices[k]`` is the ``(row, response)`` position of entry ``k``.

    Raises:
        ValueError: If a finished response sits at a position
            ``>= num_responses_per_prompt``. Raised here, before scoring, so
            the mismatch does not surface only after the scoring run.
    """
    sequences_strs, ground_truths, data_sources, indices = [], [], [], []
    for i, row in enumerate(rows):
        for j, (response, finish_reason) in enumerate(zip(row["responses"], row["finish_reasons"])):
            if finish_reason != "stop":
                continue
            if j >= num_responses_per_prompt:
                raise ValueError(
                    f"Row {i} has a finished response at position {j}, but "
                    f"--num_responses_per_prompt is {num_responses_per_prompt}"
                )
            sequences_strs.append(response)
            ground_truths.append(row["reward_model"]["ground_truth"])
            data_sources.append(row["data_source"])
            indices.append((i, j))
    return sequences_strs, ground_truths, data_sources, indices


def build_score_matrix(indices, scores, num_rows, num_responses_per_prompt):
    """Scatter flat scores into a ``num_rows x num_responses_per_prompt`` nested list.

    Args:
        indices: ``(row, response)`` positions, one per score.
        scores: Scores paired with ``indices`` by position.
        num_rows: Number of dataset rows (outer list length).
        num_responses_per_prompt: Number of score slots per row.

    Returns:
        A list of ``num_rows`` lists of ints. Positions that received no
        score stay 0.
    """
    scores_arr = np.zeros((num_rows, num_responses_per_prompt), dtype=int)
    for index, score in zip(indices, scores):
        # int() truncates toward zero, so a fractional score such as 0.5 is stored as 0.
        scores_arr[index] = int(score)
    return scores_arr.tolist()


def main():
    """Score the finished responses of a rollout dataset and save it with a ``scores`` column."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--num_responses_per_prompt', type=int, required=True)
    parser.add_argument('--num_processes', type=int, required=False, default=128)
    parser.add_argument('--begin', type=int, required=False, default=None)
    parser.add_argument('--end', type=int, required=False, default=None)
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    # NOTE(review): --save_file is optional and defaults to None, which is
    # passed straight to save_dataset -- confirm that helper accepts None.
    parser.add_argument("--save_file", type=str, required=False)
    args = parser.parse_args()

    # prepare for score
    ds = load_single_dataset(args.data)
    if args.debug:
        # Debug mode keeps at most the first 100 rows; clamp so that a
        # dataset shorter than that does not request out-of-range indices.
        debug_rows = 100
        ds = ds.select(range(min(debug_rows, len(ds))))

    if args.begin is not None or args.end is not None:
        begin, end = resolve_slice_bounds(args.begin, args.end, len(ds))
        ds = ds.select(range(begin, end))

    sequences_strs, ground_truths, data_sources, indices = collect_finished_responses(
        ds, args.num_responses_per_prompt
    )

    # score
    if sequences_strs:
        scores = run_reward_scoring(
            default_compute_score,
            sequences_strs,
            ground_truths,
            data_sources,
            num_processes=args.num_processes,
        )
    else:
        # Nothing finished with "stop": skip the scorer, every slot stays 0.
        scores = []

    # log the score
    # NOTE(review): assumes the scorer returns one score per input, in input order.
    scores_arr = build_score_matrix(indices, scores, len(ds), args.num_responses_per_prompt)
    ds = ds.add_column("scores", scores_arr)
    save_dataset(ds, args.save_file)


if __name__ == "__main__":
    main()
        

"""
    
~/verl_cs/.conda/bin/python ~/verl_cs/scripts/dsfilter_2_scored.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048.json \
    --num_responses_per_prompt 4 \
    --num_processes 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_scored.json

"""

