{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full Evaluation of CPM with CEBaB\n",
    "\n",
    "This notebook is based on `cebab_evaluation.ipynb`. It loads all the saved results produced by running that notebook and aggregates them into portable summary tables. This notebook only sorts and aggregates saved results; nothing more.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the pickled results saved by cebab_evaluation.ipynb.\n",
    "RESULTS_DIR = Path(\"../proxy_training_results\")\n",
    "model_types = [\"BERT\", \"lstm\", \"RoBERTa\", \"gpt2\"]\n",
    "\n",
    "\n",
    "def load_results(run_dir: Path) -> dict:\n",
    "    \"\"\"Load and return the pickled results dict stored in `run_dir/results.pkl`.\n",
    "\n",
    "    NOTE: unpickling can execute arbitrary code -- only load files produced\n",
    "    by our own evaluation runs.\n",
    "    \"\"\"\n",
    "    with open(run_dir / \"results.pkl\", \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "# Merge the regular and control runs of every model into one dict.\n",
    "# If two runs share the same key, the later one silently overwrites the earlier.\n",
    "mega_results = {}\n",
    "for model in model_types:\n",
    "    mega_results.update(load_results(RESULTS_DIR / f\"{model}-results\"))\n",
    "    mega_results.update(load_results(RESULTS_DIR / f\"{model}-control-results\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters copied from each run's key, plus the metrics read from its results.\n",
    "important_keys = [\n",
    "    \"seed\", \"h_dim\", \"class_num\",\n",
    "    \"control\", \"beta\", \"gemma\",\n",
    "    \"cls_dropout\", \"enc_dropout\",\n",
    "    \"model_arch\",\n",
    "]\n",
    "icace_keys = [\"ICaCE-L2\", \"ICaCE-cosine\", \"ICaCE-normdiff\"]\n",
    "\n",
    "records = []\n",
    "for run_key, run_results in mega_results.items():\n",
    "    # Each key appears to be a hashable collection of (name, value) pairs -- TODO confirm.\n",
    "    run_config = dict(run_key)\n",
    "    record = {name: run_config[name] for name in important_keys}\n",
    "    # run_results[2]: table of ICaCE metrics; only its first row is used.\n",
    "    for metric in icace_keys:\n",
    "        record[metric] = run_results[2][metric].iloc[0]\n",
    "    # run_results[-1]: presumably a one-cell DataFrame holding macro-F1 -- TODO confirm.\n",
    "    # .iloc[0, 0] is an explicit positional lookup; row[0] on a Series is deprecated.\n",
    "    record[\"macro-f1\"] = run_results[-1].iloc[0, 0]\n",
    "    records.append(record)\n",
    "\n",
    "# Build a fresh column list rather than extending important_keys in place,\n",
    "# so re-running this cell does not corrupt the key list.\n",
    "columns = important_keys + icace_keys + [\"macro-f1\"]\n",
    "df = pd.DataFrame(records, columns=columns)\n",
    "df = df.sort_values(by=[\"class_num\"], ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>class_num</th>\n",
       "      <th>control</th>\n",
       "      <th>model_arch</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">2</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">False</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.19742</td>\n",
       "      <td>0.77402</td>\n",
       "      <td>0.19654</td>\n",
       "      <td>0.973154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.21030</td>\n",
       "      <td>0.73714</td>\n",
       "      <td>0.20826</td>\n",
       "      <td>0.966701</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.25894</td>\n",
       "      <td>0.78814</td>\n",
       "      <td>0.25310</td>\n",
       "      <td>0.940496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.19468</td>\n",
       "      <td>0.78866</td>\n",
       "      <td>0.19460</td>\n",
       "      <td>0.976074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">True</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.29142</td>\n",
       "      <td>0.79914</td>\n",
       "      <td>0.28974</td>\n",
       "      <td>0.973154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.28668</td>\n",
       "      <td>0.75164</td>\n",
       "      <td>0.28442</td>\n",
       "      <td>0.966701</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.27996</td>\n",
       "      <td>0.87750</td>\n",
       "      <td>0.27278</td>\n",
       "      <td>0.940496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.28672</td>\n",
       "      <td>0.82220</td>\n",
       "      <td>0.28656</td>\n",
       "      <td>0.976074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">3</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">False</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.40952</td>\n",
       "      <td>0.61886</td>\n",
       "      <td>0.33622</td>\n",
       "      <td>0.818143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.44150</td>\n",
       "      <td>0.62222</td>\n",
       "      <td>0.34850</td>\n",
       "      <td>0.802804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.51288</td>\n",
       "      <td>0.65658</td>\n",
       "      <td>0.43578</td>\n",
       "      <td>0.749421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.43158</td>\n",
       "      <td>0.63954</td>\n",
       "      <td>0.35934</td>\n",
       "      <td>0.831138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">True</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.54710</td>\n",
       "      <td>0.73488</td>\n",
       "      <td>0.52492</td>\n",
       "      <td>0.818143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.51930</td>\n",
       "      <td>0.68638</td>\n",
       "      <td>0.47088</td>\n",
       "      <td>0.802804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.52990</td>\n",
       "      <td>0.78284</td>\n",
       "      <td>0.45884</td>\n",
       "      <td>0.749421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.54304</td>\n",
       "      <td>0.73768</td>\n",
       "      <td>0.52848</td>\n",
       "      <td>0.831138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"8\" valign=\"top\">5</th>\n",
       "      <th rowspan=\"4\" valign=\"top\">False</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.63130</td>\n",
       "      <td>0.46300</td>\n",
       "      <td>0.43188</td>\n",
       "      <td>0.695808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.54534</td>\n",
       "      <td>0.48694</td>\n",
       "      <td>0.35918</td>\n",
       "      <td>0.653323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.70378</td>\n",
       "      <td>0.53916</td>\n",
       "      <td>0.53582</td>\n",
       "      <td>0.596330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.67350</td>\n",
       "      <td>0.47492</td>\n",
       "      <td>0.48934</td>\n",
       "      <td>0.703218</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">True</th>\n",
       "      <th>bert-base-uncased</th>\n",
       "      <td>0.79230</td>\n",
       "      <td>0.70180</td>\n",
       "      <td>0.70162</td>\n",
       "      <td>0.695808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.65612</td>\n",
       "      <td>0.64920</td>\n",
       "      <td>0.55032</td>\n",
       "      <td>0.653323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>lstm</th>\n",
       "      <td>0.74136</td>\n",
       "      <td>0.72098</td>\n",
       "      <td>0.57978</td>\n",
       "      <td>0.596330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>roberta-base</th>\n",
       "      <td>0.82404</td>\n",
       "      <td>0.63068</td>\n",
       "      <td>0.78400</td>\n",
       "      <td>0.703218</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     ICaCE-L2  ICaCE-cosine  ICaCE-normdiff  \\\n",
       "class_num control model_arch                                                  \n",
       "2         False   bert-base-uncased   0.19742       0.77402         0.19654   \n",
       "                  gpt2                0.21030       0.73714         0.20826   \n",
       "                  lstm                0.25894       0.78814         0.25310   \n",
       "                  roberta-base        0.19468       0.78866         0.19460   \n",
       "          True    bert-base-uncased   0.29142       0.79914         0.28974   \n",
       "                  gpt2                0.28668       0.75164         0.28442   \n",
       "                  lstm                0.27996       0.87750         0.27278   \n",
       "                  roberta-base        0.28672       0.82220         0.28656   \n",
       "3         False   bert-base-uncased   0.40952       0.61886         0.33622   \n",
       "                  gpt2                0.44150       0.62222         0.34850   \n",
       "                  lstm                0.51288       0.65658         0.43578   \n",
       "                  roberta-base        0.43158       0.63954         0.35934   \n",
       "          True    bert-base-uncased   0.54710       0.73488         0.52492   \n",
       "                  gpt2                0.51930       0.68638         0.47088   \n",
       "                  lstm                0.52990       0.78284         0.45884   \n",
       "                  roberta-base        0.54304       0.73768         0.52848   \n",
       "5         False   bert-base-uncased   0.63130       0.46300         0.43188   \n",
       "                  gpt2                0.54534       0.48694         0.35918   \n",
       "                  lstm                0.70378       0.53916         0.53582   \n",
       "                  roberta-base        0.67350       0.47492         0.48934   \n",
       "          True    bert-base-uncased   0.79230       0.70180         0.70162   \n",
       "                  gpt2                0.65612       0.64920         0.55032   \n",
       "                  lstm                0.74136       0.72098         0.57978   \n",
       "                  roberta-base        0.82404       0.63068         0.78400   \n",
       "\n",
       "                                     macro-f1  \n",
       "class_num control model_arch                   \n",
       "2         False   bert-base-uncased  0.973154  \n",
       "                  gpt2               0.966701  \n",
       "                  lstm               0.940496  \n",
       "                  roberta-base       0.976074  \n",
       "          True    bert-base-uncased  0.973154  \n",
       "                  gpt2               0.966701  \n",
       "                  lstm               0.940496  \n",
       "                  roberta-base       0.976074  \n",
       "3         False   bert-base-uncased  0.818143  \n",
       "                  gpt2               0.802804  \n",
       "                  lstm               0.749421  \n",
       "                  roberta-base       0.831138  \n",
       "          True    bert-base-uncased  0.818143  \n",
       "                  gpt2               0.802804  \n",
       "                  lstm               0.749421  \n",
       "                  roberta-base       0.831138  \n",
       "5         False   bert-base-uncased  0.695808  \n",
       "                  gpt2               0.653323  \n",
       "                  lstm               0.596330  \n",
       "                  roberta-base       0.703218  \n",
       "          True    bert-base-uncased  0.695808  \n",
       "                  gpt2               0.653323  \n",
       "                  lstm               0.596330  \n",
       "                  roberta-base       0.703218  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average every metric over all rows sharing the same (class_num, control,\n",
    "# model_arch), i.e. over seeds and any other hyperparameters not grouped on.\n",
    "group_keys = [\"class_num\", \"control\", \"model_arch\"]\n",
    "metric_cols = [\"ICaCE-L2\", \"ICaCE-cosine\", \"ICaCE-normdiff\", \"macro-f1\"]\n",
    "df.groupby(group_keys)[metric_cols].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
