{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c686b411-f360-4720-8651-34243c132699",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "from src.exp_runner_binary import run_exp"
   ]
  },
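  {
   "cell_type": "markdown",
   "id": "2b55d201",
   "metadata": {},
   "source": [
    "# Binary Classification experiments\n",
    "\n",
    "With this notebook, you can run binary classification experiments and track how persistence times evolve during language-model (LM) fine-tuning.\n",
    "\n",
    "Each (model, dataset) pair configured below is run once per shuffle seed. The per-epoch scores are written to `data_{model}_{dataset}.csv` (using the short names), and pairs whose CSV file already exists are skipped."
   ]
  },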
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68949bed-b61c-4e13-87e7-ae78a610a16b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment grid: the next cell runs every (model, dataset) combination listed here.\n",
    "# Keys are the short names used in the output CSV file names; values are the\n",
    "# identifiers passed to `run_exp` as `model_name` / `dataset_name`.\n",
    "# Uncomment an entry to add it to the sweep.\n",
    "\n",
    "models = {\n",
    "    # \"minilm\": \"all-MiniLM-L6-v2\",\n",
    "    # \"paramini\": \"paraphrase-MiniLM-L3-v2\",\n",
    "    # \"albert\": \"sentence-transformers/paraphrase-albert-small-v2\",\n",
    "    \"tinybert_exp\": \"paraphrase-TinyBERT-L6-v2\",\n",
    "    # \"marco\": \"sentence-transformers/msmarco-distilbert-cos-v5\",\n",
    "}\n",
    "\n",
    "datasets = {\n",
    "    # \"sst2\": \"SetFit/sst2\",\n",
    "    # \"secr\": \"SetFit/SentEval-CR\",\n",
    "    # \"polarity\": \"SetFit/amazon_polarity\",  # large dataset\n",
    "    \"counterfactual\": \"SetFit/amazon_counterfactual\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c84377e4-fae1-45cd-9c54-abc27a09308c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings shared by every (model, dataset) run below.\n",
    "SHUFFLE_SEEDS = [32]  # seeds used to shuffle the data; also tried: 24, 98, 75, 9, 10, 12\n",
    "EPOCHS = 7\n",
    "BATCH_SIZE = 32\n",
    "D_COEF = 1e-1\n",
    "\n",
    "\n",
    "def estimator(x):\n",
    "    \"\"\"Summarise h0 behaviour as a single score; passed to `run_exp`.\n",
    "\n",
    "    Returns ``1 - (x.mean() - x.min()) - x.var()``. Both subtracted terms are\n",
    "    non-negative, so the score is at most 1 and is highest when the values sit\n",
    "    close to their minimum with little spread.\n",
    "\n",
    "    NOTE(review): assumes `x` is a numpy array or pandas Series. numpy's\n",
    "    `.var()` is the population variance, pandas' is the sample variance --\n",
    "    confirm which type `run_exp` passes in.\n",
    "    \"\"\"\n",
    "    return 1 - (x.mean() - x.min()) - x.var()\n",
    "\n",
    "\n",
    "for short_model_name, model_name in tqdm(models.items(), desc=\"models:\"):\n",
    "    for short_name, dataset_name in tqdm(datasets.items(), desc=\"datasets:\"):\n",
    "        out_path = f\"data_{short_model_name}_{short_name}.csv\"\n",
    "        # The per-pair CSV doubles as a cache: pairs already on disk are skipped,\n",
    "        # so an interrupted sweep resumes by re-running this cell.\n",
    "        if os.path.isfile(out_path):\n",
    "            print(f\"{out_path} exists; on to the next one!\")\n",
    "            continue\n",
    "\n",
    "        # Collect one frame per seed and concatenate once, instead of growing a\n",
    "        # DataFrame with pd.concat inside the loop.\n",
    "        seed_frames = []\n",
    "        for shuffle_seed in tqdm(SHUFFLE_SEEDS, desc=\"seeds:\"):\n",
    "            (\n",
    "                scores,\n",
    "                roc_auc_score_df,\n",
    "                accuracy_score_df,\n",
    "                h0_normed,\n",
    "                stacked_embeddings,\n",
    "                data_to_embed,\n",
    "            ) = run_exp(\n",
    "                shuffle_seed,\n",
    "                estimator=estimator,\n",
    "                model_name=model_name,\n",
    "                dataset_name=dataset_name,\n",
    "                normalize=True,\n",
    "                d_coef=D_COEF,\n",
    "                epochs=EPOCHS,\n",
    "                produce_plots=True,\n",
    "                batch_size=BATCH_SIZE,\n",
    "            )\n",
    "\n",
    "            one_score_df = pd.concat(\n",
    "                [\n",
    "                    scores,\n",
    "                    roc_auc_score_df,\n",
    "                    # presumably repeats the epoch column of the other frames;\n",
    "                    # dropped so a single \"epoch\" column is left for set_index\n",
    "                    accuracy_score_df.drop(columns=\"epoch\"),\n",
    "                ],\n",
    "                axis=1,\n",
    "            ).set_index(\"epoch\")\n",
    "            one_score_df[\"seed\"] = shuffle_seed\n",
    "            one_score_df[\"dataset_name\"] = dataset_name\n",
    "            one_score_df[\"model_name\"] = model_name\n",
    "            seed_frames.append(one_score_df)\n",
    "\n",
    "        # pd.concat rejects an empty list, so fall back to an empty frame when no\n",
    "        # seeds are configured.\n",
    "        all_scores_df = pd.concat(seed_frames) if seed_frames else pd.DataFrame()\n",
    "        print(f\"done with {model_name} and {dataset_name}. Writing to disk ... \")\n",
    "        all_scores_df.reset_index().to_csv(out_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "vscode": {
   "interpreter": {
    "hash": "b1641d34efa0dd84f57e86b47eaa36befe99074f826959190d2792bf4abfe54d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
