{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from util import get_data_from_api\n",
    "# data = get_data_from_api(\"INSERT_USER/INSERT_RUN_ID\")\n",
    "# df = pd.DataFrame(data)\n",
    "# df.to_csv(\"50k.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>polarity</th>\n",
       "      <th>semantic</th>\n",
       "      <th>model_for_data</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MultipleNegatives</td>\n",
       "      <td>0.740</td>\n",
       "      <td>0.426</td>\n",
       "      <td>minilm-6_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OnlineContrastive 1.0</td>\n",
       "      <td>0.827</td>\n",
       "      <td>0.325</td>\n",
       "      <td>minilm-6_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OnlineContrastive 0.75</td>\n",
       "      <td>0.818</td>\n",
       "      <td>0.322</td>\n",
       "      <td>minilm-6_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OnlineContrastive 0.5</td>\n",
       "      <td>0.803</td>\n",
       "      <td>0.317</td>\n",
       "      <td>minilm-6_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OnlineContrastive 0.25</td>\n",
       "      <td>0.785</td>\n",
       "      <td>0.314</td>\n",
       "      <td>minilm-6_sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.887</td>\n",
       "      <td>0.745</td>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>Triplet 5.0</td>\n",
       "      <td>0.887</td>\n",
       "      <td>0.745</td>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Triplet 1.0</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.766</td>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>134</th>\n",
       "      <td>Triplet 0.1</td>\n",
       "      <td>0.907</td>\n",
       "      <td>0.774</td>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>Triplet 0.01</td>\n",
       "      <td>0.887</td>\n",
       "      <td>0.766</td>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                       loss  polarity  semantic      model_for_data\n",
       "0         MultipleNegatives     0.740     0.426       minilm-6_sst2\n",
       "1     OnlineContrastive 1.0     0.827     0.325       minilm-6_sst2\n",
       "2    OnlineContrastive 0.75     0.818     0.322       minilm-6_sst2\n",
       "3     OnlineContrastive 0.5     0.803     0.317       minilm-6_sst2\n",
       "4    OnlineContrastive 0.25     0.785     0.314       minilm-6_sst2\n",
       "..                      ...       ...       ...                 ...\n",
       "131             Triplet 7.5     0.887     0.745  gte-base_sarcastic\n",
       "132             Triplet 5.0     0.887     0.745  gte-base_sarcastic\n",
       "133             Triplet 1.0     0.900     0.766  gte-base_sarcastic\n",
       "134             Triplet 0.1     0.907     0.774  gte-base_sarcastic\n",
       "135            Triplet 0.01     0.887     0.766  gte-base_sarcastic\n",
       "\n",
       "[136 rows x 4 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from util import model_translations\n",
    "\n",
    "df = pd.read_csv(\"50k.csv\")\n",
    "df.model = df.model.map(model_translations)\n",
    "df.loss = df.model_id.map(lambda x: x.split(\"(\")[0])\n",
    "df = df.drop(columns=[\"model_id\", \"samples\"])\n",
    "df[\"loss\"] = df[\"loss\"] + \" \" + df[\"lambda\"].map(str)\n",
    "df[\"loss\"] = df[\"loss\"].map(lambda x: x.replace(\"nan\", \"\").strip())\n",
    "df = df.drop(columns=[\"lambda\"])\n",
    "\n",
    "dataset_translation = {\n",
    "    \"sarcastic-headlines\": \"sarcastic\",\n",
    "    \"sst2\": \"sst2\",\n",
    "}\n",
    "df.dataset = df.dataset.map(dataset_translation)\n",
    "df[\"model_for_data\"] = df[\"model\"] + \"_\" + df[\"dataset\"]\n",
    "df = df.drop(columns=[\"model\", \"dataset\", \"combined\"])  # not interested in the combined score here\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model_for_data</th>\n",
       "      <th>loss</th>\n",
       "      <th>polarity</th>\n",
       "      <th>semantic</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>e5-small_sarcastic</td>\n",
       "      <td>Contrastive 0.1</td>\n",
       "      <td>0.910</td>\n",
       "      <td>0.791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>gte-base_sst2</td>\n",
       "      <td>Contrastive 0.1</td>\n",
       "      <td>0.909</td>\n",
       "      <td>0.805</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>gte-small_sst2</td>\n",
       "      <td>Contrastive 0.1</td>\n",
       "      <td>0.893</td>\n",
       "      <td>0.818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "      <td>Contrastive 0.1</td>\n",
       "      <td>0.901</td>\n",
       "      <td>0.746</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>102</th>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>Contrastive 0.1</td>\n",
       "      <td>0.836</td>\n",
       "      <td>0.215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>gte-base_sarcastic</td>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.887</td>\n",
       "      <td>0.745</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>e5-small_sst2</td>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.908</td>\n",
       "      <td>0.819</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>e5-small_sarcastic</td>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.892</td>\n",
       "      <td>0.784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>minilm-6_sarcastic</td>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.853</td>\n",
       "      <td>0.201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>135</th>\n",
       "      <td>minilm-6_sst2</td>\n",
       "      <td>Triplet 7.5</td>\n",
       "      <td>0.833</td>\n",
       "      <td>0.304</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>136 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         model_for_data             loss  polarity  semantic\n",
       "0    e5-small_sarcastic  Contrastive 0.1     0.910     0.791\n",
       "51        gte-base_sst2  Contrastive 0.1     0.909     0.805\n",
       "85       gte-small_sst2  Contrastive 0.1     0.893     0.818\n",
       "34   gte-base_sarcastic  Contrastive 0.1     0.901     0.746\n",
       "102  minilm-6_sarcastic  Contrastive 0.1     0.836     0.215\n",
       "..                  ...              ...       ...       ...\n",
       "50   gte-base_sarcastic      Triplet 7.5     0.887     0.745\n",
       "33        e5-small_sst2      Triplet 7.5     0.908     0.819\n",
       "16   e5-small_sarcastic      Triplet 7.5     0.892     0.784\n",
       "118  minilm-6_sarcastic      Triplet 7.5     0.853     0.201\n",
       "135       minilm-6_sst2      Triplet 7.5     0.833     0.304\n",
       "\n",
       "[136 rows x 4 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_df = df.groupby([\"model_for_data\", \"loss\"]).mean()\n",
    "_df = _df.reset_index()\n",
    "_df = _df.sort_values(by=\"loss\", ascending=True)\n",
    "_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llllllllll}\n",
      "\\toprule\n",
      "loss & lambda & e5-small_sarcastic & e5-small_sst2 & gte-base_sarcastic & gte-base_sst2 & gte-small_sarcastic & gte-small_sst2 & minilm-6_sarcastic & minilm-6_sst2 \\\\\n",
      "\\midrule\n",
      "Contrastive & 0.100 & 0.910 & 0.919 & 0.901 & 0.909 & 0.878 & 0.893 & 0.836 & 0.738 \\\\\n",
      "Contrastive & 0.250 & 0.914 & 0.921 & 0.905 & 0.909 & 0.892 & 0.900 & 0.843 & 0.766 \\\\\n",
      "Contrastive & 0.500 & 0.916 & 0.924 & 0.905 & 0.910 & 0.897 & 0.900 & 0.849 & 0.800 \\\\\n",
      "Contrastive & 0.750 & 0.915 & 0.925 & 0.904 & 0.913 & 0.896 & 0.902 & 0.855 & 0.816 \\\\\n",
      "Contrastive & 1.000 & 0.915 & 0.924 & 0.906 & 0.913 & 0.897 & 0.904 & 0.863 & 0.825 \\\\\n",
      "MultipleNegatives & - & 0.786 & 0.833 & 0.790 & 0.859 & 0.772 & 0.837 & 0.750 & 0.740 \\\\\n",
      "OnlineContrastive & 0.100 & 0.915 & 0.925 & 0.904 & 0.908 & 0.886 & 0.897 & 0.842 & 0.760 \\\\\n",
      "OnlineContrastive & 0.250 & 0.917 & 0.927 & 0.905 & 0.912 & 0.896 & 0.901 & 0.847 & 0.785 \\\\\n",
      "OnlineContrastive & 0.500 & 0.917 & 0.926 & 0.905 & 0.913 & 0.893 & 0.902 & 0.853 & 0.803 \\\\\n",
      "OnlineContrastive & 0.750 & 0.917 & 0.927 & 0.906 & 0.910 & 0.897 & 0.902 & 0.861 & 0.818 \\\\\n",
      "OnlineContrastive & 1.000 & 0.914 & 0.926 & 0.909 & 0.910 & 0.898 & 0.902 & 0.864 & 0.827 \\\\\n",
      "Triplet & 0.010 & 0.918 & 0.934 & 0.887 & 0.912 & 0.896 & 0.901 & 0.848 & 0.823 \\\\\n",
      "Triplet & 0.100 & 0.920 & 0.930 & 0.907 & 0.913 & 0.901 & 0.905 & 0.878 & 0.844 \\\\\n",
      "Triplet & 1.000 & 0.916 & 0.921 & 0.900 & 0.908 & 0.893 & 0.893 & 0.876 & 0.854 \\\\\n",
      "Triplet & 5.000 & 0.893 & 0.907 & 0.887 & 0.910 & 0.872 & 0.890 & 0.853 & 0.833 \\\\\n",
      "Triplet & 7.500 & 0.892 & 0.908 & 0.887 & 0.905 & 0.873 & 0.889 & 0.853 & 0.833 \\\\\n",
      "Triplet & 10.000 & 0.895 & 0.908 & 0.886 & 0.909 & 0.873 & 0.891 & 0.852 & 0.833 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def transpose_and_group(df, metric=\"polarity\", as_latex=False):\n",
    "    tmp = df.pivot(index=\"model_for_data\", columns=\"loss\", values=metric).reset_index().T\n",
    "    column_names = tmp.iloc[0]\n",
    "    tmp = tmp.drop(tmp.index[0])\n",
    "    idx = tmp.index\n",
    "    tmp = tmp.reset_index(drop=True)\n",
    "    tmp.columns = column_names\n",
    "\n",
    "    tmp[\"loss\"] = idx\n",
    "    tmp[\"lambda\"] = tmp[\"loss\"].map(lambda x: float(x.split(\" \")[1]) if len(x.split(\" \")) > 1 else \"-\")\n",
    "    tmp[\"loss\"] = tmp[\"loss\"].map(lambda x: x.split(\" \")[0])\n",
    "    tmp = tmp[[\"loss\", \"lambda\"] + list(tmp.columns[:-2])]\n",
    "    tmp = tmp.sort_values(by=[\"loss\", \"lambda\"], ascending=True)\n",
    "    tmp.columns.name = None\n",
    "\n",
    "    if as_latex:\n",
    "        return tmp.to_latex(float_format=\"%.3f\", index=False)\n",
    "    return tmp\n",
    "\n",
    "# pd.DataFrame(transpose_and_group(_df, metric=\"polarity\"))\n",
    "print(transpose_and_group(_df, metric=\"polarity\", as_latex=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{llllllllll}\n",
      "\\toprule\n",
      "loss & lambda & e5-small_sarcastic & e5-small_sst2 & gte-base_sarcastic & gte-base_sst2 & gte-small_sarcastic & gte-small_sst2 & minilm-6_sarcastic & minilm-6_sst2 \\\\\n",
      "\\midrule\n",
      "Contrastive & 0.100 & 0.791 & 0.832 & 0.746 & 0.805 & 0.772 & 0.818 & 0.215 & 0.316 \\\\\n",
      "Contrastive & 0.250 & 0.794 & 0.836 & 0.752 & 0.810 & 0.779 & 0.823 & 0.217 & 0.314 \\\\\n",
      "Contrastive & 0.500 & 0.795 & 0.835 & 0.757 & 0.813 & 0.785 & 0.825 & 0.210 & 0.309 \\\\\n",
      "Contrastive & 0.750 & 0.797 & 0.835 & 0.758 & 0.813 & 0.784 & 0.825 & 0.218 & 0.310 \\\\\n",
      "Contrastive & 1.000 & 0.795 & 0.835 & 0.761 & 0.813 & 0.782 & 0.824 & 0.231 & 0.312 \\\\\n",
      "MultipleNegatives & - & 0.829 & 0.851 & 0.801 & 0.825 & 0.815 & 0.838 & 0.385 & 0.426 \\\\\n",
      "OnlineContrastive & 0.100 & 0.799 & 0.836 & 0.752 & 0.810 & 0.780 & 0.823 & 0.223 & 0.308 \\\\\n",
      "OnlineContrastive & 0.250 & 0.803 & 0.838 & 0.761 & 0.814 & 0.787 & 0.825 & 0.225 & 0.314 \\\\\n",
      "OnlineContrastive & 0.500 & 0.805 & 0.838 & 0.771 & 0.818 & 0.791 & 0.828 & 0.224 & 0.317 \\\\\n",
      "OnlineContrastive & 0.750 & 0.805 & 0.837 & 0.766 & 0.818 & 0.791 & 0.827 & 0.226 & 0.322 \\\\\n",
      "OnlineContrastive & 1.000 & 0.806 & 0.837 & 0.768 & 0.816 & 0.788 & 0.825 & 0.230 & 0.325 \\\\\n",
      "Triplet & 0.010 & 0.815 & 0.838 & 0.766 & 0.812 & 0.786 & 0.824 & 0.208 & 0.310 \\\\\n",
      "Triplet & 0.100 & 0.812 & 0.837 & 0.774 & 0.817 & 0.794 & 0.824 & 0.269 & 0.328 \\\\\n",
      "Triplet & 1.000 & 0.795 & 0.828 & 0.766 & 0.802 & 0.773 & 0.815 & 0.237 & 0.316 \\\\\n",
      "Triplet & 5.000 & 0.785 & 0.818 & 0.745 & 0.801 & 0.759 & 0.809 & 0.202 & 0.304 \\\\\n",
      "Triplet & 7.500 & 0.784 & 0.819 & 0.745 & 0.801 & 0.759 & 0.808 & 0.201 & 0.304 \\\\\n",
      "Triplet & 10.000 & 0.784 & 0.819 & 0.744 & 0.800 & 0.758 & 0.808 & 0.205 & 0.304 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(transpose_and_group(_df, metric=\"semantic\", as_latex=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simcse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
