{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'gte-base': 'thenlper/gte-base', 'gte-small': 'thenlper/gte-small', 'e5-small': 'intfloat/e5-small-v2', 'minilm-6': 'sentence-transformers/all-MiniLM-L6-v2'}\n",
      "ParameterConfig(RANDOM_SEED=42, SAMPLING_SEED=42, VERBOSE=False, N_ITERS=5, EVAL_EVERY=2, BATCH_SIZE=64, DEVICE='cuda', K=16, LEARNING_RATE=3e-05)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../')\n",
    "from util import get_loss_data, get_original_datasets, get_models, get_train_test\n",
    "from model_config import generate_model_configs, get_parameters, ModelConfig, ParameterConfig\n",
    "from loss_config import get_loss_config\n",
    "\n",
    "# Inputs shared by the rest of the notebook, read from the repository-relative paths below.\n",
    "loss_datas = get_loss_data(path=\"../../data_generation/data\")\n",
    "original_datasets = get_original_datasets(path=\"../../data_generation/datasets\")\n",
    "\n",
    "# Mapping of short alias -> checkpoint name (see the printed dict in this cell's output).\n",
    "models = get_models()\n",
    "print(models)\n",
    "\n",
    "loss_config = get_loss_config(default_lambda=True)\n",
    "# generate_model_configs returns a mapping of config id -> ModelConfig: it is iterated\n",
    "# with .items() later and each value exposes .model_name, so a bare ModelConfig\n",
    "# annotation is wrong. The annotation is a string so it is never evaluated at runtime.\n",
    "# NOTE(review): keys render like 'Triplet(gte-base)'; assumed to be str - confirm.\n",
    "model_configs: \"dict[str, ModelConfig]\" = generate_model_configs(models, loss_config)\n",
    "params: ParameterConfig = get_parameters(path=\"../../params.yml\")\n",
    "print(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Index of the CUDA device this notebook should run on.\n",
    "GPU_INDEX = 1\n",
    "\n",
    "# set_device raises when the index does not exist (single-GPU or CPU-only hosts),\n",
    "# so only switch devices when the requested one is actually present.\n",
    "if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():\n",
    "    torch.cuda.set_device(GPU_INDEX)\n",
    "else:\n",
    "    print(f'CUDA device {GPU_INDEX} is not available; keeping the default device.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:24<00:00,  4.84s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:24<00:00,  4.82s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:24<00:00,  4.82s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:23<00:00,  4.77s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:17<00:00,  3.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:17<00:00,  3.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:17<00:00,  3.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:17<00:00,  3.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:17<00:00,  3.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:18<00:00,  3.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:18<00:00,  3.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:18<00:00,  3.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:16<00:00,  3.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:16<00:00,  3.37s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:16<00:00,  3.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:15<00:00,  3.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:17<00:00, 15.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:19<00:00, 15.88s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:16<00:00, 15.25s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(gte-base)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:21<00:00, 16.30s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:54<00:00, 10.99s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:58<00:00, 11.76s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(gte-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:54<00:00, 10.93s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.14s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(e5-small)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:02<00:00, 12.53s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Triplet(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:57<00:00, 11.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating Contrastive(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.02s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating OnlineContrastive(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:55<00:00, 11.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating MultipleNegatives(minilm-6)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:53<00:00, 10.79s/it]\n"
     ]
    }
   ],
   "source": [
    "from model_evaluator import evaluate_model\n",
    "from tqdm import tqdm\n",
    "from util import get_train_test\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Baseline scores of the pretrained encoders, stored as all_scores[dataset][model_name][k].\n",
    "all_scores = {}\n",
    "k_values = [4, 8, 16, 32, 64]\n",
    "for dataset in original_datasets.keys():\n",
    "    all_scores[dataset] = {}\n",
    "    # The split and its texts do not depend on the model, so build them once per dataset.\n",
    "    train, test = get_train_test(original_datasets, dataset)\n",
    "    train_texts = train['text'].tolist()\n",
    "    test_texts = test['text'].tolist()\n",
    "    for model_id, model_config in model_configs.items():\n",
    "        model_name = model_config.model_name\n",
    "        # model_configs holds one entry per (loss, encoder) pair, but only the pretrained\n",
    "        # encoder is loaded here and results are keyed by model_name. Repeating the run\n",
    "        # for every loss would recompute and overwrite the same entry, so each encoder\n",
    "        # is evaluated once per dataset.\n",
    "        if model_name in all_scores[dataset]:\n",
    "            continue\n",
    "        print(f\"Evaluating {model_name} on {dataset}\")\n",
    "        model = SentenceTransformer(model_name).to(params.DEVICE)\n",
    "        # Reference embeddings are computed once and reused for every k.\n",
    "        ref_train_emb = model.encode(train_texts)\n",
    "        ref_test_emb = model.encode(test_texts)\n",
    "        all_scores[dataset][model_name] = {}\n",
    "        for k in tqdm(k_values):\n",
    "            scores = evaluate_model(\n",
    "                model=model,\n",
    "                # NOTE(review): name still carries the first config id seen for this\n",
    "                # encoder; confirm evaluate_model does not rely on one name per loss.\n",
    "                name=f\"{dataset}_{model_id}\",\n",
    "                train_data=train,\n",
    "                test_data=test,\n",
    "                reference_train_emb=ref_train_emb,\n",
    "                reference_test_emb=ref_test_emb,\n",
    "                k=k,\n",
    "                verbose=False\n",
    "            )\n",
    "            all_scores[dataset][model_name][k] = scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "def _rounded_mean_std(values, ndigits=3):\n",
    "    \"\"\"Return (mean, std) of `values` via np.mean / np.std, each rounded to `ndigits`.\"\"\"\n",
    "    return np.round(np.mean(values), ndigits), np.round(np.std(values), ndigits)\n",
    "\n",
    "\n",
    "# Flatten all_scores[dataset][model_name][k] into one summary row per combination.\n",
    "data = []\n",
    "for dataset, dataset_scores in all_scores.items():\n",
    "    for model_name, model_scores in dataset_scores.items():\n",
    "        for k, scores in model_scores.items():\n",
    "            p_mean, p_std = _rounded_mean_std(scores['polarity'])\n",
    "            s_mean, s_std = _rounded_mean_std(scores['semantic'])\n",
    "            data.append({\n",
    "                'dataset': dataset,\n",
    "                'model': model_name,\n",
    "                'k': k,\n",
    "                'p_mean': p_mean,\n",
    "                'p_std': p_std,\n",
    "                's_mean': s_mean,\n",
    "                's_std': s_std,\n",
    "            })\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "\n",
    "# Invert `models` (alias -> checkpoint name) so the table shows the short aliases.\n",
    "model_translations = {full_name: alias for alias, full_name in models.items()}\n",
    "df['model'] = df.model.apply(lambda full_name: model_translations[full_name])\n",
    "\n",
    "df.to_csv('baseline_k.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>model</th>\n",
       "      <th>k</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>4</td>\n",
       "      <td>80.447</td>\n",
       "      <td>27.519</td>\n",
       "      <td>84.980</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>8</td>\n",
       "      <td>80.691</td>\n",
       "      <td>24.036</td>\n",
       "      <td>84.349</td>\n",
       "      <td>1.501</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>80.376</td>\n",
       "      <td>22.561</td>\n",
       "      <td>83.700</td>\n",
       "      <td>1.436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>32</td>\n",
       "      <td>79.942</td>\n",
       "      <td>21.068</td>\n",
       "      <td>83.044</td>\n",
       "      <td>1.395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>64</td>\n",
       "      <td>78.813</td>\n",
       "      <td>20.539</td>\n",
       "      <td>82.368</td>\n",
       "      <td>1.360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>4</td>\n",
       "      <td>78.268</td>\n",
       "      <td>27.836</td>\n",
       "      <td>86.029</td>\n",
       "      <td>1.599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>8</td>\n",
       "      <td>78.139</td>\n",
       "      <td>23.856</td>\n",
       "      <td>85.427</td>\n",
       "      <td>1.483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>77.774</td>\n",
       "      <td>22.204</td>\n",
       "      <td>84.808</td>\n",
       "      <td>1.410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>32</td>\n",
       "      <td>76.867</td>\n",
       "      <td>21.215</td>\n",
       "      <td>84.182</td>\n",
       "      <td>1.360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>64</td>\n",
       "      <td>75.502</td>\n",
       "      <td>20.692</td>\n",
       "      <td>83.532</td>\n",
       "      <td>1.315</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>4</td>\n",
       "      <td>82.167</td>\n",
       "      <td>27.201</td>\n",
       "      <td>86.510</td>\n",
       "      <td>1.855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>8</td>\n",
       "      <td>82.010</td>\n",
       "      <td>25.039</td>\n",
       "      <td>86.035</td>\n",
       "      <td>1.789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>81.465</td>\n",
       "      <td>23.696</td>\n",
       "      <td>85.516</td>\n",
       "      <td>1.735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>32</td>\n",
       "      <td>80.637</td>\n",
       "      <td>22.752</td>\n",
       "      <td>84.955</td>\n",
       "      <td>1.681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>64</td>\n",
       "      <td>79.707</td>\n",
       "      <td>21.919</td>\n",
       "      <td>84.347</td>\n",
       "      <td>1.627</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>4</td>\n",
       "      <td>63.991</td>\n",
       "      <td>30.276</td>\n",
       "      <td>50.940</td>\n",
       "      <td>7.684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>8</td>\n",
       "      <td>63.675</td>\n",
       "      <td>25.750</td>\n",
       "      <td>48.802</td>\n",
       "      <td>7.514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.016</td>\n",
       "      <td>21.877</td>\n",
       "      <td>46.594</td>\n",
       "      <td>7.426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>32</td>\n",
       "      <td>62.106</td>\n",
       "      <td>19.560</td>\n",
       "      <td>44.315</td>\n",
       "      <td>7.338</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>64</td>\n",
       "      <td>61.065</td>\n",
       "      <td>17.624</td>\n",
       "      <td>41.926</td>\n",
       "      <td>7.230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>4</td>\n",
       "      <td>70.943</td>\n",
       "      <td>27.981</td>\n",
       "      <td>83.097</td>\n",
       "      <td>1.904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>8</td>\n",
       "      <td>69.141</td>\n",
       "      <td>23.593</td>\n",
       "      <td>82.253</td>\n",
       "      <td>1.751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>67.402</td>\n",
       "      <td>20.739</td>\n",
       "      <td>81.366</td>\n",
       "      <td>1.641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>32</td>\n",
       "      <td>65.618</td>\n",
       "      <td>18.782</td>\n",
       "      <td>80.450</td>\n",
       "      <td>1.566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>64</td>\n",
       "      <td>63.953</td>\n",
       "      <td>17.460</td>\n",
       "      <td>79.515</td>\n",
       "      <td>1.529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>4</td>\n",
       "      <td>70.130</td>\n",
       "      <td>28.301</td>\n",
       "      <td>84.158</td>\n",
       "      <td>1.887</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>8</td>\n",
       "      <td>68.210</td>\n",
       "      <td>23.654</td>\n",
       "      <td>83.334</td>\n",
       "      <td>1.736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>66.826</td>\n",
       "      <td>20.573</td>\n",
       "      <td>82.468</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>32</td>\n",
       "      <td>65.113</td>\n",
       "      <td>18.643</td>\n",
       "      <td>81.579</td>\n",
       "      <td>1.555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>64</td>\n",
       "      <td>63.536</td>\n",
       "      <td>17.233</td>\n",
       "      <td>80.674</td>\n",
       "      <td>1.510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>4</td>\n",
       "      <td>74.539</td>\n",
       "      <td>27.119</td>\n",
       "      <td>84.632</td>\n",
       "      <td>1.630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>8</td>\n",
       "      <td>72.881</td>\n",
       "      <td>23.682</td>\n",
       "      <td>84.032</td>\n",
       "      <td>1.536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>71.436</td>\n",
       "      <td>21.241</td>\n",
       "      <td>83.404</td>\n",
       "      <td>1.463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>32</td>\n",
       "      <td>69.647</td>\n",
       "      <td>19.992</td>\n",
       "      <td>82.756</td>\n",
       "      <td>1.417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>64</td>\n",
       "      <td>67.725</td>\n",
       "      <td>18.867</td>\n",
       "      <td>82.096</td>\n",
       "      <td>1.391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>4</td>\n",
       "      <td>67.474</td>\n",
       "      <td>28.390</td>\n",
       "      <td>47.744</td>\n",
       "      <td>6.646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>8</td>\n",
       "      <td>65.545</td>\n",
       "      <td>23.404</td>\n",
       "      <td>45.108</td>\n",
       "      <td>6.118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.749</td>\n",
       "      <td>20.181</td>\n",
       "      <td>42.307</td>\n",
       "      <td>5.615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>32</td>\n",
       "      <td>62.192</td>\n",
       "      <td>17.515</td>\n",
       "      <td>39.393</td>\n",
       "      <td>5.137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>64</td>\n",
       "      <td>60.797</td>\n",
       "      <td>15.612</td>\n",
       "      <td>36.384</td>\n",
       "      <td>4.684</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                dataset      model   k  p_mean   p_std  s_mean  s_std\n",
       "0                  sst2   gte-base   4  80.447  27.519  84.980  1.629\n",
       "1                  sst2   gte-base   8  80.691  24.036  84.349  1.501\n",
       "2                  sst2   gte-base  16  80.376  22.561  83.700  1.436\n",
       "3                  sst2   gte-base  32  79.942  21.068  83.044  1.395\n",
       "4                  sst2   gte-base  64  78.813  20.539  82.368  1.360\n",
       "5                  sst2  gte-small   4  78.268  27.836  86.029  1.599\n",
       "6                  sst2  gte-small   8  78.139  23.856  85.427  1.483\n",
       "7                  sst2  gte-small  16  77.774  22.204  84.808  1.410\n",
       "8                  sst2  gte-small  32  76.867  21.215  84.182  1.360\n",
       "9                  sst2  gte-small  64  75.502  20.692  83.532  1.315\n",
       "10                 sst2   e5-small   4  82.167  27.201  86.510  1.855\n",
       "11                 sst2   e5-small   8  82.010  25.039  86.035  1.789\n",
       "12                 sst2   e5-small  16  81.465  23.696  85.516  1.735\n",
       "13                 sst2   e5-small  32  80.637  22.752  84.955  1.681\n",
       "14                 sst2   e5-small  64  79.707  21.919  84.347  1.627\n",
       "15                 sst2   minilm-6   4  63.991  30.276  50.940  7.684\n",
       "16                 sst2   minilm-6   8  63.675  25.750  48.802  7.514\n",
       "17                 sst2   minilm-6  16  63.016  21.877  46.594  7.426\n",
       "18                 sst2   minilm-6  32  62.106  19.560  44.315  7.338\n",
       "19                 sst2   minilm-6  64  61.065  17.624  41.926  7.230\n",
       "20  sarcastic-headlines   gte-base   4  70.943  27.981  83.097  1.904\n",
       "21  sarcastic-headlines   gte-base   8  69.141  23.593  82.253  1.751\n",
       "22  sarcastic-headlines   gte-base  16  67.402  20.739  81.366  1.641\n",
       "23  sarcastic-headlines   gte-base  32  65.618  18.782  80.450  1.566\n",
       "24  sarcastic-headlines   gte-base  64  63.953  17.460  79.515  1.529\n",
       "25  sarcastic-headlines  gte-small   4  70.130  28.301  84.158  1.887\n",
       "26  sarcastic-headlines  gte-small   8  68.210  23.654  83.334  1.736\n",
       "27  sarcastic-headlines  gte-small  16  66.826  20.573  82.468  1.629\n",
       "28  sarcastic-headlines  gte-small  32  65.113  18.643  81.579  1.555\n",
       "29  sarcastic-headlines  gte-small  64  63.536  17.233  80.674  1.510\n",
       "30  sarcastic-headlines   e5-small   4  74.539  27.119  84.632  1.630\n",
       "31  sarcastic-headlines   e5-small   8  72.881  23.682  84.032  1.536\n",
       "32  sarcastic-headlines   e5-small  16  71.436  21.241  83.404  1.463\n",
       "33  sarcastic-headlines   e5-small  32  69.647  19.992  82.756  1.417\n",
       "34  sarcastic-headlines   e5-small  64  67.725  18.867  82.096  1.391\n",
       "35  sarcastic-headlines   minilm-6   4  67.474  28.390  47.744  6.646\n",
       "36  sarcastic-headlines   minilm-6   8  65.545  23.404  45.108  6.118\n",
       "37  sarcastic-headlines   minilm-6  16  63.749  20.181  42.307  5.615\n",
       "38  sarcastic-headlines   minilm-6  32  62.192  17.515  39.393  5.137\n",
       "39  sarcastic-headlines   minilm-6  64  60.797  15.612  36.384  4.684"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the summary table assembled in the previous cell (one row per dataset / model / k).\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>model</th>\n",
       "      <th>k</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>4</td>\n",
       "      <td>80.447</td>\n",
       "      <td>27.519</td>\n",
       "      <td>84.980</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>8</td>\n",
       "      <td>80.691</td>\n",
       "      <td>24.036</td>\n",
       "      <td>84.349</td>\n",
       "      <td>1.501</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>80.376</td>\n",
       "      <td>22.561</td>\n",
       "      <td>83.700</td>\n",
       "      <td>1.436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>32</td>\n",
       "      <td>79.942</td>\n",
       "      <td>21.068</td>\n",
       "      <td>83.044</td>\n",
       "      <td>1.395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>64</td>\n",
       "      <td>78.813</td>\n",
       "      <td>20.539</td>\n",
       "      <td>82.368</td>\n",
       "      <td>1.360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>4</td>\n",
       "      <td>78.268</td>\n",
       "      <td>27.836</td>\n",
       "      <td>86.029</td>\n",
       "      <td>1.599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>8</td>\n",
       "      <td>78.139</td>\n",
       "      <td>23.856</td>\n",
       "      <td>85.427</td>\n",
       "      <td>1.483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>77.774</td>\n",
       "      <td>22.204</td>\n",
       "      <td>84.808</td>\n",
       "      <td>1.410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>32</td>\n",
       "      <td>76.867</td>\n",
       "      <td>21.215</td>\n",
       "      <td>84.182</td>\n",
       "      <td>1.360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>sst2</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>64</td>\n",
       "      <td>75.502</td>\n",
       "      <td>20.692</td>\n",
       "      <td>83.532</td>\n",
       "      <td>1.315</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>4</td>\n",
       "      <td>82.167</td>\n",
       "      <td>27.201</td>\n",
       "      <td>86.510</td>\n",
       "      <td>1.855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>8</td>\n",
       "      <td>82.010</td>\n",
       "      <td>25.039</td>\n",
       "      <td>86.035</td>\n",
       "      <td>1.789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>81.465</td>\n",
       "      <td>23.696</td>\n",
       "      <td>85.516</td>\n",
       "      <td>1.735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>32</td>\n",
       "      <td>80.637</td>\n",
       "      <td>22.752</td>\n",
       "      <td>84.955</td>\n",
       "      <td>1.681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>sst2</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>64</td>\n",
       "      <td>79.707</td>\n",
       "      <td>21.919</td>\n",
       "      <td>84.347</td>\n",
       "      <td>1.627</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>4</td>\n",
       "      <td>63.991</td>\n",
       "      <td>30.276</td>\n",
       "      <td>50.940</td>\n",
       "      <td>7.684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>8</td>\n",
       "      <td>63.675</td>\n",
       "      <td>25.750</td>\n",
       "      <td>48.802</td>\n",
       "      <td>7.514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.016</td>\n",
       "      <td>21.877</td>\n",
       "      <td>46.594</td>\n",
       "      <td>7.426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>32</td>\n",
       "      <td>62.106</td>\n",
       "      <td>19.560</td>\n",
       "      <td>44.315</td>\n",
       "      <td>7.338</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>sst2</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>64</td>\n",
       "      <td>61.065</td>\n",
       "      <td>17.624</td>\n",
       "      <td>41.926</td>\n",
       "      <td>7.230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>4</td>\n",
       "      <td>70.943</td>\n",
       "      <td>27.981</td>\n",
       "      <td>83.097</td>\n",
       "      <td>1.904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>8</td>\n",
       "      <td>69.141</td>\n",
       "      <td>23.593</td>\n",
       "      <td>82.253</td>\n",
       "      <td>1.751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>16</td>\n",
       "      <td>67.402</td>\n",
       "      <td>20.739</td>\n",
       "      <td>81.366</td>\n",
       "      <td>1.641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>32</td>\n",
       "      <td>65.618</td>\n",
       "      <td>18.782</td>\n",
       "      <td>80.450</td>\n",
       "      <td>1.566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-base</td>\n",
       "      <td>64</td>\n",
       "      <td>63.953</td>\n",
       "      <td>17.460</td>\n",
       "      <td>79.515</td>\n",
       "      <td>1.529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>4</td>\n",
       "      <td>70.130</td>\n",
       "      <td>28.301</td>\n",
       "      <td>84.158</td>\n",
       "      <td>1.887</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>8</td>\n",
       "      <td>68.210</td>\n",
       "      <td>23.654</td>\n",
       "      <td>83.334</td>\n",
       "      <td>1.736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>16</td>\n",
       "      <td>66.826</td>\n",
       "      <td>20.573</td>\n",
       "      <td>82.468</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>32</td>\n",
       "      <td>65.113</td>\n",
       "      <td>18.643</td>\n",
       "      <td>81.579</td>\n",
       "      <td>1.555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>gte-small</td>\n",
       "      <td>64</td>\n",
       "      <td>63.536</td>\n",
       "      <td>17.233</td>\n",
       "      <td>80.674</td>\n",
       "      <td>1.510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>4</td>\n",
       "      <td>74.539</td>\n",
       "      <td>27.119</td>\n",
       "      <td>84.632</td>\n",
       "      <td>1.630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>8</td>\n",
       "      <td>72.881</td>\n",
       "      <td>23.682</td>\n",
       "      <td>84.032</td>\n",
       "      <td>1.536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>16</td>\n",
       "      <td>71.436</td>\n",
       "      <td>21.241</td>\n",
       "      <td>83.404</td>\n",
       "      <td>1.463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>32</td>\n",
       "      <td>69.647</td>\n",
       "      <td>19.992</td>\n",
       "      <td>82.756</td>\n",
       "      <td>1.417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>e5-small</td>\n",
       "      <td>64</td>\n",
       "      <td>67.725</td>\n",
       "      <td>18.867</td>\n",
       "      <td>82.096</td>\n",
       "      <td>1.391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>4</td>\n",
       "      <td>67.474</td>\n",
       "      <td>28.390</td>\n",
       "      <td>47.744</td>\n",
       "      <td>6.646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>8</td>\n",
       "      <td>65.545</td>\n",
       "      <td>23.404</td>\n",
       "      <td>45.108</td>\n",
       "      <td>6.118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>16</td>\n",
       "      <td>63.749</td>\n",
       "      <td>20.181</td>\n",
       "      <td>42.307</td>\n",
       "      <td>5.615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>32</td>\n",
       "      <td>62.192</td>\n",
       "      <td>17.515</td>\n",
       "      <td>39.393</td>\n",
       "      <td>5.137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>minilm-6</td>\n",
       "      <td>64</td>\n",
       "      <td>60.797</td>\n",
       "      <td>15.612</td>\n",
       "      <td>36.384</td>\n",
       "      <td>4.684</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                dataset      model   k  p_mean   p_std  s_mean  s_std\n",
       "0                  sst2   gte-base   4  80.447  27.519  84.980  1.629\n",
       "1                  sst2   gte-base   8  80.691  24.036  84.349  1.501\n",
       "2                  sst2   gte-base  16  80.376  22.561  83.700  1.436\n",
       "3                  sst2   gte-base  32  79.942  21.068  83.044  1.395\n",
       "4                  sst2   gte-base  64  78.813  20.539  82.368  1.360\n",
       "5                  sst2  gte-small   4  78.268  27.836  86.029  1.599\n",
       "6                  sst2  gte-small   8  78.139  23.856  85.427  1.483\n",
       "7                  sst2  gte-small  16  77.774  22.204  84.808  1.410\n",
       "8                  sst2  gte-small  32  76.867  21.215  84.182  1.360\n",
       "9                  sst2  gte-small  64  75.502  20.692  83.532  1.315\n",
       "10                 sst2   e5-small   4  82.167  27.201  86.510  1.855\n",
       "11                 sst2   e5-small   8  82.010  25.039  86.035  1.789\n",
       "12                 sst2   e5-small  16  81.465  23.696  85.516  1.735\n",
       "13                 sst2   e5-small  32  80.637  22.752  84.955  1.681\n",
       "14                 sst2   e5-small  64  79.707  21.919  84.347  1.627\n",
       "15                 sst2   minilm-6   4  63.991  30.276  50.940  7.684\n",
       "16                 sst2   minilm-6   8  63.675  25.750  48.802  7.514\n",
       "17                 sst2   minilm-6  16  63.016  21.877  46.594  7.426\n",
       "18                 sst2   minilm-6  32  62.106  19.560  44.315  7.338\n",
       "19                 sst2   minilm-6  64  61.065  17.624  41.926  7.230\n",
       "20  sarcastic-headlines   gte-base   4  70.943  27.981  83.097  1.904\n",
       "21  sarcastic-headlines   gte-base   8  69.141  23.593  82.253  1.751\n",
       "22  sarcastic-headlines   gte-base  16  67.402  20.739  81.366  1.641\n",
       "23  sarcastic-headlines   gte-base  32  65.618  18.782  80.450  1.566\n",
       "24  sarcastic-headlines   gte-base  64  63.953  17.460  79.515  1.529\n",
       "25  sarcastic-headlines  gte-small   4  70.130  28.301  84.158  1.887\n",
       "26  sarcastic-headlines  gte-small   8  68.210  23.654  83.334  1.736\n",
       "27  sarcastic-headlines  gte-small  16  66.826  20.573  82.468  1.629\n",
       "28  sarcastic-headlines  gte-small  32  65.113  18.643  81.579  1.555\n",
       "29  sarcastic-headlines  gte-small  64  63.536  17.233  80.674  1.510\n",
       "30  sarcastic-headlines   e5-small   4  74.539  27.119  84.632  1.630\n",
       "31  sarcastic-headlines   e5-small   8  72.881  23.682  84.032  1.536\n",
       "32  sarcastic-headlines   e5-small  16  71.436  21.241  83.404  1.463\n",
       "33  sarcastic-headlines   e5-small  32  69.647  19.992  82.756  1.417\n",
       "34  sarcastic-headlines   e5-small  64  67.725  18.867  82.096  1.391\n",
       "35  sarcastic-headlines   minilm-6   4  67.474  28.390  47.744  6.646\n",
       "36  sarcastic-headlines   minilm-6   8  65.545  23.404  45.108  6.118\n",
       "37  sarcastic-headlines   minilm-6  16  63.749  20.181  42.307  5.615\n",
       "38  sarcastic-headlines   minilm-6  32  62.192  17.515  39.393  5.137\n",
       "39  sarcastic-headlines   minilm-6  64  60.797  15.612  36.384  4.684"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The metric columns (p_mean, p_std, s_mean, s_std) are displayed as-is: no\n",
    "# rescaling happens in this cell.\n",
    "# NOTE(review): the stored outputs show them already on a 0-100 scale -- confirm\n",
    "# the x100 conversion is applied upstream before relying on Restart & Run All.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "polarity\n",
      "Reference & - & $71.4_{21.2}$ & $81.5_{23.7}$ & $67.4_{20.7}$ & $80.4_{22.6}$ & $66.8_{20.6}$ & $77.8_{22.2}$ & $63.7_{20.2}$ & $63.0_{21.9}$ \\\\\n",
      "semantic\n",
      "Reference & - & $83.4_{1.5}$ & $85.5_{1.7}$ & $81.4_{1.6}$ & $83.7_{1.4}$ & $82.5_{1.6}$ & $84.8_{1.4}$ & $42.3_{5.6}$ & $46.6_{7.4}$ \\\\\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>e5-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>71.4</td>\n",
       "      <td>21.2</td>\n",
       "      <td>83.4</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>e5-small</th>\n",
       "      <td>sst2</td>\n",
       "      <td>81.5</td>\n",
       "      <td>23.7</td>\n",
       "      <td>85.5</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-base</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>67.4</td>\n",
       "      <td>20.7</td>\n",
       "      <td>81.4</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-base</th>\n",
       "      <td>sst2</td>\n",
       "      <td>80.4</td>\n",
       "      <td>22.6</td>\n",
       "      <td>83.7</td>\n",
       "      <td>1.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>66.8</td>\n",
       "      <td>20.6</td>\n",
       "      <td>82.5</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-small</th>\n",
       "      <td>sst2</td>\n",
       "      <td>77.8</td>\n",
       "      <td>22.2</td>\n",
       "      <td>84.8</td>\n",
       "      <td>1.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>minilm-6</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>63.7</td>\n",
       "      <td>20.2</td>\n",
       "      <td>42.3</td>\n",
       "      <td>5.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>minilm-6</th>\n",
       "      <td>sst2</td>\n",
       "      <td>63.0</td>\n",
       "      <td>21.9</td>\n",
       "      <td>46.6</td>\n",
       "      <td>7.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       dataset  p_mean  p_std  s_mean  s_std\n",
       "model                                                       \n",
       "e5-small   sarcastic-headlines    71.4   21.2    83.4    1.5\n",
       "e5-small                  sst2    81.5   23.7    85.5    1.7\n",
       "gte-base   sarcastic-headlines    67.4   20.7    81.4    1.6\n",
       "gte-base                  sst2    80.4   22.6    83.7    1.4\n",
       "gte-small  sarcastic-headlines    66.8   20.6    82.5    1.6\n",
       "gte-small                 sst2    77.8   22.2    84.8    1.4\n",
       "minilm-6   sarcastic-headlines    63.7   20.2    42.3    5.6\n",
       "minilm-6                  sst2    63.0   21.9    46.6    7.4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>e5-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>71.4</td>\n",
       "      <td>21.2</td>\n",
       "      <td>83.4</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>e5-small</th>\n",
       "      <td>sst2</td>\n",
       "      <td>81.5</td>\n",
       "      <td>23.7</td>\n",
       "      <td>85.5</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      dataset  p_mean  p_std  s_mean  s_std\n",
       "model                                                      \n",
       "e5-small  sarcastic-headlines    71.4   21.2    83.4    1.5\n",
       "e5-small                 sst2    81.5   23.7    85.5    1.7"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>gte-base</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>67.4</td>\n",
       "      <td>20.7</td>\n",
       "      <td>81.4</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-base</th>\n",
       "      <td>sst2</td>\n",
       "      <td>80.4</td>\n",
       "      <td>22.6</td>\n",
       "      <td>83.7</td>\n",
       "      <td>1.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      dataset  p_mean  p_std  s_mean  s_std\n",
       "model                                                      \n",
       "gte-base  sarcastic-headlines    67.4   20.7    81.4    1.6\n",
       "gte-base                 sst2    80.4   22.6    83.7    1.4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>gte-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>66.8</td>\n",
       "      <td>20.6</td>\n",
       "      <td>82.5</td>\n",
       "      <td>1.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-small</th>\n",
       "      <td>sst2</td>\n",
       "      <td>77.8</td>\n",
       "      <td>22.2</td>\n",
       "      <td>84.8</td>\n",
       "      <td>1.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       dataset  p_mean  p_std  s_mean  s_std\n",
       "model                                                       \n",
       "gte-small  sarcastic-headlines    66.8   20.6    82.5    1.6\n",
       "gte-small                 sst2    77.8   22.2    84.8    1.4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>minilm-6</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>63.7</td>\n",
       "      <td>20.2</td>\n",
       "      <td>42.3</td>\n",
       "      <td>5.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>minilm-6</th>\n",
       "      <td>sst2</td>\n",
       "      <td>63.0</td>\n",
       "      <td>21.9</td>\n",
       "      <td>46.6</td>\n",
       "      <td>7.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      dataset  p_mean  p_std  s_mean  s_std\n",
       "model                                                      \n",
       "minilm-6  sarcastic-headlines    63.7   20.2    42.3    5.6\n",
       "minilm-6                 sst2    63.0   21.9    46.6    7.4"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "# Column prefix in `df` for each metric reported in the paper.\n",
    "METRIC_PREFIXES = {'polarity': 'p', 'semantic': 's'}\n",
    "\n",
    "# Reference scores at k=16, indexed by model, rounded to one decimal and\n",
    "# ordered by model, then dataset (one row per model/dataset pair).\n",
    "k_16 = (\n",
    "    df[df.k == 16]\n",
    "    .drop(columns='k')\n",
    "    .set_index('model')\n",
    "    .round(1)\n",
    "    .sort_values(by=['model', 'dataset'])\n",
    ")\n",
    "\n",
    "# mean_{std} cells for the Reference row of Tables 5 and 6 in Section 4,\n",
    "# listed in the row order of k_16.\n",
    "latex_rows = {\n",
    "    metric: [\n",
    "        f'{mean}_{{{std}}}'\n",
    "        for mean, std in zip(k_16[f'{prefix}_mean'], k_16[f'{prefix}_std'])\n",
    "    ]\n",
    "    for metric, prefix in METRIC_PREFIXES.items()\n",
    "}\n",
    "\n",
    "for metric, cells in latex_rows.items():\n",
    "    print(metric)\n",
    "    joined = ' & '.join(f'${cell}$' for cell in cells)\n",
    "    print(f'Reference & - & {joined} \\\\\\\\')\n",
    "\n",
    "display(k_16)\n",
    "\n",
    "# Per-model slices of the same table; k_16 is already rounded.\n",
    "for model in sorted(models.keys()):\n",
    "    display(k_16.loc[model])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p_mean</th>\n",
       "      <th>p_std</th>\n",
       "      <th>s_mean</th>\n",
       "      <th>s_std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>e5-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>71.436</td>\n",
       "      <td>21.241</td>\n",
       "      <td>83.404</td>\n",
       "      <td>1.463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-base</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>67.402</td>\n",
       "      <td>20.739</td>\n",
       "      <td>81.366</td>\n",
       "      <td>1.641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gte-small</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>66.826</td>\n",
       "      <td>20.573</td>\n",
       "      <td>82.468</td>\n",
       "      <td>1.629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>minilm-6</th>\n",
       "      <td>sarcastic-headlines</td>\n",
       "      <td>63.749</td>\n",
       "      <td>20.181</td>\n",
       "      <td>42.307</td>\n",
       "      <td>5.615</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       dataset  p_mean   p_std  s_mean  s_std\n",
       "model                                                        \n",
       "e5-small   sarcastic-headlines  71.436  21.241  83.404  1.463\n",
       "gte-base   sarcastic-headlines  67.402  20.739  81.366  1.641\n",
       "gte-small  sarcastic-headlines  66.826  20.573  82.468  1.629\n",
       "minilm-6   sarcastic-headlines  63.749  20.181  42.307  5.615"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows of the k=16 reference table for the sarcastic-headlines dataset.\n",
    "k_16.loc[k_16['dataset'] == 'sarcastic-headlines']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simcse",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
