{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Auto-reload edited project modules before each cell executes.\n",
    "# %reload_ext (unlike %load_ext) stays quiet when this cell is re-run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Project-local modules: put the grandparent of the working directory on\n",
    "# sys.path first so the imports below can be resolved.\n",
    "sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))\n",
    "from utils import DATA_DIR, ROOT_DIR, PLOT_DIR\n",
    "from dataloader import get_nd_array, get_slice\n",
    "from download.hf import pull_predictions_from_hf\n",
    "import metaanalysis\n",
    "\n",
    "plt.close()  # close the current matplotlib figure, if any\n",
    "\n",
    "# Fitting code emits RuntimeWarnings; keep them out of the cell output.\n",
    "warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
    "\n",
    "# Uncomment to show DataFrames without truncation:\n",
    "# pd.set_option('display.max_columns', None)\n",
    "# pd.set_option('display.max_rows', None)\n",
    "# pd.set_option('display.width', None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 810,570 model evaluations\n"
     ]
    }
   ],
   "source": [
    "# Benchmark results table; each row is counted as one model evaluation.\n",
    "local_path = f'{DATA_DIR}/benchmarks.parquet'\n",
    "df_benchmarks = pd.read_parquet(local_path)\n",
    "print('Loaded {:,} model evaluations'.format(len(df_benchmarks)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import get_selected_tasks\n",
    "\n",
    "# All task names present in the raw data, sorted. When the notebook is run\n",
    "# top to bottom, this happens before the aggregate tasks are appended below.\n",
    "benchmark_tasks = set(df_benchmarks['task'].unique())\n",
    "TASKS = sorted(benchmark_tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datadecide import get_compute\n",
    "from utils import extract_flops\n",
    "from utils.constants_models import DDOS_MODEL_NAMES\n",
    "\n",
    "# Checkpoint path of every row, read from its model_config mapping.\n",
    "df_benchmarks[\"model_path\"] = df_benchmarks[\"model_config\"].apply(lambda cfg: cfg[\"model_path\"])\n",
    "\n",
    "# extract_flops yields two values per path; expand them (positionally) into\n",
    "# the flops and observational_model columns.\n",
    "df_benchmarks[[\"flops\", \"observational_model\"]] = (\n",
    "    df_benchmarks[\"model_path\"]\n",
    "    .apply(extract_flops)\n",
    "    .apply(pd.Series)\n",
    ")\n",
    "\n",
    "is_observational = df_benchmarks['observational_model'] == True\n",
    "observational_models = sorted(df_benchmarks.loc[is_observational, 'model'].unique())\n",
    "datadecide_models = DDOS_MODEL_NAMES\n",
    "\n",
    "# DataDecide rows: overwrite flops with the compute derived from their size.\n",
    "datadecide_mask = df_benchmarks['model'].isin(datadecide_models)\n",
    "df_benchmarks.loc[datadecide_mask, 'flops'] = (\n",
    "    df_benchmarks.loc[datadecide_mask, 'size'].apply(get_compute).astype('float64')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename models whose path contains '-model-merged' to the last two path\n",
    "# components joined by '-', e.g.\n",
    "#   .../peteish7/last-29-model-merged  ->  peteish7-last-29-model-merged\n",
    "# A trailing '/' is stripped first; otherwise the last component would be\n",
    "# empty and the name would come out as 'last-29-model-merged-'.\n",
    "mask = df_benchmarks['model_path'].str.contains('-model-merged', na=False)\n",
    "df_benchmarks.loc[mask, 'model'] = (\n",
    "    df_benchmarks.loc[mask, 'model_path']\n",
    "    .str.rstrip('/')\n",
    "    .str.split('/')\n",
    "    .str[-2:]\n",
    "    .str.join('-')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Investigate Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the raw task names into benchmark families. A ':' suffix marks a task\n",
    "# variant (e.g. ':rc', ':mc'); the base-task families exclude those.\n",
    "mmlu      = [t for t in TASKS if 'mmlu' in t and not any(s in t for s in (':', '_pro_'))]\n",
    "minerva   = [t for t in TASKS if 'minerva' in t and t != 'minerva' and not any(s in t for s in (':', 'math_500'))]\n",
    "mmlu_pro  = [t for t in TASKS if all(s in t for s in ('_pro_', ':rc'))]\n",
    "mmlu_mc   = [t for t in TASKS if all(s in t for s in ('mmlu', ':mc')) and '_pro_' not in t]\n",
    "\n",
    "olmes     = ['arc_challenge', 'arc_easy', 'boolq', 'csqa', 'hellaswag', 'openbookqa', 'piqa', 'socialiqa', 'winogrande']\n",
    "olmes_mc, olmes_para, olmes_distractors, olmes_enlarge = (\n",
    "    [f'{task}:{variant}' for task in olmes] for variant in ('mc', 'para', 'distractors', 'enlarge')\n",
    ")\n",
    "olmes_gen = ['drop', 'gsm8k', 'jeopardy', 'squad', 'triviaqa']  # (naturalqs not included)\n",
    "\n",
    "agi_eval        = [t for t in TASKS if all(s in t for s in ('agi_eval', ':rc'))]\n",
    "bbh             = [t for t in TASKS if 'bbh' in t and ':' not in t]\n",
    "paloma          = [t for t in TASKS if 'paloma' in t]\n",
    "llm_compression = [t for t in TASKS if 'llm_compression' in t]\n",
    "custom_loss     = [t for t in TASKS if 'custom_loss' in t]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import get_title_from_task\n",
    "\n",
    "\n",
    "def add_multitask_avg(df_benchmarks, task_set):\n",
    "    \"\"\"Append an averaged pseudo-task built from the tasks in `task_set`.\n",
    "\n",
    "    Rows of `df_benchmarks` whose task is in `task_set` are grouped by\n",
    "    (model, step, mix, size). Per group, the numeric values of primary_score,\n",
    "    logits_per_byte_corr and logits_per_char_corr are averaged and the first\n",
    "    model_config is kept. The new rows are labelled\n",
    "    `get_title_from_task(task_set)` and concatenated after the existing rows.\n",
    "\n",
    "    Returns a new DataFrame; the input frame is not modified.\n",
    "\n",
    "    Note: a model that is missing some of the tasks is averaged over the tasks\n",
    "    it does have, so its value may cover only part of `task_set`.\n",
    "    \"\"\"\n",
    "    group_keys = ['model', 'step', 'mix', 'size']\n",
    "    value_cols = ['model_config', 'primary_score', 'logits_per_byte_corr', 'logits_per_char_corr']\n",
    "\n",
    "    _slice = get_slice(df_benchmarks, task=task_set)\n",
    "\n",
    "    # groupby drops rows with NaN keys by default, so NaNs are replaced by ''\n",
    "    # to keep those rows. Only the columns used below are filled/copied.\n",
    "    new_task = _slice[group_keys + value_cols].fillna('')\n",
    "    new_task = (\n",
    "        new_task\n",
    "        .groupby(group_keys)[value_cols]\n",
    "        .agg(lambda x: x.iloc[0] if x.name == 'model_config'\n",
    "             else x[pd.to_numeric(x, errors='coerce').notnull()].mean())\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    # Turn the '' placeholder in step back into NaN (mix/size keep '').\n",
    "    new_task['step'] = pd.to_numeric(new_task['step'], errors='coerce')\n",
    "    new_task['task'] = get_title_from_task(task_set)\n",
    "\n",
    "    return pd.concat([df_benchmarks, new_task], axis=0)\n",
    "\n",
    "\n",
    "# Task suites. Counts are list entries: 'gsm8k' is in both knowledge (via\n",
    "# olmes_gen) and math, so `multitask` has 29 unique tasks.\n",
    "multitask_math = ['gsm_plus', 'gsm_symbolic_main', 'gsm_symbolic_p1', 'gsm_symbolic_p2', 'minerva_math_500', 'gsm8k', 'minerva']  # 7 (aime not included)\n",
    "multitask_code = ['mbpp', 'mbppplus', 'codex_humaneval', 'codex_humanevalplus']  # 4\n",
    "multitask_knowledge = ['medmcqa', 'autobencher'] + olmes + ['mmlu'] + olmes_gen + ['mmlu_pro', 'agi_eval']  # 19\n",
    "multitask = multitask_knowledge + multitask_math + multitask_code  # 30 entries (bbh not included)\n",
    "olmes_all = olmes + mmlu + olmes_gen\n",
    "\n",
    "# The title logic in get_title_from_task is sensitive to task order, hence the\n",
    "# pinned first task. The rest is sorted so the order is the same in every\n",
    "# session (iteration order of a set of strings changes with hash randomization).\n",
    "olmes_all = ['jeopardy'] + sorted(set(olmes_all) - {'jeopardy'})\n",
    "multitask = ['boolq'] + sorted(set(multitask) - {'boolq'})\n",
    "\n",
    "# 'mmlu', 'mmlu_pro', 'agi_eval' and 'minerva' are referenced by name in the\n",
    "# multitask_* suites, so their averages must be appended first.\n",
    "# NOTE: every call appends rows, so re-running this cell without reloading\n",
    "# df_benchmarks adds the aggregate rows again.\n",
    "for suite in [mmlu, mmlu_pro, agi_eval, minerva,\n",
    "              multitask_math, multitask_code, multitask_knowledge, multitask]:\n",
    "    df_benchmarks = add_multitask_avg(df_benchmarks, task_set=suite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tasks to analyse: the selected raw tasks plus the multitask suites built above\n",
    "# (these names are expected to match the titles given by get_title_from_task).\n",
    "selected_tasks = get_selected_tasks(TASKS) + [\n",
    "    'multitask_all', 'multitask_math', 'multitask_code', 'multitask_knowledge',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing benchmark properties:   8%|▊         | 3/37 [07:09<54:50, 96.79s/it]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bpb is empty array: [] on model \"peteish-moreeval-1B-1xC\" for task \"aime\"\n",
      "bpb is empty array: [] on model \"peteish-moreeval-1B-1xC\" for task \"aime\"\n",
      "aime failed on ladder fits No scores found for model=peteish-moreeval-1B-1xC, metric=['logits_per_byte_corr', 'primary_score'], task=aime. Seeing: []\n",
      "Failed to calculate compute cost: num_instances should be constant across task=multitask_all for task_as_list=['multitask_all']\n",
      "Failed to calculate compute cost: num_instances should be constant across task=multitask_math for task_as_list=['multitask_math']\n",
      "Failed to calculate compute cost: num_instances should be constant across task=multitask_code for task_as_list=['multitask_code']\n",
      "Failed to calculate compute cost: num_instances should be constant across task=multitask_knowledge for task_as_list=['multitask_knowledge']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing benchmark properties: 100%|██████████| 37/37 [13:34<00:00, 22.01s/it]  \n"
     ]
    }
   ],
   "source": [
    "from metaanalysis import compute_metaproperties\n",
    "\n",
    "# Slow step: about 13.5 minutes for 37 tasks in the recorded run.\n",
    "# Other task lists used previously in place of selected_tasks:\n",
    "#   [multitask_math, multitask_code, multitask_knowledge, multitask, olmes_all]\n",
    "#   [olmes]\n",
    "df_results = compute_metaproperties(\n",
    "    df_benchmarks,\n",
    "    None,\n",
    "    selected_tasks,\n",
    "    run_irt=False,\n",
    "    run_instance_analysis=False,\n",
    "    use_parallel=True,\n",
    "    quiet=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "olmes_core9\n",
      "minerva\n",
      "olmes_gen\n",
      "mmlu\n",
      "mmlu_pro\n",
      "agi_eval\n",
      "bbh\n",
      "arc_challenge\n",
      "arc_easy\n",
      "boolq\n",
      "csqa\n",
      "hellaswag\n",
      "openbookqa\n",
      "piqa\n",
      "socialiqa\n",
      "winogrande\n",
      "drop\n",
      "gsm8k\n",
      "jeopardy\n",
      "squad\n",
      "triviaqa\n",
      "mbpp\n",
      "mbppplus\n",
      "codex_humaneval\n",
      "codex_humanevalplus\n",
      "autobencher\n",
      "gsm_plus\n",
      "gsm_symbolic_main\n",
      "gsm_symbolic_p1\n",
      "gsm_symbolic_p2\n",
      "medmcqa\n",
      "minerva_math_500\n",
      "aime\n",
      "multitask_all\n",
      "multitask_math\n",
      "multitask_code\n",
      "multitask_knowledge\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1B-100B Primary</th>\n",
       "      <th>1B-100B BPB</th>\n",
       "      <th>Primary % Err 13B (↓)</th>\n",
       "      <th>BPB % Err 13B (↓)</th>\n",
       "      <th>Primary Dec Acc 150M (↑)</th>\n",
       "      <th>BPB Dec Acc 150M (↑)</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>task</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>TriviaQA</th>\n",
       "      <td>$27.9$</td>\n",
       "      <td>$\\mathbf{61.8}$</td>\n",
       "      <td>$2.5$</td>\n",
       "      <td>$\\mathbf{0.5}$</td>\n",
       "      <td>$68.3$</td>\n",
       "      <td>$\\mathbf{85.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SQuAD</th>\n",
       "      <td>$23.8$</td>\n",
       "      <td>$\\mathbf{29.0}$</td>\n",
       "      <td>$\\mathbf{7.6}$</td>\n",
       "      <td>$27.8$</td>\n",
       "      <td>$59.7$</td>\n",
       "      <td>$\\mathbf{61.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OLMES Gen</th>\n",
       "      <td>$\\mathbf{23.1}$</td>\n",
       "      <td>$20.6$</td>\n",
       "      <td>$\\mathbf{0.9}$</td>\n",
       "      <td>$2.6$</td>\n",
       "      <td>$63.3$</td>\n",
       "      <td>$\\mathbf{67.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ARC Easy</th>\n",
       "      <td>$21.0$</td>\n",
       "      <td>$\\mathbf{64.6}$</td>\n",
       "      <td>$5.3$</td>\n",
       "      <td>$\\mathbf{0.8}$</td>\n",
       "      <td>$93.0$</td>\n",
       "      <td>$93.0$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Jeopardy</th>\n",
       "      <td>$20.2$</td>\n",
       "      <td>$\\mathbf{22.6}$</td>\n",
       "      <td>$\\mathbf{3.5}$</td>\n",
       "      <td>$18.6$</td>\n",
       "      <td>$82.0$</td>\n",
       "      <td>$\\mathbf{83.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AutoBencher</th>\n",
       "      <td>$15.9$</td>\n",
       "      <td>$\\mathbf{31.3}$</td>\n",
       "      <td>$\\mathbf{0.2}$</td>\n",
       "      <td>$4.5$</td>\n",
       "      <td>$89.3$</td>\n",
       "      <td>$89.3$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Knowledge Tasks</th>\n",
       "      <td>$13.7$</td>\n",
       "      <td>$\\mathbf{44.3}$</td>\n",
       "      <td>$\\mathbf{0.8}$</td>\n",
       "      <td>$1.0$</td>\n",
       "      <td>$79.0$</td>\n",
       "      <td>$\\mathbf{80.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HellaSwag</th>\n",
       "      <td>$11.8$</td>\n",
       "      <td>$\\mathbf{14.9}$</td>\n",
       "      <td>$1.4$</td>\n",
       "      <td>$\\mathbf{1.0}$</td>\n",
       "      <td>$74.3$</td>\n",
       "      <td>$\\mathbf{95.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>DROP</th>\n",
       "      <td>$\\mathbf{11.5}$</td>\n",
       "      <td>$9.9$</td>\n",
       "      <td>$59.0$</td>\n",
       "      <td>$\\mathbf{11.3}$</td>\n",
       "      <td>$57.3$</td>\n",
       "      <td>$\\mathbf{58.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MMLU Pro</th>\n",
       "      <td>$11.0$</td>\n",
       "      <td>$\\mathbf{27.6}$</td>\n",
       "      <td>$2.7$</td>\n",
       "      <td>$\\mathbf{1.3}$</td>\n",
       "      <td>$83.0$</td>\n",
       "      <td>$\\mathbf{89.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>All Tasks</th>\n",
       "      <td>$10.0$</td>\n",
       "      <td>$\\mathbf{31.5}$</td>\n",
       "      <td>$2.3$</td>\n",
       "      <td>$\\mathbf{0.4}$</td>\n",
       "      <td>$77.0$</td>\n",
       "      <td>$\\mathbf{83.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MMLU</th>\n",
       "      <td>$9.8$</td>\n",
       "      <td>$\\mathbf{35.9}$</td>\n",
       "      <td>$4.3$</td>\n",
       "      <td>$\\mathbf{0.4}$</td>\n",
       "      <td>$89.0$</td>\n",
       "      <td>$\\mathbf{92.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ARC Challenge</th>\n",
       "      <td>$6.6$</td>\n",
       "      <td>$\\mathbf{44.8}$</td>\n",
       "      <td>$9.7$</td>\n",
       "      <td>$\\mathbf{2.1}$</td>\n",
       "      <td>$83.3$</td>\n",
       "      <td>$\\mathbf{95.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HumanEval</th>\n",
       "      <td>$6.1$</td>\n",
       "      <td>$\\mathbf{25.1}$</td>\n",
       "      <td>$9.2$</td>\n",
       "      <td>$\\mathbf{7.9}$</td>\n",
       "      <td>$74.3$</td>\n",
       "      <td>$\\mathbf{95.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Code Tasks</th>\n",
       "      <td>$5.5$</td>\n",
       "      <td>$\\mathbf{42.0}$</td>\n",
       "      <td>$29.5$</td>\n",
       "      <td>$\\mathbf{9.7}$</td>\n",
       "      <td>$80.3$</td>\n",
       "      <td>$\\mathbf{96.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CommonsenseQA</th>\n",
       "      <td>$5.5$</td>\n",
       "      <td>$\\mathbf{41.9}$</td>\n",
       "      <td>$\\mathbf{3.6}$</td>\n",
       "      <td>$5.9$</td>\n",
       "      <td>$\\mathbf{68.7}$</td>\n",
       "      <td>$65.7$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SocialIQA</th>\n",
       "      <td>$5.5$</td>\n",
       "      <td>$\\mathbf{48.0}$</td>\n",
       "      <td>$\\mathbf{0.4}$</td>\n",
       "      <td>$1.9$</td>\n",
       "      <td>$55.0$</td>\n",
       "      <td>$\\mathbf{80.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HumanEval+</th>\n",
       "      <td>$5.5$</td>\n",
       "      <td>$\\mathbf{27.4}$</td>\n",
       "      <td>$29.7$</td>\n",
       "      <td>$\\mathbf{7.1}$</td>\n",
       "      <td>$66.0$</td>\n",
       "      <td>$\\mathbf{96.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OLMES Core 9</th>\n",
       "      <td>$5.4$</td>\n",
       "      <td>$\\mathbf{73.2}$</td>\n",
       "      <td>$3.7$</td>\n",
       "      <td>$\\mathbf{0.2}$</td>\n",
       "      <td>$73.3$</td>\n",
       "      <td>$\\mathbf{79.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WinoGrande</th>\n",
       "      <td>$\\mathbf{4.6}$</td>\n",
       "      <td>$3.6$</td>\n",
       "      <td>$10.3$</td>\n",
       "      <td>$\\mathbf{0.9}$</td>\n",
       "      <td>$49.7$</td>\n",
       "      <td>$\\mathbf{75.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PIQA</th>\n",
       "      <td>$4.2$</td>\n",
       "      <td>$\\mathbf{8.8}$</td>\n",
       "      <td>$\\mathbf{0.5}$</td>\n",
       "      <td>$1.3$</td>\n",
       "      <td>$\\mathbf{73.3}$</td>\n",
       "      <td>$72.7$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BBH</th>\n",
       "      <td>$\\mathbf{3.6}$</td>\n",
       "      <td>$2.5$</td>\n",
       "      <td>$67.1$</td>\n",
       "      <td>$\\mathbf{12.9}$</td>\n",
       "      <td>$\\mathbf{64.7}$</td>\n",
       "      <td>$55.0$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MedMCQA</th>\n",
       "      <td>$3.5$</td>\n",
       "      <td>$\\mathbf{29.5}$</td>\n",
       "      <td>$8.8$</td>\n",
       "      <td>$\\mathbf{4.6}$</td>\n",
       "      <td>$60.3$</td>\n",
       "      <td>$\\mathbf{86.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AGI Eval</th>\n",
       "      <td>$2.5$</td>\n",
       "      <td>$\\mathbf{19.5}$</td>\n",
       "      <td>$13.7$</td>\n",
       "      <td>$\\mathbf{3.4}$</td>\n",
       "      <td>$58.7$</td>\n",
       "      <td>$\\mathbf{88.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OpenBookQA</th>\n",
       "      <td>$2.1$</td>\n",
       "      <td>$\\mathbf{24.2}$</td>\n",
       "      <td>$7.7$</td>\n",
       "      <td>$\\mathbf{3.3}$</td>\n",
       "      <td>$65.7$</td>\n",
       "      <td>$\\mathbf{82.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MBPP</th>\n",
       "      <td>$2.0$</td>\n",
       "      <td>$\\mathbf{41.8}$</td>\n",
       "      <td>$23.6$</td>\n",
       "      <td>$\\mathbf{1.0}$</td>\n",
       "      <td>$68.3$</td>\n",
       "      <td>$\\mathbf{95.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Minerva MATH</th>\n",
       "      <td>$1.9$</td>\n",
       "      <td>$\\mathbf{88.6}$</td>\n",
       "      <td>$11.9$</td>\n",
       "      <td>$\\mathbf{1.9}$</td>\n",
       "      <td>$51.0$</td>\n",
       "      <td>$\\mathbf{90.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GSM+</th>\n",
       "      <td>$1.8$</td>\n",
       "      <td>$\\mathbf{7.3}$</td>\n",
       "      <td>$20.0$</td>\n",
       "      <td>$\\mathbf{4.8}$</td>\n",
       "      <td>$59.7$</td>\n",
       "      <td>$\\mathbf{79.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Math Tasks</th>\n",
       "      <td>$1.8$</td>\n",
       "      <td>$\\mathbf{22.6}$</td>\n",
       "      <td>$46.0$</td>\n",
       "      <td>$\\mathbf{5.0}$</td>\n",
       "      <td>$42.3$</td>\n",
       "      <td>$\\mathbf{88.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MBPP+</th>\n",
       "      <td>$1.7$</td>\n",
       "      <td>$\\mathbf{30.8}$</td>\n",
       "      <td>$39.5$</td>\n",
       "      <td>$\\mathbf{8.9}$</td>\n",
       "      <td>$62.7$</td>\n",
       "      <td>$\\mathbf{93.0}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GSM Symbolic P1</th>\n",
       "      <td>$1.6$</td>\n",
       "      <td>$\\mathbf{6.6}$</td>\n",
       "      <td>$538.6$</td>\n",
       "      <td>$\\mathbf{5.2}$</td>\n",
       "      <td>$41.3$</td>\n",
       "      <td>$\\mathbf{81.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>BoolQ</th>\n",
       "      <td>$1.5$</td>\n",
       "      <td>$\\mathbf{64.8}$</td>\n",
       "      <td>$\\mathbf{5.1}$</td>\n",
       "      <td>$6.6$</td>\n",
       "      <td>$47.7$</td>\n",
       "      <td>$\\mathbf{62.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Minerva MATH 500</th>\n",
       "      <td>$1.4$</td>\n",
       "      <td>$\\mathbf{90.5}$</td>\n",
       "      <td>$52.5$</td>\n",
       "      <td>$\\mathbf{0.9}$</td>\n",
       "      <td>$50.7$</td>\n",
       "      <td>$\\mathbf{90.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GSM Symbolic</th>\n",
       "      <td>$1.3$</td>\n",
       "      <td>$\\mathbf{6.5}$</td>\n",
       "      <td>$83.0$</td>\n",
       "      <td>$\\mathbf{5.1}$</td>\n",
       "      <td>$51.0$</td>\n",
       "      <td>$\\mathbf{78.3}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GSM8K</th>\n",
       "      <td>$1.2$</td>\n",
       "      <td>$\\mathbf{7.0}$</td>\n",
       "      <td>$38.6$</td>\n",
       "      <td>$\\mathbf{5.9}$</td>\n",
       "      <td>$46.0$</td>\n",
       "      <td>$\\mathbf{76.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GSM Symbolic P2</th>\n",
       "      <td>$1.0$</td>\n",
       "      <td>$\\mathbf{7.0}$</td>\n",
       "      <td>$74.8$</td>\n",
       "      <td>$\\mathbf{5.1}$</td>\n",
       "      <td>$40.3$</td>\n",
       "      <td>$\\mathbf{79.7}$</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  1B-100B Primary      1B-100B BPB Primary % Err 13B (↓)  \\\n",
       "task                                                                       \n",
       "TriviaQA                   $27.9$  $\\mathbf{61.8}$                 $2.5$   \n",
       "SQuAD                      $23.8$  $\\mathbf{29.0}$        $\\mathbf{7.6}$   \n",
       "OLMES Gen         $\\mathbf{23.1}$           $20.6$        $\\mathbf{0.9}$   \n",
       "ARC Easy                   $21.0$  $\\mathbf{64.6}$                 $5.3$   \n",
       "Jeopardy                   $20.2$  $\\mathbf{22.6}$        $\\mathbf{3.5}$   \n",
       "AutoBencher                $15.9$  $\\mathbf{31.3}$        $\\mathbf{0.2}$   \n",
       "Knowledge Tasks            $13.7$  $\\mathbf{44.3}$        $\\mathbf{0.8}$   \n",
       "HellaSwag                  $11.8$  $\\mathbf{14.9}$                 $1.4$   \n",
       "DROP              $\\mathbf{11.5}$            $9.9$                $59.0$   \n",
       "MMLU Pro                   $11.0$  $\\mathbf{27.6}$                 $2.7$   \n",
       "All Tasks                  $10.0$  $\\mathbf{31.5}$                 $2.3$   \n",
       "MMLU                        $9.8$  $\\mathbf{35.9}$                 $4.3$   \n",
       "ARC Challenge               $6.6$  $\\mathbf{44.8}$                 $9.7$   \n",
       "HumanEval                   $6.1$  $\\mathbf{25.1}$                 $9.2$   \n",
       "Code Tasks                  $5.5$  $\\mathbf{42.0}$                $29.5$   \n",
       "CommonsenseQA               $5.5$  $\\mathbf{41.9}$        $\\mathbf{3.6}$   \n",
       "SocialIQA                   $5.5$  $\\mathbf{48.0}$        $\\mathbf{0.4}$   \n",
       "HumanEval+                  $5.5$  $\\mathbf{27.4}$                $29.7$   \n",
       "OLMES Core 9                $5.4$  $\\mathbf{73.2}$                 $3.7$   \n",
       "WinoGrande         $\\mathbf{4.6}$            $3.6$                $10.3$   \n",
       "PIQA                        $4.2$   $\\mathbf{8.8}$        $\\mathbf{0.5}$   \n",
       "BBH                $\\mathbf{3.6}$            $2.5$                $67.1$   \n",
       "MedMCQA                     $3.5$  $\\mathbf{29.5}$                 $8.8$   \n",
       "AGI Eval                    $2.5$  $\\mathbf{19.5}$                $13.7$   \n",
       "OpenBookQA                  $2.1$  $\\mathbf{24.2}$                 $7.7$   \n",
       "MBPP                        $2.0$  $\\mathbf{41.8}$                $23.6$   \n",
       "Minerva MATH                $1.9$  $\\mathbf{88.6}$                $11.9$   \n",
       "GSM+                        $1.8$   $\\mathbf{7.3}$                $20.0$   \n",
       "Math Tasks                  $1.8$  $\\mathbf{22.6}$                $46.0$   \n",
       "MBPP+                       $1.7$  $\\mathbf{30.8}$                $39.5$   \n",
       "GSM Symbolic P1             $1.6$   $\\mathbf{6.6}$               $538.6$   \n",
       "BoolQ                       $1.5$  $\\mathbf{64.8}$        $\\mathbf{5.1}$   \n",
       "Minerva MATH 500            $1.4$  $\\mathbf{90.5}$                $52.5$   \n",
       "GSM Symbolic                $1.3$   $\\mathbf{6.5}$                $83.0$   \n",
       "GSM8K                       $1.2$   $\\mathbf{7.0}$                $38.6$   \n",
       "GSM Symbolic P2             $1.0$   $\\mathbf{7.0}$                $74.8$   \n",
       "\n",
       "                 BPB % Err 13B (↓) Primary Dec Acc 150M (↑)  \\\n",
       "task                                                          \n",
       "TriviaQA            $\\mathbf{0.5}$                   $68.3$   \n",
       "SQuAD                       $27.8$                   $59.7$   \n",
       "OLMES Gen                    $2.6$                   $63.3$   \n",
       "ARC Easy            $\\mathbf{0.8}$                   $93.0$   \n",
       "Jeopardy                    $18.6$                   $82.0$   \n",
       "AutoBencher                  $4.5$                   $89.3$   \n",
       "Knowledge Tasks              $1.0$                   $79.0$   \n",
       "HellaSwag           $\\mathbf{1.0}$                   $74.3$   \n",
       "DROP               $\\mathbf{11.3}$                   $57.3$   \n",
       "MMLU Pro            $\\mathbf{1.3}$                   $83.0$   \n",
       "All Tasks           $\\mathbf{0.4}$                   $77.0$   \n",
       "MMLU                $\\mathbf{0.4}$                   $89.0$   \n",
       "ARC Challenge       $\\mathbf{2.1}$                   $83.3$   \n",
       "HumanEval           $\\mathbf{7.9}$                   $74.3$   \n",
       "Code Tasks          $\\mathbf{9.7}$                   $80.3$   \n",
       "CommonsenseQA                $5.9$          $\\mathbf{68.7}$   \n",
       "SocialIQA                    $1.9$                   $55.0$   \n",
       "HumanEval+          $\\mathbf{7.1}$                   $66.0$   \n",
       "OLMES Core 9        $\\mathbf{0.2}$                   $73.3$   \n",
       "WinoGrande          $\\mathbf{0.9}$                   $49.7$   \n",
       "PIQA                         $1.3$          $\\mathbf{73.3}$   \n",
       "BBH                $\\mathbf{12.9}$          $\\mathbf{64.7}$   \n",
       "MedMCQA             $\\mathbf{4.6}$                   $60.3$   \n",
       "AGI Eval            $\\mathbf{3.4}$                   $58.7$   \n",
       "OpenBookQA          $\\mathbf{3.3}$                   $65.7$   \n",
       "MBPP                $\\mathbf{1.0}$                   $68.3$   \n",
       "Minerva MATH        $\\mathbf{1.9}$                   $51.0$   \n",
       "GSM+                $\\mathbf{4.8}$                   $59.7$   \n",
       "Math Tasks          $\\mathbf{5.0}$                   $42.3$   \n",
       "MBPP+               $\\mathbf{8.9}$                   $62.7$   \n",
       "GSM Symbolic P1     $\\mathbf{5.2}$                   $41.3$   \n",
       "BoolQ                        $6.6$                   $47.7$   \n",
       "Minerva MATH 500    $\\mathbf{0.9}$                   $50.7$   \n",
       "GSM Symbolic        $\\mathbf{5.1}$                   $51.0$   \n",
       "GSM8K               $\\mathbf{5.9}$                   $46.0$   \n",
       "GSM Symbolic P2     $\\mathbf{5.1}$                   $40.3$   \n",
       "\n",
       "                 BPB Dec Acc 150M (↑)  \n",
       "task                                   \n",
       "TriviaQA              $\\mathbf{85.3}$  \n",
       "SQuAD                 $\\mathbf{61.7}$  \n",
       "OLMES Gen             $\\mathbf{67.3}$  \n",
       "ARC Easy                       $93.0$  \n",
       "Jeopardy              $\\mathbf{83.0}$  \n",
       "AutoBencher                    $89.3$  \n",
       "Knowledge Tasks       $\\mathbf{80.0}$  \n",
       "HellaSwag             $\\mathbf{95.3}$  \n",
       "DROP                  $\\mathbf{58.7}$  \n",
       "MMLU Pro              $\\mathbf{89.0}$  \n",
       "All Tasks             $\\mathbf{83.7}$  \n",
       "MMLU                  $\\mathbf{92.0}$  \n",
       "ARC Challenge         $\\mathbf{95.0}$  \n",
       "HumanEval             $\\mathbf{95.7}$  \n",
       "Code Tasks            $\\mathbf{96.7}$  \n",
       "CommonsenseQA                  $65.7$  \n",
       "SocialIQA             $\\mathbf{80.0}$  \n",
       "HumanEval+            $\\mathbf{96.3}$  \n",
       "OLMES Core 9          $\\mathbf{79.3}$  \n",
       "WinoGrande            $\\mathbf{75.0}$  \n",
       "PIQA                           $72.7$  \n",
       "BBH                            $55.0$  \n",
       "MedMCQA               $\\mathbf{86.7}$  \n",
       "AGI Eval              $\\mathbf{88.0}$  \n",
       "OpenBookQA            $\\mathbf{82.7}$  \n",
       "MBPP                  $\\mathbf{95.3}$  \n",
       "Minerva MATH          $\\mathbf{90.0}$  \n",
       "GSM+                  $\\mathbf{79.0}$  \n",
       "Math Tasks            $\\mathbf{88.3}$  \n",
       "MBPP+                 $\\mathbf{93.0}$  \n",
       "GSM Symbolic P1       $\\mathbf{81.3}$  \n",
       "BoolQ                 $\\mathbf{62.3}$  \n",
       "Minerva MATH 500      $\\mathbf{90.3}$  \n",
       "GSM Symbolic          $\\mathbf{78.3}$  \n",
       "GSM8K                 $\\mathbf{76.7}$  \n",
       "GSM Symbolic P2       $\\mathbf{79.7}$  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from utils import get_pretty_task_name\n",
    "\n",
    "# sizes = ['1B', '1B-100B', '13B', '32B']\n",
    "sizes = ['1B-100B']\n",
    "metric_a = 'primary_score'\n",
    "metric_b = 'logits_per_byte_corr'\n",
    "\n",
    "# '60M', '150M'\n",
    "# In order to add DataDecide, we need step_rel_std:last30 numbers\n",
    "\n",
    "# df_results columns reported alongside the SNR of each metric\n",
    "ERROR_METRICS_PRIMARY = [\n",
    "    f'dec_acc:{metric_a}:150M',\n",
    "    'rel_error:step_2:13B:bpb_to_primary',  # displayed as 'Primary % Err 13B'\n",
    "]\n",
    "\n",
    "ERROR_METRICS_BPB = [\n",
    "    f'dec_acc:{metric_b}:150M',\n",
    "    'rel_error:step_1:13B:bpb_to_primary',  # displayed as 'BPB % Err 13B'\n",
    "]\n",
    "\n",
    "MISSING = '--'  # rendered for values that are absent from df_results or NaN\n",
    "\n",
    "table_tasks = df_results.index\n",
    "\n",
    "# Collect all data into a dict:\n",
    "#   data[task][size][col_without_size] -> noise measure (snr / rel_std / step_rel_std)\n",
    "#   data[task]['error'][err_metric]    -> value of an ERROR_METRICS_* column\n",
    "data = {}\n",
    "for task in table_tasks:\n",
    "    task_data = df_results.loc[task]\n",
    "    data[task] = {'error': {}}\n",
    "    # Add noise measures for each size\n",
    "    for size in sizes:\n",
    "        data[task][size] = {}\n",
    "        for col in [\n",
    "            f'snr:{metric_a}:{size}',\n",
    "            f'rel_std:{metric_a}:{size}',\n",
    "            f'step_rel_std:last30:{metric_a}:{size}',\n",
    "            f'snr:{metric_b}:{size}',\n",
    "            f'rel_std:{metric_b}:{size}',\n",
    "            f'step_rel_std:last30:{metric_b}:{size}',\n",
    "        ]:\n",
    "            if col in task_data.index:\n",
    "                col_no_size = col.replace(f':{size}', '')\n",
    "                data[task][size][col_no_size] = task_data[col]\n",
    "\n",
    "    for col in ERROR_METRICS_PRIMARY + ERROR_METRICS_BPB:\n",
    "        if col in task_data.index:\n",
    "            data[task]['error'][col] = task_data[col]\n",
    "\n",
    "\n",
    "def pretty_metric_name(metric):\n",
    "    \"\"\"Return the display name for a metric / df_results column (unknown names pass through).\"\"\"\n",
    "    PRETTY_METRIC_NAME = {\n",
    "        'primary_score': 'Primary',\n",
    "        'logits_per_byte_corr': 'BPB',\n",
    "        'dec_acc:primary_score:150M': 'Primary Dec Acc 150M (↑)',\n",
    "        'dec_acc:logits_per_byte_corr:150M': 'BPB Dec Acc 150M (↑)',\n",
    "        'rel_error:step_1:13B:bpb_to_primary': 'BPB % Err 13B (↓)',\n",
    "        'rel_error:step_2:13B:bpb_to_primary': 'Primary % Err 13B (↓)',\n",
    "    }\n",
    "    return PRETTY_METRIC_NAME.get(metric, metric)\n",
    "\n",
    "\n",
    "def is_missing(value):\n",
    "    \"\"\"True if a collected value is absent (None) or NaN and should render as MISSING.\"\"\"\n",
    "    return value is None or pd.isna(value)\n",
    "\n",
    "\n",
    "# Convert to readable metrics\n",
    "rows = []\n",
    "for task, task_data in data.items():\n",
    "    row = {'task': get_pretty_task_name(task)}\n",
    "    for metric, err_metrics in zip([metric_a, metric_b], [ERROR_METRICS_PRIMARY, ERROR_METRICS_BPB]):\n",
    "        for size in sizes:\n",
    "            noise = task_data[size]\n",
    "            snr = noise.get(f'snr:{metric}')\n",
    "            rel_std = noise.get(f'rel_std:{metric}')\n",
    "            step_rel_std = noise.get(f'step_rel_std:last30:{metric}')\n",
    "            col_name = size + ' ' + pretty_metric_name(metric)\n",
    "            # Only SNR is shown, but the cell is filled only when all three noise\n",
    "            # measures are present (alternative format showing all three:\n",
    "            # f'${snr:.1f}_{{{rel_std:.3f} / {step_rel_std:.3f}}}$')\n",
    "            if any(is_missing(x) for x in (snr, rel_std, step_rel_std)):\n",
    "                row[col_name] = MISSING\n",
    "            else:\n",
    "                row[col_name] = f'${snr:.1f}$'\n",
    "\n",
    "        for err_metric in err_metrics:\n",
    "            # .get: the collection loop skips error columns that are absent from df_results\n",
    "            error = task_data['error'].get(err_metric)\n",
    "            # Values are scaled by 100 for display (rendered as percentages)\n",
    "            row[pretty_metric_name(err_metric)] = MISSING if is_missing(error) else f'${error*100:.1f}$'\n",
    "\n",
    "    rows.append(row)\n",
    "\n",
    "df_noise_bpb = pd.DataFrame(rows).set_index('task')\n",
    "\n",
    "# Sort cols (a column absent from every row comes back as NaN)\n",
    "df_noise_bpb = df_noise_bpb.reindex([\n",
    "    \"1B-100B Primary\",\n",
    "    \"1B-100B BPB\",\n",
    "    \"Primary % Err 13B (↓)\",\n",
    "    \"BPB % Err 13B (↓)\",\n",
    "    \"Primary Dec Acc 150M (↑)\",\n",
    "    \"BPB Dec Acc 150M (↑)\",\n",
    "], axis=1)\n",
    "\n",
    "# Remove certain tasks\n",
    "df_noise_bpb = df_noise_bpb[~(\n",
    "    df_noise_bpb.index.str.contains(':mc', case=False) |\n",
    "    df_noise_bpb.index.str.contains('paloma', case=False) |\n",
    "    df_noise_bpb.index.str.contains('aime', case=False)\n",
    ")]\n",
    "\n",
    "# Sort by Primary column, converting string values like '$8.16$' to floats\n",
    "# (MISSING cells do not match the pattern, parse to NaN and sort last)\n",
    "df_noise_bpb = df_noise_bpb.sort_values(\n",
    "    by='1B-100B Primary',\n",
    "    key=lambda x: x.str.extract(r'\\$([\\d.]+)\\$')[0].astype(float),\n",
    "    ascending=False\n",
    ")\n",
    "\n",
    "\n",
    "def get_val(cell):\n",
    "    \"\"\"Parse the number from a rendered cell such as '$8.2$'; None if the cell holds no number.\"\"\"\n",
    "    if not isinstance(cell, str):\n",
    "        return None  # NaN from the column reindex above\n",
    "    try:\n",
    "        return float(cell.strip('$').split('_')[0].strip('%'))\n",
    "    except ValueError:\n",
    "        return None  # e.g. the MISSING placeholder\n",
    "\n",
    "\n",
    "def bold(cell):\n",
    "    \"\"\"Wrap the contents of a rendered '$...$' cell in LaTeX mathbf.\"\"\"\n",
    "    return f'$\\\\mathbf{{{cell.strip(\"$\")}}}$'\n",
    "\n",
    "\n",
    "# Bold the better value in each Primary/BPB column pair:\n",
    "# lower is better for error columns, higher otherwise; near-ties (< 0.01) bold neither\n",
    "col_pairs = [\n",
    "    ['1B-100B Primary', '1B-100B BPB'],\n",
    "    ['Primary Dec Acc 150M (↑)', 'BPB Dec Acc 150M (↑)'],\n",
    "    ['Primary % Err 13B (↓)', 'BPB % Err 13B (↓)']\n",
    "]\n",
    "for col_a, col_b in col_pairs:\n",
    "    lower_is_better = 'Err' in col_a\n",
    "    for idx in df_noise_bpb.index:\n",
    "        col_a_val = df_noise_bpb.loc[idx, col_a]\n",
    "        col_b_val = df_noise_bpb.loc[idx, col_b]\n",
    "        val_a = get_val(col_a_val)\n",
    "        val_b = get_val(col_b_val)\n",
    "\n",
    "        if val_a is None or val_b is None:\n",
    "            continue  # nothing to compare against\n",
    "        if abs(val_a - val_b) < 0.01:\n",
    "            continue  # values are effectively equal\n",
    "        a_is_better = val_a < val_b if lower_is_better else val_a > val_b\n",
    "        if a_is_better:\n",
    "            df_noise_bpb.loc[idx, col_a] = bold(col_a_val)\n",
    "        else:\n",
    "            df_noise_bpb.loc[idx, col_b] = bold(col_b_val)\n",
    "\n",
    "df_noise_bpb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{table*}\n",
      "\\scriptsize\n",
      "\\centering\n",
      "\\makebox[\\textwidth]{\n",
      "\\label{tab:noise}\n",
      "\\begin{tabular}{p{0.18\\textwidth}C{0.05\\textwidth}C{0.05\\textwidth}C{0.05\\textwidth}C{0.05\\textwidth}C{0.05\\textwidth}C{0.05\\textwidth}}\n",
      "\\toprule\n",
      " & 1B-100B Primary & 1B-100B BPB & Primary \\% Err 13B (↓) & BPB \\% Err 13B (↓) & Primary Dec Acc 150M (↑) & BPB Dec Acc 150M (↑) \\\\\n",
      "Task &  &  &  &  &  &  \\\\\n",
      "\\midrule\n",
      "TriviaQA & $27.9$ & $\\mathbf{61.8}$ & $2.5$ & $\\mathbf{0.5}$ & $68.3$ & $\\mathbf{85.3}$ \\\\\n",
      "SQuAD & $23.8$ & $\\mathbf{29.0}$ & $\\mathbf{7.6}$ & $27.8$ & $59.7$ & $\\mathbf{61.7}$ \\\\\n",
      "OLMES Gen & $\\mathbf{23.1}$ & $20.6$ & $\\mathbf{0.9}$ & $2.6$ & $63.3$ & $\\mathbf{67.3}$ \\\\\n",
      "ARC Easy & $21.0$ & $\\mathbf{64.6}$ & $5.3$ & $\\mathbf{0.8}$ & $93.0$ & $93.0$ \\\\\n",
      "Jeopardy & $20.2$ & $\\mathbf{22.6}$ & $\\mathbf{3.5}$ & $18.6$ & $82.0$ & $\\mathbf{83.0}$ \\\\\n",
      "AutoBencher & $15.9$ & $\\mathbf{31.3}$ & $\\mathbf{0.2}$ & $4.5$ & $89.3$ & $89.3$ \\\\\n",
      "Knowledge Tasks & $13.7$ & $\\mathbf{44.3}$ & $\\mathbf{0.8}$ & $1.0$ & $79.0$ & $\\mathbf{80.0}$ \\\\\n",
      "HellaSwag & $11.8$ & $\\mathbf{14.9}$ & $1.4$ & $\\mathbf{1.0}$ & $74.3$ & $\\mathbf{95.3}$ \\\\\n",
      "DROP & $\\mathbf{11.5}$ & $9.9$ & $59.0$ & $\\mathbf{11.3}$ & $57.3$ & $\\mathbf{58.7}$ \\\\\n",
      "MMLU Pro & $11.0$ & $\\mathbf{27.6}$ & $2.7$ & $\\mathbf{1.3}$ & $83.0$ & $\\mathbf{89.0}$ \\\\\n",
      "All Tasks & $10.0$ & $\\mathbf{31.5}$ & $2.3$ & $\\mathbf{0.4}$ & $77.0$ & $\\mathbf{83.7}$ \\\\\n",
      "MMLU & $9.8$ & $\\mathbf{35.9}$ & $4.3$ & $\\mathbf{0.4}$ & $89.0$ & $\\mathbf{92.0}$ \\\\\n",
      "ARC Challenge & $6.6$ & $\\mathbf{44.8}$ & $9.7$ & $\\mathbf{2.1}$ & $83.3$ & $\\mathbf{95.0}$ \\\\\n",
      "HumanEval & $6.1$ & $\\mathbf{25.1}$ & $9.2$ & $\\mathbf{7.9}$ & $74.3$ & $\\mathbf{95.7}$ \\\\\n",
      "Code Tasks & $5.5$ & $\\mathbf{42.0}$ & $29.5$ & $\\mathbf{9.7}$ & $80.3$ & $\\mathbf{96.7}$ \\\\\n",
      "CommonsenseQA & $5.5$ & $\\mathbf{41.9}$ & $\\mathbf{3.6}$ & $5.9$ & $\\mathbf{68.7}$ & $65.7$ \\\\\n",
      "SocialIQA & $5.5$ & $\\mathbf{48.0}$ & $\\mathbf{0.4}$ & $1.9$ & $55.0$ & $\\mathbf{80.0}$ \\\\\n",
      "HumanEval+ & $5.5$ & $\\mathbf{27.4}$ & $29.7$ & $\\mathbf{7.1}$ & $66.0$ & $\\mathbf{96.3}$ \\\\\n",
      "OLMES Core 9 & $5.4$ & $\\mathbf{73.2}$ & $3.7$ & $\\mathbf{0.2}$ & $73.3$ & $\\mathbf{79.3}$ \\\\\n",
      "WinoGrande & $\\mathbf{4.6}$ & $3.6$ & $10.3$ & $\\mathbf{0.9}$ & $49.7$ & $\\mathbf{75.0}$ \\\\\n",
      "PIQA & $4.2$ & $\\mathbf{8.8}$ & $\\mathbf{0.5}$ & $1.3$ & $\\mathbf{73.3}$ & $72.7$ \\\\\n",
      "BBH & $\\mathbf{3.6}$ & $2.5$ & $67.1$ & $\\mathbf{12.9}$ & $\\mathbf{64.7}$ & $55.0$ \\\\\n",
      "MedMCQA & $3.5$ & $\\mathbf{29.5}$ & $8.8$ & $\\mathbf{4.6}$ & $60.3$ & $\\mathbf{86.7}$ \\\\\n",
      "AGI Eval & $2.5$ & $\\mathbf{19.5}$ & $13.7$ & $\\mathbf{3.4}$ & $58.7$ & $\\mathbf{88.0}$ \\\\\n",
      "OpenBookQA & $2.1$ & $\\mathbf{24.2}$ & $7.7$ & $\\mathbf{3.3}$ & $65.7$ & $\\mathbf{82.7}$ \\\\\n",
      "MBPP & $2.0$ & $\\mathbf{41.8}$ & $23.6$ & $\\mathbf{1.0}$ & $68.3$ & $\\mathbf{95.3}$ \\\\\n",
      "Minerva MATH & $1.9$ & $\\mathbf{88.6}$ & $11.9$ & $\\mathbf{1.9}$ & $51.0$ & $\\mathbf{90.0}$ \\\\\n",
      "GSM+ & $1.8$ & $\\mathbf{7.3}$ & $20.0$ & $\\mathbf{4.8}$ & $59.7$ & $\\mathbf{79.0}$ \\\\\n",
      "Math Tasks & $1.8$ & $\\mathbf{22.6}$ & $46.0$ & $\\mathbf{5.0}$ & $42.3$ & $\\mathbf{88.3}$ \\\\\n",
      "MBPP+ & $1.7$ & $\\mathbf{30.8}$ & $39.5$ & $\\mathbf{8.9}$ & $62.7$ & $\\mathbf{93.0}$ \\\\\n",
      "GSM Symb. P1 & $1.6$ & $\\mathbf{6.6}$ & $538.6$ & $\\mathbf{5.2}$ & $41.3$ & $\\mathbf{81.3}$ \\\\\n",
      "BoolQ & $1.5$ & $\\mathbf{64.8}$ & $\\mathbf{5.1}$ & $6.6$ & $47.7$ & $\\mathbf{62.3}$ \\\\\n",
      "Minerva MATH 500 & $1.4$ & $\\mathbf{90.5}$ & $52.5$ & $\\mathbf{0.9}$ & $50.7$ & $\\mathbf{90.3}$ \\\\\n",
      "GSM Symb. & $1.3$ & $\\mathbf{6.5}$ & $83.0$ & $\\mathbf{5.1}$ & $51.0$ & $\\mathbf{78.3}$ \\\\\n",
      "GSM8K & $1.2$ & $\\mathbf{7.0}$ & $38.6$ & $\\mathbf{5.9}$ & $46.0$ & $\\mathbf{76.7}$ \\\\\n",
      "GSM Symb. P2 & $1.0$ & $\\mathbf{7.0}$ & $74.8$ & $\\mathbf{5.1}$ & $40.3$ & $\\mathbf{79.7}$ \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "}\n",
      "\\caption{SNR using BPB}\n",
      "\\end{table*}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "CAPTION = \"SNR using BPB\"\n",
    "LABEL = \"tab:noise\"\n",
    "\n",
    "# Render only the tabular body and build the table* wrapper by hand, so that\n",
    "# \\label comes after \\caption (a \\label placed before \\caption does not refer to the table number)\n",
    "tabular = df_noise_bpb.rename_axis('Task').fillna('--').to_latex(\n",
    "    escape=False,  # cells already contain LaTeX math such as $\\mathbf{...}$\n",
    "    float_format=lambda x: '{:.3f}'.format(x),\n",
    "    column_format='p{0.18\\\\textwidth}' + 'C{0.05\\\\textwidth}'*len(df_noise_bpb.columns)\n",
    ")\n",
    "# escape=False leaves '%' (e.g. in the '% Err' headers) unescaped; LaTeX would read it as a comment\n",
    "tabular = tabular.replace('%', '\\\\%').replace('Symbolic', 'Symb.')\n",
    "\n",
    "table_str = '\\n'.join([\n",
    "    '\\\\begin{table*}',\n",
    "    '\\\\scriptsize',\n",
    "    '\\\\centering',\n",
    "    '\\\\makebox[\\\\textwidth]{',\n",
    "    tabular.rstrip('\\n'),\n",
    "    '}',\n",
    "    '\\\\caption{' + CAPTION + '}',\n",
    "    '\\\\label{' + LABEL + '}',\n",
    "    '\\\\end{table*}',\n",
    "]) + '\\n'\n",
    "\n",
    "print(table_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
