{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "import sys\n",
    "sys.path.insert(0, \"SentEval\")\n",
    "import senteval\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative paths to the SentEval directory and to its data folder.\n",
    "PATH_TO_SENTEVAL = \"SentEval\"\n",
    "PATH_TO_DATA = \"SentEval/data\"\n",
    "\n",
    "# Settings handed to the SentEval engine; the nested 'classifier' dict is\n",
    "# declared inline instead of being attached in a second statement.\n",
    "params_senteval = {\n",
    "    'task_path': PATH_TO_DATA,\n",
    "    'usepytorch': True,\n",
    "    'kfold': 3,\n",
    "    'classifier': {\n",
    "        'nhid': 0,\n",
    "        'optim': 'adam',\n",
    "        'batch_size': 128,\n",
    "        'tenacity': 5,\n",
    "        'epoch_size': 3,\n",
    "    },\n",
    "}\n",
    "def prepare(params, samples):\n",
    "    \"\"\"No-op preparation hook passed to the SentEval engine; both arguments are ignored.\"\"\"\n",
    "    return None\n",
    "\n",
    "def eval_model(model_name, transfer_tasks=None):\n",
    "    \"\"\"Evaluate one SentenceTransformer model on SentEval tasks.\n",
    "\n",
    "    Args:\n",
    "        model_name: Identifier handed to ``SentenceTransformer`` to load the model.\n",
    "        transfer_tasks: Optional list of SentEval task names to evaluate.\n",
    "            When left as ``None`` the default list of eight tasks below is used.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``se.eval`` returns for the requested tasks.\n",
    "    \"\"\"\n",
    "    if transfer_tasks is None:\n",
    "        transfer_tasks = ['MR', 'CR', 'SUBJ', 'MPQA', 'SICKEntailment', 'SST2', 'TREC', 'MRPC']\n",
    "\n",
    "    print(f\"Evaluating {model_name}\")\n",
    "    model = SentenceTransformer(model_name).to(\"cuda\")\n",
    "\n",
    "    # https://github.com/UKPLab/sentence-transformers/issues/50#issuecomment-566452390\n",
    "    def batcher(params, batch):\n",
    "        \"\"\"Join each tokenized sample into one lowercase string and encode the batch.\"\"\"\n",
    "        # A sample that joins to an empty string is replaced by '-' before encoding.\n",
    "        sentences = [' '.join(sample).lower() or '-' for sample in batch]\n",
    "        return model.encode(sentences)\n",
    "\n",
    "    se = senteval.engine.SE(params_senteval, batcher, prepare)\n",
    "    results = se.eval(transfer_tasks)\n",
    "    print(results)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the parent directory first on sys.path before importing `util`\n",
    "# (presumably that is where util.py lives -- confirm).\n",
    "sys.path.insert(0, \"..\")\n",
    "\n",
    "from util import get_setfit_models\n",
    "\n",
    "# Evaluate every model returned by get_setfit_models(), keyed by model name.\n",
    "# A plain loop is used so results gathered so far stay in the dict if a later model fails.\n",
    "eval_results = dict()\n",
    "\n",
    "for setfit_model in get_setfit_models():\n",
    "    eval_results[setfit_model] = eval_model(setfit_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected evaluation results to eval_results.json.\n",
    "import json\n",
    "\n",
    "with open('eval_results.json', 'w') as results_file:\n",
    "    json.dump(eval_results, results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every model, keep only the \"acc\" entry of each task's result dict.\n",
    "acc_results = {\n",
    "    model_name: {task_name: task_scores[\"acc\"] for task_name, task_scores in task_results.items()}\n",
    "    for model_name, task_results in eval_results.items()\n",
    "}\n",
    "acc_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# One row per model, one column per task, built as a single method chain.\n",
    "df = (\n",
    "    pd.DataFrame(acc_results)\n",
    "    .T\n",
    "    # \"avg\" is the row-wise mean of all the accuracies, rounded to 2 decimals\n",
    "    .assign(avg=lambda scores: scores.mean(axis=1).round(2))\n",
    "    # move the index into a regular column named \"model\"\n",
    "    .reset_index()\n",
    "    .rename(columns={\"index\": \"model\"})\n",
    ")\n",
    "df.to_csv(\"setfit_senteval_results.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
