{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44a7767f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from collections import Counter\n",
    "from verl.utils.reward_score.latex_math import math_is_equivalent, extract_solution\n",
    "\n",
    "def pass_at_1_from_scores(scores: List[List[int]]) -> float:\n",
    "    \"\"\"\n",
    "    DeepSeek-style pass@1 estimator from k samples per problem.\n",
    "    \"\"\"\n",
    "    total, correct = 0, 0\n",
    "    for row in scores:\n",
    "        correct += sum(int(x) for x in row)\n",
    "        total   += len(row)\n",
    "    return (correct / total) if total else 0.0\n",
    "\n",
    "def consensus_at_k(answers, k, gts):\n",
    "    \"\"\"Majority-vote accuracy from first K answers per problem.\"\"\"\n",
    "    n = len(answers)\n",
    "    hits = 0\n",
    "    for preds, gt in zip(answers, gts):\n",
    "        topk = preds[:k]\n",
    "        freq = Counter(topk)\n",
    "        max_freq = max(freq.values())\n",
    "        winners = sorted([a for a, f in freq.items() if f == max_freq])\n",
    "        voted = winners[0]\n",
    "        hits += math_is_equivalent(voted, gt)\n",
    "    return hits / n if n else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d51eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_DIRS=[\n",
    "    \"datasets/test_aime2024/test.parquet\",\n",
    "    \"datasets/test_aime2025/test.parquet\",\n",
    "    \"datasets/test_amc/test.parquet\",\n",
    "    \"datasets/test_brumo2025/test.parquet\",\n",
    "    \"datasets/test_hmmt_feb_2025/test.parquet\"\n",
    "]\n",
    "\n",
    "MODEL=\"Qwen/Qwen3-1.7B\"\n",
    "OUTPUT_DIR=\"outputs/inference_results/16K/\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c96fcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import datasets\n",
    "from tqdm.auto import tqdm \n",
    "from pathlib import Path\n",
    "\n",
    "from verl.utils.reward_score import default_compute_score\n",
    "\n",
    "def load_dataset(parquet_path):\n",
    "    dataframe = datasets.load_dataset(\"parquet\", data_files=parquet_path)[\"train\"]\n",
    "    df = dataframe.to_pandas()\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d47c5997",
   "metadata": {},
   "outputs": [],
   "source": [
    "for parquet_path in TEST_DIRS:\n",
    "    df = load_dataset(parquet_path)\n",
    "    original_prompts = df[\"prompt\"].tolist()\n",
    "    ground_truths = df[\"reward_model\"].apply(lambda x: x[\"ground_truth\"]).values\n",
    "    data_src = df[\"data_source\"].iloc[0]\n",
    "\n",
    "    output_filename = f\"{Path(data_src).name}_{Path(MODEL).name}_generations.json\"\n",
    "    output_path = os.path.join(OUTPUT_DIR, output_filename)\n",
    "    with open(output_path, \"r\") as f:\n",
    "        generations = json.load(f)\n",
    "\n",
    "    scores_all = []\n",
    "    preds_all = []\n",
    "    for p, gts in tqdm(zip(original_prompts, ground_truths), total=len(original_prompts)):\n",
    "        po = generations[str(p)][\"generations\"]\n",
    "        scores_p = []\n",
    "        preds_p = []\n",
    "        for o in po:\n",
    "            scores_p.append(default_compute_score(data_src, o, gts)==1)\n",
    "            extracted_pred = extract_solution(o)\n",
    "            preds_p.append(extracted_pred if extracted_pred else \"\")\n",
    "        scores_all.append(scores_p)\n",
    "        preds_all.append(preds_p)\n",
    "    print(f\"{data_src} pass@1\", 100*pass_at_1_from_scores(scores_all))\n",
    "    print(f\"{data_src} cons@32\", 100*consensus_at_k(preds_all, k=32, gts=ground_truths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d3ca82",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "verl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
