{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071aaf1b-2428-42ff-b7f6-2137d062a4cd",
   "metadata": {},
   "source": [
    "# Task-Specific Performance Gains Table Generation\n",
    "\n",
    "This notebook generatestables showing task-specific performance gains across different representation methods.\n",
    "\n",
    "- **Data**\n",
    "    - **Input**: `all_runs_v11.pkl` - 9 vision transformers, 19 datasets\n",
    "    - **Methods**: 6 representation approaches (CLS baseline, attentive probes, linear/attentive multi-layer)\n",
    "- **Statistical Analysis**\n",
    "    - **Wilcoxon signed-rank tests**: Compare each method against the best-performing method per dataset\n",
    "    - **Significance threshold**: p < 0.05\n",
    "    - **Metrics**: Train and test balanced accuracy gains\n",
    "\n",
    "- **Table Generation**: Creates two LaTeX table versions:\n",
    "    1. **Statistical significance formatting**: Bold for statistically significant improvements over best method\n",
    "    2. **Ranking-based formatting**: Bold for best performance, underline for second-best performance\n",
    "\n",
    "- **Output**: LaTeX tables showing mean ± standard deviation performance gains across models, with datasets sorted by domain category and statistical formatting applied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bb033029-6c6b-43ab-a49e-cc9ec5ef747f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wilcoxon\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../..')\n",
    "\n",
    "from constants import base_model_name_mapping, BASE_PATH_PROJECT, experiment_with_probe_type_order_list, experiment_order_list\n",
    "from helper import style_multimodel_heatmap, init_plotting_params, save_or_show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "100ab5a4-5ef7-4baa-aa8b-b6e5319a17f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"agg.path.chunksize\": 0,\n",
      "  \"axes.labelsize\": 13.0,\n",
      "  \"axes.titlesize\": 14.0,\n",
      "  \"axes3d.trackballsize\": 0.667,\n",
      "  \"boxplot.flierprops.markersize\": 6.0,\n",
      "  \"boxplot.meanprops.markersize\": 6.0,\n",
      "  \"errorbar.capsize\": 0.0,\n",
      "  \"figure.figsize\": [\n",
      "    6.4,\n",
      "    4.8\n",
      "  ],\n",
      "  \"figure.labelsize\": \"large\",\n",
      "  \"figure.titlesize\": \"large\",\n",
      "  \"font.cursive\": [\n",
      "    \"Apple Chancery\",\n",
      "    \"Textile\",\n",
      "    \"Zapf Chancery\",\n",
      "    \"Sand\",\n",
      "    \"Script MT\",\n",
      "    \"Felipa\",\n",
      "    \"Comic Neue\",\n",
      "    \"Comic Sans MS\",\n",
      "    \"cursive\"\n",
      "  ],\n",
      "  \"font.family\": [\n",
      "    \"sans-serif\"\n",
      "  ],\n",
      "  \"font.fantasy\": [\n",
      "    \"Chicago\",\n",
      "    \"Charcoal\",\n",
      "    \"Impact\",\n",
      "    \"Western\",\n",
      "    \"xkcd script\",\n",
      "    \"fantasy\"\n",
      "  ],\n",
      "  \"font.monospace\": [\n",
      "    \"DejaVu Sans Mono\",\n",
      "    \"Bitstream Vera Sans Mono\",\n",
      "    \"Computer Modern Typewriter\",\n",
      "    \"Andale Mono\",\n",
      "    \"Nimbus Mono L\",\n",
      "    \"Courier New\",\n",
      "    \"Courier\",\n",
      "    \"Fixed\",\n",
      "    \"Terminal\",\n",
      "    \"monospace\"\n",
      "  ],\n",
      "  \"font.sans-serif\": [\n",
      "    \"DejaVu Sans\",\n",
      "    \"Bitstream Vera Sans\",\n",
      "    \"Computer Modern Sans Serif\",\n",
      "    \"Lucida Grande\",\n",
      "    \"Verdana\",\n",
      "    \"Geneva\",\n",
      "    \"Lucid\",\n",
      "    \"Arial\",\n",
      "    \"Helvetica\",\n",
      "    \"Avant Garde\",\n",
      "    \"sans-serif\"\n",
      "  ],\n",
      "  \"font.serif\": [\n",
      "    \"DejaVu Serif\",\n",
      "    \"Bitstream Vera Serif\",\n",
      "    \"Computer Modern Roman\",\n",
      "    \"New Century Schoolbook\",\n",
      "    \"Century Schoolbook L\",\n",
      "    \"Utopia\",\n",
      "    \"ITC Bookman\",\n",
      "    \"Bookman\",\n",
      "    \"Nimbus Roman No9 L\",\n",
      "    \"Times New Roman\",\n",
      "    \"Times\",\n",
      "    \"Palatino\",\n",
      "    \"Charter\",\n",
      "    \"serif\"\n",
      "  ],\n",
      "  \"font.size\": 10.0,\n",
      "  \"font.stretch\": \"normal\",\n",
      "  \"font.style\": \"normal\",\n",
      "  \"font.variant\": \"normal\",\n",
      "  \"font.weight\": \"normal\",\n",
      "  \"legend.fontsize\": 12.0,\n",
      "  \"legend.title_fontsize\": 13.0,\n",
      "  \"lines.markersize\": 6.0,\n",
      "  \"mathtext.fontset\": \"dejavusans\",\n",
      "  \"pdf.fonttype\": 42,\n",
      "  \"pdf.use14corefonts\": false,\n",
      "  \"pgf.rcfonts\": true,\n",
      "  \"ps.fonttype\": 42,\n",
      "  \"ps.papersize\": \"letter\",\n",
      "  \"svg.fonttype\": \"path\",\n",
      "  \"xtick.labelsize\": 12.0,\n",
      "  \"xtick.major.size\": 3.5,\n",
      "  \"xtick.minor.size\": 2.0,\n",
      "  \"ytick.labelsize\": 12.0,\n",
      "  \"ytick.major.size\": 3.5,\n",
      "  \"ytick.minor.size\": 2.0\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "init_plotting_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f985fc8b-a3f7-4a98-bbf1-c92639b1bb5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE = True\n",
    "\n",
    "base_storing_path = BASE_PATH_PROJECT / 'results_iclr_exp/plots' \n",
    "\n",
    "if SAVE:\n",
    "    base_storing_path.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed4100ec-6930-4df5-8229-6fb059a3a7a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>base_model</th>\n",
       "      <th>level_2</th>\n",
       "      <th>task</th>\n",
       "      <th>mode</th>\n",
       "      <th>combiner</th>\n",
       "      <th>feature_normalization</th>\n",
       "      <th>feature_alignment</th>\n",
       "      <th>train_split</th>\n",
       "      <th>val_proportion</th>\n",
       "      <th>...</th>\n",
       "      <th>res_folder</th>\n",
       "      <th>abs_perf_gain_train_lp_acc1</th>\n",
       "      <th>abs_perf_gain_train_lp_acc5</th>\n",
       "      <th>abs_perf_gain_train_lp_bal_acc1</th>\n",
       "      <th>abs_perf_gain_train_lp_bal_acc5</th>\n",
       "      <th>abs_perf_gain_test_lp_acc1</th>\n",
       "      <th>abs_perf_gain_test_lp_acc5</th>\n",
       "      <th>abs_perf_gain_test_lp_bal_acc1</th>\n",
       "      <th>abs_perf_gain_test_lp_bal_acc5</th>\n",
       "      <th>abs_perf_gain_best_val_bal_acc1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>imagenet-subset-50k</td>\n",
       "      <td>OpenCLIP_ViT-B-16_openai</td>\n",
       "      <td>0</td>\n",
       "      <td>attentive_probe</td>\n",
       "      <td>combined_models</td>\n",
       "      <td>stacked_zero_pad</td>\n",
       "      <td>True</td>\n",
       "      <td>[null, null]</td>\n",
       "      <td>train</td>\n",
       "      <td>0.2</td>\n",
       "      <td>...</td>\n",
       "      <td>results_iclr_exp</td>\n",
       "      <td>0.03238</td>\n",
       "      <td>0.01422</td>\n",
       "      <td>0.03238</td>\n",
       "      <td>0.01422</td>\n",
       "      <td>0.0134</td>\n",
       "      <td>0.0093</td>\n",
       "      <td>0.0134</td>\n",
       "      <td>0.0093</td>\n",
       "      <td>0.0142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>imagenet-subset-50k</td>\n",
       "      <td>OpenCLIP_ViT-B-16_openai</td>\n",
       "      <td>189</td>\n",
       "      <td>attentive_probe</td>\n",
       "      <td>combined_models</td>\n",
       "      <td>stacked_zero_pad</td>\n",
       "      <td>True</td>\n",
       "      <td>[null, null, null, null, null, null, null, nul...</td>\n",
       "      <td>train</td>\n",
       "      <td>0.2</td>\n",
       "      <td>...</td>\n",
       "      <td>results_iclr_exp</td>\n",
       "      <td>0.14866</td>\n",
       "      <td>0.04452</td>\n",
       "      <td>0.14866</td>\n",
       "      <td>0.04452</td>\n",
       "      <td>0.0202</td>\n",
       "      <td>0.01154</td>\n",
       "      <td>0.0202</td>\n",
       "      <td>0.01154</td>\n",
       "      <td>0.0502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>imagenet-subset-50k</td>\n",
       "      <td>OpenCLIP_ViT-B-16_openai</td>\n",
       "      <td>375</td>\n",
       "      <td>attentive_probe</td>\n",
       "      <td>combined_models</td>\n",
       "      <td>stacked_zero_pad</td>\n",
       "      <td>True</td>\n",
       "      <td>[null, null, null, null]</td>\n",
       "      <td>train</td>\n",
       "      <td>0.2</td>\n",
       "      <td>...</td>\n",
       "      <td>results_iclr_exp</td>\n",
       "      <td>0.05232</td>\n",
       "      <td>0.0212</td>\n",
       "      <td>0.05232</td>\n",
       "      <td>0.0212</td>\n",
       "      <td>0.0142</td>\n",
       "      <td>0.00882</td>\n",
       "      <td>0.0142</td>\n",
       "      <td>0.00882</td>\n",
       "      <td>0.0173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>imagenet-subset-50k</td>\n",
       "      <td>OpenCLIP_ViT-B-16_openai</td>\n",
       "      <td>454</td>\n",
       "      <td>attentive_probe</td>\n",
       "      <td>combined_models</td>\n",
       "      <td>stacked_zero_pad</td>\n",
       "      <td>True</td>\n",
       "      <td>[null, null, null, null, null, null, null, null]</td>\n",
       "      <td>train</td>\n",
       "      <td>0.2</td>\n",
       "      <td>...</td>\n",
       "      <td>results_iclr_exp</td>\n",
       "      <td>0.10088</td>\n",
       "      <td>0.03352</td>\n",
       "      <td>0.10088</td>\n",
       "      <td>0.03352</td>\n",
       "      <td>0.01698</td>\n",
       "      <td>0.01004</td>\n",
       "      <td>0.01698</td>\n",
       "      <td>0.01004</td>\n",
       "      <td>0.0306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>imagenet-subset-50k</td>\n",
       "      <td>OpenCLIP_ViT-B-16_openai</td>\n",
       "      <td>533</td>\n",
       "      <td>linear_probe</td>\n",
       "      <td>single_model</td>\n",
       "      <td>NaN</td>\n",
       "      <td>True</td>\n",
       "      <td>null</td>\n",
       "      <td>train</td>\n",
       "      <td>0.2</td>\n",
       "      <td>...</td>\n",
       "      <td>results_iclr_exp</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 82 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "               dataset                base_model  level_2             task  \\\n",
       "0  imagenet-subset-50k  OpenCLIP_ViT-B-16_openai        0  attentive_probe   \n",
       "1  imagenet-subset-50k  OpenCLIP_ViT-B-16_openai      189  attentive_probe   \n",
       "2  imagenet-subset-50k  OpenCLIP_ViT-B-16_openai      375  attentive_probe   \n",
       "3  imagenet-subset-50k  OpenCLIP_ViT-B-16_openai      454  attentive_probe   \n",
       "4  imagenet-subset-50k  OpenCLIP_ViT-B-16_openai      533     linear_probe   \n",
       "\n",
       "              mode          combiner feature_normalization  \\\n",
       "0  combined_models  stacked_zero_pad                  True   \n",
       "1  combined_models  stacked_zero_pad                  True   \n",
       "2  combined_models  stacked_zero_pad                  True   \n",
       "3  combined_models  stacked_zero_pad                  True   \n",
       "4     single_model               NaN                  True   \n",
       "\n",
       "                                   feature_alignment train_split  \\\n",
       "0                                       [null, null]       train   \n",
       "1  [null, null, null, null, null, null, null, nul...       train   \n",
       "2                           [null, null, null, null]       train   \n",
       "3   [null, null, null, null, null, null, null, null]       train   \n",
       "4                                               null       train   \n",
       "\n",
       "  val_proportion  ...        res_folder abs_perf_gain_train_lp_acc1  \\\n",
       "0            0.2  ...  results_iclr_exp                     0.03238   \n",
       "1            0.2  ...  results_iclr_exp                     0.14866   \n",
       "2            0.2  ...  results_iclr_exp                     0.05232   \n",
       "3            0.2  ...  results_iclr_exp                     0.10088   \n",
       "4            0.2  ...  results_iclr_exp                         0.0   \n",
       "\n",
       "  abs_perf_gain_train_lp_acc5 abs_perf_gain_train_lp_bal_acc1  \\\n",
       "0                     0.01422                         0.03238   \n",
       "1                     0.04452                         0.14866   \n",
       "2                      0.0212                         0.05232   \n",
       "3                     0.03352                         0.10088   \n",
       "4                         0.0                             0.0   \n",
       "\n",
       "  abs_perf_gain_train_lp_bal_acc5 abs_perf_gain_test_lp_acc1  \\\n",
       "0                         0.01422                     0.0134   \n",
       "1                         0.04452                     0.0202   \n",
       "2                          0.0212                     0.0142   \n",
       "3                         0.03352                    0.01698   \n",
       "4                             0.0                        0.0   \n",
       "\n",
       "  abs_perf_gain_test_lp_acc5 abs_perf_gain_test_lp_bal_acc1  \\\n",
       "0                     0.0093                         0.0134   \n",
       "1                    0.01154                         0.0202   \n",
       "2                    0.00882                         0.0142   \n",
       "3                    0.01004                        0.01698   \n",
       "4                        0.0                            0.0   \n",
       "\n",
       "  abs_perf_gain_test_lp_bal_acc5 abs_perf_gain_best_val_bal_acc1  \n",
       "0                         0.0093                          0.0142  \n",
       "1                        0.01154                          0.0502  \n",
       "2                        0.00882                          0.0173  \n",
       "3                        0.01004                          0.0306  \n",
       "4                            0.0                             0.0  \n",
       "\n",
       "[5 rows x 82 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_runs_path = BASE_PATH_PROJECT / 'results_iclr_exp/aggregated/all_runs_v11.pkl'\n",
    "all_runs = pd.read_pickle(all_runs_path)\n",
    "all_runs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6eb340fc-6708-4bb3-9b3b-38f63054dedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_runs = all_runs[~all_runs['dataset'].isin(['imagenet-subset-50k', 'wds/imagenet1k'])].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c424cfb4-ae4d-48fa-9db4-b73f84acea8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_cols = [\n",
    "    'abs_perf_gain_train_lp_bal_acc1',\n",
    "    'abs_perf_gain_test_lp_bal_acc1',\n",
    "    'test_lp_bal_acc1'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "87b46cb3-4f8f-47a1-a3d5-4c38153b97aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_runs[metrics_cols] = all_runs[metrics_cols].astype(float)\n",
    "all_runs[metrics_cols] *= 100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25f9090f-0933-4b91-9f9b-cc5a7778c62f",
   "metadata": {},
   "source": [
    "### Table version with statistical tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dcd49a42-969c-48f0-8175-f89e3f3c3bdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_if_is_bold(mean_val, std_val, is_significant):\n",
    "    formatted = f\"{mean_val:.2f} ± {std_val:.2f}\"\n",
    "    if is_significant:\n",
    "        return formatted\n",
    "    else:\n",
    "        return f\"\\\\textbf{{{formatted}}}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b26f67e6-bf0b-48a6-9447-0a4b14377bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_cols = [\n",
    "    'abs_perf_gain_train_lp_bal_acc1_mod',\n",
    "    'abs_perf_gain_test_lp_bal_acc1_mod',\n",
    "]\n",
    "\n",
    "all_runs['abs_perf_gain_train_lp_bal_acc1_mod'] = all_runs['abs_perf_gain_train_lp_bal_acc1'].copy().astype(float)\n",
    "all_runs['abs_perf_gain_test_lp_bal_acc1_mod'] = all_runs['abs_perf_gain_test_lp_bal_acc1'].copy().astype(float)\n",
    "\n",
    "all_runs.loc[all_runs['Experiment'] == 'CLS last layer', 'abs_perf_gain_train_lp_bal_acc1_mod'] = all_runs.loc[all_runs['Experiment'] == 'CLS last layer', 'train_lp_bal_acc1'].astype(float)\n",
    "all_runs.loc[all_runs['Experiment'] == 'CLS last layer', 'abs_perf_gain_test_lp_bal_acc1_mod'] = all_runs.loc[all_runs['Experiment'] == 'CLS last layer', 'test_lp_bal_acc1'].astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "74717d7e-02e3-4122-b30b-c2482ab3d973",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CLS last layer',\n",
       " 'All tokens last layer (attentive)',\n",
       " 'CLS+AP last layer (linear)',\n",
       " 'CLS+AP layers from all blocks (linear)',\n",
       " 'CLS+AP last layer (attentive)',\n",
       " 'CLS+AP layers from all blocks (attentive)']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "curr_order = [col for col in experiment_with_probe_type_order_list if (\"middle & last\" not in col) and (\"quarterly\" not in col)]\n",
    "curr_order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6a6744ed-4aad-4c43-8d14-a233ca5529df",
   "metadata": {},
   "outputs": [],
   "source": [
    "subset_runs = all_runs[all_runs['Experiment'].isin(curr_order)].copy().reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "afe2a892-c7da-4039-9613-b88f340a3686",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_mapping = subset_runs[['dataset', 'dataset_fmt', 'dataset_domain']].value_counts().reset_index().set_index('dataset')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3c6e7b6b-a964-44c8-9766-feb337be5306",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_p_values_wrt_to_highest_mean(ds_data, alpha=0.05, metric_col = 'test_lp_bal_acc1'):\n",
    "    best_mean_exp_per_ds = ds_data.groupby(\"Experiment\")[metric_col].mean().idxmax()\n",
    "\n",
    "    pvalues = {best_mean_exp_per_ds: np.nan}\n",
    "    exp1_data = pd.to_numeric(ds_data[ds_data['Experiment'] == best_mean_exp_per_ds][metric_col]).values\n",
    "    assert len(exp1_data)==9\n",
    "    for exp in sorted(ds_data[\"Experiment\"].unique()):\n",
    "        if exp == best_mean_exp_per_ds:\n",
    "            continue\n",
    "        \n",
    "        exp2_data = pd.to_numeric(ds_data[ds_data['Experiment'] == exp][metric_col]).values\n",
    "        if len(exp1_data) != len(exp2_data):\n",
    "            continue\n",
    "\n",
    "        assert len(exp1_data) == len(exp2_data)\n",
    "        \n",
    "        statistic, pval = wilcoxon(exp1_data, exp2_data, alternative='greater')\n",
    "        pvalues[exp] = pval\n",
    "\n",
    "    rejected = {exp: pval < alpha for exp, pval in pvalues.items()}\n",
    "    return dict(pvalues=pvalues, rejected=rejected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1da64153-a9f8-4272-9d11-516a1ded6222",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_238806/2396546013.py:1: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  res = subset_runs.groupby('dataset').apply(get_p_values_wrt_to_highest_mean)\n"
     ]
    }
   ],
   "source": [
    "res = subset_runs.groupby('dataset').apply(get_p_values_wrt_to_highest_mean)\n",
    "res_df = pd.DataFrame(res.tolist(), index=res.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6357032b-9f6b-40e4-bd90-efa9fb6ac9bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_data = subset_runs.groupby([\"dataset\", \"Experiment\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4317e070-f600-4d47-9158-7e827e26d38f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored latex table at name='abs_perf_gain_train_lp_bal_acc1_mod_aggr_over_models_table_with_testing_v3.tex'\n",
      "Stored latex table at name='abs_perf_gain_test_lp_bal_acc1_mod_aggr_over_models_table_with_testing_v3.tex'\n"
     ]
    }
   ],
   "source": [
    "for metric_col in metrics_cols:\n",
    "    aggr_data = grouped_data[metric_col].agg([\"mean\", \"std\"]).reset_index()\n",
    "    \n",
    "    mean_pivot = pd.pivot(\n",
    "        aggr_data,\n",
    "        index='dataset',\n",
    "        columns=\"Experiment\", \n",
    "        values=\"mean\"\n",
    "    ).loc[:, curr_order]\n",
    "    \n",
    "    formatted_data = []\n",
    "    for idx, row in aggr_data.iterrows():\n",
    "        dataset = row['dataset']\n",
    "        exp = row['Experiment']\n",
    "        if dataset in mean_pivot.index and exp in mean_pivot.columns:\n",
    "            is_significant = res_df.loc[dataset, 'rejected'][exp]\n",
    "            formatted = format_if_is_bold(row['mean'], row['std'], is_significant)\n",
    "            formatted_data.append({\n",
    "                'dataset': dataset,\n",
    "                'Experiment': exp,\n",
    "                'mean_std': formatted\n",
    "            }) \n",
    "    \n",
    "    formatted_df = pd.DataFrame(formatted_data)\n",
    "    pivoted_aggr_data = pd.pivot(\n",
    "        formatted_df,\n",
    "        index='dataset',\n",
    "        columns=\"Experiment\",\n",
    "        values=\"mean_std\"\n",
    "    )\n",
    "    \n",
    "    pivoted_aggr_data = pivoted_aggr_data.loc[:, curr_order]\n",
    "    pivoted_aggr_data = pivoted_aggr_data.sort_values('CLS+AP layers from all blocks (attentive)', \n",
    "                                                      key=lambda x: x.apply(lambda val: float(val.split('±')[0].split('{')[-1].strip())))\n",
    "    \n",
    "    pivoted_aggr_data.index.name = None\n",
    "    pivoted_aggr_data.columns.name = None\n",
    "    pivoted_aggr_data = pivoted_aggr_data.reset_index(names='Dataset')\n",
    "    pivoted_aggr_data.insert(0, 'Category', ds_mapping.loc[pivoted_aggr_data[\"Dataset\"].tolist(), 'dataset_domain'].reset_index(drop=True))\n",
    "    pivoted_aggr_data['Dataset'] = ds_mapping.loc[pivoted_aggr_data[\"Dataset\"].tolist(), 'dataset_fmt'].reset_index(drop=True)\n",
    "\n",
    "    pivoted_aggr_data = pivoted_aggr_data.sort_values('Category', kind='stable')\n",
    "\n",
    "    latex_version_pivoted_aggr_data = pivoted_aggr_data.to_latex(escape=False, index=False)\n",
    "\n",
    "    if SAVE:\n",
    "        name = f\"{metric_col}_aggr_over_models_table_with_testing_v3.tex\" \n",
    "        filename = base_storing_path / 'tables_for_per_model_size_n_dataset_perf_gain' / name\n",
    "        with open(filename, 'w', encoding='utf-8') as f:\n",
    "            f.write(latex_version_pivoted_aggr_data)\n",
    "        print(f\"Stored latex table at {name=}\")\n",
    "\n",
    "    else:\n",
    "        print(metric_col)\n",
    "        print()\n",
    "        print(latex_version_pivoted_aggr_data)\n",
    "        print()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6fc47a3-bf25-43b4-bed2-7f886a411bc4",
   "metadata": {},
   "source": [
    "### Table version with highest median bold, second highest underline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ce2ef7f5-9004-40c0-bde1-7b9a511bafea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_for_latex(mean_val, std_val, row_means):\n",
    "    formatted = f\"{mean_val:.2f} ± {std_val:.2f}\"\n",
    "    \n",
    "    sorted_means = sorted(row_means, reverse=True)\n",
    "    if mean_val not in sorted_means:\n",
    "        return formatted\n",
    "        \n",
    "    rank = sorted_means.index(mean_val) + 1\n",
    "\n",
    "    if rank == 1:\n",
    "        return f\"\\\\textbf{{{formatted}}}\"\n",
    "    elif rank == 2:\n",
    "        return f\"\\\\underline{{{formatted}}}\"\n",
    "    else:\n",
    "        return formatted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "252a683d-971f-4b1b-a291-f732b38ae023",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored latex table at name='abs_perf_gain_train_lp_bal_acc1_mod_aggr_over_models_table_without_testing_v3.tex'\n",
      "Stored latex table at name='abs_perf_gain_test_lp_bal_acc1_mod_aggr_over_models_table_without_testing_v3.tex'\n"
     ]
    }
   ],
   "source": [
    "for metric_col in metrics_cols:\n",
    "    aggr_data = grouped_data[metric_col].agg([\"mean\", \"std\"]).reset_index()\n",
    "    \n",
    "    # Get means for ranking\n",
    "    mean_pivot = pd.pivot(\n",
    "        aggr_data,\n",
    "        index='dataset',\n",
    "        columns=\"Experiment\", \n",
    "        values=\"mean\"\n",
    "    ).loc[:, curr_order]\n",
    "    ranks = mean_pivot.iloc[:, 1:].rank(axis=1, ascending=False).mean()\n",
    "    \n",
    "    formatted_data = []\n",
    "    for idx, row in aggr_data.iterrows():\n",
    "        model = row['dataset']\n",
    "        exp = row['Experiment']\n",
    "        if model in mean_pivot.index and exp in mean_pivot.columns:\n",
    "            row_means = mean_pivot.loc[model].values\n",
    "            formatted = format_for_latex(row['mean'], row['std'], row_means[1:])\n",
    "            formatted_data.append({\n",
    "                'dataset': model,\n",
    "                'Experiment': exp,\n",
    "                'mean_std': formatted\n",
    "            })\n",
    "    \n",
    "    formatted_df = pd.DataFrame(formatted_data)\n",
    "    pivoted_aggr_data = pd.pivot(\n",
    "        formatted_df,\n",
    "        index='dataset',\n",
    "        columns=\"Experiment\",\n",
    "        values=\"mean_std\"\n",
    "    )\n",
    "    pivoted_aggr_data = pivoted_aggr_data.loc[:, curr_order]\n",
    "    pivoted_aggr_data = pivoted_aggr_data.sort_values(curr_order[-1], key=lambda x: x.apply(lambda val: float(val.split('±')[0].split('{')[-1].strip())))\n",
    "    \n",
    "    pivoted_aggr_data.index.name = None\n",
    "    pivoted_aggr_data.columns.name = None\n",
    "    pivoted_aggr_data = pivoted_aggr_data.reset_index(names='Dataset')\n",
    "    pivoted_aggr_data.insert(0, 'Category', ds_mapping.loc[pivoted_aggr_data[\"Dataset\"].tolist(), 'dataset_domain'].reset_index(drop=True))\n",
    "    pivoted_aggr_data['Dataset'] = ds_mapping.loc[pivoted_aggr_data[\"Dataset\"].tolist(), 'dataset_fmt'].reset_index(drop=True)\n",
    "\n",
    "    pivoted_aggr_data = pivoted_aggr_data.sort_values('Category', kind='stable')\n",
    "    tmp = pd.concat([pivoted_aggr_data, ranks.to_frame().T])\n",
    "    latex_version_pivoted_aggr_data = tmp.to_latex(escape=False, index=False)\n",
    "\n",
    "    if SAVE:\n",
    "        name = f\"{metric_col}_aggr_over_models_table_without_testing_v3.tex\" \n",
    "        filename = base_storing_path / 'tables_for_per_model_size_n_dataset_perf_gain' / name\n",
    "        with open(filename, 'w', encoding='utf-8') as f:\n",
    "            f.write(latex_version_pivoted_aggr_data)\n",
    "        print(f\"Stored latex table at {name=}\")\n",
    "\n",
    "    else:\n",
    "        print(metric_col)\n",
    "        print()\n",
    "        print(latex_version_pivoted_aggr_data)\n",
    "        print()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff71f51-6265-4c05-96af-0a88ed0ebbb4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
