{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "36c7e384",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def read_jsonl(fp):\n",
    "    \"\"\"\n",
    "    Read a JSONL file and return a list of dictionaries.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "    with open(fp, 'r') as f:\n",
    "        for line in f:\n",
    "            data.append(json.loads(line))\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12cf76a5",
   "metadata": {},
   "source": [
    "### S5.5 Self-verification frequency and accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be508ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def calculate_verification_pass_rate(data):\n",
    "    \"\"\"\n",
    "    Calculate the average pass rate for responses containing verification keywords.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    data : list[dict]\n",
    "        Each dict must contain 'response' (list of strings) and 'correct_list' (list of 0/1 values)\n",
    "        where each item in correct_list indicates if the corresponding response is correct\n",
    "    \n",
    "    Returns:\n",
    "    --------\n",
    "    dict:\n",
    "        - verification_responses: number of responses containing verification keywords\n",
    "        - verification_pass_rate: average pass rate for verification responses\n",
    "        - non_verification_pass_rate: average pass rate for non-verification responses\n",
    "        - overall_pass_rate: average pass rate for all responses\n",
    "    \"\"\"\n",
    "    # Define verification keywords\n",
    "    verification_keywords = [\n",
    "        \"verify\", \"verifying\", \"recheck\", \"validate\", \"re-evaluate\"\n",
    "    ]\n",
    "    \n",
    "    verification_correct = 0\n",
    "    verification_total = 0\n",
    "    non_verification_correct = 0\n",
    "    non_verification_total = 0\n",
    "    \n",
    "    for item in data:\n",
    "        responses = item.get(\"response\", [])\n",
    "        correct_list = item.get(\"correct_list\", [])\n",
    "        \n",
    "        if len(responses) != len(correct_list):\n",
    "            raise ValueError(\"Responses and correct_list must have the same length\")\n",
    "            \n",
    "        for response, is_correct in zip(responses, correct_list):\n",
    "            if not isinstance(response, str):\n",
    "                raise ValueError(\"Each response must be a string\")\n",
    "                \n",
    "            # Check if response contains any verification keyword\n",
    "            has_verification = False\n",
    "            for keyword in verification_keywords:\n",
    "                # Use word boundary check to match whole words only\n",
    "                pattern = re.compile(r'\\b' + re.escape(keyword) + r'\\b', re.IGNORECASE)\n",
    "                if pattern.search(response):\n",
    "                    has_verification = True\n",
    "                    break\n",
    "            if has_verification:\n",
    "                verification_total += 1\n",
    "                verification_correct += is_correct\n",
    "            else:\n",
    "                non_verification_total += 1\n",
    "                non_verification_correct += is_correct\n",
    "    \n",
    "    verification_pass_rate = verification_correct / verification_total if verification_total > 0 else 0\n",
    "    non_verification_pass_rate = non_verification_correct / non_verification_total if non_verification_total > 0 else 0\n",
    "    overall_total = verification_total + non_verification_total\n",
    "    overall_correct = verification_correct + non_verification_correct\n",
    "    overall_pass_rate = overall_correct / overall_total if overall_total > 0 else 0\n",
    "    \n",
    "    return {\n",
    "        \"verification_responses\": verification_total,\n",
    "        \"verification_pass_rate\": verification_pass_rate,\n",
    "        \"verification_portion\": verification_total / overall_total if overall_total > 0 else 0,\n",
    "        \"non_verification_pass_rate\": non_verification_pass_rate,\n",
    "        \"overall_pass_rate\": overall_pass_rate,\n",
    "        \"total_responses\": overall_total\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ce312b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Zero-RL and RISE generations for each benchmark, one pair per benchmark.\n",
    "# Within a pair the Zero-RL file is read first, then the RISE file.\n",
    "zero1, rise1 = (\n",
    "    read_jsonl('/path/to/zero_rl/math500/generation.jsonl'),\n",
    "    read_jsonl('/path/to/rise/math500/generation.jsonl'),\n",
    ")\n",
    "\n",
    "zero2, rise2 = (\n",
    "    read_jsonl('/path/to/zero_rl/aime24/generation.jsonl'),\n",
    "    read_jsonl('/path/to/rise/aime24/generation.jsonl'),\n",
    ")\n",
    "\n",
    "zero3, rise3 = (\n",
    "    read_jsonl('/path/to/zero_rl/amc23/generation.jsonl'),\n",
    "    read_jsonl('/path/to/rise/amc23/generation.jsonl'),\n",
    ")\n",
    "\n",
    "zero4, rise4 = (\n",
    "    read_jsonl('/path/to/zero_rl/minerva_math/generation.jsonl'),\n",
    "    read_jsonl('/path/to/rise/minerva_math/generation.jsonl'),\n",
    ")\n",
    "\n",
    "zero5, rise5 = (\n",
    "    read_jsonl('/path/to/zero_rl/olympiad_bench/generation.jsonl'),\n",
    "    read_jsonl('/path/to/rise/olympiad_bench/generation.jsonl'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a57c5368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The i-th entry of each list belongs to the same benchmark.\n",
    "baseline_collection = [zero1, zero2, zero3, zero4, zero5]\n",
    "rise_collection = [rise1, rise2, rise3, rise4, rise5]\n",
    "\n",
    "# Walk both lists in lockstep and print the two reports side by side.\n",
    "for i, (baseline, rise) in enumerate(zip(baseline_collection, rise_collection)):\n",
    "    print(f\"Zero-RL {i+1}:\")\n",
    "    print(calculate_verification_pass_rate(baseline))\n",
    "    print('-'*40)\n",
    "\n",
    "    print(f\"RISE {i+1}:\")\n",
    "    print(calculate_verification_pass_rate(rise))\n",
    "    print('='*80)\n",
    "    print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc91a7e4",
   "metadata": {},
   "source": [
    "### Appendix E.3: General Reflection Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "640d40f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "keywords_by_type = {\n",
    "    \"verification\": {\n",
    "        \"verify\": 0,\n",
    "        \"verifying\": 0,\n",
    "        \"recheck\": 0,\n",
    "        \"validate\": 0,\n",
    "        \"re-evaluate\": 0\n",
    "    },\n",
    "    \"general_reflection\": {\n",
    "        \"however\": 0,\n",
    "        \"alternatively\": 0,\n",
    "        \"wait\": 0,\n",
    "        \"retry\": 0,\n",
    "        \"recheck\": 0,\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d724e0b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from collections.abc import Iterable\n",
    "from copy import deepcopy\n",
    "\n",
    "def count_reflection_keywords(\n",
    "    data: list[dict],\n",
    "    keywords_by_type: dict[str, dict[str, int]],\n",
    "    mode: str = \"word_count\"\n",
    ") -> dict[str, dict[str, int]]:\n",
    "    \"\"\"\n",
    "    Count reflection-related keywords inside `data` and return a new\n",
    "    dictionary with updated counts.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : list[dict]\n",
    "        Each dict must contain a key `\"response\"` whose value is an\n",
    "        iterable (list/tuple) of strings.\n",
    "    keywords_by_type : dict[str, dict[str, int]]\n",
    "        Nested dictionary whose innermost keys are the keywords\n",
    "        (all assumed to be lower-case) and whose values are the\n",
    "        starting counts (usually zero).\n",
    "    mode : str\n",
    "        One of two modes:\n",
    "        - `\"word_count\"`: Count total occurrences of keywords in each response.\n",
    "        - `\"response_count\"`: Count the number of responses that contain at least one keyword.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict[str, dict[str, int]]\n",
    "        A deep-copied version of `keywords_by_type` with counts\n",
    "        incremented according to the specified mode.\n",
    "    \"\"\"\n",
    "    # --- compile regex patterns once for speed ----------------------\n",
    "    patterns: dict[str, dict[str, re.Pattern]] = {\n",
    "        cat: {kw: re.compile(rf\"\\b{re.escape(kw)}\\b\", re.IGNORECASE)\n",
    "              for kw in kw_dict}\n",
    "        for cat, kw_dict in keywords_by_type.items()\n",
    "    }\n",
    "\n",
    "    # --- prepare a fresh counter dict --------------------------------\n",
    "    counts = deepcopy(keywords_by_type)\n",
    "\n",
    "    # --- scan every response string ----------------------------------\n",
    "    for item in data:\n",
    "        # tolerate non-iterable or missing \"response\" fields\n",
    "        responses: Iterable[str] = item.get(\"response\", [])\n",
    "        for text in responses:\n",
    "            if not isinstance(text, str):\n",
    "                continue\n",
    "            for cat, pat_dict in patterns.items():\n",
    "                found_keywords = set()  # track which keywords are in this response\n",
    "                for kw, pat in pat_dict.items():\n",
    "                    matches = pat.findall(text)\n",
    "                    if matches:\n",
    "                        if mode == \"word_count\":\n",
    "                            counts[cat][kw] += len(matches)  # count each occurrence\n",
    "                        elif mode == \"response_count\":\n",
    "                            found_keywords.add(kw)  # mark that we found a keyword\n",
    "                if mode == \"response_count\" and found_keywords:\n",
    "                    for kw in found_keywords:\n",
    "                        counts[cat][kw] += 1  # increment once per response that has the keyword\n",
    "\n",
    "    return counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e90adb50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_print_counts(\n",
    "    counts: dict[str, dict[str, int]],\n",
    "    total_responses: int | None = None,\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Display keyword-frequency statistics in a tidy, indented format.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    counts : dict[str, dict[str, int]]\n",
    "        Nested dict whose innermost values are keyword counts.\n",
    "    total_responses : int | None, optional\n",
    "        If given and > 0, an extra “per-response” figure\n",
    "        (total word count ÷ total_responses) is printed for each category.\n",
    "    \"\"\"\n",
    "    per_resp = (\n",
    "        lambda total: f\" | per-response: {total / total_responses:.3f}\"\n",
    "        if total_responses and total_responses > 0 else \"\"\n",
    "    )\n",
    "\n",
    "    for category, kwdict in counts.items():\n",
    "        header = category.replace(\"_\", \" \").title()\n",
    "        print(header)\n",
    "        print(\"-\" * len(header))\n",
    "\n",
    "        # print individual keyword counts\n",
    "        for kw, n in sorted(kwdict.items(), key=lambda kv: (-kv[1], kv[0])):\n",
    "            print(f\"  {kw:<15} : {n}\")\n",
    "\n",
    "        # total for the category\n",
    "        total = sum(kwdict.values())\n",
    "        print(f\"  {'Total word count':<15} : {total}{per_resp(total)}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc53039",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The i-th entry of each list belongs to the same benchmark.\n",
    "baseline_collection = [zero1, zero2, zero3, zero4, zero5]\n",
    "rise_collection = [rise1, rise2, rise3, rise4, rise5]\n",
    "\n",
    "for i, (baseline, rise) in enumerate(zip(baseline_collection, rise_collection)):\n",
    "    # Count the responses of each dataset directly: this works for an empty\n",
    "    # dataset, does not assume every item holds the same number of responses,\n",
    "    # and gives each model its own denominator for the per-response figure.\n",
    "    baseline_total = sum(len(item.get(\"response\") or []) for item in baseline)\n",
    "    rise_total = sum(len(item.get(\"response\") or []) for item in rise)\n",
    "\n",
    "    print(f\"Zero-RL {i+1}:\")\n",
    "    print(\"Total Responses: \", baseline_total)\n",
    "    baseline_counts = count_reflection_keywords(baseline, keywords_by_type)\n",
    "    pretty_print_counts(baseline_counts, baseline_total)\n",
    "\n",
    "    print(f\"RISE {i+1}:\")\n",
    "    print(\"Total Responses: \", rise_total)\n",
    "    rise_counts = count_reflection_keywords(rise, keywords_by_type)\n",
    "    pretty_print_counts(rise_counts, rise_total)\n",
    "    print('='*80)\n",
    "    print('\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
