{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "from model_evaluator import evaluate_model\n",
    "from util import get_setfit_models, get_datasets, get_train_test, get_original_datasets, get_models\n",
    "import torch\n",
    "\n",
    "# GPU index this notebook prefers. The index is checked before it is selected\n",
    "# so the notebook does not abort in its first cell on a host that lacks it.\n",
    "CUDA_DEVICE_INDEX = 1\n",
    "\n",
    "if not torch.cuda.is_available():\n",
    "    print(\"CUDA is not available; no GPU device selected\")\n",
    "elif CUDA_DEVICE_INDEX < torch.cuda.device_count():\n",
    "    torch.cuda.set_device(CUDA_DEVICE_INDEX)\n",
    "else:\n",
    "    # Fewer GPUs than expected: leave the current (default) CUDA device active.\n",
    "    print(f\"GPU {CUDA_DEVICE_INDEX} is not present; keeping the default CUDA device\")\n",
    "\n",
    "# Raw datasets; passed to get_train_test when the train/test splits are built.\n",
    "original_data = get_original_datasets()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build every train/test split exactly once per dataset.\n",
    "# Splitting inside the per-baseline loop repeats the same work for each model\n",
    "# and overwrites train/test on every pass, so the encodings stored for an\n",
    "# earlier baseline could stop matching the stored split.\n",
    "# NOTE(review): this only changes results if get_train_test is not deterministic - confirm.\n",
    "train = {}\n",
    "test = {}\n",
    "\n",
    "for dataset in get_datasets():\n",
    "    train[dataset], test[dataset] = get_train_test(original_data, dataset_name=dataset)\n",
    "\n",
    "# Pre-encode all splits with every baseline model.\n",
    "# Keys are (baseline_name, dataset); each value holds the 'train' and 'test' encodings.\n",
    "baseline_encodings = {}\n",
    "\n",
    "for baseline_name, baseline_model in get_models().items():\n",
    "    baseline_sentence_transformer = SentenceTransformer(baseline_model)\n",
    "\n",
    "    for dataset in train:\n",
    "        print(f\"Loading {baseline_name} for {dataset}\")\n",
    "        # Encode the stored split so the embeddings line up with train/test row for row.\n",
    "        baseline_encodings[(baseline_name, dataset)] = {\n",
    "            'train': baseline_sentence_transformer.encode(train[dataset].text.tolist()),\n",
    "            'test': baseline_sentence_transformer.encode(test[dataset].text.tolist()),\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Evaluate every model returned by get_setfit_models() on every dataset,\n",
    "# passing the pre-computed encodings of its baseline as reference embeddings.\n",
    "# Raw scores are kept in all_scores under (setfit_model, dataset).\n",
    "all_scores = {}\n",
    "\n",
    "for setfit_model in get_setfit_models():\n",
    "    setfit_encoder = SentenceTransformer(setfit_model).to('cuda')\n",
    "\n",
    "    # Baseline key: first two hyphen-separated tokens of the last path segment.\n",
    "    checkpoint_name = setfit_model.split(\"/\")[-1]\n",
    "    baseline_key = \"-\".join(checkpoint_name.split(\"-\")[:2])\n",
    "\n",
    "    print(f\"Source model: {setfit_model}\")\n",
    "    print(f\"Baseline is {baseline_key}\")\n",
    "\n",
    "    for dataset in get_datasets():\n",
    "        print(f\"Eval dataset {dataset}\")\n",
    "        scores = evaluate_model(\n",
    "            model=setfit_encoder,\n",
    "            name=setfit_model,\n",
    "            train_data=train[dataset],\n",
    "            test_data=test[dataset],\n",
    "            reference_train_emb=baseline_encodings[(baseline_key, dataset)]['train'],\n",
    "            reference_test_emb=baseline_encodings[(baseline_key, dataset)]['test'],\n",
    "            k=16,\n",
    "            verbose=True,\n",
    "        )\n",
    "\n",
    "        # Report mean and standard deviation of each metric's values.\n",
    "        for metric_name, metric_values in scores.items():\n",
    "            print(metric_name)\n",
    "            print(f\"{metric_name}: {np.mean(metric_values)} +- {np.std(metric_values)}\")\n",
    "\n",
    "        all_scores[(setfit_model, dataset)] = scores\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build \"$mean_{std}$\" LaTeX strings, keeping only the pairs where model name and\n",
    "# dataset name both contain \"sst\" or both contain \"sarc\".\n",
    "paper_formatting = {}\n",
    "\n",
    "for (model_name, dataset_name), metric_scores in all_scores.items():\n",
    "    sst_pair = \"sst\" in model_name and \"sst\" in dataset_name\n",
    "    sarc_pair = \"sarc\" in model_name and \"sarc\" in dataset_name\n",
    "    if not (sst_pair or sarc_pair):\n",
    "        continue\n",
    "\n",
    "    # Model type: first two hyphen-separated tokens of the last path segment.\n",
    "    model_type = \"-\".join(model_name.split(\"/\")[-1].split(\"-\")[:2])\n",
    "\n",
    "    for metric, samples in metric_scores.items():\n",
    "        sample_mean = np.mean(samples)\n",
    "        sample_std = np.std(samples)\n",
    "        # Mean with the standard deviation as a subscript, both rounded to one decimal.\n",
    "        cell_text = f\"${sample_mean.round(1)}_{{{sample_std.round(1)}}}$\"\n",
    "        paper_formatting[(model_type, metric, dataset_name)] = cell_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "# Group the formatted cells by metric. Iterating the (model, metric, dataset)\n",
    "# keys in sorted order fixes the column order inside each row.\n",
    "table_rows = defaultdict(list)\n",
    "for key in sorted(paper_formatting):\n",
    "    _, metric, _ = key\n",
    "    table_rows[metric].append(paper_formatting[key])\n",
    "\n",
    "LOSS = \"Cosine (SetFit)\"\n",
    "\n",
    "# Print one LaTeX table row per metric; metrics with an underscore in their name are skipped.\n",
    "for metric, table_row in table_rows.items():\n",
    "    if \"_\" not in metric:\n",
    "        print(metric)\n",
    "        cells = \" & \".join(table_row)\n",
    "        print(f\"{LOSS} & - & \" + cells + \" \\\\\\\\\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
