{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b43eb6e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import glob\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from math_verify import parse, verify\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use(\"bmh\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc698b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load results\n",
    "df_dir = \"../results/ibm-granite__granite-3.3-2b-instruct/\"\n",
    "df_path = f\"{df_dir}/openai__gsm8k_results.csv\"\n",
    "df = pd.read_csv(df_path, index_col=\"question\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670e56ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct answer\n",
    "def extract_ans(text):\n",
    "    if text is None or pd.isna(text):\n",
    "        return None\n",
    "    \n",
    "    # simplified example\n",
    "    patterns = [r'####\\s*(\\d+)', r'\\\\?boxed\\{?(\\d+)\\}?']\n",
    "    for p in patterns:\n",
    "        m = re.search(p, text)\n",
    "        if m:\n",
    "            return int(m.group(1))\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e4c073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(df):\n",
    "    \n",
    "    correct = df[\"ground_truth\"].apply(lambda x: extract_ans(x))\n",
    "\n",
    "    results = {}\n",
    "        \n",
    "    for method in df:        \n",
    "        # Two forms of extraction, if either match consider correct\n",
    "        res_a = df[method].apply(lambda x: extract_ans(x))\n",
    "        option_a = res_a == correct\n",
    "        \n",
    "        option_b = []\n",
    "        \n",
    "        for q, row in df.iterrows():\n",
    "            ans = parse(row[method])\n",
    "            truth = parse(row[\"ground_truth\"])\n",
    "            match = verify(ans, truth)\n",
    "            option_b.append(match)\n",
    "            \n",
    "        res = np.logical_or(option_a, option_b)        \n",
    "        print(method,  round(100 * res.mean()))\n",
    "        \n",
    "        results[method] = res.mean()\n",
    "        \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89edd945",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = get_results(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f169fe4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dirs = \"../results/closed/*/meta-llama__Llama-3.2-3B-Instruct\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9711ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "\n",
    "for k_dir in glob.glob(dirs):\n",
    "    df_path = f\"{k_dir}/openai__gsm8k_results.csv\"\n",
    "    k = int(df_path.split(\"closed/\")[1].split(\"/\")[0])\n",
    "    df = pd.read_csv(df_path, index_col=\"question\")\n",
    "    \n",
    "    results[k] = df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e653b488",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = sorted(results.items(), key=lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245d61d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "k_results = {}\n",
    "\n",
    "for k, df in results:\n",
    "    print(f\"K={k}\")\n",
    "    k_result = get_results(df)\n",
    "    # Top-k to actual k\n",
    "    k_result = {key.replace(\"k\", str(k)): v for key, v in k_result.items()}\n",
    "    \n",
    "    proposed = k_result.pop(\"proposed\")\n",
    "    k_result[f\"top-{k} + EDEN\"] = proposed\n",
    "    \n",
    "    k_results = k_results | k_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a33d24d",
   "metadata": {},
   "outputs": [],
   "source": [
    "eden = pd.read_csv(\"../results/meta-llama__Llama-3.2-3B-Instruct/openai__gsm8k_results.csv\")\n",
    "eden_results = get_results(eden)\n",
    "k_results[\"EDEN\"] = eden_results[\"proposed_5\"]\n",
    "k_results.pop(\"ground_truth\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad6c9ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = list(k_results.keys())\n",
    "ys = list(k_results.values())\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "\n",
    "max_idx = ys.index(max(ys))\n",
    "last_idx = len(xs) - 1\n",
    "\n",
    "# ---- Bars ----\n",
    "for i, (x, y) in enumerate(zip(xs, ys)):\n",
    "    color = \"purple\" if \"EDEN\" in x else \"orange\"\n",
    "    if \"greedy\" in x:\n",
    "        color = \"gray\"\n",
    "    alpha = 1 if i == max_idx else 0.7\n",
    "    \n",
    "    if i == 0:\n",
    "        i -= 0.5\n",
    "    elif i == last_idx:\n",
    "        i += 0.5\n",
    "    \n",
    "    plt.bar(i, y, color=color, alpha=alpha, zorder=3)\n",
    "\n",
    "plt.ylabel(\"Accuracy\")\n",
    "\n",
    "# ---- Values above bars ----\n",
    "for i, val in enumerate(ys):\n",
    "    if i == 0:\n",
    "        i -= 0.5\n",
    "    elif i == last_idx:\n",
    "        i += 0.5\n",
    "        \n",
    "    plt.text(i, val + 0.005, f\"{val:.2f}\", ha='center', fontsize=12)\n",
    "\n",
    "# ---- Dashed vertical separators ----\n",
    "group_boundaries = []\n",
    "for i, label in enumerate(xs[:-1]):\n",
    "    current_group = label.split()[0]\n",
    "    next_group = xs[i+1].split()[0]\n",
    "    if current_group != next_group:\n",
    "        # Center the vertical line between the bars\n",
    "        group_boundaries.append(i + 0.5)\n",
    "\n",
    "\n",
    "group_boundaries[-1] += 0.25  # move last boundary to the right\n",
    "\n",
    "for i, b in enumerate(group_boundaries):\n",
    "    linestyle = \"-\" if i in [0, len(group_boundaries)-1] else \"--\"\n",
    "    plt.axvline(x=b, color=\"black\",\n",
    "                linestyle=linestyle, alpha=0.5, zorder=2)\n",
    "\n",
    "# ---- Group labels ----\n",
    "group_labels = {\n",
    "    \"K=5\": [1, 2],\n",
    "    \"K=10\": [3, 4],\n",
    "    \"K=20\": [5, 6]\n",
    "}\n",
    "\n",
    "y_top = 1.05 * max(ys)\n",
    "for label, idxs in group_labels.items():\n",
    "    center = (min(idxs) + max(idxs)) / 2\n",
    "    plt.text(center, y_top, label, ha=\"center\", va=\"bottom\",\n",
    "             fontsize=13, fontweight=\"bold\", alpha=0.4)\n",
    "    \n",
    "    \n",
    "# ---- Bracket helper (drawn in axes coords so it sits below axis) ----\n",
    "def draw_bracket(x0, x1, text, ax, height=0.04, y=-0.12, text_offset=0.01):\n",
    "    \"\"\"\n",
    "    Draw a rectangular-ish bracket between data x0 and x1, placed at axes y.\n",
    "    text uses normal weight (not bold).\n",
    "    \"\"\"\n",
    "    # map data x to axes fraction [0..1]\n",
    "    x0_disp = (x0 - ax.get_xlim()[0]) / (ax.get_xlim()[1] - ax.get_xlim()[0])\n",
    "    x1_disp = (x1 - ax.get_xlim()[0]) / (ax.get_xlim()[1] - ax.get_xlim()[0])\n",
    "\n",
    "    ax.plot([x0_disp, x0_disp, x1_disp, x1_disp],\n",
    "            [y, y-height, y-height, y],\n",
    "            transform=ax.transAxes, color=\"gray\", lw=1.2, clip_on=False, zorder=6)\n",
    "    ax.text((x0_disp + x1_disp) / 2, y - height - text_offset, text,\n",
    "            transform=ax.transAxes, ha=\"center\", va=\"top\", fontsize=12, fontweight='normal', color=\"gray\", zorder=6)\n",
    "\n",
    "ax = plt.gca()\n",
    "\n",
    "# Closed: from leftmost bar to the one before last\n",
    "draw_bracket(-1.0, last_idx - 0.25, \"Closed\", ax)\n",
    "\n",
    "# Open: last bar only\n",
    "draw_bracket(last_idx - 0.25, last_idx + 1.275,\n",
    "             \"Open\", ax)\n",
    "\n",
    "# ---- Ellipsis under the final vertical line to indicate spectrum between Closed and Open ----\n",
    "# ---- Ellipsis under the final vertical line ----\n",
    "final_line_x = group_boundaries[-1]  # last boundary line\n",
    "plt.text(final_line_x, 0.9 * min(ys), \"...\", ha=\"center\", va=\"top\", fontsize=16, color=\"black\", alpha=0.85)\n",
    "\n",
    "plt.title(\"API Strategy\")\n",
    "plt.ylim(0.9 * min(ys), 1.1 * max(ys))\n",
    "\n",
    "x_ticks = [x for x in range(len(xs))]\n",
    "x_ticks[0] -= 0.5\n",
    "x_ticks[-1] += 0.5\n",
    "\n",
    "plt.xticks(x_ticks, [x.replace(\" + \", \"+\\n\") + \" \" for x in xs], rotation=0, ha=\"center\")\n",
    "\n",
    "plt.grid(False)\n",
    "\n",
    "plt.savefig(\"../plots/comparison_gsm8k.pdf\", bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
