{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e24a01-df45-42b5-aaac-3b4b2922fcf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from IPython.display import clear_output\n",
    "from persim import plot_diagrams\n",
    "from ripser import ripser\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from src.baseline_sep_metrics import (\n",
    "    balanced_accuracy_index,\n",
    "    kmeans_ch_score,\n",
    "    roc_auc_index,\n",
    "    thornton_separability_index,\n",
    ")\n",
    "from src.gpu_utils import roc_auc_gpu_safe\n",
    "\n",
    "\n",
    "from src.exp_runner_multiclass import run_exp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01eeb39e-8a06-41cc-b81b-3a6811355593",
   "metadata": {},
   "source": [
    "# Experiments for the persistent-homology paper\n",
    "\n",
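    "This notebook runs the multi-class experiments: for every (model, dataset) pair it calls `run_exp` over one or more shuffle seeds and writes the per-epoch scores to a CSV file."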
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acc11215-bd3f-4271-9d1b-174ed12ac4c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The estimator we use to summarise H0 behaviour.\n",
    "def estimator(x):\n",
    "    \"\"\"Collapse an array of H0 values into one scalar score.\n",
    "\n",
    "    Returns ``1 - (mean(x) - min(x)) - var(x)``. ``x`` only needs\n",
    "    ``.mean()``, ``.min()`` and ``.var()`` (NumPy array / pandas Series).\n",
    "    What ``run_exp`` passes in is not visible here -- presumably H0\n",
    "    persistence values; confirm in ``src.exp_runner_multiclass``.\n",
    "    \"\"\"\n",
    "    return 1 - (x.mean() - x.min()) - x.var()\n",
    "\n",
    "\n",
    "datasets = {\n",
    "    \"emotion_long\": \"SetFit/emotion\",\n",
    "    #   \"patent\": \"ccdv/patent-classification\"\n",
    "    # \"finance_long\" : \"financial_phrasebank\",\n",
    "    # \"tweet_long\" : \"tweet_eval\",\n",
    "}\n",
    "\n",
    "models = {\n",
    "    # \"minilm_var\": \"all-MiniLM-L6-v2\",\n",
    "    # \"paraphrase_var\": \"sentence-transformers/paraphrase-albert-small-v2\",\n",
    "    \"tinybert_var\": \"paraphrase-TinyBERT-L6-v2\"\n",
    "}\n",
    "\n",
    "OUTPUT_DIR = \"multiclass_data\"\n",
    "SHUFFLE_SEEDS = [32]\n",
    "# Set to True to recompute and overwrite results that already exist on disk.\n",
    "OVERWRITE_EXISTING = False\n",
    "\n",
    "# DataFrame.to_csv does not create missing directories.\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "for short_model_name, model_name in tqdm(models.items(), desc=\"models:\"):\n",
    "    for short_name, dataset_name in tqdm(datasets.items(), desc=\"datasets:\"):\n",
    "        # Build the path once so the skip check and the write cannot drift apart.\n",
    "        out_path = os.path.join(\n",
    "            OUTPUT_DIR, f\"multiclass_data_{short_model_name}_{short_name}.csv\"\n",
    "        )\n",
    "        if os.path.isfile(out_path) and not OVERWRITE_EXISTING:\n",
    "            print(\"file exists; on to the next one!\")\n",
    "            continue\n",
    "\n",
    "        # One frame per seed, concatenated once after the loop\n",
    "        # (growing a DataFrame with concat inside the loop is quadratic).\n",
    "        per_seed_dfs = []\n",
    "        for shuffle_seed in tqdm(SHUFFLE_SEEDS, desc=\"seeds:\"):\n",
    "            # h0 / stacked_embeddings / data_to_embed are not used in this cell;\n",
    "            # they stay in the kernel (last seed) for interactive inspection.\n",
    "            (\n",
    "                scores,\n",
    "                roc_auc_score_df,\n",
    "                accuracy_score_df,\n",
    "                h0,\n",
    "                stacked_embeddings,\n",
    "                data_to_embed,\n",
    "            ) = run_exp(\n",
    "                shuffle_seed,\n",
    "                estimator=estimator,\n",
    "                model_name=model_name,\n",
    "                dataset_name=dataset_name,\n",
    "                norm=True,\n",
    "                d_coef=1e-1,  # for prodigy-opt\n",
    "                epochs=6,\n",
    "                produce_plots=True,\n",
    "                batch_size=64,\n",
    "            )\n",
    "            one_score_df = pd.concat(\n",
    "                [\n",
    "                    scores,\n",
    "                    roc_auc_score_df,\n",
    "                    # \"epoch\" is presumably already in scores / roc_auc_score_df;\n",
    "                    # set_index below needs exactly one such column.\n",
    "                    accuracy_score_df.drop(columns=\"epoch\"),\n",
    "                ],\n",
    "                axis=1,\n",
    "            )\n",
    "            one_score_df = one_score_df.set_index(\"epoch\")\n",
    "            one_score_df[\"seed\"] = shuffle_seed\n",
    "            one_score_df[\"dataset_name\"] = dataset_name\n",
    "            one_score_df[\"model_name\"] = model_name\n",
    "            per_seed_dfs.append(one_score_df)\n",
    "\n",
    "        all_scores_df = pd.concat(per_seed_dfs)\n",
    "        print(\n",
    "            f\"done with {model_name} and {dataset_name}. Writing to disk ... \"\n",
    "        )\n",
    "        all_scores_df.reset_index().to_csv(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6def5789-a791-4c1d-8439-683a3e27ffc0",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Alternative configuration (this cell is collapsed in the notebook):\n",
    "# MiniLM on SetFit/emotion, seven shuffle seeds, fixed lr=1e-4, 7 epochs.\n",
    "# NOTE(review): unlike the main experiment cell, results are written to the\n",
    "# current working directory rather than multiclass_data/.\n",
    "\n",
    "\n",
    "def estimator(x):\n",
    "    \"\"\"Collapse an array of H0 values into ``1 - (mean - min) - var``.\"\"\"\n",
    "    return 1 - (x.mean() - x.min()) - x.var()\n",
    "\n",
    "\n",
    "datasets = {\n",
    "    \"emotion\": \"SetFit/emotion\",\n",
    "    # \"newsgroup\": \"SetFit/20_newsgroups\"\n",
    "}\n",
    "\n",
    "models = {\n",
    "    \"minilm\": \"all-MiniLM-L6-v2\",\n",
    "    #  \"paraphrase\": \"sentence-transformers/paraphrase-albert-small-v2\",\n",
    "    #  \"tinybert\": \"paraphrase-TinyBERT-L6-v2\"\n",
    "}\n",
    "\n",
    "\n",
    "for short_model_name, model_name in tqdm(models.items(), desc=\"models:\"):\n",
    "    for short_name, dataset_name in tqdm(datasets.items(), desc=\"datasets:\"):\n",
    "        # Build the path once (with short_model_name) so the skip check looks\n",
    "        # for exactly the file that is written below.\n",
    "        out_path = f\"multiclass_data_{short_model_name}_{short_name}.csv\"\n",
    "        if os.path.isfile(out_path):\n",
    "            print(\"file exists; on to the next one!\")\n",
    "            continue\n",
    "\n",
    "        per_seed_dfs = []  # concatenated once after the loop (avoids quadratic concat)\n",
    "        for shuffle_seed in tqdm([32, 24, 98, 75, 9, 10, 12], desc=\"seeds:\"):\n",
    "            (\n",
    "                scores,\n",
    "                roc_auc_score_df,\n",
    "                accuracy_score_df,\n",
    "                h0,\n",
    "                stacked_embeddings,\n",
    "                data_to_embed,\n",
    "            ) = run_exp(\n",
    "                shuffle_seed,\n",
    "                estimator=estimator,\n",
    "                model_name=model_name,\n",
    "                dataset_name=dataset_name,\n",
    "                norm=True,\n",
    "                lr=1e-4,\n",
    "                epochs=7,\n",
    "                batch_size=64,\n",
    "                produce_plots=True,\n",
    "            )\n",
    "            one_score_df = pd.concat(\n",
    "                [\n",
    "                    scores,\n",
    "                    roc_auc_score_df,\n",
    "                    accuracy_score_df.drop(columns=\"epoch\"),\n",
    "                ],\n",
    "                axis=1,\n",
    "            )\n",
    "            one_score_df = one_score_df.set_index(\"epoch\")\n",
    "            one_score_df[\"seed\"] = shuffle_seed\n",
    "            one_score_df[\"dataset_name\"] = dataset_name\n",
    "            one_score_df[\"model_name\"] = model_name\n",
    "            per_seed_dfs.append(one_score_df)\n",
    "            clear_output(wait=True)\n",
    "\n",
    "        all_scores_df = pd.concat(per_seed_dfs)\n",
    "        print(\n",
    "            f\"done with {model_name} and {dataset_name}. Writing to disk ... \"\n",
    "        )\n",
    "        all_scores_df.reset_index().to_csv(out_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
