{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>dataset</th>\n",
       "      <th>loss</th>\n",
       "      <th>epoch</th>\n",
       "      <th>polarity_mean</th>\n",
       "      <th>semantic_mean</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegativesRankingLoss</td>\n",
       "      <td>1</td>\n",
       "      <td>66.541</td>\n",
       "      <td>41.155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegativesRankingLoss</td>\n",
       "      <td>2</td>\n",
       "      <td>67.275</td>\n",
       "      <td>40.589</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegativesRankingLoss</td>\n",
       "      <td>3</td>\n",
       "      <td>67.968</td>\n",
       "      <td>40.264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegativesRankingLoss</td>\n",
       "      <td>4</td>\n",
       "      <td>68.594</td>\n",
       "      <td>40.044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>MultipleNegativesRankingLoss</td>\n",
       "      <td>5</td>\n",
       "      <td>68.978</td>\n",
       "      <td>39.852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>675</th>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "      <td>TripletLoss</td>\n",
       "      <td>1</td>\n",
       "      <td>87.952</td>\n",
       "      <td>83.035</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>676</th>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "      <td>TripletLoss</td>\n",
       "      <td>2</td>\n",
       "      <td>89.278</td>\n",
       "      <td>82.520</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>677</th>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "      <td>TripletLoss</td>\n",
       "      <td>3</td>\n",
       "      <td>89.579</td>\n",
       "      <td>82.152</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>678</th>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "      <td>TripletLoss</td>\n",
       "      <td>4</td>\n",
       "      <td>90.023</td>\n",
       "      <td>82.005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>679</th>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "      <td>TripletLoss</td>\n",
       "      <td>5</td>\n",
       "      <td>90.295</td>\n",
       "      <td>81.913</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>680 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        model              dataset                          loss  epoch  \\\n",
       "0    minilm-6  sarcastic-headlines  MultipleNegativesRankingLoss      1   \n",
       "1    minilm-6  sarcastic-headlines  MultipleNegativesRankingLoss      2   \n",
       "2    minilm-6  sarcastic-headlines  MultipleNegativesRankingLoss      3   \n",
       "3    minilm-6  sarcastic-headlines  MultipleNegativesRankingLoss      4   \n",
       "4    minilm-6  sarcastic-headlines  MultipleNegativesRankingLoss      5   \n",
       "..        ...                  ...                           ...    ...   \n",
       "675  gte-base                 sst2                   TripletLoss      1   \n",
       "676  gte-base                 sst2                   TripletLoss      2   \n",
       "677  gte-base                 sst2                   TripletLoss      3   \n",
       "678  gte-base                 sst2                   TripletLoss      4   \n",
       "679  gte-base                 sst2                   TripletLoss      5   \n",
       "\n",
       "     polarity_mean  semantic_mean  \n",
       "0           66.541         41.155  \n",
       "1           67.275         40.589  \n",
       "2           67.968         40.264  \n",
       "3           68.594         40.044  \n",
       "4           68.978         39.852  \n",
       "..             ...            ...  \n",
       "675         87.952         83.035  \n",
       "676         89.278         82.520  \n",
       "677         89.579         82.152  \n",
       "678         90.023         82.005  \n",
       "679         90.295         81.913  \n",
       "\n",
       "[680 rows x 6 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import matplotlib\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "sys.path.append(\"../\")  # for utils\n",
    "sys.path.append(\"../section4-N=50k/\")  # the appendix is based on the 50k sample training\n",
    "from util import sns_config, matplot_config, tex_config, Baseline, aggregated_losses_for_model, filter_on_metric\n",
    "\n",
    "# matplot_config is applied after the seaborn style, so it wins on any overlapping rcParams.\n",
    "sns.set_style(style=\"whitegrid\", rc=sns_config)\n",
    "matplotlib.rcParams.update(matplot_config)\n",
    "\n",
    "# Baseline helper plus the per-epoch history of the 50k-sample training run.\n",
    "baseline = Baseline(path=\"../section3-baseline/baseline_k.csv\")\n",
    "df = pd.read_csv(filepath_or_buffer=\"../section4-N=50k/50k_novemberrun_history.csv\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['sarcastic-headlines', 'sst2'], dtype=object)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct dataset names found in the training history; the plotting cell iterates over them.\n",
    "datasets = df[\"dataset\"].unique()\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>model</th>\n",
       "      <th>k</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>80.376</td>\n",
       "      <td>22.561</td>\n",
       "      <td>83.700</td>\n",
       "      <td>1.436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>77.774</td>\n",
       "      <td>22.204</td>\n",
       "      <td>84.808</td>\n",
       "      <td>1.410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>81.465</td>\n",
       "      <td>23.696</td>\n",
       "      <td>85.516</td>\n",
       "      <td>1.735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.016</td>\n",
       "      <td>21.877</td>\n",
       "      <td>46.594</td>\n",
       "      <td>7.426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>67.402</td>\n",
       "      <td>20.739</td>\n",
       "      <td>81.366</td>\n",
       "      <td>1.641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>66.826</td>\n",
       "      <td>20.573</td>\n",
       "      <td>82.468</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>71.436</td>\n",
       "      <td>21.241</td>\n",
       "      <td>83.404</td>\n",
       "      <td>1.463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.749</td>\n",
       "      <td>20.181</td>\n",
       "      <td>42.307</td>\n",
       "      <td>5.615</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                dataset      model   k  p_mean   p_std  s_mean  s_std\n",
       "2                  sst2   gte-base  16  80.376  22.561  83.700  1.436\n",
       "7                  sst2  gte-small  16  77.774  22.204  84.808  1.410\n",
       "12                 sst2   e5-small  16  81.465  23.696  85.516  1.735\n",
       "17                 sst2   minilm-6  16  63.016  21.877  46.594  7.426\n",
       "22  sarcastic-headlines   gte-base  16  67.402  20.739  81.366  1.641\n",
       "27  sarcastic-headlines  gte-small  16  66.826  20.573  82.468  1.629\n",
       "32  sarcastic-headlines   e5-small  16  71.436  21.241  83.404  1.463\n",
       "37  sarcastic-headlines   minilm-6  16  63.749  20.181  42.307  5.615"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the baseline table exposed by the Baseline helper.\n",
    "baseline.baseline_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/0k/1bg63zt532nb9d86g5tk_6vh0000gn/T/ipykernel_56016/115537045.py:56: UserWarning: FigureCanvasPgf is non-interactive, and thus cannot be shown\n",
      "  plt.show()\n",
      "/var/folders/0k/1bg63zt532nb9d86g5tk_6vh0000gn/T/ipykernel_56016/115537045.py:56: UserWarning: FigureCanvasPgf is non-interactive, and thus cannot be shown\n",
      "  plt.show()\n",
      "/var/folders/0k/1bg63zt532nb9d86g5tk_6vh0000gn/T/ipykernel_56016/115537045.py:56: UserWarning: FigureCanvasPgf is non-interactive, and thus cannot be shown\n",
      "  plt.show()\n",
      "/var/folders/0k/1bg63zt532nb9d86g5tk_6vh0000gn/T/ipykernel_56016/115537045.py:56: UserWarning: FigureCanvasPgf is non-interactive, and thus cannot be shown\n",
      "  plt.show()\n"
     ]
    }
   ],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "# Export switch: when True, the pgf backend is selected and every figure is written\n",
    "# to a .pgf file; when False, the figures are shown in the notebook instead.\n",
    "save_latex = True\n",
    "if save_latex:\n",
    "    matplotlib.use(\"pgf\")\n",
    "    matplotlib.rcParams.update(tex_config)\n",
    "\n",
    "# Dash patterns in matplotlib's (offset, (on, off, ...)) form, cycled over the loss curves.\n",
    "linestyles = [\n",
    "    (0, (1, 1)),\n",
    "    (0, (5, 5)),\n",
    "    (0, (5, 1)),\n",
    "    (0, (3, 5, 1, 5)),\n",
    "]\n",
    "\n",
    "metrics = [\"polarity_mean\", \"semantic_mean\"]\n",
    "# Baseline column compared against each history metric.\n",
    "baseline_metric = {\"polarity_mean\": \"p_mean\", \"semantic_mean\": \"s_mean\"}\n",
    "folder = \"figures/model_max_loss\"\n",
    "\n",
    "for model in df.model.unique():\n",
    "    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 8))\n",
    "    # Rows are datasets, columns are metrics.\n",
    "    subplot_mapping = {\n",
    "        (\"sst2\", \"polarity_mean\"): axes[0, 0],\n",
    "        (\"sst2\", \"semantic_mean\"): axes[0, 1],\n",
    "        (\"sarcastic-headlines\", \"polarity_mean\"): axes[1, 0],\n",
    "        (\"sarcastic-headlines\", \"semantic_mean\"): axes[1, 1],\n",
    "    }\n",
    "\n",
    "    fig.suptitle(f\"{model} on SST-2 and Sarcastic Headlines\", fontsize=24)\n",
    "    for metric in metrics:\n",
    "        # Boolean indexing already yields a new frame, so no explicit copy of df is needed.\n",
    "        _df = filter_on_metric(df[df.model == model], metric, to_drop=[\"polarity_mean\", \"semantic_mean\"])\n",
    "\n",
    "        for dataset in datasets:\n",
    "            ax = subplot_mapping[(dataset, metric)]\n",
    "            ax.set_xlim(1, 5)\n",
    "            ax.set_title(f\"{metric.replace('semantic', 'semantic similarity').split('_')[0]}, {dataset}\")\n",
    "            ax.set_xlabel(\"Epoch\")\n",
    "            ax.set_ylabel(\"Score\")\n",
    "\n",
    "            for j, loss in enumerate(_df.loss.unique()):\n",
    "                # Collapse rows that share an epoch to their maximum score for this loss and dataset.\n",
    "                best_per_epoch = (\n",
    "                    _df[(_df.loss == loss) & (_df.dataset == dataset)]\n",
    "                    .groupby(\"epoch\", as_index=False)[metric]\n",
    "                    .max()\n",
    "                )\n",
    "                sns.lineplot(\n",
    "                    data=best_per_epoch,\n",
    "                    x=\"epoch\",\n",
    "                    y=metric,\n",
    "                    ax=ax,\n",
    "                    label=loss,\n",
    "                    legend=False,\n",
    "                    errorbar=None,\n",
    "                    # Wrap around so more losses than dash patterns cannot raise IndexError.\n",
    "                    linestyle=linestyles[j % len(linestyles)],\n",
    "                )\n",
    "\n",
    "            baseline_val = baseline.get(model=model, dataset=dataset, metric=baseline_metric[metric])\n",
    "            # NOTE(review): this truthiness check also skips a baseline of exactly 0;\n",
    "            # confirm what Baseline.get returns when no entry exists.\n",
    "            if baseline_val:\n",
    "                ax.plot([1, 5], [baseline_val, baseline_val], linestyle=\"solid\", color=\"black\", alpha=1, linewidth=3, label=f\"{model} baseline\")\n",
    "\n",
    "    # Build one shared legend: collect each label once across all four axes instead of\n",
    "    # relying on whichever axis the loops above happened to touch last.\n",
    "    legend_entries = {}\n",
    "    for ax in axes.flat:\n",
    "        for handle, label in zip(*ax.get_legend_handles_labels()):\n",
    "            legend_entries.setdefault(label, handle)\n",
    "    fig.legend(list(legend_entries.values()), list(legend_entries.keys()), loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=3, borderaxespad=-2)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    if save_latex:\n",
    "        # The pgf backend is non-interactive: save directly rather than calling plt.show(),\n",
    "        # which would only emit a UserWarning.\n",
    "        os.makedirs(folder, exist_ok=True)\n",
    "        fig.savefig(f\"{folder}/{model}-max-loss.pgf\", bbox_inches=\"tight\")\n",
    "    else:\n",
    "        plt.show()\n",
    "    # Release the figure so figures do not accumulate across models.\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simcse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
