{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2e60528f-b13c-4af3-b22a-e9ca64a2c345",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "import pickle\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from tqdm import tqdm\n",
    "import glob\n",
    "import json\n",
    "\n",
    "from tot import get_task as get_task_tot\n",
    "from tasks import get_task\n",
    "\n",
    "plt.rcParams['pdf.fonttype'] = 42\n",
    "plt.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "import math\n",
    "\n",
    "SUFFIXES = ['', 'K', 'M', 'B', 'T', 'P', 'E']   # up to exa\n",
    "\n",
    "def human_number_format(n: float, precision: int = 1) -> str:\n",
    "    \"\"\"\n",
    "    Convert 31510669  -> '31.5M'\n",
    "            -4200     -> '-4.2K'\n",
    "              27      -> '27'\n",
    "    \"\"\"\n",
    "    if n == 0:\n",
    "        return f'0'\n",
    "\n",
    "    magnitude = int(math.log10(abs(n)) // 3)         # 0 for <1K, 1 for <1M, …\n",
    "    magnitude = min(magnitude, len(SUFFIXES) - 1)    # cap if number is huge\n",
    "    scaled = n / 10**(3 * magnitude)\n",
    "\n",
    "    fmt = f'{{:.{precision}f}}{{}}'\n",
    "    return fmt.format(scaled, SUFFIXES[magnitude])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46a7c39b-24e8-4e62-a775-b6fcb603049e",
   "metadata": {},
   "source": [
    "# Game24\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e08e0564-d83b-47ac-b480-1ac062dabde9",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"tot_test_split\", \"four_digits_unsolvable\"]\n",
    "model_names = {\n",
    "    \"mistralai/Mistral-Small-24B-Instruct-2501\": \"Mistral-Small\",\n",
    "    \"microsoft/phi-4\": \"Phi-4\",\n",
    "    \"RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic\": \"LLama-3.3-70B\",\n",
    "}\n",
    "tasks = {d:get_task(d) for d in datasets}\n",
    "tot_tasks = {d:get_task_tot(d) for d in datasets}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e7ecee8f-f2a3-4b4f-901e-ec6843fc4bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def success_any(p: float, k: int) -> float:\n",
    "    \"\"\"\n",
    "    Probability at least one of k independent samples is correct,\n",
    "    given single-sample success prob p.\n",
    "    \"\"\"\n",
    "    return 1 - (1 - p)**k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e338816-d262-4cf2-bf0d-0c6ab8f1e0cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_from_path(path):\n",
    "    for d in datasets:\n",
    "        if d in path:\n",
    "            return d\n",
    "    raise ValueError(\"No dataset found in path\")\n",
    "\n",
    "def model_from_path(path):\n",
    "    for d in model_names.keys():\n",
    "        if d in path:\n",
    "            return d\n",
    "    raise ValueError(\"No model found in path\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "173a1544-3414-404c-ab20-8cf0c996ef37",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def analyze_baseline_log(folder):\n",
    "    dataset = dataset_from_path(folder)\n",
    "    task = tasks[dataset]\n",
    "    \n",
    "    files = [x for x in os.listdir(folder) if x.endswith(\".dat\")]\n",
    "    if len(files) == 0:\n",
    "        raise ValueError(\"No .dat files in folder\")\n",
    "    baseline_logs = []\n",
    "    for fp in files:\n",
    "        with open(folder + fp, \"rb\") as f:\n",
    "            baseline_logs.append(pickle.load(f))\n",
    "    \n",
    "    baseline_results = [l[\"results\"] for l in baseline_logs]\n",
    "    for b in baseline_results:\n",
    "        b[\"solution\"] = b.solution.apply(lambda x: None if \"impossible\" in x.lower() else x)\n",
    "    \n",
    "    baseline_results = [task.evaluate_results(r) for r in baseline_results]\n",
    "    \n",
    "    all_preds = pd.concat([r.is_correct.reset_index(drop=True) for r in baseline_results], axis=1, ignore_index=True)\n",
    "    all_preds.columns = range(len(baseline_results))\n",
    "    \n",
    "    # token usages per run.\n",
    "    ct = np.mean([l[\"token_usage\"] for l in baseline_logs])\n",
    "    \n",
    "    p = all_preds.mean(axis=1)\n",
    "    oracle = success_any(p, len(baseline_results)).mean()\n",
    "    #major = (p > 0.5).mean()\n",
    "    a = pd.concat([r[[\"root\",\"solution\"]] for r in baseline_results], axis=0, ignore_index=True)\n",
    "    b = a.groupby(\"root\").solution.agg(lambda x: x.mode(dropna=False).sample(1)).to_frame().reset_index()\n",
    "    major = task.evaluate_results(b).is_correct.mean()\n",
    "    \n",
    "    r = dict(oracle_rate=oracle, oracle_ct=ct * len(baseline_results), \n",
    "             majority_rate=major, majority_ct=ct * len(baseline_results),\n",
    "             mean_rate=all_preds.mean().mean(), mean_ct=ct, dataset=dataset, \n",
    "             backend=model_from_path(folder), method=\"io\" if \"/io/\" in folder else \"cot\")\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "85233509-e09e-4aeb-aa0c-b167dcd611a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/12 [00:00<?, ?it/s]/local/u17414_10161085/ipykernel_893916/3024597436.py:29: UserWarning: Unable to sort modes: '<' not supported between instances of 'NoneType' and 'str'\n",
      "  b = a.groupby(\"root\").solution.agg(lambda x: x.mode(dropna=False).sample(1)).to_frame().reset_index()\n",
      " 17%|█▋        | 2/12 [00:07<00:37,  3.73s/it]/local/u17414_10161085/ipykernel_893916/3024597436.py:29: UserWarning: Unable to sort modes: '<' not supported between instances of 'NoneType' and 'str'\n",
      "  b = a.groupby(\"root\").solution.agg(lambda x: x.mode(dropna=False).sample(1)).to_frame().reset_index()\n",
      " 50%|█████     | 6/12 [00:39<00:46,  7.74s/it]/local/u17414_10161085/ipykernel_893916/3024597436.py:29: UserWarning: Unable to sort modes: '<' not supported between instances of 'NoneType' and 'str'\n",
      "  b = a.groupby(\"root\").solution.agg(lambda x: x.mode(dropna=False).sample(1)).to_frame().reset_index()\n",
      "100%|██████████| 12/12 [01:13<00:00,  6.10s/it]\n"
     ]
    }
   ],
   "source": [
    "base_dirs = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/**/io/**/0.dat\", recursive=True)]\n",
    "base_dirs += [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/**/cot/**/0.dat\", recursive=True)]\n",
    "data = [y for x in tqdm(base_dirs) if not isinstance((y:=analyze_baseline_log(x)), BaseException)]\n",
    "baseline_data = pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "4374251b-005e-4455-b59c-8f8728f8d1a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_d0d24\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_d0d24_level0_col0\" class=\"col_heading level0 col0\" >oracle_rate</th>\n",
       "      <th id=\"T_d0d24_level0_col1\" class=\"col_heading level0 col1\" >oracle_ct</th>\n",
       "      <th id=\"T_d0d24_level0_col2\" class=\"col_heading level0 col2\" >majority_rate</th>\n",
       "      <th id=\"T_d0d24_level0_col3\" class=\"col_heading level0 col3\" >majority_ct</th>\n",
       "      <th id=\"T_d0d24_level0_col4\" class=\"col_heading level0 col4\" >mean_rate</th>\n",
       "      <th id=\"T_d0d24_level0_col5\" class=\"col_heading level0 col5\" >mean_ct</th>\n",
       "      <th id=\"T_d0d24_level0_col6\" class=\"col_heading level0 col6\" >dataset</th>\n",
       "      <th id=\"T_d0d24_level0_col7\" class=\"col_heading level0 col7\" >backend</th>\n",
       "      <th id=\"T_d0d24_level0_col8\" class=\"col_heading level0 col8\" >method</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row0\" class=\"row_heading level0 row0\" >6</th>\n",
       "      <td id=\"T_d0d24_row0_col0\" class=\"data row0 col0\" >0.932149</td>\n",
       "      <td id=\"T_d0d24_row0_col1\" class=\"data row0 col1\" >8729274.000000</td>\n",
       "      <td id=\"T_d0d24_row0_col2\" class=\"data row0 col2\" >0.52</td>\n",
       "      <td id=\"T_d0d24_row0_col3\" class=\"data row0 col3\" >8.7M</td>\n",
       "      <td id=\"T_d0d24_row0_col4\" class=\"data row0 col4\" >0.31</td>\n",
       "      <td id=\"T_d0d24_row0_col5\" class=\"data row0 col5\" >87.3K</td>\n",
       "      <td id=\"T_d0d24_row0_col6\" class=\"data row0 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row0_col7\" class=\"data row0 col7\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</td>\n",
       "      <td id=\"T_d0d24_row0_col8\" class=\"data row0 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row1\" class=\"row_heading level0 row1\" >0</th>\n",
       "      <td id=\"T_d0d24_row1_col0\" class=\"data row1 col0\" >0.966443</td>\n",
       "      <td id=\"T_d0d24_row1_col1\" class=\"data row1 col1\" >7702936.000000</td>\n",
       "      <td id=\"T_d0d24_row1_col2\" class=\"data row1 col2\" >0.73</td>\n",
       "      <td id=\"T_d0d24_row1_col3\" class=\"data row1 col3\" >7.7M</td>\n",
       "      <td id=\"T_d0d24_row1_col4\" class=\"data row1 col4\" >0.29</td>\n",
       "      <td id=\"T_d0d24_row1_col5\" class=\"data row1 col5\" >77.0K</td>\n",
       "      <td id=\"T_d0d24_row1_col6\" class=\"data row1 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row1_col7\" class=\"data row1 col7\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</td>\n",
       "      <td id=\"T_d0d24_row1_col8\" class=\"data row1 col8\" >io</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row2\" class=\"row_heading level0 row2\" >7</th>\n",
       "      <td id=\"T_d0d24_row2_col0\" class=\"data row2 col0\" >0.046712</td>\n",
       "      <td id=\"T_d0d24_row2_col1\" class=\"data row2 col1\" >4909433.000000</td>\n",
       "      <td id=\"T_d0d24_row2_col2\" class=\"data row2 col2\" >0.00</td>\n",
       "      <td id=\"T_d0d24_row2_col3\" class=\"data row2 col3\" >4.9M</td>\n",
       "      <td id=\"T_d0d24_row2_col4\" class=\"data row2 col4\" >0.00</td>\n",
       "      <td id=\"T_d0d24_row2_col5\" class=\"data row2 col5\" >49.1K</td>\n",
       "      <td id=\"T_d0d24_row2_col6\" class=\"data row2 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row2_col7\" class=\"data row2 col7\" >microsoft/phi-4</td>\n",
       "      <td id=\"T_d0d24_row2_col8\" class=\"data row2 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row3\" class=\"row_heading level0 row3\" >1</th>\n",
       "      <td id=\"T_d0d24_row3_col0\" class=\"data row3 col0\" >0.070035</td>\n",
       "      <td id=\"T_d0d24_row3_col1\" class=\"data row3 col1\" >4749780.000000</td>\n",
       "      <td id=\"T_d0d24_row3_col2\" class=\"data row3 col2\" >0.01</td>\n",
       "      <td id=\"T_d0d24_row3_col3\" class=\"data row3 col3\" >4.7M</td>\n",
       "      <td id=\"T_d0d24_row3_col4\" class=\"data row3 col4\" >0.00</td>\n",
       "      <td id=\"T_d0d24_row3_col5\" class=\"data row3 col5\" >47.5K</td>\n",
       "      <td id=\"T_d0d24_row3_col6\" class=\"data row3 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row3_col7\" class=\"data row3 col7\" >microsoft/phi-4</td>\n",
       "      <td id=\"T_d0d24_row3_col8\" class=\"data row3 col8\" >io</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row4\" class=\"row_heading level0 row4\" >8</th>\n",
       "      <td id=\"T_d0d24_row4_col0\" class=\"data row4 col0\" >0.999990</td>\n",
       "      <td id=\"T_d0d24_row4_col1\" class=\"data row4 col1\" >8170270.000000</td>\n",
       "      <td id=\"T_d0d24_row4_col2\" class=\"data row4 col2\" >0.92</td>\n",
       "      <td id=\"T_d0d24_row4_col3\" class=\"data row4 col3\" >8.2M</td>\n",
       "      <td id=\"T_d0d24_row4_col4\" class=\"data row4 col4\" >0.40</td>\n",
       "      <td id=\"T_d0d24_row4_col5\" class=\"data row4 col5\" >81.7K</td>\n",
       "      <td id=\"T_d0d24_row4_col6\" class=\"data row4 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row4_col7\" class=\"data row4 col7\" >mistralai/Mistral-Small-24B-Instruct-2501</td>\n",
       "      <td id=\"T_d0d24_row4_col8\" class=\"data row4 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row5\" class=\"row_heading level0 row5\" >2</th>\n",
       "      <td id=\"T_d0d24_row5_col0\" class=\"data row5 col0\" >0.893459</td>\n",
       "      <td id=\"T_d0d24_row5_col1\" class=\"data row5 col1\" >5630125.000000</td>\n",
       "      <td id=\"T_d0d24_row5_col2\" class=\"data row5 col2\" >0.27</td>\n",
       "      <td id=\"T_d0d24_row5_col3\" class=\"data row5 col3\" >5.6M</td>\n",
       "      <td id=\"T_d0d24_row5_col4\" class=\"data row5 col4\" >0.12</td>\n",
       "      <td id=\"T_d0d24_row5_col5\" class=\"data row5 col5\" >56.3K</td>\n",
       "      <td id=\"T_d0d24_row5_col6\" class=\"data row5 col6\" >four_digits_unsolvable</td>\n",
       "      <td id=\"T_d0d24_row5_col7\" class=\"data row5 col7\" >mistralai/Mistral-Small-24B-Instruct-2501</td>\n",
       "      <td id=\"T_d0d24_row5_col8\" class=\"data row5 col8\" >io</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row6\" class=\"row_heading level0 row6\" >9</th>\n",
       "      <td id=\"T_d0d24_row6_col0\" class=\"data row6 col0\" >0.781550</td>\n",
       "      <td id=\"T_d0d24_row6_col1\" class=\"data row6 col1\" >5885322.000000</td>\n",
       "      <td id=\"T_d0d24_row6_col2\" class=\"data row6 col2\" >0.25</td>\n",
       "      <td id=\"T_d0d24_row6_col3\" class=\"data row6 col3\" >5.9M</td>\n",
       "      <td id=\"T_d0d24_row6_col4\" class=\"data row6 col4\" >0.20</td>\n",
       "      <td id=\"T_d0d24_row6_col5\" class=\"data row6 col5\" >58.9K</td>\n",
       "      <td id=\"T_d0d24_row6_col6\" class=\"data row6 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row6_col7\" class=\"data row6 col7\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</td>\n",
       "      <td id=\"T_d0d24_row6_col8\" class=\"data row6 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row7\" class=\"row_heading level0 row7\" >3</th>\n",
       "      <td id=\"T_d0d24_row7_col0\" class=\"data row7 col0\" >0.986317</td>\n",
       "      <td id=\"T_d0d24_row7_col1\" class=\"data row7 col1\" >6331750.000000</td>\n",
       "      <td id=\"T_d0d24_row7_col2\" class=\"data row7 col2\" >0.51</td>\n",
       "      <td id=\"T_d0d24_row7_col3\" class=\"data row7 col3\" >6.3M</td>\n",
       "      <td id=\"T_d0d24_row7_col4\" class=\"data row7 col4\" >0.34</td>\n",
       "      <td id=\"T_d0d24_row7_col5\" class=\"data row7 col5\" >63.3K</td>\n",
       "      <td id=\"T_d0d24_row7_col6\" class=\"data row7 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row7_col7\" class=\"data row7 col7\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</td>\n",
       "      <td id=\"T_d0d24_row7_col8\" class=\"data row7 col8\" >io</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row8\" class=\"row_heading level0 row8\" >10</th>\n",
       "      <td id=\"T_d0d24_row8_col0\" class=\"data row8 col0\" >0.707847</td>\n",
       "      <td id=\"T_d0d24_row8_col1\" class=\"data row8 col1\" >2985804.000000</td>\n",
       "      <td id=\"T_d0d24_row8_col2\" class=\"data row8 col2\" >0.15</td>\n",
       "      <td id=\"T_d0d24_row8_col3\" class=\"data row8 col3\" >3.0M</td>\n",
       "      <td id=\"T_d0d24_row8_col4\" class=\"data row8 col4\" >0.08</td>\n",
       "      <td id=\"T_d0d24_row8_col5\" class=\"data row8 col5\" >29.9K</td>\n",
       "      <td id=\"T_d0d24_row8_col6\" class=\"data row8 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row8_col7\" class=\"data row8 col7\" >microsoft/phi-4</td>\n",
       "      <td id=\"T_d0d24_row8_col8\" class=\"data row8 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row9\" class=\"row_heading level0 row9\" >4</th>\n",
       "      <td id=\"T_d0d24_row9_col0\" class=\"data row9 col0\" >0.580091</td>\n",
       "      <td id=\"T_d0d24_row9_col1\" class=\"data row9 col1\" >3005699.000000</td>\n",
       "      <td id=\"T_d0d24_row9_col2\" class=\"data row9 col2\" >0.19</td>\n",
       "      <td id=\"T_d0d24_row9_col3\" class=\"data row9 col3\" >3.0M</td>\n",
       "      <td id=\"T_d0d24_row9_col4\" class=\"data row9 col4\" >0.13</td>\n",
       "      <td id=\"T_d0d24_row9_col5\" class=\"data row9 col5\" >30.1K</td>\n",
       "      <td id=\"T_d0d24_row9_col6\" class=\"data row9 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row9_col7\" class=\"data row9 col7\" >microsoft/phi-4</td>\n",
       "      <td id=\"T_d0d24_row9_col8\" class=\"data row9 col8\" >io</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row10\" class=\"row_heading level0 row10\" >11</th>\n",
       "      <td id=\"T_d0d24_row10_col0\" class=\"data row10 col0\" >0.890138</td>\n",
       "      <td id=\"T_d0d24_row10_col1\" class=\"data row10 col1\" >6614005.000000</td>\n",
       "      <td id=\"T_d0d24_row10_col2\" class=\"data row10 col2\" >0.10</td>\n",
       "      <td id=\"T_d0d24_row10_col3\" class=\"data row10 col3\" >6.6M</td>\n",
       "      <td id=\"T_d0d24_row10_col4\" class=\"data row10 col4\" >0.11</td>\n",
       "      <td id=\"T_d0d24_row10_col5\" class=\"data row10 col5\" >66.1K</td>\n",
       "      <td id=\"T_d0d24_row10_col6\" class=\"data row10 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row10_col7\" class=\"data row10 col7\" >mistralai/Mistral-Small-24B-Instruct-2501</td>\n",
       "      <td id=\"T_d0d24_row10_col8\" class=\"data row10 col8\" >cot</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_d0d24_level0_row11\" class=\"row_heading level0 row11\" >5</th>\n",
       "      <td id=\"T_d0d24_row11_col0\" class=\"data row11 col0\" >0.936007</td>\n",
       "      <td id=\"T_d0d24_row11_col1\" class=\"data row11 col1\" >7358527.000000</td>\n",
       "      <td id=\"T_d0d24_row11_col2\" class=\"data row11 col2\" >0.34</td>\n",
       "      <td id=\"T_d0d24_row11_col3\" class=\"data row11 col3\" >7.4M</td>\n",
       "      <td id=\"T_d0d24_row11_col4\" class=\"data row11 col4\" >0.20</td>\n",
       "      <td id=\"T_d0d24_row11_col5\" class=\"data row11 col5\" >73.6K</td>\n",
       "      <td id=\"T_d0d24_row11_col6\" class=\"data row11 col6\" >tot_test_split</td>\n",
       "      <td id=\"T_d0d24_row11_col7\" class=\"data row11 col7\" >mistralai/Mistral-Small-24B-Instruct-2501</td>\n",
       "      <td id=\"T_d0d24_row11_col8\" class=\"data row11 col8\" >io</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x1539060f1730>"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = baseline_data.sort_values([\"dataset\", \"backend\",\"method\"])\n",
    "s = d.style\n",
    "s.format(human_number_format, [\"mean_ct\", \"majority_ct\"]).format(lambda x: f\"{x:.2f}\", [\"majority_rate\", \"mean_rate\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2844a2c2-9138-4765-9e6a-186eb76eab37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from utils import all_paths_between, get_unsolved_roots, get_solved_roots, get_winning_subgraph, get_solution_nodes\n",
    "\n",
    "def analyze_hp_log(folder, require_results=False, only_success_rates=False, last_iter_only=False):\n",
    "    try:\n",
    "        dataset = dataset_from_path(folder)\n",
    "        task = tasks[dataset]\n",
    "        with open(folder + \"0.dat\", \"rb\") as f:\n",
    "            logs = pickle.load(f)\n",
    "        with open(folder + \"config.json\", \"rb\") as f:\n",
    "            config = json.load(f)\n",
    "        sum_runtime = sum([l[\"time\"] for l in logs if \"time\" in l])\n",
    "        if last_iter_only:\n",
    "            # last two logs if results are there, single log else\n",
    "            if \"graph\" not in logs[-1]:\n",
    "                logs = logs[-2:]\n",
    "            else:\n",
    "                logs = logs[-1:]\n",
    "\n",
    "        graphs = [l[\"graph\"] for l in logs if \"graph\" in l]\n",
    "\n",
    "        has_results = False\n",
    "        if \"results\" in logs[-1]:\n",
    "            has_results = True\n",
    "            # check whether the root nodes need to be converted from the old format\n",
    "            solutions = logs[-1][\"results\"]\n",
    "            solutions[\"root\"] = solutions.root.apply(lambda x: \"\".join(x[1:-1].split(\",\")).strip() if x.startswith(\"(\") else x)\n",
    "            r = task.evaluate_results(solutions)\n",
    "        if not has_results and require_results:\n",
    "            print(\"No results in folder\", folder)\n",
    "            return None\n",
    "            \n",
    "\n",
    "        pt = [sum(l[\"token_usage\"].values()) for l in logs if \"token_usage\" in l]\n",
    "        if \"graph\" not in logs[-1]:\n",
    "            # the finalization step needs to be merged\n",
    "            pt[-2] += pt[-1] - pt[-2] # already cumulative\n",
    "            pt = pt[:-1]\n",
    "        config[\"completion_tokens\"] = pt\n",
    "\n",
    "        call_counts = [l[\"model_calls\"] for l in logs]\n",
    "        config[\"model_calls\"] = call_counts\n",
    "\n",
    "        root_shortcut_rate = []\n",
    "        winning_paths_with_shortcuts_rate = []\n",
    "        success_rate = [] if has_results else None\n",
    "            \n",
    "        for G in tqdm(graphs, disable=last_iter_only):\n",
    "            if has_results:\n",
    "                unsolved = set(get_unsolved_roots(G))\n",
    "                local_r = r.copy()\n",
    "                local_r.solution = local_r.apply(lambda row: None if row.root in unsolved else row.solution, axis=1)\n",
    "                local_r = task.evaluate_results(local_r)\n",
    "                success_rate.append(local_r.is_correct.mean())\n",
    "                if only_success_rates:\n",
    "                    continue\n",
    "\n",
    "            W_pred = get_winning_subgraph(G)\n",
    "            solution_nodes = get_solution_nodes(G)\n",
    "            roots = get_solved_roots(G)\n",
    "            \n",
    "            paths = all_paths_between(W_pred, roots, solution_nodes)\n",
    "            if len(paths) > 0:\n",
    "                paths = pd.DataFrame({\"path\":paths})\n",
    "                paths[\"root\"] = paths.path.apply(lambda x: x[0][0])\n",
    "                paths[\"has_shortcut\"] = paths.path.apply(lambda p: any([W_pred.edges[e].get(\"is_shortcut\", False)  for e in p]))\n",
    "                mean_roots_with_only_shortcuts = paths.groupby(\"root\").has_shortcut.all().mean()\n",
    "                mean_winning_paths_with_shortcuts = paths.has_shortcut.mean()\n",
    "            else:\n",
    "                mean_roots_with_only_shortcuts = mean_winning_paths_with_shortcuts = 0\n",
    "            root_shortcut_rate.append(mean_roots_with_only_shortcuts)\n",
    "            winning_paths_with_shortcuts_rate.append(mean_winning_paths_with_shortcuts)\n",
    "\n",
    "\n",
    "        root_shortcut_rate = np.array(root_shortcut_rate)\n",
    "        winning_paths_with_shortcuts_rate = np.array(winning_paths_with_shortcuts_rate)\n",
    "\n",
    "        # count nodes per layer over time\n",
    "        layers = pd.DataFrame([[len(x) for x in nx.bfs_layers(G, [r for r in G if G.in_degree(r) == 0])] for l in logs if (G := l.get(\"graph\",None))]).fillna(0).values\n",
    "        config[\"nodes_per_layer\"] = np.array(layers)\n",
    "\n",
    "    \n",
    "        config[\"root_shortcut_rate\"] = root_shortcut_rate\n",
    "        config[\"winning_paths_with_shortcuts_rate\"] = winning_paths_with_shortcuts_rate\n",
    "        config[\"success_rate\"] = success_rate\n",
    "        config[\"runtime\"] = sum_runtime\n",
    "\n",
    "        if last_iter_only:\n",
    "            # unpack\n",
    "            config[\"completion_tokens\"] = config[\"completion_tokens\"][0]\n",
    "            config[\"nodes_per_layer\"] = config[\"nodes_per_layer\"][0]\n",
    "            config[\"root_shortcut_rate\"] = config[\"root_shortcut_rate\"][0]\n",
    "            config[\"winning_paths_with_shortcuts_rate\"] = config[\"winning_paths_with_shortcuts_rate\"][0]\n",
    "            config[\"success_rate\"] = config[\"success_rate\"][0] if has_results else None\n",
    "            config[\"model_calls\"] = config[\"model_calls\"][-1]\n",
    "            \n",
    "        #config[\"logs\"] = logs\n",
    "    except Exception as e:\n",
    "        print(folder)\n",
    "        return e\n",
    "    return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "dca3e1d9-0aa7-4367-9aed-d21b888251f9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 75/75 [01:51<00:00,  1.48s/it]\n"
     ]
    }
   ],
   "source": [
    "hp_dirs_make24 = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/tot_test_split/hp_[0-9]/**/0.dat\", recursive=True)]\n",
    "data = [y for x in tqdm(hp_dirs_make24) if not isinstance((y:=analyze_hp_log(x, last_iter_only=True)), BaseException)]\n",
    "hp_data_make24 = pd.DataFrame(data)\n",
    "hp_data_make24[\"backend\"] = hp_data_make24.logdir.apply(model_from_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "81796deb-8e6b-474f-8faa-1ac67b998217",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_eced6\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_eced6_level0_col0\" class=\"col_heading level0 col0\" >success_rate</th>\n",
       "      <th id=\"T_eced6_level0_col1\" class=\"col_heading level0 col1\" colspan=\"2\">completion_tokens</th>\n",
       "      <th id=\"T_eced6_level0_col3\" class=\"col_heading level0 col3\" >model_calls</th>\n",
       "      <th id=\"T_eced6_level0_col4\" class=\"col_heading level0 col4\" >backend</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_eced6_level1_col0\" class=\"col_heading level1 col0\" >mean</th>\n",
       "      <th id=\"T_eced6_level1_col1\" class=\"col_heading level1 col1\" >mean</th>\n",
       "      <th id=\"T_eced6_level1_col2\" class=\"col_heading level1 col2\" >std</th>\n",
       "      <th id=\"T_eced6_level1_col3\" class=\"col_heading level1 col3\" >mean</th>\n",
       "      <th id=\"T_eced6_level1_col4\" class=\"col_heading level1 col4\" >count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >backend</th>\n",
       "      <th class=\"index_name level1\" >n_shortcuts</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "      <th class=\"blank col4\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level0_row0\" class=\"row_heading level0 row0\" rowspan=\"5\">RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</th>\n",
       "      <th id=\"T_eced6_level1_row0\" class=\"row_heading level1 row0\" >0</th>\n",
       "      <td id=\"T_eced6_row0_col0\" class=\"data row0 col0\" >0.962</td>\n",
       "      <td id=\"T_eced6_row0_col1\" class=\"data row0 col1\" >12.8M</td>\n",
       "      <td id=\"T_eced6_row0_col2\" class=\"data row0 col2\" >1.2M</td>\n",
       "      <td id=\"T_eced6_row0_col3\" class=\"data row0 col3\" >171.7K</td>\n",
       "      <td id=\"T_eced6_row0_col4\" class=\"data row0 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row1\" class=\"row_heading level1 row1\" >10</th>\n",
       "      <td id=\"T_eced6_row1_col0\" class=\"data row1 col0\" >0.980</td>\n",
       "      <td id=\"T_eced6_row1_col1\" class=\"data row1 col1\" >10.6M</td>\n",
       "      <td id=\"T_eced6_row1_col2\" class=\"data row1 col2\" >1.8M</td>\n",
       "      <td id=\"T_eced6_row1_col3\" class=\"data row1 col3\" >141.4K</td>\n",
       "      <td id=\"T_eced6_row1_col4\" class=\"data row1 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row2\" class=\"row_heading level1 row2\" >50</th>\n",
       "      <td id=\"T_eced6_row2_col0\" class=\"data row2 col0\" >0.964</td>\n",
       "      <td id=\"T_eced6_row2_col1\" class=\"data row2 col1\" >8.4M</td>\n",
       "      <td id=\"T_eced6_row2_col2\" class=\"data row2 col2\" >559.7K</td>\n",
       "      <td id=\"T_eced6_row2_col3\" class=\"data row2 col3\" >107.6K</td>\n",
       "      <td id=\"T_eced6_row2_col4\" class=\"data row2 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row3\" class=\"row_heading level1 row3\" >100</th>\n",
       "      <td id=\"T_eced6_row3_col0\" class=\"data row3 col0\" >0.952</td>\n",
       "      <td id=\"T_eced6_row3_col1\" class=\"data row3 col1\" >8.5M</td>\n",
       "      <td id=\"T_eced6_row3_col2\" class=\"data row3 col2\" >1.5M</td>\n",
       "      <td id=\"T_eced6_row3_col3\" class=\"data row3 col3\" >104.6K</td>\n",
       "      <td id=\"T_eced6_row3_col4\" class=\"data row3 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row4\" class=\"row_heading level1 row4\" >200</th>\n",
       "      <td id=\"T_eced6_row4_col0\" class=\"data row4 col0\" >0.952</td>\n",
       "      <td id=\"T_eced6_row4_col1\" class=\"data row4 col1\" >8.3M</td>\n",
       "      <td id=\"T_eced6_row4_col2\" class=\"data row4 col2\" >732.4K</td>\n",
       "      <td id=\"T_eced6_row4_col3\" class=\"data row4 col3\" >96.4K</td>\n",
       "      <td id=\"T_eced6_row4_col4\" class=\"data row4 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level0_row5\" class=\"row_heading level0 row5\" rowspan=\"5\">microsoft/phi-4</th>\n",
       "      <th id=\"T_eced6_level1_row5\" class=\"row_heading level1 row5\" >0</th>\n",
       "      <td id=\"T_eced6_row5_col0\" class=\"data row5 col0\" >0.604</td>\n",
       "      <td id=\"T_eced6_row5_col1\" class=\"data row5 col1\" >21.6M</td>\n",
       "      <td id=\"T_eced6_row5_col2\" class=\"data row5 col2\" >5.5M</td>\n",
       "      <td id=\"T_eced6_row5_col3\" class=\"data row5 col3\" >193.6K</td>\n",
       "      <td id=\"T_eced6_row5_col4\" class=\"data row5 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row6\" class=\"row_heading level1 row6\" >10</th>\n",
       "      <td id=\"T_eced6_row6_col0\" class=\"data row6 col0\" >0.902</td>\n",
       "      <td id=\"T_eced6_row6_col1\" class=\"data row6 col1\" >18.2M</td>\n",
       "      <td id=\"T_eced6_row6_col2\" class=\"data row6 col2\" >2.5M</td>\n",
       "      <td id=\"T_eced6_row6_col3\" class=\"data row6 col3\" >175.2K</td>\n",
       "      <td id=\"T_eced6_row6_col4\" class=\"data row6 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row7\" class=\"row_heading level1 row7\" >50</th>\n",
       "      <td id=\"T_eced6_row7_col0\" class=\"data row7 col0\" >0.956</td>\n",
       "      <td id=\"T_eced6_row7_col1\" class=\"data row7 col1\" >18.6M</td>\n",
       "      <td id=\"T_eced6_row7_col2\" class=\"data row7 col2\" >9.1M</td>\n",
       "      <td id=\"T_eced6_row7_col3\" class=\"data row7 col3\" >161.5K</td>\n",
       "      <td id=\"T_eced6_row7_col4\" class=\"data row7 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row8\" class=\"row_heading level1 row8\" >100</th>\n",
       "      <td id=\"T_eced6_row8_col0\" class=\"data row8 col0\" >0.962</td>\n",
       "      <td id=\"T_eced6_row8_col1\" class=\"data row8 col1\" >16.3M</td>\n",
       "      <td id=\"T_eced6_row8_col2\" class=\"data row8 col2\" >5.3M</td>\n",
       "      <td id=\"T_eced6_row8_col3\" class=\"data row8 col3\" >141.8K</td>\n",
       "      <td id=\"T_eced6_row8_col4\" class=\"data row8 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row9\" class=\"row_heading level1 row9\" >200</th>\n",
       "      <td id=\"T_eced6_row9_col0\" class=\"data row9 col0\" >0.946</td>\n",
       "      <td id=\"T_eced6_row9_col1\" class=\"data row9 col1\" >16.7M</td>\n",
       "      <td id=\"T_eced6_row9_col2\" class=\"data row9 col2\" >5.2M</td>\n",
       "      <td id=\"T_eced6_row9_col3\" class=\"data row9 col3\" >140.7K</td>\n",
       "      <td id=\"T_eced6_row9_col4\" class=\"data row9 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level0_row10\" class=\"row_heading level0 row10\" rowspan=\"5\">mistralai/Mistral-Small-24B-Instruct-2501</th>\n",
       "      <th id=\"T_eced6_level1_row10\" class=\"row_heading level1 row10\" >0</th>\n",
       "      <td id=\"T_eced6_row10_col0\" class=\"data row10 col0\" >0.716</td>\n",
       "      <td id=\"T_eced6_row10_col1\" class=\"data row10 col1\" >15.6M</td>\n",
       "      <td id=\"T_eced6_row10_col2\" class=\"data row10 col2\" >333.0K</td>\n",
       "      <td id=\"T_eced6_row10_col3\" class=\"data row10 col3\" >166.1K</td>\n",
       "      <td id=\"T_eced6_row10_col4\" class=\"data row10 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row11\" class=\"row_heading level1 row11\" >10</th>\n",
       "      <td id=\"T_eced6_row11_col0\" class=\"data row11 col0\" >0.808</td>\n",
       "      <td id=\"T_eced6_row11_col1\" class=\"data row11 col1\" >15.2M</td>\n",
       "      <td id=\"T_eced6_row11_col2\" class=\"data row11 col2\" >646.6K</td>\n",
       "      <td id=\"T_eced6_row11_col3\" class=\"data row11 col3\" >159.2K</td>\n",
       "      <td id=\"T_eced6_row11_col4\" class=\"data row11 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row12\" class=\"row_heading level1 row12\" >50</th>\n",
       "      <td id=\"T_eced6_row12_col0\" class=\"data row12 col0\" >0.820</td>\n",
       "      <td id=\"T_eced6_row12_col1\" class=\"data row12 col1\" >15.4M</td>\n",
       "      <td id=\"T_eced6_row12_col2\" class=\"data row12 col2\" >585.6K</td>\n",
       "      <td id=\"T_eced6_row12_col3\" class=\"data row12 col3\" >156.5K</td>\n",
       "      <td id=\"T_eced6_row12_col4\" class=\"data row12 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row13\" class=\"row_heading level1 row13\" >100</th>\n",
       "      <td id=\"T_eced6_row13_col0\" class=\"data row13 col0\" >0.858</td>\n",
       "      <td id=\"T_eced6_row13_col1\" class=\"data row13 col1\" >13.3M</td>\n",
       "      <td id=\"T_eced6_row13_col2\" class=\"data row13 col2\" >1.9M</td>\n",
       "      <td id=\"T_eced6_row13_col3\" class=\"data row13 col3\" >130.5K</td>\n",
       "      <td id=\"T_eced6_row13_col4\" class=\"data row13 col4\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_eced6_level1_row14\" class=\"row_heading level1 row14\" >200</th>\n",
       "      <td id=\"T_eced6_row14_col0\" class=\"data row14 col0\" >0.870</td>\n",
       "      <td id=\"T_eced6_row14_col1\" class=\"data row14 col1\" >12.7M</td>\n",
       "      <td id=\"T_eced6_row14_col2\" class=\"data row14 col2\" >2.9M</td>\n",
       "      <td id=\"T_eced6_row14_col3\" class=\"data row14 col3\" >118.8K</td>\n",
       "      <td id=\"T_eced6_row14_col4\" class=\"data row14 col4\" >5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x14f8955c17c0>"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = hp_data_make24.groupby([\"backend\",\"n_shortcuts\"], group_keys=False)\n",
    "hp_tot_avg = g.agg({\"success_rate\":\"mean\",\"completion_tokens\":[\"mean\",\"std\"], \"model_calls\":\"mean\", \"backend\":\"count\"})\n",
    "hp_tot_avg.style.format(human_number_format, [\"completion_tokens\",\"model_calls\"]).format(lambda x: f\"{x:.3f}\", \"success_rate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d72eb47-c08f-4b00-93d3-76710668b015",
   "metadata": {},
   "source": [
    "## Ablation\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7b7eb1eb-fc51-4a32-b500-611b6199f38a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:39<00:00,  1.96s/it]\n"
     ]
    }
   ],
   "source": [
    "d = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/tot_test_split/hp_ablation_*/**/0.dat\", recursive=True)]\n",
    "data = [y for x in tqdm(d) if not isinstance((y:=analyze_hp_log(x, last_iter_only=True)), BaseException)]\n",
    "ablation_data = pd.DataFrame(data)\n",
    "ablation_data[\"backend\"] = ablation_data.logdir.apply(model_from_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "f9d8de0f-5ce6-416e-a68c-cf0aec1c41d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_08ec8\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_08ec8_level0_col0\" class=\"col_heading level0 col0\" >backend</th>\n",
       "      <th id=\"T_08ec8_level0_col1\" class=\"col_heading level0 col1\" >completion_tokens</th>\n",
       "      <th id=\"T_08ec8_level0_col2\" class=\"col_heading level0 col2\" >model_calls</th>\n",
       "      <th id=\"T_08ec8_level0_col3\" class=\"col_heading level0 col3\" >success_rate</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >do_verify_moves</th>\n",
       "      <th class=\"index_name level1\" >do_shortcut</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_08ec8_level0_row0\" class=\"row_heading level0 row0\" rowspan=\"2\">False</th>\n",
       "      <th id=\"T_08ec8_level1_row0\" class=\"row_heading level1 row0\" >False</th>\n",
       "      <td id=\"T_08ec8_row0_col0\" class=\"data row0 col0\" >5</td>\n",
       "      <td id=\"T_08ec8_row0_col1\" class=\"data row0 col1\" >28701192.400000</td>\n",
       "      <td id=\"T_08ec8_row0_col2\" class=\"data row0 col2\" >221567.200000</td>\n",
       "      <td id=\"T_08ec8_row0_col3\" class=\"data row0 col3\" >0.524000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_08ec8_level1_row1\" class=\"row_heading level1 row1\" >True</th>\n",
       "      <td id=\"T_08ec8_row1_col0\" class=\"data row1 col0\" >5</td>\n",
       "      <td id=\"T_08ec8_row1_col1\" class=\"data row1 col1\" >15388496.200000</td>\n",
       "      <td id=\"T_08ec8_row1_col2\" class=\"data row1 col2\" >122543.600000</td>\n",
       "      <td id=\"T_08ec8_row1_col3\" class=\"data row1 col3\" >0.916000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_08ec8_level0_row2\" class=\"row_heading level0 row2\" rowspan=\"2\">True</th>\n",
       "      <th id=\"T_08ec8_level1_row2\" class=\"row_heading level1 row2\" >False</th>\n",
       "      <td id=\"T_08ec8_row2_col0\" class=\"data row2 col0\" >5</td>\n",
       "      <td id=\"T_08ec8_row2_col1\" class=\"data row2 col1\" >20826618.000000</td>\n",
       "      <td id=\"T_08ec8_row2_col2\" class=\"data row2 col2\" >201514.800000</td>\n",
       "      <td id=\"T_08ec8_row2_col3\" class=\"data row2 col3\" >0.944000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_08ec8_level1_row3\" class=\"row_heading level1 row3\" >True</th>\n",
       "      <td id=\"T_08ec8_row3_col0\" class=\"data row3 col0\" >5</td>\n",
       "      <td id=\"T_08ec8_row3_col1\" class=\"data row3 col1\" >13135582.200000</td>\n",
       "      <td id=\"T_08ec8_row3_col2\" class=\"data row3 col2\" >116830.000000</td>\n",
       "      <td id=\"T_08ec8_row3_col3\" class=\"data row3 col3\" >0.962000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x14f88e7c1a00>"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ablation_results = ablation_data.groupby([\"do_verify_moves\",\"do_shortcut\"]).agg({\"backend\":\"count\", \"completion_tokens\":\"mean\", \"model_calls\":\"mean\", \"success_rate\":\"mean\",})\n",
    "ablation_results.style"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "8da454a6-b352-4a4a-9e33-d3bba24e8551",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrr}\n",
      " & do_verify_moves & do_shortcut & completion_tokens & model_calls & success_rate \\\\\n",
      "0 & \\xmark & \\xmark & 28.7M & 221.6K & 52.40 \\\\\n",
      "1 & \\xmark & \\cmark & 15.4M & 122.5K & 91.60 \\\\\n",
      "2 & \\cmark & \\xmark & 20.8M & 201.5K & 94.40 \\\\\n",
      "3 & \\cmark & \\cmark & 13.1M & 116.8K & 96.20 \\\\\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "s = ablation_results.drop(\"backend\",axis=1).reset_index().style\n",
    "s = s.format(human_number_format, \"completion_tokens\")\n",
    "s = s.format(lambda x: f\"{x*100:.2f}\", \"success_rate\")\n",
    "s = s.format(lambda x: \"\\\\cmark\" if x else \"\\\\xmark\", [\"do_verify_moves\",\"do_shortcut\"])\n",
    "s = s.format(human_number_format, \"model_calls\")\n",
    "print(s.to_latex())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47867a39-15a3-4dbb-9866-db8926e00cc1",
   "metadata": {},
   "source": [
    "## Unsolvables\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a777d26-887c-46b2-bb31-76f05abc5e40",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:34<00:00,  1.90s/it]\n"
     ]
    }
   ],
   "source": [
    "hp_dirs_make24 = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/four_digits_unsolvable/hp_[0-9]/**/0.dat\", recursive=True)]\n",
    "data = [y for x in tqdm(hp_dirs_make24) if not isinstance((y:=analyze_hp_log(x, last_iter_only=True)), BaseException)]\n",
    "hp_data_make24 = pd.DataFrame(data)\n",
    "hp_data_make24[\"backend\"] = hp_data_make24.logdir.apply(model_from_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "caab5bd4-9665-4205-83a5-7b841a01004c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_b4c33\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_b4c33_level0_col0\" class=\"col_heading level0 col0\" >success_rate</th>\n",
       "      <th id=\"T_b4c33_level0_col1\" class=\"col_heading level0 col1\" >completion_tokens</th>\n",
       "      <th id=\"T_b4c33_level0_col2\" class=\"col_heading level0 col2\" >model_calls</th>\n",
       "      <th id=\"T_b4c33_level0_col3\" class=\"col_heading level0 col3\" >backend</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >backend</th>\n",
       "      <th class=\"index_name level1\" >n_shortcuts</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_b4c33_level0_row0\" class=\"row_heading level0 row0\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</th>\n",
       "      <th id=\"T_b4c33_level1_row0\" class=\"row_heading level1 row0\" >10</th>\n",
       "      <td id=\"T_b4c33_row0_col0\" class=\"data row0 col0\" >1.000</td>\n",
       "      <td id=\"T_b4c33_row0_col1\" class=\"data row0 col1\" >9.6M</td>\n",
       "      <td id=\"T_b4c33_row0_col2\" class=\"data row0 col2\" >132.7K</td>\n",
       "      <td id=\"T_b4c33_row0_col3\" class=\"data row0 col3\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4c33_level0_row1\" class=\"row_heading level0 row1\" rowspan=\"2\">microsoft/phi-4</th>\n",
       "      <th id=\"T_b4c33_level1_row1\" class=\"row_heading level1 row1\" >50</th>\n",
       "      <td id=\"T_b4c33_row1_col0\" class=\"data row1 col0\" >1.000</td>\n",
       "      <td id=\"T_b4c33_row1_col1\" class=\"data row1 col1\" >17.2M</td>\n",
       "      <td id=\"T_b4c33_row1_col2\" class=\"data row1 col2\" >148.9K</td>\n",
       "      <td id=\"T_b4c33_row1_col3\" class=\"data row1 col3\" >3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4c33_level1_row2\" class=\"row_heading level1 row2\" >100</th>\n",
       "      <td id=\"T_b4c33_row2_col0\" class=\"data row2 col0\" >0.998</td>\n",
       "      <td id=\"T_b4c33_row2_col1\" class=\"data row2 col1\" >21.7M</td>\n",
       "      <td id=\"T_b4c33_row2_col2\" class=\"data row2 col2\" >189.2K</td>\n",
       "      <td id=\"T_b4c33_row2_col3\" class=\"data row2 col3\" >5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_b4c33_level0_row3\" class=\"row_heading level0 row3\" >mistralai/Mistral-Small-24B-Instruct-2501</th>\n",
       "      <th id=\"T_b4c33_level1_row3\" class=\"row_heading level1 row3\" >200</th>\n",
       "      <td id=\"T_b4c33_row3_col0\" class=\"data row3 col0\" >1.000</td>\n",
       "      <td id=\"T_b4c33_row3_col1\" class=\"data row3 col1\" >12.6M</td>\n",
       "      <td id=\"T_b4c33_row3_col2\" class=\"data row3 col2\" >136.9K</td>\n",
       "      <td id=\"T_b4c33_row3_col3\" class=\"data row3 col3\" >5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x14f8949b3c80>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = hp_data_make24.groupby([\"backend\",\"n_shortcuts\"], group_keys=False)\n",
    "hp_tot_avg = g.agg({\"success_rate\":\"mean\",\"completion_tokens\":\"mean\", \"model_calls\":\"mean\", \"backend\":\"count\"})\n",
    "hp_tot_avg.style.format(human_number_format, [\"completion_tokens\",\"model_calls\"]).format(lambda x: f\"{x:.3f}\", \"success_rate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37fd9534-6548-4c20-a0f2-36a0f2ddd048",
   "metadata": {},
   "source": [
    "## ToT\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "53efda74-1679-4e05-bb40-713866a3c1f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_tot_baseline(folder):\n",
    "    try:\n",
    "        with open(folder + \"0.dat\", \"rb\") as f:\n",
    "            logs = pickle.load(f)\n",
    "    \n",
    "        results = logs[\"results\"]\n",
    "        out = logs[\"outputs\"] # list of tuples (idx, (ys, infos))\n",
    "    \n",
    "        b = pd.DataFrame([x[1][1] for x in out])\n",
    "        mean_visited_nodes = b.steps.apply(lambda x: sum([len(y[\"new_ys\"]) for y in x])).mean()\n",
    "        rates = [results[:,:i].any(axis=1).mean() for i in range(1,results.shape[1]+1)]\n",
    "        r = dict(visited_nodes=[mean_visited_nodes] * len(rates), success_rates=rates, args=logs[\"args\"], runtime=logs[\"runtime\"], \n",
    "                 completion_tokens=[logs[\"completion_tokens\"]] * len(rates), prompt_tokens=[logs[\"prompt_tokens\"]] * len(rates), \n",
    "                 model_calls=[logs[\"model_calls\"]] * len(rates))\n",
    "    except Exception as e:\n",
    "        print(folder, e)\n",
    "        r = e\n",
    "    return r\n",
    "\n",
    "import re\n",
    "pattern_equal_24 = re.compile(r\"=\\s*24(?![.\\d])\\b\", re.IGNORECASE)\n",
    "#pattern_left_24 = re.compile(r'$left:\\s*24$', re.IGNORECASE)\n",
    "\n",
    "def claims_24(solution):\n",
    "    return bool(pattern_equal_24.search(solution))\n",
    "\n",
    "def analyse_tot_baseline_for_unsolvables(folder):\n",
    "    try:\n",
    "        with open(folder + \"0.dat\", \"rb\") as f:\n",
    "            logs = pickle.load(f)\n",
    "        \n",
    "        results = logs[\"results\"]\n",
    "        out = logs[\"outputs\"] # list of tuples (idx, (ys, infos))\n",
    "        b = pd.DataFrame([x[1][1] for x in out])\n",
    "        mean_visited_nodes = b.steps.apply(lambda x: sum([len(y[\"new_ys\"]) for y in x])).mean()\n",
    "        \n",
    "        task = tasks[logs[\"args\"].dataset]\n",
    "        \n",
    "        roots = [task.samples[o[0]] for o in out]\n",
    "        all_solutions = [o[1][0] for o in out]\n",
    "        result_df = pd.DataFrame({\"root\":roots, \"solution\":all_solutions})\n",
    "        result_df = result_df.explode(\"solution\")\n",
    "        result_df[\"solution\"] = result_df.solution.apply(lambda o: o.strip().split(\"\\n\")[-1])\n",
    "        result_df[\"claims_24\"] = result_df.solution.apply(claims_24)\n",
    "        sr = 1 - result_df.groupby(\"root\").claims_24.any().mean()\n",
    "        \n",
    "        r = dict(visited_nodes=[mean_visited_nodes], success_rates=[sr], args=logs[\"args\"], runtime=logs[\"runtime\"],\n",
    "                 completion_tokens=[logs[\"completion_tokens\"]], prompt_tokens=[logs[\"prompt_tokens\"]],\n",
    "                 model_calls=[logs[\"model_calls\"]])\n",
    "\n",
    "    except Exception as e:\n",
    "        print(folder, e)\n",
    "        r = e\n",
    "\n",
    "    return r"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80446eb9-8bf0-42be-9a6f-e150f4892caf",
   "metadata": {},
   "source": [
    "### get param ranges for follow up runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1b92e8d3-7ae5-49da-9411-6f95dbdf374c",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 4/90 [00:00<00:02, 38.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_11_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_11_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_11_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_3_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_3_10/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█▍        | 13/90 [00:00<00:02, 35.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_3_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_7_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_7_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/0_1_3_7_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_11_1/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 22/90 [00:00<00:01, 37.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_11_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_11_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_3_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_3_10/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 27/90 [00:00<00:01, 38.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_3_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_7_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_7_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic/1_1_3_7_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_11_1/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 36/90 [00:00<00:01, 34.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_11_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_11_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_3_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_3_10/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 40/90 [00:01<00:01, 31.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_3_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_7_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_7_20/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|█████▎    | 48/90 [00:01<00:01, 29.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/0_1_3_7_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_11_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_11_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_11_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_3_1/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 59%|█████▉    | 53/90 [00:01<00:01, 29.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_3_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_3_20/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 59/90 [00:02<00:03, 10.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_7_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_7_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/microsoft/phi-4/1_1_3_7_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_11_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_11_10/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 68/90 [00:03<00:01, 17.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_11_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_3_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_3_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_3_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_7_1/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 87%|████████▋ | 78/90 [00:03<00:00, 24.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_7_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/0_1_3_7_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_11_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_11_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_11_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_11_3/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 88/90 [00:03<00:00, 32.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_11_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_3_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_3_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_3_20/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_3_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_3_5/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_7_1/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_7_10/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_7_20/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 90/90 [00:03<00:00, 24.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_7_3/ 'model_calls'\n",
      "data/logs/tot_test_split/tot_broad_search/mistralai/Mistral-Small-24B-Instruct-2501/1_1_3_7_5/ 'model_calls'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'DataFrame' object has no attribute 'success_rates'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mAttributeError\u001b[39m                            Traceback (most recent call last)",
      "\u001b[32m/local/jobs/u17414_10927129/ipykernel_3109482/996472489.py\u001b[39m in \u001b[36m?\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m      2\u001b[39m \n\u001b[32m      3\u001b[39m data = [y \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;28;01min\u001b[39;00m tqdm(tot_dirs_make24) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;01mnot\u001b[39;00m isinstance((y:=analyze_tot_baseline(x)), BaseException)]\n\u001b[32m      4\u001b[39m d = pd.DataFrame(data)\n\u001b[32m      5\u001b[39m tot_make24_data = pd.concat([d, d.apply(\u001b[38;5;28;01mlambda\u001b[39;00m x: x.args.__dict__, axis=\u001b[32m1\u001b[39m, result_type=\u001b[33m\"expand\"\u001b[39m)], axis=\u001b[32m1\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m tot_make24_data[\u001b[33m\"oracle_success_rate\"\u001b[39m] = tot_make24_data.success_rates.apply(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[-\u001b[32m1\u001b[39m])\n\u001b[32m      7\u001b[39m tot_make24_data[\u001b[33m\"oracle_ct\"\u001b[39m] = tot_make24_data.completion_tokens.apply(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[-\u001b[32m1\u001b[39m])\n\u001b[32m      8\u001b[39m tot_make24_data[\u001b[33m\"first_success_rate\"\u001b[39m] = tot_make24_data.success_rates.apply(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[\u001b[32m0\u001b[39m])\n\u001b[32m      9\u001b[39m tot_make24_data[\u001b[33m\"first_ct\"\u001b[39m] = tot_make24_data.completion_tokens.apply(\u001b[38;5;28;01mlambda\u001b[39;00m x: x[\u001b[32m0\u001b[39m])\n",
      "\u001b[32m/mnt/vast-kisski/projects/kisski-tubr-rallm/workspace/holistic_prompting/hp_env/lib/python3.12/site-packages/pandas/core/generic.py\u001b[39m in \u001b[36m?\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m   6295\u001b[39m             \u001b[38;5;28;01mand\u001b[39;00m name \u001b[38;5;28;01mnot\u001b[39;00m \u001b[38;5;28;01min\u001b[39;00m self._accessors\n\u001b[32m   6296\u001b[39m             \u001b[38;5;28;01mand\u001b[39;00m self._info_axis._can_hold_identifiers_and_holds_name(name)\n\u001b[32m   6297\u001b[39m         ):\n\u001b[32m   6298\u001b[39m             \u001b[38;5;28;01mreturn\u001b[39;00m self[name]\n\u001b[32m-> \u001b[39m\u001b[32m6299\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m object.__getattribute__(self, name)\n",
      "\u001b[31mAttributeError\u001b[39m: 'DataFrame' object has no attribute 'success_rates'"
     ]
    }
   ],
   "source": [
    "tot_dirs_make24 = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/tot_test_split/tot_broad_search/**/0.dat\", recursive=True)]\n",
    "\n",
    "data = [y for x in tqdm(tot_dirs_make24) if not isinstance((y:=analyze_tot_baseline(x)), BaseException)]\n",
    "d = pd.DataFrame(data)\n",
    "tot_make24_data = pd.concat([d, d.apply(lambda x: x.args.__dict__, axis=1, result_type=\"expand\")], axis=1)\n",
    "tot_make24_data[\"oracle_success_rate\"] = tot_make24_data.success_rates.apply(lambda x: x[-1])\n",
    "tot_make24_data[\"oracle_ct\"] = tot_make24_data.completion_tokens.apply(lambda x: x[-1])\n",
    "tot_make24_data[\"first_success_rate\"] = tot_make24_data.success_rates.apply(lambda x: x[0])\n",
    "tot_make24_data[\"first_ct\"] = tot_make24_data.completion_tokens.apply(lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "c2f9e185-77ea-44f5-9baf-78af791208a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = tot_make24_data.groupby(\"backend\", group_keys=False).first_success_rate.nlargest(3).index\n",
    "top_runs = tot_make24_data.loc[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "5e0ac199-187d-40f7-be47-bc64cdae15ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'backend': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, {'backend': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': True, 'n_evaluate_sample': 7, 'n_select_sample': 20}, {'backend': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'do_verify': False, 'n_evaluate_sample': 7, 'n_select_sample': 20}, {'backend': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 5}, {'backend': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, {'backend': 'microsoft/phi-4', 'do_verify': True, 'n_evaluate_sample': 7, 'n_select_sample': 20}, {'backend': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 10}, {'backend': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': True, 'n_evaluate_sample': 11, 'n_select_sample': 20}, {'backend': 'mistralai/Mistral-Small-24B-Instruct-2501', 'do_verify': False, 'n_evaluate_sample': 11, 'n_select_sample': 3}]\n"
     ]
    }
   ],
   "source": [
    "print(top_runs[[\"backend\", \"do_verify\", \"n_evaluate_sample\", \"n_select_sample\"]].to_dict(orient=\"records\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbafbc73-a313-4059-859c-76672a62a43d",
   "metadata": {},
   "source": [
    "## tot test with repetitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cd55b22d-0870-4dff-9052-dba04afa6509",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 44/44 [00:01<00:00, 29.84it/s]\n"
     ]
    }
   ],
   "source": [
    "tot_dirs_make24 = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/tot_test_split/tot_[0-9]/**/0.dat\", recursive=True)]\n",
    "\n",
    "data = [y for x in tqdm(tot_dirs_make24) if not isinstance((y:=analyze_tot_baseline(x)), BaseException)]\n",
    "d = pd.DataFrame(data)\n",
    "tot_make24_data = pd.concat([d, d.apply(lambda x: x.args.__dict__, axis=1, result_type=\"expand\")], axis=1)\n",
    "tot_make24_data[\"oracle_success_rate\"] = tot_make24_data.success_rates.apply(lambda x: x[-1])\n",
    "tot_make24_data[\"oracle_ct\"] = tot_make24_data.completion_tokens.apply(lambda x: x[-1])\n",
    "tot_make24_data[\"first_success_rate\"] = tot_make24_data.success_rates.apply(lambda x: x[0])\n",
    "tot_make24_data[\"first_ct\"] = tot_make24_data.completion_tokens.apply(lambda x: x[0])\n",
    "tot_make24_data[\"first_model_calls\"] = tot_make24_data.model_calls.apply(lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9b5abb03-6559-4185-bfa5-5a24ea825d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = [\"backend\",\"n_evaluate_sample\",\"n_select_sample\", \"do_verify\"]\n",
    "rep_data = tot_make24_data.groupby(cols).filter(lambda x: len(x) > 1 or True)\n",
    "g = rep_data.groupby(cols)\n",
    "rc = \"mean\" #[\"mean\", \"std\"]\n",
    "tot_rep_data = g.agg({\"runtime\":\"count\",\"first_success_rate\":rc,\"first_ct\":rc, \"oracle_success_rate\":rc, \"oracle_ct\":rc, \n",
    "                      \"first_model_calls\":rc})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a7b2aa9a-88a7-43e8-8430-49c2211035cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_7c959\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_7c959_level0_col0\" class=\"col_heading level0 col0\" >runtime</th>\n",
       "      <th id=\"T_7c959_level0_col1\" class=\"col_heading level0 col1\" >first_success_rate</th>\n",
       "      <th id=\"T_7c959_level0_col2\" class=\"col_heading level0 col2\" >first_ct</th>\n",
       "      <th id=\"T_7c959_level0_col3\" class=\"col_heading level0 col3\" >oracle_success_rate</th>\n",
       "      <th id=\"T_7c959_level0_col4\" class=\"col_heading level0 col4\" >oracle_ct</th>\n",
       "      <th id=\"T_7c959_level0_col5\" class=\"col_heading level0 col5\" >first_model_calls</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >backend</th>\n",
       "      <th class=\"index_name level1\" >n_evaluate_sample</th>\n",
       "      <th class=\"index_name level2\" >n_select_sample</th>\n",
       "      <th class=\"index_name level3\" >do_verify</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "      <th class=\"blank col4\" >&nbsp;</th>\n",
       "      <th class=\"blank col5\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_7c959_level0_row0\" class=\"row_heading level0 row0\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</th>\n",
       "      <th id=\"T_7c959_level1_row0\" class=\"row_heading level1 row0\" >11</th>\n",
       "      <th id=\"T_7c959_level2_row0\" class=\"row_heading level2 row0\" >20</th>\n",
       "      <th id=\"T_7c959_level3_row0\" class=\"row_heading level3 row0\" >True</th>\n",
       "      <td id=\"T_7c959_row0_col0\" class=\"data row0 col0\" >5</td>\n",
       "      <td id=\"T_7c959_row0_col1\" class=\"data row0 col1\" >0.938</td>\n",
       "      <td id=\"T_7c959_row0_col2\" class=\"data row0 col2\" >75.3M</td>\n",
       "      <td id=\"T_7c959_row0_col3\" class=\"data row0 col3\" >0.960</td>\n",
       "      <td id=\"T_7c959_row0_col4\" class=\"data row0 col4\" >75.3M</td>\n",
       "      <td id=\"T_7c959_row0_col5\" class=\"data row0 col5\" >606.0K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_7c959_level0_row1\" class=\"row_heading level0 row1\" >microsoft/phi-4</th>\n",
       "      <th id=\"T_7c959_level1_row1\" class=\"row_heading level1 row1\" >11</th>\n",
       "      <th id=\"T_7c959_level2_row1\" class=\"row_heading level2 row1\" >20</th>\n",
       "      <th id=\"T_7c959_level3_row1\" class=\"row_heading level3 row1\" >True</th>\n",
       "      <td id=\"T_7c959_row1_col0\" class=\"data row1 col0\" >4</td>\n",
       "      <td id=\"T_7c959_row1_col1\" class=\"data row1 col1\" >0.733</td>\n",
       "      <td id=\"T_7c959_row1_col2\" class=\"data row1 col2\" >119.3M</td>\n",
       "      <td id=\"T_7c959_row1_col3\" class=\"data row1 col3\" >0.780</td>\n",
       "      <td id=\"T_7c959_row1_col4\" class=\"data row1 col4\" >119.3M</td>\n",
       "      <td id=\"T_7c959_row1_col5\" class=\"data row1 col5\" >617.2K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_7c959_level0_row2\" class=\"row_heading level0 row2\" >mistralai/Mistral-Small-24B-Instruct-2501</th>\n",
       "      <th id=\"T_7c959_level1_row2\" class=\"row_heading level1 row2\" >11</th>\n",
       "      <th id=\"T_7c959_level2_row2\" class=\"row_heading level2 row2\" >20</th>\n",
       "      <th id=\"T_7c959_level3_row2\" class=\"row_heading level3 row2\" >True</th>\n",
       "      <td id=\"T_7c959_row2_col0\" class=\"data row2 col0\" >5</td>\n",
       "      <td id=\"T_7c959_row2_col1\" class=\"data row2 col1\" >0.506</td>\n",
       "      <td id=\"T_7c959_row2_col2\" class=\"data row2 col2\" >79.1M</td>\n",
       "      <td id=\"T_7c959_row2_col3\" class=\"data row2 col3\" >0.626</td>\n",
       "      <td id=\"T_7c959_row2_col4\" class=\"data row2 col4\" >79.1M</td>\n",
       "      <td id=\"T_7c959_row2_col5\" class=\"data row2 col5\" >424.2K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x14f894aefda0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# best methods on tot test\n",
    "tot_best = tot_rep_data.loc[tot_rep_data.first_success_rate.groupby(\"backend\").idxmax()]\n",
    "s = tot_best.style\n",
    "s.format(human_number_format, [\"first_ct\", \"oracle_ct\",\"first_model_calls\"]).format(lambda x: f\"{x:.3f}\", [\"first_success_rate\", \"oracle_success_rate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b66dd36-4664-4108-8b1b-9e14a014b95f",
   "metadata": {},
   "source": [
    "### tot unsolvable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6522faa6-9066-472d-82bd-e3e65553b246",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 15/15 [00:00<00:00, 57.93it/s]\n"
     ]
    }
   ],
   "source": [
    "tot_dirs_make24 = [os.path.dirname(x)+\"/\" for x in glob.glob(\"data/logs/four_digits_unsolvable/tot_[0-9]/**/0.dat\", recursive=True)]\n",
    "\n",
    "data = [y for x in tqdm(tot_dirs_make24) if not isinstance((y:=analyse_tot_baseline_for_unsolvables(x)), BaseException)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "715c1947-8ea1-41d9-83b5-0cf1fd3a098e",
   "metadata": {},
   "outputs": [],
   "source": [
    "d = pd.DataFrame(data)\n",
    "tot_unsolvable = pd.concat([d, d.apply(lambda x: x.args.__dict__, axis=1, result_type=\"expand\")], axis=1)\n",
    "tot_unsolvable[\"first_success_rate\"] = tot_unsolvable.success_rates.apply(lambda x: x[0])\n",
    "tot_unsolvable[\"first_ct\"] = tot_unsolvable.completion_tokens.apply(lambda x: x[0])\n",
    "tot_unsolvable[\"first_model_calls\"] = tot_unsolvable.model_calls.apply(lambda x: x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "6675f2d2-faf6-471d-bae6-97601c96affa",
   "metadata": {},
   "outputs": [],
   "source": [
    "tot_unsolvable = tot_unsolvable.set_index(tot_best.index.names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "4c439c02-3e9a-4fcc-99b1-be617048369f",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = [\"backend\",\"n_evaluate_sample\",\"n_select_sample\", \"do_verify\"]\n",
    "rep_data = tot_unsolvable.groupby(cols).filter(lambda x: len(x) > 1)\n",
    "g = rep_data.groupby(cols)\n",
    "rc = \"mean\" #[\"mean\", \"std\"]\n",
    "tot_unsolvable_rep = g.agg({\"runtime\":\"count\",\"first_success_rate\":rc,\"first_ct\":rc, \"first_model_calls\":rc})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "7323f457-de1e-4253-abf0-e8ad4a9bea01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "</style>\n",
       "<table id=\"T_bc73c\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_bc73c_level0_col0\" class=\"col_heading level0 col0\" >runtime</th>\n",
       "      <th id=\"T_bc73c_level0_col1\" class=\"col_heading level0 col1\" >first_success_rate</th>\n",
       "      <th id=\"T_bc73c_level0_col2\" class=\"col_heading level0 col2\" >first_ct</th>\n",
       "      <th id=\"T_bc73c_level0_col3\" class=\"col_heading level0 col3\" >first_model_calls</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >backend</th>\n",
       "      <th class=\"index_name level1\" >n_evaluate_sample</th>\n",
       "      <th class=\"index_name level2\" >n_select_sample</th>\n",
       "      <th class=\"index_name level3\" >do_verify</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_bc73c_level0_row0\" class=\"row_heading level0 row0\" >RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic</th>\n",
       "      <th id=\"T_bc73c_level1_row0\" class=\"row_heading level1 row0\" >11</th>\n",
       "      <th id=\"T_bc73c_level2_row0\" class=\"row_heading level2 row0\" >20</th>\n",
       "      <th id=\"T_bc73c_level3_row0\" class=\"row_heading level3 row0\" >True</th>\n",
       "      <td id=\"T_bc73c_row0_col0\" class=\"data row0 col0\" >5</td>\n",
       "      <td id=\"T_bc73c_row0_col1\" class=\"data row0 col1\" >78.60</td>\n",
       "      <td id=\"T_bc73c_row0_col2\" class=\"data row0 col2\" >70.6M</td>\n",
       "      <td id=\"T_bc73c_row0_col3\" class=\"data row0 col3\" >562.8K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_bc73c_level0_row1\" class=\"row_heading level0 row1\" >microsoft/phi-4</th>\n",
       "      <th id=\"T_bc73c_level1_row1\" class=\"row_heading level1 row1\" >11</th>\n",
       "      <th id=\"T_bc73c_level2_row1\" class=\"row_heading level2 row1\" >20</th>\n",
       "      <th id=\"T_bc73c_level3_row1\" class=\"row_heading level3 row1\" >True</th>\n",
       "      <td id=\"T_bc73c_row1_col0\" class=\"data row1 col0\" >5</td>\n",
       "      <td id=\"T_bc73c_row1_col1\" class=\"data row1 col1\" >82.20</td>\n",
       "      <td id=\"T_bc73c_row1_col2\" class=\"data row1 col2\" >110.9M</td>\n",
       "      <td id=\"T_bc73c_row1_col3\" class=\"data row1 col3\" >571.1K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_bc73c_level0_row2\" class=\"row_heading level0 row2\" >mistralai/Mistral-Small-24B-Instruct-2501</th>\n",
       "      <th id=\"T_bc73c_level1_row2\" class=\"row_heading level1 row2\" >11</th>\n",
       "      <th id=\"T_bc73c_level2_row2\" class=\"row_heading level2 row2\" >20</th>\n",
       "      <th id=\"T_bc73c_level3_row2\" class=\"row_heading level3 row2\" >True</th>\n",
       "      <td id=\"T_bc73c_row2_col0\" class=\"data row2 col0\" >5</td>\n",
       "      <td id=\"T_bc73c_row2_col1\" class=\"data row2 col1\" >74.20</td>\n",
       "      <td id=\"T_bc73c_row2_col2\" class=\"data row2 col2\" >65.8M</td>\n",
       "      <td id=\"T_bc73c_row2_col3\" class=\"data row2 col3\" >359.0K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x14f8950a5dc0>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = tot_unsolvable_rep.style\n",
    "s.format(human_number_format, [\"first_ct\",\"first_model_calls\"]).format(lambda x: f\"{x*100:.2f}\", \"first_success_rate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8de4730-5bf3-4009-8b58-96af7eeb1e35",
   "metadata": {},
   "source": [
    "# Retrosynthesis\n",
    "***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "617c110d-472b-464c-94e9-2ef808c534fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3877c311-cbab-4e3f-9c5c-8c14fba33b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load ours\n",
    "\n",
    "paths = glob.glob(\"data/chem/hp/**/result.json\")\n",
    "results = []\n",
    "for p in paths:\n",
    "    with open(p) as f:\n",
    "        r = json.load(f)\n",
    "        a = r.pop(\"args\")\n",
    "        r.update(a)\n",
    "    results.append(r)\n",
    "hp_data = pd.DataFrame(results)\n",
    "hp_data[\"method\"] = \"hp\"\n",
    "\n",
    "all_results.append(hp_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "efafceed-2a96-4050-95ef-9597bad56717",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [00:09<00:00,  1.42it/s]\n"
     ]
    }
   ],
   "source": [
    "# load egmcts\n",
    "\n",
    "runs = glob.glob(\"data/chem/eg-mcts/**/\")\n",
    "\n",
    "# retro has 190 molecules..\n",
    "fd = []\n",
    "for folder in tqdm(runs):\n",
    "    for i in range(190):\n",
    "        with open(f\"{folder}{i}/result.json\") as f:\n",
    "            r = json.load(f)\n",
    "            a = r.pop(\"args\")\n",
    "            r.update(a)\n",
    "            fd.append(r)\n",
    "\n",
    "d = pd.DataFrame(fd)\n",
    "\n",
    "g = d.groupby([\"folder\",\"use_value_fn\",\"iterations\",\"expansion_topk\"])\n",
    "egmcts_data = g.agg(dict(\n",
    "    succ=list, iter=\"sum\", routes=list, route_len=list, expand_model_call=\"sum\",\n",
    "    value_model_call=\"sum\",reaction_nodes_lens=\"sum\", mol_nodes_lens=\"sum\", mol=list\n",
    ")).reset_index()\n",
    "egmcts_data[\"method\"] = \"egmcts\"\n",
    "\n",
    "all_results.append(egmcts_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "66b72681-7630-419e-9684-9074e88f7e7a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# check completeness of experiments\n",
    "retro_data = pd.concat(all_results, axis=0, ignore_index=True)\n",
    "for _, row in retro_data.iterrows():\n",
    "    if not len(row.succ) == 190:\n",
    "        print(\"failed\", row.folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b3de137a-825f-4ccb-81bb-de205502f04c",
   "metadata": {},
   "outputs": [],
   "source": [
    "retro_data[\"success_rate\"] = retro_data.succ.apply(lambda x: np.mean(x))\n",
    "retro_data[\"avg_iter\"] = retro_data[\"iter\"] / 190\n",
    "retro_data[\"avg_route_len\"] = retro_data.route_len.apply(lambda x: pd.Series(x).mean())\n",
    "retro_data[\"avg_value_model_calls\"] = retro_data[\"value_model_call\"] / 190\n",
    "retro_data[\"avg_expand_model_calls\"] = retro_data[\"expand_model_call\"] / 190\n",
    "retro_data[\"avg_mol_nodes\"] = retro_data[\"mol_nodes_lens\"] / 190\n",
    "retro_data[\"avg_reaction_nodes\"] = retro_data[\"reaction_nodes_lens\"] / 190\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17c6e239-af5d-42e2-ad5e-8fd572552211",
   "metadata": {},
   "outputs": [],
   "source": [
    "for_table = retro_data[[\"method\", \"iterations\",\"expansion_topk\", \"success_rate\", \"avg_iter\", \"avg_value_model_calls\", \"avg_expand_model_calls\",\"avg_mol_nodes\",\"avg_reaction_nodes\"]]\n",
    "for_table = for_table.sort_values([\"method\",\"iterations\"])\n",
    "g = for_table.groupby(\"expansion_topk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "58a45eb4-12a9-4b38-a08a-927c4b2bc481",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>method</th>\n",
       "      <th>iterations</th>\n",
       "      <th>expansion_topk</th>\n",
       "      <th>success_rate</th>\n",
       "      <th>avg_iter</th>\n",
       "      <th>avg_value_model_calls</th>\n",
       "      <th>avg_expand_model_calls</th>\n",
       "      <th>avg_mol_nodes</th>\n",
       "      <th>avg_reaction_nodes</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>10</td>\n",
       "      <td>50</td>\n",
       "      <td>0.415789</td>\n",
       "      <td>8.010526</td>\n",
       "      <td>135.115789</td>\n",
       "      <td>8.010526</td>\n",
       "      <td>190.815789</td>\n",
       "      <td>131.105263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>50</td>\n",
       "      <td>50</td>\n",
       "      <td>0.815789</td>\n",
       "      <td>20.263158</td>\n",
       "      <td>337.263158</td>\n",
       "      <td>20.263158</td>\n",
       "      <td>453.857895</td>\n",
       "      <td>324.521053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>100</td>\n",
       "      <td>50</td>\n",
       "      <td>0.878947</td>\n",
       "      <td>27.594737</td>\n",
       "      <td>453.389474</td>\n",
       "      <td>27.594737</td>\n",
       "      <td>590.378947</td>\n",
       "      <td>432.778947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>200</td>\n",
       "      <td>50</td>\n",
       "      <td>0.910526</td>\n",
       "      <td>37.305263</td>\n",
       "      <td>602.626316</td>\n",
       "      <td>37.305263</td>\n",
       "      <td>766.805263</td>\n",
       "      <td>570.973684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>300</td>\n",
       "      <td>50</td>\n",
       "      <td>0.942105</td>\n",
       "      <td>43.852632</td>\n",
       "      <td>701.547368</td>\n",
       "      <td>43.852632</td>\n",
       "      <td>882.684211</td>\n",
       "      <td>662.189474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>400</td>\n",
       "      <td>50</td>\n",
       "      <td>0.942105</td>\n",
       "      <td>49.642105</td>\n",
       "      <td>793.126316</td>\n",
       "      <td>49.642105</td>\n",
       "      <td>987.810526</td>\n",
       "      <td>746.415789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>egmcts</td>\n",
       "      <td>500</td>\n",
       "      <td>50</td>\n",
       "      <td>0.942105</td>\n",
       "      <td>55.431579</td>\n",
       "      <td>882.773684</td>\n",
       "      <td>55.431579</td>\n",
       "      <td>1091.647368</td>\n",
       "      <td>828.784211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>hp</td>\n",
       "      <td>10</td>\n",
       "      <td>50</td>\n",
       "      <td>0.726316</td>\n",
       "      <td>5.557895</td>\n",
       "      <td>94.800000</td>\n",
       "      <td>5.557895</td>\n",
       "      <td>81.415789</td>\n",
       "      <td>95.805263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>hp</td>\n",
       "      <td>50</td>\n",
       "      <td>50</td>\n",
       "      <td>0.936842</td>\n",
       "      <td>10.478947</td>\n",
       "      <td>171.421053</td>\n",
       "      <td>10.478947</td>\n",
       "      <td>139.194737</td>\n",
       "      <td>172.426316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>hp</td>\n",
       "      <td>100</td>\n",
       "      <td>50</td>\n",
       "      <td>0.957895</td>\n",
       "      <td>12.842105</td>\n",
       "      <td>207.168421</td>\n",
       "      <td>12.842105</td>\n",
       "      <td>166.157895</td>\n",
       "      <td>208.173684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>hp</td>\n",
       "      <td>200</td>\n",
       "      <td>50</td>\n",
       "      <td>0.968421</td>\n",
       "      <td>17.305263</td>\n",
       "      <td>278.031579</td>\n",
       "      <td>17.305263</td>\n",
       "      <td>217.647368</td>\n",
       "      <td>279.036842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>hp</td>\n",
       "      <td>300</td>\n",
       "      <td>50</td>\n",
       "      <td>0.989474</td>\n",
       "      <td>18.784211</td>\n",
       "      <td>301.994737</td>\n",
       "      <td>18.784211</td>\n",
       "      <td>234.978947</td>\n",
       "      <td>303.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>hp</td>\n",
       "      <td>400</td>\n",
       "      <td>50</td>\n",
       "      <td>0.994737</td>\n",
       "      <td>19.057895</td>\n",
       "      <td>309.300000</td>\n",
       "      <td>19.057895</td>\n",
       "      <td>240.078947</td>\n",
       "      <td>310.305263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>hp</td>\n",
       "      <td>500</td>\n",
       "      <td>50</td>\n",
       "      <td>0.994737</td>\n",
       "      <td>19.115789</td>\n",
       "      <td>316.000000</td>\n",
       "      <td>19.115789</td>\n",
       "      <td>246.547368</td>\n",
       "      <td>317.005263</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    method  iterations  expansion_topk  success_rate   avg_iter  \\\n",
       "36  egmcts          10              50      0.415789   8.010526   \n",
       "46  egmcts          50              50      0.815789  20.263158   \n",
       "34  egmcts         100              50      0.878947  27.594737   \n",
       "38  egmcts         200              50      0.910526  37.305263   \n",
       "40  egmcts         300              50      0.942105  43.852632   \n",
       "42  egmcts         400              50      0.942105  49.642105   \n",
       "44  egmcts         500              50      0.942105  55.431579   \n",
       "9       hp          10              50      0.726316   5.557895   \n",
       "32      hp          50              50      0.936842  10.478947   \n",
       "4       hp         100              50      0.957895  12.842105   \n",
       "14      hp         200              50      0.968421  17.305263   \n",
       "19      hp         300              50      0.989474  18.784211   \n",
       "23      hp         400              50      0.994737  19.057895   \n",
       "27      hp         500              50      0.994737  19.115789   \n",
       "\n",
       "    avg_value_model_calls  avg_expand_model_calls  avg_mol_nodes  \\\n",
       "36             135.115789                8.010526     190.815789   \n",
       "46             337.263158               20.263158     453.857895   \n",
       "34             453.389474               27.594737     590.378947   \n",
       "38             602.626316               37.305263     766.805263   \n",
       "40             701.547368               43.852632     882.684211   \n",
       "42             793.126316               49.642105     987.810526   \n",
       "44             882.773684               55.431579    1091.647368   \n",
       "9               94.800000                5.557895      81.415789   \n",
       "32             171.421053               10.478947     139.194737   \n",
       "4              207.168421               12.842105     166.157895   \n",
       "14             278.031579               17.305263     217.647368   \n",
       "19             301.994737               18.784211     234.978947   \n",
       "23             309.300000               19.057895     240.078947   \n",
       "27             316.000000               19.115789     246.547368   \n",
       "\n",
       "    avg_reaction_nodes  \n",
       "36          131.105263  \n",
       "46          324.521053  \n",
       "34          432.778947  \n",
       "38          570.973684  \n",
       "40          662.189474  \n",
       "42          746.415789  \n",
       "44          828.784211  \n",
       "9            95.805263  \n",
       "32          172.426316  \n",
       "4           208.173684  \n",
       "14          279.036842  \n",
       "19          303.000000  \n",
       "23          310.305263  \n",
       "27          317.005263  "
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1a118bef-e8df-45a5-a349-ff821232d7b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Success rates:\n",
      "EG-MCTS: 81.58 & 87.89 & 91.05 & 94.21 & 94.21 & 94.21\n",
      "HP: 93.68 & 95.79 & 96.84 & 98.95 & 99.47 & 99.47\n"
     ]
    }
   ],
   "source": [
    "d = g.get_group(50)\n",
    "print(\"Success rates:\")\n",
    "print(\"EG-MCTS:\", \" & \".join([f\"{x*100:.2f}\" for x in d[(d.method == \"egmcts\") & (d.iterations > 10)][\"success_rate\"]]))\n",
    "print(\"HP:\", \" & \".join([f\"{x*100:.2f}\" for x in d[(d.method == \"hp\") & (d.iterations > 10)][\"success_rate\"]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "caaa416c-7ae3-4dbc-a7a4-09b505d49640",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llrrr}\n",
      " & method & avg_iter & avg_reaction_nodes & avg_mol_nodes \\\\\n",
      "44 & egmcts & 55.43 & 828.78 & 1091.65 \\\\\n",
      "27 & hp & 19.12 & 317.01 & 246.55 \\\\\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "temp = d[(d.iterations > 10) & (d.iterations==500)][[\"method\", \"avg_iter\", \"avg_reaction_nodes\",\"avg_mol_nodes\"]]\n",
    "s = temp.style.format(lambda x: f\"{x:.2f}\", subset=temp.columns[1:])\n",
    "print(s.to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64536e60-6974-4a3f-9dea-db48cd8d11ec",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
