{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "990d4224",
   "metadata": {},
   "source": [
    "## SciBench Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b43eb6e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "from math_verify import parse, verify\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc698b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load results\n",
    "df_path = \"../results/google__gemma-3-1b-it/xw27__scibench_results.csv\"\n",
    "df = pd.read_csv(df_path, index_col=\"question\")\n",
    "methods = [c for c in df.columns if c != \"ground_truth\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e283e244",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = \"xw27/scibench\"\n",
    "dataset_subset = \"default\"\n",
    "dataset_split = \"train\"\n",
    "text_field = \"problem_text\"         # or \"question\" depending what the JSON key is\n",
    "answer_field = \"answer_number\"  # ground truth answer field\n",
    "choices_field = None\n",
    "\n",
    "local_dataset_dir = f\"../local/data/{dataset_name}-{dataset_subset}\"\n",
    "hf_cache_dir = \"../.hf_cache\"\n",
    "dataset = load_dataset(local_dataset_dir, dataset_subset, split=dataset_split, cache_dir=hf_cache_dir, trust_remote_code=True)\n",
    "\n",
    "# Clean up dataset to remove columns with None values\n",
    "cols_with_none = []\n",
    "for col in dataset.column_names:\n",
    "    # Check if any example in this column is None\n",
    "    if any(example[col] is None for example in dataset):\n",
    "        cols_with_none.append(col)\n",
    "\n",
    "dataset = dataset.remove_columns(cols_with_none)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b7132b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_ans(text):\n",
    "    if not text or pd.isna(text):\n",
    "        return None\n",
    "\n",
    "    # --- Helper: balanced-brace extractor ---\n",
    "    def extract_balanced(text, start):\n",
    "        i = start + len(r'\\boxed{')\n",
    "        depth = 1\n",
    "        buf = []\n",
    "        while i < len(text) and depth > 0:\n",
    "            if text[i] == '{':\n",
    "                depth += 1\n",
    "            elif text[i] == '}':\n",
    "                depth -= 1\n",
    "                if depth == 0:\n",
    "                    break\n",
    "            if depth > 0:\n",
    "                buf.append(text[i])\n",
    "            i += 1\n",
    "        return \"\".join(buf) if depth == 0 else None\n",
    "\n",
    "    # --- 1. Look for last well-formed \\boxed{...} ---\n",
    "    start = text.rfind(r'\\boxed{')\n",
    "    if start != -1:\n",
    "        ans = extract_balanced(text, start)\n",
    "        if ans is not None and ans.strip():\n",
    "            return ans.strip().strip(\". \")\n",
    "\n",
    "    # --- 2. Plain text fallbacks ---\n",
    "    plain_patterns = [\n",
    "        r'(?:final answer|answer)[:=]?\\s*([^\\n\\.]+)',\n",
    "    ]\n",
    "    for pat in plain_patterns:\n",
    "        m = re.search(pat, text, flags=re.IGNORECASE)\n",
    "        if m:\n",
    "            ans = m.group(1).strip().strip(\". \")\n",
    "            # strip leading '=' or ':' if still present\n",
    "            ans = re.sub(r'^[=:]+\\s*', '', ans)\n",
    "            return ans if ans else None\n",
    "\n",
    "    # --- 3. Cleanup stray \"is: $3150\" style ---\n",
    "    m = re.search(r'is[: ]+\\$?([0-9][^\\s\\}]*)', text, flags=re.IGNORECASE)\n",
    "    if m:\n",
    "        ans = m.group(1).strip().strip(\". \")\n",
    "        return ans if ans else None\n",
    "\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dc10893",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_math_matches(x, y):\n",
    "    math_matches = []\n",
    "    \n",
    "    for answer, gt in zip(x, y):\n",
    "        ans = parse(answer)\n",
    "        truth = parse(gt)\n",
    "        match = verify(ans, truth)\n",
    "        math_matches.append(match)\n",
    "    \n",
    "    return math_matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab7205e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = df[\"ground_truth\"].values\n",
    "correct = [x[1:] if x.startswith(\"+\") else x for x in correct]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e4c073",
   "metadata": {},
   "outputs": [],
   "source": [
    "for method in methods:  \n",
    "    \n",
    "     # Try exact match\n",
    "    exact_match = (df[method] == correct).values\n",
    "    \n",
    "    # Extracted match\n",
    "    res = df[method].apply(lambda x: extract_ans(x))\n",
    "    extracted_match = (res == correct).values\n",
    "        \n",
    "    # Try match matches from math_verify    \n",
    "    math_matches = get_math_matches(df[method].values, correct)\n",
    "    math_extracted_matches = get_math_matches(res, correct)\n",
    "    \n",
    "    matches = np.logical_or.reduce((exact_match, extracted_match, math_matches, math_extracted_matches))\n",
    "    \n",
    "    print(method, \"Matches:\", round(100 * np.mean(matches)), \"%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf14ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.isna().sum()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
