{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"agnews\", \"sst-5\", \"gsm8k\", \"copa\", \"subj\"]\n",
    "models = [\"llama\", \"qwen\", \"mistral\"]\n",
    "seeds = [42, 43, 44]\n",
    "dfs = []\n",
    "for dataset in datasets:\n",
    "    for model in models:\n",
    "        for seed in seeds:\n",
    "            path = glob(\n",
    "                f\"../results/main_results/{dataset}/{model}/CAPO/seed{seed}/*/*/prompt_scores.parquet\"\n",
    "            )[0]\n",
    "            df = pd.read_parquet(path)\n",
    "            df[\"evaluated_blocks\"] = 10 - df.isnull().sum(axis=1)\n",
    "            df[\"evaluated_blocks_when_no_racing\"] = 10\n",
    "            dfs.append(\n",
    "                df.assign(\n",
    "                    dataset=dataset,\n",
    "                    model=model,\n",
    "                    seed=seed,\n",
    "                )\n",
    "            )\n",
    "\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.groupby([\"dataset\", \"model\", \"seed\"]).agg(\n",
    "    mean_evaluated_blocks=(\"evaluated_blocks\", \"sum\"),\n",
    "    mean_evaluated_blocks_when_no_racing=(\"evaluated_blocks_when_no_racing\", \"sum\"),\n",
    ")\n",
    "\n",
    "# build mean\n",
    "df = df.groupby([\"dataset\", \"model\"]).agg(\n",
    "    mean_evaluated_blocks=(\"mean_evaluated_blocks\", \"mean\"),\n",
    "    mean_evaluated_blocks_when_no_racing=(\"mean_evaluated_blocks_when_no_racing\", \"mean\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"saved_evals\"] = df[\"mean_evaluated_blocks_when_no_racing\"] - df[\"mean_evaluated_blocks\"]\n",
    "df[\"saved_evals_ratio\"] = df[\"saved_evals\"] / df[\"mean_evaluated_blocks_when_no_racing\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>mean_evaluated_blocks</th>\n",
       "      <th>mean_evaluated_blocks_when_no_racing</th>\n",
       "      <th>saved_evals</th>\n",
       "      <th>saved_evals_ratio</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dataset</th>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">agnews</th>\n",
       "      <th>llama</th>\n",
       "      <td>929.000000</td>\n",
       "      <td>1886.666667</td>\n",
       "      <td>957.666667</td>\n",
       "      <td>0.507597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral</th>\n",
       "      <td>608.333333</td>\n",
       "      <td>1356.666667</td>\n",
       "      <td>748.333333</td>\n",
       "      <td>0.551597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen</th>\n",
       "      <td>707.000000</td>\n",
       "      <td>1310.000000</td>\n",
       "      <td>603.000000</td>\n",
       "      <td>0.460305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">copa</th>\n",
       "      <th>llama</th>\n",
       "      <td>804.666667</td>\n",
       "      <td>1690.000000</td>\n",
       "      <td>885.333333</td>\n",
       "      <td>0.523866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral</th>\n",
       "      <td>754.666667</td>\n",
       "      <td>1273.333333</td>\n",
       "      <td>518.666667</td>\n",
       "      <td>0.407330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen</th>\n",
       "      <td>948.666667</td>\n",
       "      <td>1566.666667</td>\n",
       "      <td>618.000000</td>\n",
       "      <td>0.394468</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">gsm8k</th>\n",
       "      <th>llama</th>\n",
       "      <td>317.666667</td>\n",
       "      <td>630.000000</td>\n",
       "      <td>312.333333</td>\n",
       "      <td>0.495767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral</th>\n",
       "      <td>314.000000</td>\n",
       "      <td>456.666667</td>\n",
       "      <td>142.666667</td>\n",
       "      <td>0.312409</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen</th>\n",
       "      <td>376.666667</td>\n",
       "      <td>633.333333</td>\n",
       "      <td>256.666667</td>\n",
       "      <td>0.405263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">sst-5</th>\n",
       "      <th>llama</th>\n",
       "      <td>832.666667</td>\n",
       "      <td>1316.666667</td>\n",
       "      <td>484.000000</td>\n",
       "      <td>0.367595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral</th>\n",
       "      <td>703.333333</td>\n",
       "      <td>1093.333333</td>\n",
       "      <td>390.000000</td>\n",
       "      <td>0.356707</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen</th>\n",
       "      <td>836.333333</td>\n",
       "      <td>1070.000000</td>\n",
       "      <td>233.666667</td>\n",
       "      <td>0.218380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"3\" valign=\"top\">subj</th>\n",
       "      <th>llama</th>\n",
       "      <td>648.333333</td>\n",
       "      <td>1566.666667</td>\n",
       "      <td>918.333333</td>\n",
       "      <td>0.586170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral</th>\n",
       "      <td>625.000000</td>\n",
       "      <td>1260.000000</td>\n",
       "      <td>635.000000</td>\n",
       "      <td>0.503968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen</th>\n",
       "      <td>672.666667</td>\n",
       "      <td>1360.000000</td>\n",
       "      <td>687.333333</td>\n",
       "      <td>0.505392</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 mean_evaluated_blocks  mean_evaluated_blocks_when_no_racing  \\\n",
       "dataset model                                                                  \n",
       "agnews  llama               929.000000                           1886.666667   \n",
       "        mistral             608.333333                           1356.666667   \n",
       "        qwen                707.000000                           1310.000000   \n",
       "copa    llama               804.666667                           1690.000000   \n",
       "        mistral             754.666667                           1273.333333   \n",
       "        qwen                948.666667                           1566.666667   \n",
       "gsm8k   llama               317.666667                            630.000000   \n",
       "        mistral             314.000000                            456.666667   \n",
       "        qwen                376.666667                            633.333333   \n",
       "sst-5   llama               832.666667                           1316.666667   \n",
       "        mistral             703.333333                           1093.333333   \n",
       "        qwen                836.333333                           1070.000000   \n",
       "subj    llama               648.333333                           1566.666667   \n",
       "        mistral             625.000000                           1260.000000   \n",
       "        qwen                672.666667                           1360.000000   \n",
       "\n",
       "                 saved_evals  saved_evals_ratio  \n",
       "dataset model                                    \n",
       "agnews  llama     957.666667           0.507597  \n",
       "        mistral   748.333333           0.551597  \n",
       "        qwen      603.000000           0.460305  \n",
       "copa    llama     885.333333           0.523866  \n",
       "        mistral   518.666667           0.407330  \n",
       "        qwen      618.000000           0.394468  \n",
       "gsm8k   llama     312.333333           0.495767  \n",
       "        mistral   142.666667           0.312409  \n",
       "        qwen      256.666667           0.405263  \n",
       "sst-5   llama     484.000000           0.367595  \n",
       "        mistral   390.000000           0.356707  \n",
       "        qwen      233.666667           0.218380  \n",
       "subj    llama     918.333333           0.586170  \n",
       "        mistral   635.000000           0.503968  \n",
       "        qwen      687.333333           0.505392  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "mean_evaluated_blocks                    671.933333\n",
       "mean_evaluated_blocks_when_no_racing    1231.333333\n",
       "saved_evals                              559.400000\n",
       "saved_evals_ratio                          0.439788\n",
       "dtype: float64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.mean()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
