{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Routerbench"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 1 from our paper, 0-shot results\n",
      "\\begin{tabular}{rrrrrrrrr}\n",
      "\\toprule\n",
      "0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 \\\\\n",
      "\\midrule\n",
      "69.619211 & 69.619211 & 69.619211 & 69.223137 & 69.223137 & 69.223137 & 70.506406 & 70.506406 & 70.506406 \\\\\n",
      "79.726787 & 74.965568 & 71.812637 & 81.235991 & 74.432665 & 71.328369 & 83.245409 & 74.632066 & 72.670410 \\\\\n",
      "80.863873 & 74.641304 & 72.475432 & 82.332691 & 73.029091 & 69.526597 & 84.480300 & 73.643526 & 69.788751 \\\\\n",
      "81.129172 & 76.100591 & 72.666791 & 83.049498 & 75.148599 & 70.175505 & 84.449863 & 75.103982 & 70.255392 \\\\\n",
      "82.339794 & 76.556644 & 73.232318 & 84.342562 & 76.315419 & 72.744801 & 87.268375 & 77.628052 & 74.401078 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# 0-shot RouterBench result files. The outer key is the comma-separated id list\n",
    "# that prefixes each file name (presumably the model ids in the pool -- confirm);\n",
    "# the inner key is the low/medium/high tag that appears in the file name.\n",
    "files = {\n",
    "    '9,4,5': {\n",
    "        'low': '9,4,5_low_0shot.json',\n",
    "        'medium': '9,4,5_medium_0shot.json',\n",
    "        'high': '9,4,5_high_0shot.json'\n",
    "    },\n",
    "    '0,9,4,3,5': {\n",
    "        'low': '0,9,4,3,5_low_0shot.json',\n",
    "        'medium': '0,9,4,3,5_medium_0shot.json',\n",
    "        'high': '0,9,4,3,5_high_0shot.json'\n",
    "    },\n",
    "    '0,1,2,3,4,5,6,7,8,9,10': {\n",
    "        'low': '0,1,2,3,4,5,6,7,8,9,10_low_0shot.json',\n",
    "        'medium': '0,1,2,3,4,5,6,7,8,9,10_medium_0shot.json',\n",
    "        'high': '0,1,2,3,4,5,6,7,8,9,10_high_0shot.json'\n",
    "    },\n",
    "}\n",
    "\n",
    "base_path = '../data/results/routerbench/'\n",
    "\n",
    "# One list of AUC values (scaled by 100) per strategy. Entry i of every list\n",
    "# belongs to the result file described by column_labels[i].\n",
    "all_results = {\n",
    "    'linear': [],\n",
    "    'routing': [],\n",
    "    'cascading': [],\n",
    "    'cascading_ours': [],\n",
    "    'cascade_routing': []\n",
    "}\n",
    "column_labels = []\n",
    "\n",
    "for n in files:\n",
    "    for strategy in files[n]:\n",
    "        with open(base_path + files[n][strategy], 'r') as file:\n",
    "            results = json.load(file)\n",
    "        column_labels.append(f'{n} ({strategy})')\n",
    "        all_results['linear'].append(results['aucs_baseline']['auc'] * 100)\n",
    "        all_results['routing'].append(results['aucs_router']['auc'] * 100)\n",
    "        all_results['cascading'].append(results['aucs_cascade']['auc'] * 100)\n",
    "        # A result file without 'aucs_cascade_ours' is reported as missing (NaN)\n",
    "        # instead of as a measured AUC of 0, which would look like a real score.\n",
    "        if 'aucs_cascade_ours' in results:\n",
    "            all_results['cascading_ours'].append(results['aucs_cascade_ours']['auc'] * 100)\n",
    "        else:\n",
    "            all_results['cascading_ours'].append(np.nan)\n",
    "        all_results['cascade_routing'].append(results['aucs']['auc'] * 100)\n",
    "\n",
    "# Rows are strategies, columns are the individual result files.\n",
    "df = pd.DataFrame(all_results, index=column_labels).T\n",
    "print(\"Table 1 from our paper, 0-shot results\")\n",
    "# Keep row and column labels so every number can be attributed; escape=True\n",
    "# protects the underscores in the strategy names when the LaTeX is compiled.\n",
    "print(df.to_latex(index=True, escape=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 1 from our paper, but relative results\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>routing</th>\n",
       "      <td>25.851969</td>\n",
       "      <td>29.760017</td>\n",
       "      <td>64.724369</td>\n",
       "      <td>25.860392</td>\n",
       "      <td>36.140585</td>\n",
       "      <td>67.281467</td>\n",
       "      <td>31.579911</td>\n",
       "      <td>72.618332</td>\n",
       "      <td>79.975246</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cascading</th>\n",
       "      <td>13.125524</td>\n",
       "      <td>38.138287</td>\n",
       "      <td>26.499559</td>\n",
       "      <td>15.331345</td>\n",
       "      <td>86.347009</td>\n",
       "      <td>1060.504033</td>\n",
       "      <td>19.952030</td>\n",
       "      <td>127.012190</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cascading_ours</th>\n",
       "      <td>10.518034</td>\n",
       "      <td>7.036359</td>\n",
       "      <td>18.556577</td>\n",
       "      <td>9.352171</td>\n",
       "      <td>19.691624</td>\n",
       "      <td>269.779744</td>\n",
       "      <td>20.213870</td>\n",
       "      <td>54.900003</td>\n",
       "      <td>10000.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        0          1          2          3          4  \\\n",
       "routing         25.851969  29.760017  64.724369  25.860392  36.140585   \n",
       "cascading       13.125524  38.138287  26.499559  15.331345  86.347009   \n",
       "cascading_ours  10.518034   7.036359  18.556577   9.352171  19.691624   \n",
       "\n",
       "                          5          6           7             8  \n",
       "routing           67.281467  31.579911   72.618332     79.975246  \n",
       "cascading       1060.504033  19.952030  127.012190  10000.000000  \n",
       "cascading_ours   269.779744  20.213870   54.900003  10000.000000  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relative gain of 'cascade_routing' over each baseline strategy, both measured\n",
    "# as the difference to the 'linear' row of `df`.\n",
    "# NOTE: this reads the `df` produced by the 0-shot cell directly above; later\n",
    "# cells reassign `df`, so re-run that cell first when executing out of order.\n",
    "relative_improvement = {\n",
    "    'routing': [],\n",
    "    'cascading': [],\n",
    "    'cascading_ours': [],\n",
    "}\n",
    "\n",
    "linear_row = df.loc['linear']\n",
    "gain_ours = df.loc['cascade_routing'] - linear_row\n",
    "\n",
    "for name, improvements in relative_improvement.items():\n",
    "    gain_name = df.loc[name] - linear_row\n",
    "    for col in df.columns:\n",
    "        # 10000 is the sentinel for a baseline that does not exceed the 'linear'\n",
    "        # row, where the ratio below would divide by a non-positive gain.\n",
    "        if gain_name[col] <= 0:\n",
    "            improvements.append(10000)\n",
    "        else:\n",
    "            improvements.append((gain_ours[col] - gain_name[col]) / gain_name[col] * 100)\n",
    "\n",
    "# Rows are the baseline strategies, columns follow the column order of `df`.\n",
    "df2 = pd.DataFrame(relative_improvement).T\n",
    "print(\"Table 1 from our paper, but relative results\")\n",
    "df2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 1 from our paper, 5-shot results\n",
      "\\begin{tabular}{rrrrrrrrr}\n",
      "\\toprule\n",
      "0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 \\\\\n",
      "\\midrule\n",
      "74.210281 & 74.210281 & 74.210281 & 73.823577 & 73.823577 & 73.823577 & 75.160194 & 75.160194 & 75.160194 \\\\\n",
      "81.495843 & 77.223977 & 76.006404 & 82.425209 & 76.843640 & 75.542914 & 85.341982 & 77.767974 & 76.435299 \\\\\n",
      "83.163363 & 78.578931 & 76.885932 & 84.271977 & 76.593376 & 73.918097 & 87.143259 & 78.603027 & 74.942672 \\\\\n",
      "82.678614 & 78.793712 & 77.003991 & 84.264853 & 77.195958 & 74.297096 & 86.671101 & 78.674625 & 75.075952 \\\\\n",
      "83.818197 & 78.920287 & 77.106730 & 85.506838 & 78.822596 & 76.736936 & 88.783678 & 80.881557 & 78.020648 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 5-shot RouterBench result files. The outer key is the comma-separated id list\n",
    "# that prefixes each file name (presumably the model ids in the pool -- confirm);\n",
    "# the inner key is the low/medium/high tag that appears in the file name.\n",
    "files = {\n",
    "    '9,4,5': {\n",
    "        'low': '9,4,5_low_5shot.json',\n",
    "        'medium': '9,4,5_medium_5shot.json',\n",
    "        'high': '9,4,5_high_5shot.json'\n",
    "    },\n",
    "    '0,9,4,3,5': {\n",
    "        'low': '0,9,4,3,5_low_5shot.json',\n",
    "        'medium': '0,9,4,3,5_medium_5shot.json',\n",
    "        'high': '0,9,4,3,5_high_5shot.json'\n",
    "    },\n",
    "    '0,1,2,3,4,5,6,7,8,9,10': {\n",
    "        'low': '0,1,2,3,4,5,6,7,8,9,10_low_5shot.json',\n",
    "        'medium': '0,1,2,3,4,5,6,7,8,9,10_medium_5shot.json',\n",
    "        'high': '0,1,2,3,4,5,6,7,8,9,10_high_5shot.json'\n",
    "    },\n",
    "}\n",
    "\n",
    "base_path = '../data/results/routerbench/'\n",
    "\n",
    "# One list of AUC values (scaled by 100) per strategy. Entry i of every list\n",
    "# belongs to the result file described by column_labels[i].\n",
    "all_results = {\n",
    "    'linear': [],\n",
    "    'routing': [],\n",
    "    'cascading': [],\n",
    "    'cascading_ours': [],\n",
    "    'cascade_routing': []\n",
    "}\n",
    "column_labels = []\n",
    "\n",
    "for n in files:\n",
    "    for strategy in files[n]:\n",
    "        with open(base_path + files[n][strategy], 'r') as file:\n",
    "            results = json.load(file)\n",
    "        column_labels.append(f'{n} ({strategy})')\n",
    "        all_results['linear'].append(results['aucs_baseline']['auc'] * 100)\n",
    "        all_results['routing'].append(results['aucs_router']['auc'] * 100)\n",
    "        all_results['cascading'].append(results['aucs_cascade']['auc'] * 100)\n",
    "        # A result file without 'aucs_cascade_ours' is reported as missing (NaN)\n",
    "        # instead of as a measured AUC of 0, which would look like a real score.\n",
    "        if 'aucs_cascade_ours' in results:\n",
    "            all_results['cascading_ours'].append(results['aucs_cascade_ours']['auc'] * 100)\n",
    "        else:\n",
    "            all_results['cascading_ours'].append(np.nan)\n",
    "        all_results['cascade_routing'].append(results['aucs']['auc'] * 100)\n",
    "\n",
    "# Rows are strategies, columns are the individual result files.\n",
    "df = pd.DataFrame(all_results, index=column_labels).T\n",
    "print(\"Table 1 from our paper, 5-shot results\")\n",
    "# Keep row and column labels so every number can be attributed; escape=True\n",
    "# protects the underscores in the strategy names when the LaTeX is compiled.\n",
    "print(df.to_latex(index=True, escape=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 2 from our paper, first half on classification\n",
      "\\begin{tabular}{rrr}\n",
      "\\toprule\n",
      "0 & 1 & 2 \\\\\n",
      "\\midrule\n",
      "74.282600 & 61.683007 & 63.389435 \\\\\n",
      "74.921240 & 64.439964 & 64.890300 \\\\\n",
      "74.803145 & 54.308093 & 61.216450 \\\\\n",
      "75.523230 & 64.699685 & 64.970634 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Result files for the classification benchmarks, one per table column.\n",
    "files = [\n",
    "    '0,1,2.json',\n",
    "    '3,4,5.json',\n",
    "    '6,7,8.json',\n",
    "]\n",
    "\n",
    "base_path = '../data/results/classification/mmlu_arc_mixeval/'\n",
    "\n",
    "# One list of AUC values (scaled by 100) per strategy, in the order of `files`.\n",
    "all_results = {\n",
    "    'linear': [],\n",
    "    'routing': [],\n",
    "    'cascading': [],\n",
    "    'cascade_routing': []\n",
    "}\n",
    "\n",
    "for file_name in files:\n",
    "    with open(base_path + file_name, 'r') as file:\n",
    "        results = json.load(file)\n",
    "    all_results['linear'].append(results['aucs_baseline']['auc'] * 100)\n",
    "    all_results['routing'].append(results['aucs_router']['auc'] * 100)\n",
    "    all_results['cascading'].append(results['aucs_cascade']['auc'] * 100)\n",
    "    all_results['cascade_routing'].append(results['aucs']['auc'] * 100)\n",
    "\n",
    "# Label each column with its file name minus the extension; rows are strategies.\n",
    "column_labels = [file_name.rsplit('.', 1)[0] for file_name in files]\n",
    "df = pd.DataFrame(all_results, index=column_labels).T\n",
    "print('Table 2 from our paper, first half on classification')\n",
    "# Keep row and column labels so every number can be attributed; escape=True\n",
    "# protects the underscore in 'cascade_routing' when the LaTeX is compiled.\n",
    "print(df.to_latex(index=True, escape=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 2 from our paper, second half on open-form\n",
      "\\begin{tabular}{rrr}\n",
      "\\toprule\n",
      "0 & 1 & 2 \\\\\n",
      "\\midrule\n",
      "79.113000 & 54.097636 & 53.859567 \\\\\n",
      "79.318910 & 58.395591 & 58.705107 \\\\\n",
      "79.232862 & 56.175065 & 48.288646 \\\\\n",
      "79.837041 & 59.615956 & 58.689949 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Result files for the open-form benchmarks, one per table column.\n",
    "files = [\n",
    "    '0,1,2.json',\n",
    "    '3,4,5.json',\n",
    "    '6,7,8.json',\n",
    "]\n",
    "\n",
    "base_path = '../data/results/free_form/mmlu_gsm8k/'\n",
    "\n",
    "# One list of AUC values (scaled by 100) per strategy, in the order of `files`.\n",
    "all_results = {\n",
    "    'linear': [],\n",
    "    'routing': [],\n",
    "    'cascading': [],\n",
    "    'cascade_routing': []\n",
    "}\n",
    "\n",
    "for file_name in files:\n",
    "    with open(base_path + file_name, 'r') as file:\n",
    "        results = json.load(file)\n",
    "    all_results['linear'].append(results['aucs_baseline']['auc'] * 100)\n",
    "    all_results['routing'].append(results['aucs_router']['auc'] * 100)\n",
    "    all_results['cascading'].append(results['aucs_cascade']['auc'] * 100)\n",
    "    all_results['cascade_routing'].append(results['aucs']['auc'] * 100)\n",
    "\n",
    "# Label each column with its file name minus the extension; rows are strategies.\n",
    "column_labels = [file_name.rsplit('.', 1)[0] for file_name in files]\n",
    "df = pd.DataFrame(all_results, index=column_labels).T\n",
    "print('Table 2 from our paper, second half on open-form')\n",
    "# Keep row and column labels so every number can be attributed; escape=True\n",
    "# protects the underscore in 'cascade_routing' when the LaTeX is compiled.\n",
    "print(df.to_latex(index=True, escape=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table 3 from our paper\n",
      "\\begin{tabular}{rrrrrr}\n",
      "\\toprule\n",
      "0 & 1 & 2 & 3 & 4 & 5 \\\\\n",
      "\\midrule\n",
      "76.560166 & 0.699906 & 76.341777 & 1.844149 & 77.604619 & 12.754497 \\\\\n",
      "76.362055 & 0.480752 & 76.144915 & 0.788402 & 77.162711 & 1.587761 \\\\\n",
      "76.346145 & 0.364291 & 76.122586 & 0.773881 & 77.111513 & 3.039908 \\\\\n",
      "76.547995 & 0.535915 & 76.312433 & 2.235476 & 77.624591 & 87.231799 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Ablation result files: outer key names the configuration group, inner key\n",
    "# names the ablation variant whose result file is listed.\n",
    "files = {\n",
    "    '3': {\n",
    "        'normal': '9,4,5_medium_0shot_False_False_False.json',\n",
    "        'greedy': '9,4,5_medium_0shot_False_True_False.json',\n",
    "        'no_deviation': '9,4,5_medium_0shot_False_False_True.json',\n",
    "        'no_speedup': '9,4,5_medium_0shot_True_False_False.json',\n",
    "    },\n",
    "    '5': {\n",
    "        'normal': '0,9,4,3,5_medium_0shot_False_False_False.json',\n",
    "        'greedy': '0,9,4,3,5_medium_0shot_False_True_False.json',\n",
    "        'no_deviation': '0,9,4,3,5_medium_0shot_False_False_True.json',\n",
    "        'no_speedup': '0,9,4,3,5_medium_0shot_True_False_False.json',\n",
    "    },\n",
    "    '10': {\n",
    "        'normal': '0,1,2,3,4,5,6,7,8,9,10_medium_0shot_False_False_False.json',\n",
    "        'greedy': '0,1,2,3,4,5,6,7,8,9,10_medium_0shot_False_True_False.json',\n",
    "        'no_deviation': '0,1,2,3,4,5,6,7,8,9,10_medium_0shot_False_False_True.json',\n",
    "        'no_speedup': '0,1,2,3,4,5,6,7,8,9,10_medium_0shot_True_False_False.json',\n",
    "    },\n",
    "}\n",
    "base_path = '../data/results/routerbench_times/'\n",
    "\n",
    "# Per variant: an (AUC, time) pair for every group in `files`, flattened into one\n",
    "# list, so the table columns alternate between the two measurements.\n",
    "all_results = {\n",
    "    'normal': [],\n",
    "    'greedy': [],\n",
    "    'no_deviation': [],\n",
    "    'no_speedup': []\n",
    "}\n",
    "column_labels = []\n",
    "\n",
    "for n in files:\n",
    "    column_labels.extend([f'{n} (auc)', f'{n} (time)'])\n",
    "    for strategy in files[n]:\n",
    "        with open(base_path + files[n][strategy], 'r') as file:\n",
    "            results = json.load(file)\n",
    "        all_results[strategy].append(results['aucs']['auc'] * 100)\n",
    "        # Mean of the recorded 'mean_times', scaled by 1000\n",
    "        # (presumably seconds -> milliseconds; confirm against the writer).\n",
    "        all_results[strategy].append(np.mean(results['test']['mean_times']) * 1000)\n",
    "\n",
    "# Rows are ablation variants, columns are the (group, measurement) pairs.\n",
    "df = pd.DataFrame(all_results, index=column_labels).T\n",
    "print('Table 3 from our paper')\n",
    "# Keep row and column labels so AUC and time columns can be told apart;\n",
    "# escape=True protects the underscores in the variant names in LaTeX.\n",
    "print(df.to_latex(index=True, escape=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "selection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
