{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MR</th>\n",
       "      <th>CR</th>\n",
       "      <th>SUBJ</th>\n",
       "      <th>MPQA</th>\n",
       "      <th>SST2</th>\n",
       "      <th>TREC</th>\n",
       "      <th>avg</th>\n",
       "      <th>ID</th>\n",
       "      <th>model_type</th>\n",
       "      <th>dataset</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>84.30</td>\n",
       "      <td>88.85</td>\n",
       "      <td>90.91</td>\n",
       "      <td>86.08</td>\n",
       "      <td>89.18</td>\n",
       "      <td>86.0</td>\n",
       "      <td>84.27</td>\n",
       "      <td>setfit</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>81.61</td>\n",
       "      <td>86.52</td>\n",
       "      <td>90.01</td>\n",
       "      <td>87.50</td>\n",
       "      <td>88.69</td>\n",
       "      <td>86.0</td>\n",
       "      <td>81.92</td>\n",
       "      <td>setfit</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>85.43</td>\n",
       "      <td>85.16</td>\n",
       "      <td>86.58</td>\n",
       "      <td>83.93</td>\n",
       "      <td>91.05</td>\n",
       "      <td>88.0</td>\n",
       "      <td>82.18</td>\n",
       "      <td>setfit</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>82.69</td>\n",
       "      <td>83.97</td>\n",
       "      <td>90.65</td>\n",
       "      <td>86.80</td>\n",
       "      <td>88.80</td>\n",
       "      <td>90.2</td>\n",
       "      <td>81.62</td>\n",
       "      <td>setfit</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>71.20</td>\n",
       "      <td>66.44</td>\n",
       "      <td>86.57</td>\n",
       "      <td>79.63</td>\n",
       "      <td>80.94</td>\n",
       "      <td>74.4</td>\n",
       "      <td>74.61</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>81.21</td>\n",
       "      <td>84.53</td>\n",
       "      <td>87.43</td>\n",
       "      <td>84.76</td>\n",
       "      <td>86.49</td>\n",
       "      <td>81.2</td>\n",
       "      <td>81.53</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>82.40</td>\n",
       "      <td>76.27</td>\n",
       "      <td>90.47</td>\n",
       "      <td>85.75</td>\n",
       "      <td>89.95</td>\n",
       "      <td>71.4</td>\n",
       "      <td>78.81</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>88.95</td>\n",
       "      <td>88.98</td>\n",
       "      <td>91.06</td>\n",
       "      <td>86.28</td>\n",
       "      <td>93.41</td>\n",
       "      <td>79.8</td>\n",
       "      <td>84.97</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>84.33</td>\n",
       "      <td>88.82</td>\n",
       "      <td>92.82</td>\n",
       "      <td>88.04</td>\n",
       "      <td>90.83</td>\n",
       "      <td>88.4</td>\n",
       "      <td>85.01</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>89.31</td>\n",
       "      <td>89.27</td>\n",
       "      <td>92.91</td>\n",
       "      <td>85.95</td>\n",
       "      <td>93.19</td>\n",
       "      <td>80.8</td>\n",
       "      <td>85.50</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>80.51</td>\n",
       "      <td>83.52</td>\n",
       "      <td>90.17</td>\n",
       "      <td>86.11</td>\n",
       "      <td>87.59</td>\n",
       "      <td>84.6</td>\n",
       "      <td>81.84</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>sarcastic</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>87.72</td>\n",
       "      <td>89.59</td>\n",
       "      <td>90.85</td>\n",
       "      <td>86.86</td>\n",
       "      <td>91.38</td>\n",
       "      <td>79.0</td>\n",
       "      <td>84.83</td>\n",
       "      <td>finetuned</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>sst2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      MR     CR   SUBJ   MPQA   SST2  TREC    avg         ID model_type  \\\n",
       "0  84.30  88.85  90.91  86.08  89.18  86.0  84.27     setfit   gte-base   \n",
       "1  81.61  86.52  90.01  87.50  88.69  86.0  81.92     setfit   gte-base   \n",
       "2  85.43  85.16  86.58  83.93  91.05  88.0  82.18     setfit   e5-small   \n",
       "3  82.69  83.97  90.65  86.80  88.80  90.2  81.62     setfit   e5-small   \n",
       "0  71.20  66.44  86.57  79.63  80.94  74.4  74.61  finetuned   minilm-6   \n",
       "1  81.21  84.53  87.43  84.76  86.49  81.2  81.53  finetuned   minilm-6   \n",
       "2  82.40  76.27  90.47  85.75  89.95  71.4  78.81  finetuned   e5-small   \n",
       "3  88.95  88.98  91.06  86.28  93.41  79.8  84.97  finetuned   e5-small   \n",
       "4  84.33  88.82  92.82  88.04  90.83  88.4  85.01  finetuned   gte-base   \n",
       "5  89.31  89.27  92.91  85.95  93.19  80.8  85.50  finetuned   gte-base   \n",
       "6  80.51  83.52  90.17  86.11  87.59  84.6  81.84  finetuned  gte-small   \n",
       "7  87.72  89.59  90.85  86.86  91.38  79.0  84.83  finetuned  gte-small   \n",
       "\n",
       "     dataset  \n",
       "0       sst2  \n",
       "1  sarcastic  \n",
       "2       sst2  \n",
       "3  sarcastic  \n",
       "0  sarcastic  \n",
       "1       sst2  \n",
       "2  sarcastic  \n",
       "3       sst2  \n",
       "4  sarcastic  \n",
       "5       sst2  \n",
       "6  sarcastic  \n",
       "7       sst2  "
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "senteval_setfit = pd.read_csv(\"setfit/setfit_senteval_results.csv\")\n",
    "senteval_setfit[\"ID\"] = [\"setfit\"] * len(senteval_setfit)\n",
    "senteval_finetuned = pd.read_csv(\"finetuned/finetuned_senteval_results.csv\")\n",
    "senteval_finetuned[\"ID\"] = [\"finetuned\"] * len(senteval_finetuned)\n",
    "merged = pd.concat([senteval_setfit, senteval_finetuned])\n",
    "model_col = merged.model.tolist()\n",
    "dataset_col = [\"sst2\" if \"sst2\" in model else \"sarcastic\" for model in model_col]\n",
    "model_types = merged.model.apply(lambda x: \"-\".join(x.replace(\"USERNAME/\", \"\").split(\"-\")[:2]))\n",
    "model_types = model_types.apply(lambda x: \"minilm-6\" if x == \"all-MiniLM\" else x)\n",
    "merged[\"model_type\"] = model_types\n",
    "merged[\"dataset\"] = dataset_col\n",
    "merged = merged.drop(\"model\", axis=1)\n",
    "merged = merged.drop(columns=[\"SICKEntailment\", \"MRPC\"])\n",
    "merged.loc[merged.ID == \"pretrained\", \"dataset\"] = \"-\"\n",
    "merged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllrrrrrrr}\n",
      "\\toprule\n",
      "Variant & Base model & Dataset & MR & CR & SUBJ & MPQA & SST2 & TREC & avg \\\\\n",
      "\\midrule\n",
      "finetuned & gte-base & sst2 & 89.31 & 89.27 & 92.91 & 85.95 & 93.19 & 80.80 & 85.50 \\\\\n",
      "finetuned & gte-base & sarcastic & 84.33 & 88.82 & 92.82 & 88.04 & 90.83 & 88.40 & 85.01 \\\\\n",
      "finetuned & e5-small & sst2 & 88.95 & 88.98 & 91.06 & 86.28 & 93.41 & 79.80 & 84.97 \\\\\n",
      "finetuned & gte-small & sst2 & 87.72 & 89.59 & 90.85 & 86.86 & 91.38 & 79.00 & 84.83 \\\\\n",
      "setfit & gte-base & sst2 & 84.30 & 88.85 & 90.91 & 86.08 & 89.18 & 86.00 & 84.27 \\\\\n",
      "setfit & e5-small & sst2 & 85.43 & 85.16 & 86.58 & 83.93 & 91.05 & 88.00 & 82.18 \\\\\n",
      "setfit & gte-base & sarcastic & 81.61 & 86.52 & 90.01 & 87.50 & 88.69 & 86.00 & 81.92 \\\\\n",
      "finetuned & gte-small & sarcastic & 80.51 & 83.52 & 90.17 & 86.11 & 87.59 & 84.60 & 81.84 \\\\\n",
      "setfit & e5-small & sarcastic & 82.69 & 83.97 & 90.65 & 86.80 & 88.80 & 90.20 & 81.62 \\\\\n",
      "finetuned & minilm-6 & sst2 & 81.21 & 84.53 & 87.43 & 84.76 & 86.49 & 81.20 & 81.53 \\\\\n",
      "finetuned & e5-small & sarcastic & 82.40 & 76.27 & 90.47 & 85.75 & 89.95 & 71.40 & 78.81 \\\\\n",
      "finetuned & minilm-6 & sarcastic & 71.20 & 66.44 & 86.57 & 79.63 & 80.94 & 74.40 & 74.61 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# sort by AVG score\n",
    "merged = merged.sort_values(\"avg\", ascending=False)\n",
    "merged = merged.reset_index(drop=True)\n",
    "\n",
    "merged = merged.rename(columns={\"ID\": \"Variant\", \"model_type\": \"Base model\", \"dataset\": \"Dataset\"})\n",
    "cols = merged.columns.tolist()\n",
    "cols = cols[-3:] + cols[:-3]\n",
    "merged = merged[cols]\n",
    "merged = merged.reset_index(drop=True)\n",
    "\n",
    "print(merged.to_latex(float_format=\"%.2f\", index=False))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
