import argparse
import gc
import os
from datetime import datetime

import torch
from datasets import load_dataset

from sal.config import Config
from sal.models.reward_models import load_prm

# Command-line interface: the dataset subset to score and the output directory.
parser = argparse.ArgumentParser(
    description="Evaluate and score dataset using PRM.",
)
parser.add_argument(
    "--dataset_subset",
    type=str,
    required=True,
    help="Subset of the dataset to load.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory to save the output.",
)
args = parser.parse_args()

# Values taken from the command line.
dataset_subset = args.dataset_subset
output_dir = args.output_dir

# Fixed run settings.
dataset_name = "HuggingFaceH4/Llama-3.2-1B-Instruct-beam-search-completions"
dataset_split = "train"
batch_size = 1

# Each reward model checkpoint is listed next to the dataset column that
# receives its scores; the two parallel lists below are derived from this
# table so they always stay aligned.
_PRM_SCORE_COLUMNS = (
    ("RLHFlow/Llama3.1-8B-PRM-Deepseek-Data", "rlhf_score"),
    ("peiyi9979/math-shepherd-mistral-7b-prm", "math_shepherd_score"),
    ("Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B", "skywork_1.5B_score"),
    ("Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B", "skywork_7B_score"),
)
prm_paths = [path for path, _ in _PRM_SCORE_COLUMNS]
score_names = [column for _, column in _PRM_SCORE_COLUMNS]

dataset = load_dataset(dataset_name, name=dataset_subset, split=dataset_split)

def add_scores(score_name: str, batch, prm):
    """Score one batch with ``prm`` and return the result as a new column.

    Args:
        score_name: Key under which the scores are returned.
        batch: Mapping that provides the ``"problem"`` and ``"completions"``
            entries handed to the reward model (presumably the column dict
            that a batched ``Dataset.map`` passes in).
        prm: Object exposing ``score(prompts, completions)``.

    Returns:
        A single-entry dict mapping ``score_name`` to whatever
        ``prm.score`` returned.
    """
    scored = {score_name: prm.score(batch["problem"], batch["completions"])}
    # Release cached CUDA allocations once the batch has been scored.
    torch.cuda.empty_cache()
    return scored

# The two lists are consumed pairwise; without this check a length mismatch
# would make zip() silently skip the unmatched models or column names.
if len(prm_paths) != len(score_names):
    raise ValueError(
        f"prm_paths has {len(prm_paths)} entries but "
        f"score_names has {len(score_names)}"
    )

# Create the output directory before scoring so that an unusable path fails
# immediately instead of after every reward model has been run.
os.makedirs(output_dir, exist_ok=True)

# Score the dataset once per reward model, adding one score column each pass.
for prm_path, score_name in zip(prm_paths, score_names):
    print(f"🔄 Scoring with PRM: {prm_path}")

    config = Config(prm_path=prm_path, prm_batch_size=batch_size)
    prm = load_prm(config)

    # The lambda reads `score_name` and `prm` when it is called; the dataset
    # is loaded without streaming, so map() applies it within this iteration.
    dataset = dataset.map(
        lambda batch: add_scores(score_name, batch, prm),
        batched=True,
        batch_size=batch_size,
        desc=f"Scoring with {prm_path}",
    )
    print(f"✅ Scored with {prm_path}")

    # Drop this model before the next load_prm() call; otherwise the old model
    # stays referenced while the new one is being loaded and both occupy
    # memory at the same time.
    del prm
    gc.collect()
    torch.cuda.empty_cache()

# Write the fully scored dataset as JSON Lines, named after the subset.
output_filename = f"{dataset_subset}.jsonl"
output_path = os.path.join(output_dir, output_filename)

dataset.to_json(output_path, lines=True, force_ascii=False)
print(f"✅ Saved scores to: {output_path}")
