{
 "cells": [
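  {
   "cell_type": "markdown",
   "id": "7debdf56",
   "metadata": {},
   "source": [
    "## HumanEval Evaluation\n",
    "\n",
    "Runs each method's generated code against the HumanEval unit tests, reports the percentage of tests passed per method, and plots accuracy against compute for the proposed method versus its `num_beams` comparison."
   ]
  },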
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b43eb6e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from pathlib import Path\n",
    "from typing import Any, List, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "\n",
    "plt.style.use(\"bmh\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc698b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load results: one column of generated code per method, indexed by question\n",
    "df_dir = \"../results/google__gemma-3-1b-it\"\n",
    "df_path = f\"{df_dir}/openai__openai_humaneval_results.csv\"\n",
    "df = pd.read_csv(df_path, index_col=\"question\")\n",
    "\n",
    "# Columns that are not outputs of a method under comparison. The evaluation loop\n",
    "# also treats `task_id` as a non-method column, so keep the two in agreement.\n",
    "NON_METHOD_COLUMNS = {\"ground_truth\", \"task_id\"}\n",
    "methods = [c for c in df.columns if c not in NON_METHOD_COLUMNS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc65d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HumanEval test split, read from a local copy of the dataset (cache under ../.hf_cache)\n",
    "human_eval_dataset = load_dataset(\n",
    "    path=\"../local/data/openai/openai_humaneval-openai_humaneval\",\n",
    "    name=\"openai_humaneval\",\n",
    "    split=\"test\",\n",
    "    cache_dir=\"../.hf_cache\",\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6e4952",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute each method's generated code and score it against the HumanEval tests.\n",
    "#\n",
    "# SECURITY NOTE: this exec()s model-generated code with full builtins inside the\n",
    "# notebook process. Only run it on outputs you trust, ideally in a sandbox.\n",
    "func_names = human_eval_dataset[\"entry_point\"]\n",
    "tests = human_eval_dataset[\"test\"]\n",
    "\n",
    "SKIPPED_ENTRY_POINTS = {\"how_many_times\"}  # skipping this test as it has issues\n",
    "\n",
    "# Fallback patterns for pulling function definitions out of a generation that does\n",
    "# not execute cleanly as a whole. Both stop at the first non-indented line (an\n",
    "# empty line included), so a match can be a truncated function body.\n",
    "base_pattern = r\"def\\s+[a-zA-Z_]\\w*\\s*\\([^)]*\\)\\s*(?:->\\s*[\\w\\[\\], ]+)?\\s*:\\n(?:    .*\\n?)+\"\n",
    "general_pattern = r\"\"\"\n",
    "        def\\s+[a-zA-Z_]\\w*              # function name\n",
    "        \\s*\\([^)]*\\)                    # arguments (...)\n",
    "        \\s*(?:->\\s*[\\w\\[\\], ]+)?\\s*:    # optional return annotation\n",
    "        (?:\\n[ \\t]+.*)*                 # indented body\n",
    "        \"\"\"\n",
    "BASE_DEF_RE = re.compile(base_pattern, re.MULTILINE)\n",
    "GENERAL_DEF_RE = re.compile(general_pattern, re.MULTILINE | re.VERBOSE)\n",
    "\n",
    "\n",
    "def extract_func_defs(code: str) -> List[str]:\n",
    "    \"\"\"Return the function definitions the two fallback regexes find in `code`.\n",
    "\n",
    "    The first `base_pattern` match (if any) comes first, followed by every\n",
    "    `general_pattern` match, so the same function can appear more than once.\n",
    "    \"\"\"\n",
    "    func_defs = []\n",
    "    base_match = BASE_DEF_RE.search(code)\n",
    "    if base_match:\n",
    "        func_defs.append(base_match.group(0))\n",
    "    func_defs.extend(m.group(0) for m in GENERAL_DEF_RE.finditer(code))\n",
    "    return func_defs\n",
    "\n",
    "\n",
    "def define_solutions(code_strs, env: dict, errors: list) -> None:\n",
    "    \"\"\"Define every generated solution in `code_strs` inside `env`.\n",
    "\n",
    "    For each generation the regex-extracted definitions run first and the full\n",
    "    generation runs last, so anything the full generation defines overrides the\n",
    "    possibly truncated regex copies; those copies only survive for names the\n",
    "    full generation fails to define. Exceptions are appended to `errors`.\n",
    "    \"\"\"\n",
    "    for code in code_strs:\n",
    "        if not isinstance(code, str):\n",
    "            continue  # e.g. a missing (NaN) generation\n",
    "        for func_def in extract_func_defs(code):\n",
    "            try:\n",
    "                exec(func_def, env)\n",
    "            except Exception as e:\n",
    "                errors.append(e)\n",
    "        try:\n",
    "            exec(code, env)\n",
    "        except Exception as e:\n",
    "            errors.append(e)\n",
    "\n",
    "\n",
    "def run_tests(env: dict, errors: list) -> List[int]:\n",
    "    \"\"\"Run every HumanEval test against the functions defined in `env`.\n",
    "\n",
    "    Returns 1 per passed test and 0 per failed test; entry points listed in\n",
    "    SKIPPED_ENTRY_POINTS are left out. Failure exceptions go to `errors`.\n",
    "    \"\"\"\n",
    "    outcomes = []\n",
    "    for test, func_name in zip(tests, func_names):\n",
    "        if func_name in SKIPPED_ENTRY_POINTS:\n",
    "            continue\n",
    "        try:\n",
    "            exec(test, env)  # the test source is expected to define `check`\n",
    "            env[\"check\"](env[func_name])\n",
    "            outcomes.append(1)  # valid code that passed every assertion\n",
    "        except Exception as e:\n",
    "            outcomes.append(0)\n",
    "            errors.append(e)\n",
    "    return outcomes\n",
    "\n",
    "\n",
    "run_exceptions = []\n",
    "test_exceptions = []\n",
    "accuracy_results = {}\n",
    "\n",
    "for method in df.columns:\n",
    "    if method == \"task_id\":\n",
    "        continue\n",
    "\n",
    "    # Fresh namespace per method so different methods' solutions never mix. All of a\n",
    "    # method's generations share it, so a later one can redefine an earlier helper.\n",
    "    method_env = {\"__builtins__\": __builtins__, \"List\": List, \"Tuple\": Tuple, \"Any\": Any}\n",
    "    define_solutions(df[method].values, method_env, run_exceptions)\n",
    "    results = run_tests(method_env, test_exceptions)\n",
    "\n",
    "    accuracy = 100 * np.mean(results)\n",
    "    print(method, round(accuracy))\n",
    "    accuracy_results[method] = accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ee25fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean `total_branches` over the rows of each proposed method's info CSV.\n",
    "# Read into `info_df` (not `df`) so the results frame used by the evaluation\n",
    "# cell is not clobbered when cells are re-run.\n",
    "sizes = {}\n",
    "\n",
    "for method in methods:\n",
    "    if not method.startswith(\"proposed\"):\n",
    "        continue\n",
    "\n",
    "    info_path = f\"{df_dir}/openai__openai_humaneval_{method}_info.csv\"\n",
    "    info_df = pd.read_csv(info_path)\n",
    "    expansions = info_df[\"total_branches\"].mean()\n",
    "    print(method, round(expansions))\n",
    "    sizes[method] = expansions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b403243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each proposed method is paired with the column whose name has \"proposed\" replaced by\n",
    "# \"num_beams\"; that comparison is charged the integer after the first \"_\" times T.\n",
    "# NOTE(review): T presumably is a per-beam generation budget -- TODO confirm.\n",
    "T = 400\n",
    "\n",
    "# (accuracy, compute) points, sorted by compute so each curve is drawn left to right.\n",
    "proposed_curve = sorted(\n",
    "    [(accuracy_results[proposed], size) for proposed, size in sizes.items()],\n",
    "    key=lambda point: point[1],\n",
    ")\n",
    "comparison_curve = sorted(\n",
    "    [\n",
    "        (accuracy_results[proposed.replace(\"proposed\", \"num_beams\")], int(proposed.split(\"_\")[1]) * T)\n",
    "        for proposed in sizes\n",
    "    ],\n",
    "    key=lambda point: point[1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf61d52",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "\n",
    "# Each curve is a sequence of (accuracy, compute) points; compute goes on the x-axis.\n",
    "proposed_acc, proposed_compute = zip(*proposed_curve)\n",
    "comparison_acc, comparison_compute = zip(*comparison_curve)\n",
    "ax.plot(proposed_compute, proposed_acc, color=\"purple\", marker=\"*\", alpha=0.5, lw=4, label=\"Proposed\")\n",
    "ax.plot(comparison_compute, comparison_acc, color=\"orange\", marker=\"o\", alpha=0.5, lw=4, label=\"Comparison\")\n",
    "\n",
    "ax.set_xlabel(\"Compute (minimize)\")\n",
    "ax.set_ylabel(\"Accuracy (maximize)\")\n",
    "ax.set_title(\"Frontiers: Accuracy vs Compute\")\n",
    "ax.legend()\n",
    "ax.grid(True, linestyle=\"--\", alpha=0.4)\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "plot_dir = Path(\"../plots\")\n",
    "plot_dir.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(plot_dir / \"frontier.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
