import asyncio
import json
import os
import argparse
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager.prime import run_reward_scoring
from scripts.utils import save_dataset
import datasets


def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON value. Blank lines (for
    example an extra empty line at the end of the file) are skipped
    instead of making ``json.loads`` raise.

    Args:
        jsonl_path: Path to a ``.jsonl`` file, decoded as UTF-8.

    Returns:
        A list holding one parsed JSON value per non-blank line, in file order.
    """
    # `with` guarantees the handle is closed; iterating the file object
    # streams lines instead of materialising them all via readlines().
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def load_single_dataset(dataset_path: str, dataset_split: str = None) -> datasets.Dataset:
    """Load a dataset from one data file, a dataset directory, or a dataset name.

    Files are dispatched on their extension:

    * ``.jsonl``   -- one JSON record per line (via ``load_jsonl``);
    * ``.json``    -- parsed with ``json.load``;
    * ``.parquet`` -- loaded with ``datasets.load_dataset("parquet", ...)``.

    JSON content that parses to a list of records is converted with
    ``datasets.Dataset.from_list``. Any non-file path is first passed to
    ``datasets.load_dataset``; if that raises ``ValueError`` it is loaded
    with ``datasets.load_from_disk`` instead.

    Args:
        dataset_path: Path to a data file, or a directory / dataset name.
        dataset_split: Split to select (e.g. ``"train"``). Applied whenever
            the loaded object is a ``datasets.DatasetDict``.

    Returns:
        The loaded dataset. For a parquet file with no ``dataset_split``, the
        single ``"train"`` split is returned as a ``datasets.Dataset``. The
        directory / dataset-name branch may still return a
        ``datasets.DatasetDict`` when ``dataset_split`` is ``None``.

    Raises:
        RuntimeError: If ``dataset_path`` is a file with an unsupported extension.
    """
    if os.path.isfile(dataset_path):
        # Load from a single file.
        if dataset_path.endswith("jsonl"):
            dataset = load_jsonl(dataset_path)
        elif dataset_path.endswith("json"):
            # `with` closes the handle, which the bare open() call leaked.
            with open(dataset_path, "r", encoding="utf-8") as f:
                dataset = json.load(f)
        elif dataset_path.endswith("parquet"):
            # With a single `data_files` path, load_dataset places every row
            # in the "train" split of a DatasetDict.
            dataset = datasets.load_dataset("parquet", data_files=dataset_path)
            if dataset_split is None:
                # Unwrap it so callers get the Dataset this function is
                # annotated to return (len(), row iteration, add_column).
                dataset = dataset["train"]
        else:
            raise RuntimeError(f"No support file type for {dataset_path.split('.')[-1]}")
        if isinstance(dataset, list):
            dataset = datasets.Dataset.from_list(dataset)
    else:
        # Load from a directory or a dataset name.
        try:
            # load_dataset applies `split` itself, so return directly.
            return datasets.load_dataset(dataset_path, split=dataset_split)
        except ValueError:
            # NOTE(review): presumably reached for directories written with
            # `save_to_disk`, which load_dataset cannot read -- confirm.
            dataset = datasets.load_from_disk(dataset_path)

    # Pick the requested split out of a multi-split result.
    if dataset_split is not None and isinstance(dataset, datasets.DatasetDict):
        dataset = dataset[dataset_split]
    return dataset
        

def prepare_eval_row(row):
    """Project a single-response rollout record onto the reward-scoring fields.

    Args:
        row: Mapping that provides ``response``, ``data_source`` and a nested
            ``reward_model["ground_truth"]`` entry.

    Returns:
        A new dict with exactly the keys ``sequences_str``, ``ground_truth``
        and ``data_source``, in that order.
    """
    # Keyword arguments are evaluated left to right, so the lookups happen
    # in the same order as the key order of the resulting dict.
    return dict(
        sequences_str=row["response"],
        ground_truth=row["reward_model"]["ground_truth"],
        data_source=row["data_source"],
    )


def _parse_args():
    """Parse the command-line options of this scoring script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True,
                        help="Data file (.json/.jsonl/.parquet) or dataset directory / name.")
    parser.add_argument("--dataset_split", type=str, required=False, default=None,
                        help="Split to use when the data holds several splits.")
    parser.add_argument("--num-responses-per-prompt", type=int, required=True,
                        help="Expected responses per row; rows that differ trigger a warning.")
    parser.add_argument("--save_file", type=str, required=False,
                        help="Where to save the dataset with the added 'scores' column.")
    parser.add_argument("--begin", type=int, required=False, default=None,
                        help="First row to score (inclusive); defaults to 0.")
    parser.add_argument("--end", type=int, required=False, default=None,
                        help="Row to stop at (exclusive); defaults to the dataset length.")
    return parser.parse_args()


def _select_range(dataset, begin, end):
    """Return rows ``[begin, end)`` of ``dataset``, or ``dataset`` itself if both are None.

    ``Dataset.select`` is used rather than Python slicing, because slicing a
    ``datasets.Dataset`` yields a dict of columns instead of a Dataset.

    Raises:
        ValueError: If the resolved range is not within ``0 <= begin <= end <= len``.
    """
    if begin is None and end is None:
        return dataset
    n = len(dataset)
    begin = 0 if begin is None else begin
    end = n if end is None else end
    if not 0 <= begin <= end <= n:
        raise ValueError(f"Invalid slice range: begin={begin}, end={end}, length={n}")
    return dataset.select(range(begin, end))


def _flatten_rows(dataset, expected_per_prompt):
    """Flatten per-prompt rows into parallel per-response lists.

    Every entry of a row's ``responses`` is emitted, and that row's ground
    truth and data source are repeated once per response. The repeat count
    is the row's *actual* number of responses, so the lists stay aligned
    even when a row deviates from ``expected_per_prompt``.

    Returns:
        ``(sequences, ground_truths, data_sources, counts)``, where
        ``counts[i]`` is the number of responses contributed by row ``i``.
    """
    sequences, ground_truths, data_sources, counts = [], [], [], []
    mismatched = 0
    for row in dataset:
        responses = row["responses"]
        k = len(responses)
        sequences.extend(responses)
        ground_truths.extend([row["reward_model"]["ground_truth"]] * k)
        data_sources.extend([row["data_source"]] * k)
        counts.append(k)
        if k != expected_per_prompt:
            mismatched += 1
    if mismatched:
        print(f"Warning: {mismatched} row(s) do not have {expected_per_prompt} responses; "
              "using each row's actual response count.")
    return sequences, ground_truths, data_sources, counts


def _chunk_scores(scores, counts):
    """Split the flat ``scores`` sequence back into one chunk per row, sized by ``counts``."""
    chunks = []
    start = 0
    for k in counts:
        chunks.append(scores[start:start + k])
        start += k
    return chunks


def _report(data_sources, scores):
    """Print correct/total counts and accuracy per data source, then overall.

    A response counts as correct when its score is strictly greater than 0.5.
    """
    score_dict = {}
    count_dict = {}
    for task, score in zip(data_sources, scores):
        count_dict[task] = count_dict.get(task, 0) + 1
        if score > 0.5:
            score_dict[task] = score_dict.get(task, 0) + 1
    print(score_dict)
    print(count_dict)
    for task, cnt in count_dict.items():
        print(task, round(score_dict.get(task, 0) / cnt * 100, 2), "%")
    total = sum(count_dict.values())
    # An empty selection (e.g. --begin == --end) must not raise ZeroDivisionError.
    print("all", sum(score_dict.values()) / total if total else 0.0)


def main():
    """Score every response of every selected row, optionally save, and print accuracy."""
    args = _parse_args()
    test_set = load_single_dataset(args.data, dataset_split=args.dataset_split)
    test_set = _select_range(test_set, args.begin, args.end)

    sequences, ground_truths, data_sources, counts = _flatten_rows(
        test_set, args.num_responses_per_prompt
    )
    scores = run_reward_scoring(
        default_compute_score,
        sequences,
        ground_truths,
        data_sources,
        num_processes=16,
    )
    if len(scores) != len(sequences):
        raise RuntimeError(
            f"Expected {len(sequences)} scores, got {len(scores)}; cannot align them with rows."
        )
    test_set = test_set.add_column("scores", _chunk_scores(scores, counts))

    if args.save_file is not None:
        save_dataset(test_set, args.save_file)

    _report(data_sources, scores)


if __name__ == "__main__":
    main()

"""
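Example invocations (machine-specific paths; adjust before running).
Each command scores one rollout file with 64 responses per prompt and writes
the scored dataset to --save_file.
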
~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_scored.json










~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_0_203.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_0_203_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_203_406.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_203_406_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_406_609.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_406_609_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_609_812.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_609_812_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_812_1014.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_812_1014_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_1014_1216.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_1014_1216_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_1216_1418.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_1216_1418_scored.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/bon1_verify.py \
    --data ~/LLaMA-Factory-250514/saves_shuyan/prime_math_valid_1418_1620.json \
    --num-responses-per-prompt 64 \
    --save_file ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/prime_math_valid_1418_1620_scored.json


"""