{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import wandb\n",
    "api = wandb.Api()\n",
    "\n",
    "# Project is specified by <entity/project-name>\n",
    "runs = api.runs(\"ph10m/polarity-similarity-november\")\n",
    "\n",
    "summary_list, config_list, name_list = [], [], []\n",
    "for run in runs: \n",
    "    # .summary contains the output keys/values for metrics like accuracy.\n",
    "    #  We call ._json_dict to omit large files \n",
    "    summary_list.append(run.summary._json_dict)\n",
    "\n",
    "    # .config contains the hyperparameters.\n",
    "    #  We remove special values that start with _.\n",
    "    config_list.append(\n",
    "        {k: v for k,v in run.config.items()\n",
    "          if not k.startswith('_')})\n",
    "\n",
    "    # .name is the human-readable name of the run.\n",
    "    name_list.append(run.name)\n",
    "\n",
    "runs_df = pd.DataFrame({\n",
    "    \"summary\": summary_list,\n",
    "    \"config\": config_list,\n",
    "    \"name\": name_list\n",
    "    })\n",
    "\n",
    "# runs_df.to_csv(\"project.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>polarity_mean</th>\n",
       "      <th>polarity_std</th>\n",
       "      <th>semantic_mean</th>\n",
       "      <th>semantic_std</th>\n",
       "      <th>dataset</th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNegatives(minilm-6)</td>\n",
       "      <td>68.978</td>\n",
       "      <td>22.022</td>\n",
       "      <td>39.852</td>\n",
       "      <td>6.081</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegatives</td>\n",
       "      <td>all-MiniLM-L6-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=1)</td>\n",
       "      <td>80.478</td>\n",
       "      <td>27.831</td>\n",
       "      <td>29.186</td>\n",
       "      <td>6.369</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>all-MiniLM-L6-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.75)</td>\n",
       "      <td>79.969</td>\n",
       "      <td>27.357</td>\n",
       "      <td>28.516</td>\n",
       "      <td>6.457</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>all-MiniLM-L6-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.5)</td>\n",
       "      <td>79.039</td>\n",
       "      <td>26.862</td>\n",
       "      <td>28.285</td>\n",
       "      <td>6.537</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>all-MiniLM-L6-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.25)</td>\n",
       "      <td>78.867</td>\n",
       "      <td>26.370</td>\n",
       "      <td>27.092</td>\n",
       "      <td>6.628</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>all-MiniLM-L6-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet(gte-base)(lambda=7.5)</td>\n",
       "      <td>90.087</td>\n",
       "      <td>24.958</td>\n",
       "      <td>79.935</td>\n",
       "      <td>1.811</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet(gte-base)(lambda=5)</td>\n",
       "      <td>90.130</td>\n",
       "      <td>25.085</td>\n",
       "      <td>79.927</td>\n",
       "      <td>1.808</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet(gte-base)(lambda=1)</td>\n",
       "      <td>90.618</td>\n",
       "      <td>24.890</td>\n",
       "      <td>80.342</td>\n",
       "      <td>1.777</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.1)</td>\n",
       "      <td>91.184</td>\n",
       "      <td>25.641</td>\n",
       "      <td>81.870</td>\n",
       "      <td>1.658</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.01)</td>\n",
       "      <td>90.295</td>\n",
       "      <td>24.894</td>\n",
       "      <td>81.913</td>\n",
       "      <td>1.679</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         name  polarity_mean  polarity_std  \\\n",
       "0                 MultipleNegatives(minilm-6)         68.978        22.022   \n",
       "1       OnlineContrastive(minilm-6)(lambda=1)         80.478        27.831   \n",
       "2    OnlineContrastive(minilm-6)(lambda=0.75)         79.969        27.357   \n",
       "3     OnlineContrastive(minilm-6)(lambda=0.5)         79.039        26.862   \n",
       "4    OnlineContrastive(minilm-6)(lambda=0.25)         78.867        26.370   \n",
       "..                                        ...            ...           ...   \n",
       "131             Triplet(gte-base)(lambda=7.5)         90.087        24.958   \n",
       "132               Triplet(gte-base)(lambda=5)         90.130        25.085   \n",
       "133               Triplet(gte-base)(lambda=1)         90.618        24.890   \n",
       "134             Triplet(gte-base)(lambda=0.1)         91.184        25.641   \n",
       "135            Triplet(gte-base)(lambda=0.01)         90.295        24.894   \n",
       "\n",
       "     semantic_mean  semantic_std              dataset               loss  \\\n",
       "0           39.852         6.081  sarcastic-headlines  MultipleNegatives   \n",
       "1           29.186         6.369  sarcastic-headlines  OnlineContrastive   \n",
       "2           28.516         6.457  sarcastic-headlines  OnlineContrastive   \n",
       "3           28.285         6.537  sarcastic-headlines  OnlineContrastive   \n",
       "4           27.092         6.628  sarcastic-headlines  OnlineContrastive   \n",
       "..             ...           ...                  ...                ...   \n",
       "131         79.935         1.811                 sst2            Triplet   \n",
       "132         79.927         1.808                 sst2            Triplet   \n",
       "133         80.342         1.777                 sst2            Triplet   \n",
       "134         81.870         1.658                 sst2            Triplet   \n",
       "135         81.913         1.679                 sst2            Triplet   \n",
       "\n",
       "                model  \n",
       "0    all-MiniLM-L6-v2  \n",
       "1    all-MiniLM-L6-v2  \n",
       "2    all-MiniLM-L6-v2  \n",
       "3    all-MiniLM-L6-v2  \n",
       "4    all-MiniLM-L6-v2  \n",
       "..                ...  \n",
       "131          gte-base  \n",
       "132          gte-base  \n",
       "133          gte-base  \n",
       "134          gte-base  \n",
       "135          gte-base  \n",
       "\n",
       "[136 rows x 8 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric = [\"polarity\", \"semantic\"]\n",
    "stats = [\"mean\", \"std\"]\n",
    "\n",
    "for m in metric:\n",
    "    for s in stats:\n",
    "        _id = f\"{m}_{s}\"\n",
    "        runs_df[_id] = runs_df.summary.apply(lambda x: x[_id])\n",
    "# runtime if desired\n",
    "# runs_df[\"_runtime\"] = runs_df.summary.apply(lambda x: x[\"_runtime\"])\n",
    "\n",
    "config_keys = [\"dataset\", \"simple_model_name\", \"model\"]\n",
    "for k in config_keys:\n",
    "    runs_df[k] = runs_df.config.apply(lambda x: x[k])\n",
    "\n",
    "runs_df = runs_df.rename(columns={\"simple_model_name\": \"loss\"})\n",
    "runs_df = runs_df.drop(columns=[\"summary\", \"config\"])\n",
    "runs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>polarity_mean</th>\n",
       "      <th>polarity_std</th>\n",
       "      <th>semantic_mean</th>\n",
       "      <th>semantic_std</th>\n",
       "      <th>dataset</th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>lambda</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNegatives(minilm-6)</td>\n",
       "      <td>68.978</td>\n",
       "      <td>22.022</td>\n",
       "      <td>39.852</td>\n",
       "      <td>6.081</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegatives</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=1)</td>\n",
       "      <td>80.478</td>\n",
       "      <td>27.831</td>\n",
       "      <td>29.186</td>\n",
       "      <td>6.369</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.75)</td>\n",
       "      <td>79.969</td>\n",
       "      <td>27.357</td>\n",
       "      <td>28.516</td>\n",
       "      <td>6.457</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.5)</td>\n",
       "      <td>79.039</td>\n",
       "      <td>26.862</td>\n",
       "      <td>28.285</td>\n",
       "      <td>6.537</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.25)</td>\n",
       "      <td>78.867</td>\n",
       "      <td>26.370</td>\n",
       "      <td>27.092</td>\n",
       "      <td>6.628</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet(gte-base)(lambda=7.5)</td>\n",
       "      <td>90.087</td>\n",
       "      <td>24.958</td>\n",
       "      <td>79.935</td>\n",
       "      <td>1.811</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>7.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet(gte-base)(lambda=5)</td>\n",
       "      <td>90.130</td>\n",
       "      <td>25.085</td>\n",
       "      <td>79.927</td>\n",
       "      <td>1.808</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>5.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet(gte-base)(lambda=1)</td>\n",
       "      <td>90.618</td>\n",
       "      <td>24.890</td>\n",
       "      <td>80.342</td>\n",
       "      <td>1.777</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.1)</td>\n",
       "      <td>91.184</td>\n",
       "      <td>25.641</td>\n",
       "      <td>81.870</td>\n",
       "      <td>1.658</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>0.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.01)</td>\n",
       "      <td>90.295</td>\n",
       "      <td>24.894</td>\n",
       "      <td>81.913</td>\n",
       "      <td>1.679</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         name  polarity_mean  polarity_std  \\\n",
       "0                 MultipleNegatives(minilm-6)         68.978        22.022   \n",
       "1       OnlineContrastive(minilm-6)(lambda=1)         80.478        27.831   \n",
       "2    OnlineContrastive(minilm-6)(lambda=0.75)         79.969        27.357   \n",
       "3     OnlineContrastive(minilm-6)(lambda=0.5)         79.039        26.862   \n",
       "4    OnlineContrastive(minilm-6)(lambda=0.25)         78.867        26.370   \n",
       "..                                        ...            ...           ...   \n",
       "131             Triplet(gte-base)(lambda=7.5)         90.087        24.958   \n",
       "132               Triplet(gte-base)(lambda=5)         90.130        25.085   \n",
       "133               Triplet(gte-base)(lambda=1)         90.618        24.890   \n",
       "134             Triplet(gte-base)(lambda=0.1)         91.184        25.641   \n",
       "135            Triplet(gte-base)(lambda=0.01)         90.295        24.894   \n",
       "\n",
       "     semantic_mean  semantic_std              dataset               loss  \\\n",
       "0           39.852         6.081  sarcastic-headlines  MultipleNegatives   \n",
       "1           29.186         6.369  sarcastic-headlines  OnlineContrastive   \n",
       "2           28.516         6.457  sarcastic-headlines  OnlineContrastive   \n",
       "3           28.285         6.537  sarcastic-headlines  OnlineContrastive   \n",
       "4           27.092         6.628  sarcastic-headlines  OnlineContrastive   \n",
       "..             ...           ...                  ...                ...   \n",
       "131         79.935         1.811                 sst2            Triplet   \n",
       "132         79.927         1.808                 sst2            Triplet   \n",
       "133         80.342         1.777                 sst2            Triplet   \n",
       "134         81.870         1.658                 sst2            Triplet   \n",
       "135         81.913         1.679                 sst2            Triplet   \n",
       "\n",
       "        model  lambda  \n",
       "0    minilm-6     NaN  \n",
       "1    minilm-6    1.00  \n",
       "2    minilm-6    0.75  \n",
       "3    minilm-6    0.50  \n",
       "4    minilm-6    0.25  \n",
       "..        ...     ...  \n",
       "131  gte-base    7.50  \n",
       "132  gte-base    5.00  \n",
       "133  gte-base    1.00  \n",
       "134  gte-base    0.10  \n",
       "135  gte-base    0.01  \n",
       "\n",
       "[136 rows x 9 columns]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "import sys\n",
    "sys.path.insert(0, \"../\")\n",
    "from util import model_translations\n",
    "\n",
    "# rename model:\n",
    "runs_df[\"model\"] = runs_df.model.apply(lambda x: model_translations[x])\n",
    "\n",
    "def get_lambda(x):\n",
    "    # the model_name consists of LOSS(MODEL-TYPE)(LAMBDA=FLOATVALUE)\n",
    "    pattern = r\"lambda=(\\d+\\.?\\d*)\"\n",
    "    match = re.search(pattern, x)\n",
    "    if match:\n",
    "        return float(match.group(1))\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "runs_df[\"lambda\"] = runs_df.name.apply(get_lambda)\n",
    "runs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "polarity\n",
      "Average & - & 88.5 & 90.3 & 86.8 & 89.8 & 84.9 & 88.4 & 79.2 & 76.6 \\\\\n",
      "semantic\n",
      "Average & - & 80.0 & 83.5 & 76.7 & 81.3 & 78.6 & 82.3 & 26.7 & 33.7 \\\\\n"
     ]
    }
   ],
   "source": [
    "# group on the model and compute the average mean \n",
    "from IPython.display import display\n",
    "\n",
    "for m in metric:\n",
    "    _mean = f\"{m}_mean\"\n",
    "    # use only the metric m_mean\n",
    "    model_df = runs_df[[\"model\", \"dataset\", _mean]]\n",
    "    # rename _mean to avg\n",
    "    model_df = model_df.rename(columns={_mean: f\"{m}_avg\"})\n",
    "    model_grouped = model_df.groupby([\"model\", \"dataset\"]).mean().round(1)[f\"{m}_avg\"].values\n",
    "    # to latex:\n",
    "    print(m)\n",
    "    print(f\"Average & - & \" + \" & \".join([str(x) for x in model_grouped]) + \" \\\\\\\\\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "polarity e5-small sarcastic-headlines => 90.64\n",
      "semantic e5-small sarcastic-headlines => 82.54\n",
      "polarity e5-small sst2 => 91.88\n",
      "semantic e5-small sst2 => 84.74\n",
      "polarity gte-base sarcastic-headlines => 89.65\n",
      "semantic gte-base sarcastic-headlines => 80.43\n",
      "polarity gte-base sst2 => 91.18\n",
      "semantic gte-base sst2 => 82.55\n",
      "polarity gte-small sarcastic-headlines => 88.43\n",
      "semantic gte-small sarcastic-headlines => 81.64\n",
      "polarity gte-small sst2 => 89.89\n",
      "semantic gte-small sst2 => 83.87\n",
      "polarity minilm-6 sarcastic-headlines => 84.07\n",
      "semantic minilm-6 sarcastic-headlines => 39.85\n",
      "polarity minilm-6 sst2 => 83.16\n",
      "semantic minilm-6 sst2 => 43.5\n"
     ]
    }
   ],
   "source": [
    "# iterate through each model:\n",
    "from IPython.display import display\n",
    "runs_df = runs_df.round(2)\n",
    "\n",
    "max_values = {}\n",
    "\n",
    "for model in sorted(runs_df.model.unique()):\n",
    "    for dataset in model_df.dataset.unique():\n",
    "        subset = runs_df[(runs_df.model == model) & (runs_df.dataset == dataset)]\n",
    "        subset = subset.sort_values(by=\"name\")\n",
    "    \n",
    "        for m in metric:\n",
    "            max_values[(m, model, dataset)] = subset[f\"{m}_mean\"].max()\n",
    "for (m, model, dataset), value in max_values.items():\n",
    "    print(m, model, dataset, \"=>\", value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_max(row, _metric):\n",
    "    m = row.model\n",
    "    d = row.dataset\n",
    "    return max_values[(_metric, m, d)]\n",
    "\n",
    "for m in metric:\n",
    "    runs_df[f\"{m}_max\"] = runs_df.apply(lambda x: apply_max(x, m), axis=1)\n",
    "    runs_df[f\"{m}_max\"] = runs_df.apply(lambda x: x[f\"{m}_max\"] == x[f\"{m}_mean\"], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a latex string for ${mean}_{std}$ for each metric\n",
    "from collections import defaultdict\n",
    "latex_strings = defaultdict(list)\n",
    "\n",
    "for m in metric:\n",
    "    # all mean val:\n",
    "    mean_vals = runs_df[f\"{m}_mean\"].values\n",
    "    std_vals = runs_df[f\"{m}_std\"].values\n",
    "    is_max = runs_df[f\"{m}_max\"].values\n",
    "    # create latex string\n",
    "    latex_vals = []\n",
    "    for mean, std, maximum in zip(mean_vals, std_vals, is_max):\n",
    "        if maximum:\n",
    "            latex_vals.append(f\"$\\\\mathbf{{{float(mean):.1f}}}_{{{float(std):.1f}}}$\")\n",
    "        else:\n",
    "            latex_vals.append(f\"${float(mean):.1f}_{{{float(std):.1f}}}$\")\n",
    "    runs_df[m] = latex_vals\n",
    "\n",
    "runs_df = runs_df.drop(columns=[f\"{m}_mean\" for m in [\"polarity\", \"semantic\"]])\n",
    "runs_df = runs_df.drop(columns=[f\"{m}_std\" for m in [\"polarity\", \"semantic\"]])\n",
    "runs_df = runs_df.drop(columns=[f\"{m}_max\" for m in [\"polarity\", \"semantic\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>dataset</th>\n",
       "      <th>loss</th>\n",
       "      <th>model</th>\n",
       "      <th>lambda</th>\n",
       "      <th>polarity</th>\n",
       "      <th>semantic</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNegatives(minilm-6)</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegatives</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>$69.0_{22.0}$</td>\n",
       "      <td>$\\mathbf{39.9}_{6.1}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=1)</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>1.00</td>\n",
       "      <td>$80.5_{27.8}$</td>\n",
       "      <td>$29.2_{6.4}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.75)</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.75</td>\n",
       "      <td>$80.0_{27.4}$</td>\n",
       "      <td>$28.5_{6.5}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.5)</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.50</td>\n",
       "      <td>$79.0_{26.9}$</td>\n",
       "      <td>$28.3_{6.5}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContrastive(minilm-6)(lambda=0.25)</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>OnlineContrastive</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>0.25</td>\n",
       "      <td>$78.9_{26.4}$</td>\n",
       "      <td>$27.1_{6.6}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet(gte-base)(lambda=7.5)</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>7.50</td>\n",
       "      <td>$90.1_{25.0}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet(gte-base)(lambda=5)</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>5.00</td>\n",
       "      <td>$90.1_{25.1}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet(gte-base)(lambda=1)</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>1.00</td>\n",
       "      <td>$90.6_{24.9}$</td>\n",
       "      <td>$80.3_{1.8}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.1)</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>0.10</td>\n",
       "      <td>$\\mathbf{91.2}_{25.6}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet(gte-base)(lambda=0.01)</td>\n",
       "      <td>sst2</td>\n",
       "      <td>Triplet</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>0.01</td>\n",
       "      <td>$90.3_{24.9}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         name              dataset  \\\n",
       "0                 MultipleNegatives(minilm-6)  sarcastic-headlines   \n",
       "1       OnlineContrastive(minilm-6)(lambda=1)  sarcastic-headlines   \n",
       "2    OnlineContrastive(minilm-6)(lambda=0.75)  sarcastic-headlines   \n",
       "3     OnlineContrastive(minilm-6)(lambda=0.5)  sarcastic-headlines   \n",
       "4    OnlineContrastive(minilm-6)(lambda=0.25)  sarcastic-headlines   \n",
       "..                                        ...                  ...   \n",
       "131             Triplet(gte-base)(lambda=7.5)                 sst2   \n",
       "132               Triplet(gte-base)(lambda=5)                 sst2   \n",
       "133               Triplet(gte-base)(lambda=1)                 sst2   \n",
       "134             Triplet(gte-base)(lambda=0.1)                 sst2   \n",
       "135            Triplet(gte-base)(lambda=0.01)                 sst2   \n",
       "\n",
       "                  loss     model  lambda                polarity  \\\n",
       "0    MultipleNegatives  minilm-6     NaN           $69.0_{22.0}$   \n",
       "1    OnlineContrastive  minilm-6    1.00           $80.5_{27.8}$   \n",
       "2    OnlineContrastive  minilm-6    0.75           $80.0_{27.4}$   \n",
       "3    OnlineContrastive  minilm-6    0.50           $79.0_{26.9}$   \n",
       "4    OnlineContrastive  minilm-6    0.25           $78.9_{26.4}$   \n",
       "..                 ...       ...     ...                     ...   \n",
       "131            Triplet  gte-base    7.50           $90.1_{25.0}$   \n",
       "132            Triplet  gte-base    5.00           $90.1_{25.1}$   \n",
       "133            Triplet  gte-base    1.00           $90.6_{24.9}$   \n",
       "134            Triplet  gte-base    0.10  $\\mathbf{91.2}_{25.6}$   \n",
       "135            Triplet  gte-base    0.01           $90.3_{24.9}$   \n",
       "\n",
       "                  semantic  \n",
       "0    $\\mathbf{39.9}_{6.1}$  \n",
       "1             $29.2_{6.4}$  \n",
       "2             $28.5_{6.5}$  \n",
       "3             $28.3_{6.5}$  \n",
       "4             $27.1_{6.6}$  \n",
       "..                     ...  \n",
       "131           $79.9_{1.8}$  \n",
       "132           $79.9_{1.8}$  \n",
       "133           $80.3_{1.8}$  \n",
       "134           $81.9_{1.7}$  \n",
       "135           $81.9_{1.7}$  \n",
       "\n",
       "[136 rows x 7 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "runs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>polarity</th>\n",
       "      <th>semantic</th>\n",
       "      <th>model_for_data</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNeg</td>\n",
       "      <td>$69.0_{22.0}$</td>\n",
       "      <td>$\\mathbf{39.9}_{6.1}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContr 1.0</td>\n",
       "      <td>$80.5_{27.8}$</td>\n",
       "      <td>$29.2_{6.4}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContr 0.75</td>\n",
       "      <td>$80.0_{27.4}$</td>\n",
       "      <td>$28.5_{6.5}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContr 0.5</td>\n",
       "      <td>$79.0_{26.9}$</td>\n",
       "      <td>$28.3_{6.5}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContr 0.25</td>\n",
       "      <td>$78.9_{26.4}$</td>\n",
       "      <td>$27.1_{6.6}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>$90.1_{25.0}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet 5.0</td>\n",
       "      <td>$90.1_{25.1}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet 1.0</td>\n",
       "      <td>$90.6_{24.9}$</td>\n",
       "      <td>$80.3_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet 0.1</td>\n",
       "      <td>$\\mathbf{91.2}_{25.6}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet 0.01</td>\n",
       "      <td>$90.3_{24.9}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                 loss                polarity               semantic  \\\n",
       "0         MultipleNeg           $69.0_{22.0}$  $\\mathbf{39.9}_{6.1}$   \n",
       "1     OnlineContr 1.0           $80.5_{27.8}$           $29.2_{6.4}$   \n",
       "2    OnlineContr 0.75           $80.0_{27.4}$           $28.5_{6.5}$   \n",
       "3     OnlineContr 0.5           $79.0_{26.9}$           $28.3_{6.5}$   \n",
       "4    OnlineContr 0.25           $78.9_{26.4}$           $27.1_{6.6}$   \n",
       "..                ...                     ...                    ...   \n",
       "131       Triplet 7.5           $90.1_{25.0}$           $79.9_{1.8}$   \n",
       "132       Triplet 5.0           $90.1_{25.1}$           $79.9_{1.8}$   \n",
       "133       Triplet 1.0           $90.6_{24.9}$           $80.3_{1.8}$   \n",
       "134       Triplet 0.1  $\\mathbf{91.2}_{25.6}$           $81.9_{1.7}$   \n",
       "135      Triplet 0.01           $90.3_{24.9}$           $81.9_{1.7}$   \n",
       "\n",
       "         model_for_data  \n",
       "0    minilm-6_sarcastic  \n",
       "1    minilm-6_sarcastic  \n",
       "2    minilm-6_sarcastic  \n",
       "3    minilm-6_sarcastic  \n",
       "4    minilm-6_sarcastic  \n",
       "..                  ...  \n",
       "131       gte-base_sst2  \n",
       "132       gte-base_sst2  \n",
       "133       gte-base_sst2  \n",
       "134       gte-base_sst2  \n",
       "135       gte-base_sst2  \n",
       "\n",
       "[136 rows x 4 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from util import model_translations\n",
    "\n",
    "loss_translations = {\n",
    "    \"MultipleNegatives\": \"MultipleNeg\",\n",
    "    \"OnlineContrastive\": \"OnlineContr\",\n",
    "    \"Triplet\": \"Triplet\",\n",
    "    \"Contrastive\": \"Contrastive\",\n",
    "}\n",
    "\n",
    "df = runs_df.copy()\n",
    "# df.model = df.model.map(model_translations)\n",
    "\n",
    "df[\"loss\"] = df.loss.map(loss_translations)\n",
    "df[\"loss\"] = df[\"loss\"] + \" \" + df[\"lambda\"].map(str)\n",
    "df[\"loss\"] = df[\"loss\"].map(lambda x: x.replace(\"nan\", \"\").strip())\n",
    "df = df.drop(columns=[\"lambda\"])\n",
    "\n",
    "dataset_translation = {\n",
    "    \"sarcastic-headlines\": \"sarcastic\",\n",
    "    \"sst2\": \"sst2\",\n",
    "}\n",
    "df.dataset = df.dataset.map(dataset_translation)\n",
    "df[\"model_for_data\"] = df[\"model\"] + \"_\" + df[\"dataset\"]\n",
    "df = df.drop(columns=[\"name\", \"model\", \"dataset\"])  # not interested in the combined score here\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>polarity</th>\n",
       "      <th>semantic</th>\n",
       "      <th>model_for_data</th>\n",
       "      <th>polarity_max</th>\n",
       "      <th>semantic_max</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNeg</td>\n",
       "      <td>$69.0_{22.0}$</td>\n",
       "      <td>$\\mathbf{39.9_{6.1}}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContr 1.0</td>\n",
       "      <td>$80.5_{27.8}$</td>\n",
       "      <td>$29.2_{6.4}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContr 0.75</td>\n",
       "      <td>$80.0_{27.4}$</td>\n",
       "      <td>$28.5_{6.5}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContr 0.5</td>\n",
       "      <td>$79.0_{26.9}$</td>\n",
       "      <td>$28.3_{6.5}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContr 0.25</td>\n",
       "      <td>$78.9_{26.4}$</td>\n",
       "      <td>$27.1_{6.6}$</td>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>$90.1_{25.0}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet 5.0</td>\n",
       "      <td>$90.1_{25.1}$</td>\n",
       "      <td>$79.9_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet 1.0</td>\n",
       "      <td>$90.6_{24.9}$</td>\n",
       "      <td>$80.3_{1.8}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet 0.1</td>\n",
       "      <td>$\\mathbf{91.2_{25.6}}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet 0.01</td>\n",
       "      <td>$90.3_{24.9}$</td>\n",
       "      <td>$81.9_{1.7}$</td>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                 loss                polarity               semantic  \\\n",
       "0         MultipleNeg           $69.0_{22.0}$  $\\mathbf{39.9_{6.1}}$   \n",
       "1     OnlineContr 1.0           $80.5_{27.8}$           $29.2_{6.4}$   \n",
       "2    OnlineContr 0.75           $80.0_{27.4}$           $28.5_{6.5}$   \n",
       "3     OnlineContr 0.5           $79.0_{26.9}$           $28.3_{6.5}$   \n",
       "4    OnlineContr 0.25           $78.9_{26.4}$           $27.1_{6.6}$   \n",
       "..                ...                     ...                    ...   \n",
       "131       Triplet 7.5           $90.1_{25.0}$           $79.9_{1.8}$   \n",
       "132       Triplet 5.0           $90.1_{25.1}$           $79.9_{1.8}$   \n",
       "133       Triplet 1.0           $90.6_{24.9}$           $80.3_{1.8}$   \n",
       "134       Triplet 0.1  $\\mathbf{91.2_{25.6}}$           $81.9_{1.7}$   \n",
       "135      Triplet 0.01           $90.3_{24.9}$           $81.9_{1.7}$   \n",
       "\n",
       "         model_for_data  polarity_max  semantic_max  \n",
       "0    minilm-6_sarcastic         False          True  \n",
       "1    minilm-6_sarcastic         False         False  \n",
       "2    minilm-6_sarcastic         False         False  \n",
       "3    minilm-6_sarcastic         False         False  \n",
       "4    minilm-6_sarcastic         False         False  \n",
       "..                  ...           ...           ...  \n",
       "131       gte-base_sst2         False         False  \n",
       "132       gte-base_sst2         False         False  \n",
       "133       gte-base_sst2         False         False  \n",
       "134       gte-base_sst2          True         False  \n",
       "135       gte-base_sst2         False         False  \n",
       "\n",
       "[136 rows x 6 columns]"
      ]
     },
     "execution_count": 174,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# where the \"polarity_max\" is True, boldface the value\n",
    "def boldface_max(x, metric):\n",
    "    if x[f\"{metric}_max\"]:\n",
    "        body = x[metric].replace(\"$\", \"\")\n",
    "        return f\"$\\mathbf{{{body}}}$\"\n",
    "    else:\n",
    "        return x[metric]\n",
    "\n",
    "max_df[\"polarity\"] = max_df.apply(lambda x: boldface_max(x, \"polarity\"), axis=1)\n",
    "max_df[\"semantic\"] = max_df.apply(lambda x: boldface_max(x, \"semantic\"), axis=1)\n",
    "max_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llllllllll}\n",
      "\\toprule\n",
      "loss & lambda & e5-small_sarcastic & e5-small_sst2 & gte-base_sarcastic & gte-base_sst2 & gte-small_sarcastic & gte-small_sst2 & minilm-6_sarcastic & minilm-6_sst2 \\\\\n",
      "\\midrule\n",
      "Contrastive & 0.10 & $88.8_{24.3}$ & $89.5_{23.2}$ & $86.9_{25.6}$ & $89.2_{24.1}$ & $81.9_{27.0}$ & $88.0_{25.6}$ & $75.9_{25.6}$ & $68.0_{24.9}$ \\\\\n",
      "Contrastive & 0.25 & $89.3_{25.1}$ & $90.7_{23.2}$ & $88.2_{26.1}$ & $90.0_{25.0}$ & $84.3_{26.9}$ & $88.8_{26.1}$ & $76.8_{26.4}$ & $72.4_{26.9}$ \\\\\n",
      "Contrastive & 0.50 & $89.8_{25.6}$ & $91.2_{23.8}$ & $88.8_{26.5}$ & $90.3_{25.3}$ & $86.8_{27.5}$ & $89.1_{27.1}$ & $77.8_{27.2}$ & $75.1_{27.8}$ \\\\\n",
      "Contrastive & 0.75 & $89.9_{25.1}$ & $91.6_{23.6}$ & $88.9_{26.6}$ & $90.6_{25.1}$ & $87.7_{27.3}$ & $89.5_{26.9}$ & $79.0_{27.7}$ & $77.3_{28.6}$ \\\\\n",
      "Contrastive & 1.00 & $89.8_{25.5}$ & $91.2_{24.3}$ & $88.7_{26.7}$ & $90.7_{25.1}$ & $87.8_{27.0}$ & $89.6_{26.8}$ & $80.3_{28.1}$ & $78.4_{28.9}$ \\\\\n",
      "MultipleNeg & - & $73.6_{22.2}$ & $80.8_{22.4}$ & $73.1_{22.4}$ & $81.8_{23.5}$ & $72.0_{22.6}$ & $80.6_{23.4}$ & $69.0_{22.0}$ & $69.4_{23.1}$ \\\\\n",
      "OnlineContr & 0.10 & $89.6_{24.7}$ & $90.4_{23.7}$ & $87.4_{25.8}$ & $89.5_{24.2}$ & $82.6_{27.0}$ & $88.2_{25.8}$ & $78.9_{26.0}$ & $70.8_{26.5}$ \\\\\n",
      "OnlineContr & 0.25 & $90.0_{25.2}$ & $91.5_{23.8}$ & $88.2_{26.4}$ & $90.2_{25.4}$ & $84.4_{27.3}$ & $88.9_{26.7}$ & $78.9_{26.4}$ & $74.6_{27.8}$ \\\\\n",
      "OnlineContr & 0.50 & $89.7_{25.9}$ & $91.6_{24.4}$ & $88.2_{27.3}$ & $90.6_{26.0}$ & $86.0_{27.6}$ & $89.3_{27.2}$ & $79.0_{26.9}$ & $76.5_{27.9}$ \\\\\n",
      "OnlineContr & 0.75 & $89.5_{26.5}$ & $91.7_{24.5}$ & $88.6_{27.4}$ & $90.8_{25.6}$ & $87.2_{27.9}$ & $89.2_{27.6}$ & $80.0_{27.4}$ & $77.5_{28.2}$ \\\\\n",
      "OnlineContr & 1.00 & $89.6_{26.6}$ & $91.7_{25.0}$ & $88.3_{27.3}$ & $90.7_{26.0}$ & $87.5_{27.7}$ & $89.6_{27.5}$ & $80.5_{27.8}$ & $78.4_{28.7}$ \\\\\n",
      "Triplet & 0.01 & $90.2_{25.6}$ & $91.5_{25.1}$ & $82.5_{25.7}$ & $90.3_{24.9}$ & $84.0_{25.5}$ & $89.1_{26.2}$ & $78.5_{24.5}$ & $76.9_{26.9}$ \\\\\n",
      "Triplet & 0.10 & $\\mathbf{90.6}_{26.3}$ & $\\mathbf{91.9}_{25.0}$ & $\\mathbf{89.7}_{27.1}$ & $\\mathbf{91.2}_{25.6}$ & $\\mathbf{88.4}_{27.2}$ & $\\mathbf{89.9}_{27.0}$ & $83.5_{26.9}$ & $80.6_{28.6}$ \\\\\n",
      "Triplet & 1.00 & $90.1_{25.7}$ & $90.9_{23.5}$ & $88.4_{26.6}$ & $90.6_{24.9}$ & $87.4_{27.0}$ & $88.6_{25.7}$ & $\\mathbf{84.1}_{28.6}$ & $\\mathbf{83.2}_{31.1}$ \\\\\n",
      "Triplet & 5.00 & $88.2_{25.1}$ & $89.3_{23.4}$ & $86.5_{26.8}$ & $90.1_{25.1}$ & $84.9_{26.5}$ & $88.2_{26.1}$ & $81.5_{27.7}$ & $81.3_{30.1}$ \\\\\n",
      "Triplet & 7.50 & $88.2_{25.4}$ & $89.6_{23.1}$ & $86.6_{27.0}$ & $90.1_{25.0}$ & $84.8_{26.4}$ & $88.2_{25.9}$ & $81.4_{27.8}$ & $81.5_{30.1}$ \\\\\n",
      "Triplet & 10.00 & $88.1_{25.1}$ & $89.6_{22.9}$ & $86.8_{26.6}$ & $90.2_{24.9}$ & $84.8_{26.8}$ & $88.1_{26.2}$ & $81.6_{27.8}$ & $81.2_{30.4}$ \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def transpose_and_group(df, metric=\"polarity\", as_latex=False):\n",
    "    tmp = df.pivot(index=\"model_for_data\", columns=\"loss\", values=metric).reset_index().T\n",
    "    column_names = tmp.iloc[0]\n",
    "    tmp = tmp.drop(tmp.index[0])\n",
    "    idx = tmp.index\n",
    "    tmp = tmp.reset_index(drop=True)\n",
    "    tmp.columns = column_names\n",
    "\n",
    "    tmp[\"loss\"] = idx\n",
    "    tmp[\"lambda\"] = tmp[\"loss\"].map(lambda x: float(x.split(\" \")[1]) if len(x.split(\" \")) > 1 else \"-\")\n",
    "    tmp[\"loss\"] = tmp[\"loss\"].map(lambda x: x.split(\" \")[0])\n",
    "    tmp = tmp[[\"loss\", \"lambda\"] + list(tmp.columns[:-2])]\n",
    "    tmp = tmp.sort_values(by=[\"loss\", \"lambda\"], ascending=True)\n",
    "    tmp.columns.name = None\n",
    "\n",
    "    if as_latex:\n",
    "        return tmp.to_latex(float_format=\"%.2f\", index=False)\n",
    "    return tmp\n",
    "\n",
    "# pd.DataFrame(transpose_and_group(_df, metric=\"polarity\"))\n",
    "print(transpose_and_group(df, metric=\"polarity\", as_latex=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simcse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
