{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluating CPM with CEBaB\n",
    "Our Causal Proxy Model (CPM) provides concept-based explanations for a black-box model. We use the newly developed CEBaB benchmark to compare CPM with other concept-based explanation methods. This notebook evaluates CPM on the CEBaB benchmark under different settings.\n",
    "\n",
    "More importantly, we also introduce new baselines for CPM. Specifically, we evaluate the black-box model itself with interchange intervention evaluation (introduced in more detail below).\n",
    "\n",
    "In this notebook, we can evaluate the following models:\n",
    "- CPM: `BERT-base-uncased`\n",
    "- CPM: `RoBERTa-base`\n",
    "- CPM: `GPT2`\n",
    "- CPM: `LSTM+GloVe`\n",
    "- CPM: `Control`\n",
    "\n",
    "and under the following settings for the number of sentiment classes:\n",
    "- 2-class\n",
    "- 3-class\n",
    "- 5-class"
   ]
  },
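  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Interchange intervention is only named above, so here is a rough illustration of the general idea on a toy model. It is a hypothetical sketch, not the CPM code: the weights, `forward`, and `interchange_intervention` are made up for illustration. The model is run on a *source* input, selected hidden dimensions from that run are copied into the run on a *base* input, and the rest of the forward pass continues from the patched hidden state. We assume that `interchange_layer` and `h_dim` in the grid below play roughly the roles of the intervened layer and the number of swapped dimensions.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy model for illustration; NOT the CPM code from this repository.\n",
    "W1 = [[1.0, -1.0], [0.5, 2.0], [-1.0, 1.0]]  # input (2 dims) -> hidden (3 dims)\n",
    "W2 = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]     # hidden (3 dims) -> logits (2 classes)\n",
    "\n",
    "\n",
    "def matvec(W, v):\n",
    "    return [sum(w * x for w, x in zip(row, v)) for row in W]\n",
    "\n",
    "\n",
    "def forward(x, patch=None):\n",
    "    # Returns (hidden, logits); `patch` maps hidden dims to values that overwrite them.\n",
    "    hidden = [max(0.0, h) for h in matvec(W1, x)]  # ReLU\n",
    "    if patch:\n",
    "        for dim, value in patch.items():\n",
    "            hidden[dim] = value\n",
    "    return hidden, matvec(W2, hidden)\n",
    "\n",
    "\n",
    "def interchange_intervention(base, source, dims):\n",
    "    # 1) Run on the source input and record its hidden state.\n",
    "    source_hidden, _ = forward(source)\n",
    "    # 2) Rerun on the base input with the source values swapped in on `dims`.\n",
    "    _, logits = forward(base, patch={d: source_hidden[d] for d in dims})\n",
    "    return logits\n",
    "\n",
    "\n",
    "base, source = [1.0, 0.0], [0.0, 1.0]\n",
    "print(forward(base)[1])                                   # [1.0, 0.5]\n",
    "print(interchange_intervention(base, source, [0]))        # [0.0, 0.5]\n",
    "print(interchange_intervention(base, source, [0, 1, 2]))  # [1.0, 1.0], same as forward(source)[1]\n",
    "```\n",
    "\n",
    "Swapping no dimensions leaves the base prediction unchanged, and swapping every dimension of this single hidden layer reproduces the source prediction; partial swaps mix the two."
   ]
  },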
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Imports and Libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from libs import *\n",
    "from modelings.modelings_bert import *\n",
    "from modelings.modelings_roberta import *\n",
    "from modelings.modelings_gpt2 import *\n",
    "from modelings.modelings_lstm import *\n",
    "\"\"\"\n",
    "For evaluation, we use a single random seed, since\n",
    "the models have already been trained with 5\n",
    "different seeds.\n",
    "\"\"\"\n",
    "_ = random.seed(123)\n",
    "_ = np.random.seed(123)\n",
    "_ = torch.manual_seed(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Main evaluation script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "# Re-running the cells below adds entries to `results` (identical settings are overwritten); re-run this cell to start fresh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The cells below run the CEBaB benchmark for every\n",
    "combination of the conditions listed in this grid.\n",
    "\"\"\"\n",
    "grid = {\n",
    "    \"eval_split\": [\"test\"],\n",
    "    # dev,test\n",
    "    \"control\": [\"ks\"],\n",
    "    # baseline-random,baseline-blackbox,hdims,layers,ks,approximate,ablation\n",
    "    \"seed\": [42, 66, 77],\n",
    "    # 42, 66, 77\n",
    "    \"h_dim\": [64],\n",
    "    # 1,16,64,128,192\n",
    "    # 1,16,64,75\n",
    "    \"interchange_layer\" : [1],\n",
    "    # 0,1; 2,4,6,8,10,12\n",
    "    \"class_num\": [5],\n",
    "    \"k\" : [19684], \n",
    "    # 0;10,100,500,1000,3000,6000,9848,19684\n",
    "    \"alpha\" : [1.0],\n",
    "    # 0.0,1.0\n",
    "    \"beta\" : [1.0],\n",
    "    # 0.0,1.0\n",
    "    \"gemma\" : [3.0],\n",
    "    # 0.0,3.0\n",
    "    \"model_arch\" : [\"lstm\"],\n",
    "    # lstm, bert-base-uncased, roberta-base, gpt2\n",
    "    \"lr\" : [\"0.001\"],\n",
    "    # 8e-05; 0.001\n",
    "    \"counterfactual_type\" : [\"true\"],\n",
    "    # approximate,true\n",
    "    \"random_source\" : [True]\n",
    "}\n",
    "\n",
    "keys, values = zip(*grid.items())\n",
    "permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]\n",
    "\n",
    "device = 'cuda:2'\n",
    "batch_size = 32\n",
    "\n",
    "if grid[\"control\"][0] == \"hdims\" or grid[\"control\"][0] == \"layers\":\n",
    "    assert grid[\"eval_split\"][0] == \"dev\"\n",
    "else:\n",
    "    assert grid[\"eval_split\"][0] == \"test\""
   ]
  },
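  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two lines after `grid` expand it into one settings dictionary per combination of values (a Cartesian product via `itertools.product`), so the number of runs is the product of the list lengths. A standalone illustration with a smaller, hypothetical grid:\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "# Hypothetical mini-grid that reuses a few keys from the grid above\n",
    "mini_grid = {\n",
    "    'seed': [42, 66, 77],\n",
    "    'h_dim': [16, 64],\n",
    "    'class_num': [5],\n",
    "}\n",
    "\n",
    "keys, values = zip(*mini_grid.items())\n",
    "settings = [dict(zip(keys, v)) for v in itertools.product(*values)]\n",
    "\n",
    "print(len(settings))  # 6 (= 3 * 2 * 1)\n",
    "print(settings[0])    # {'seed': 42, 'h_dim': 16, 'class_num': 5}\n",
    "print(settings[1])    # {'seed': 42, 'h_dim': 64, 'class_num': 5}\n",
    "```\n",
    "\n",
    "The last key varies fastest, and adding a value to any list multiplies the number of evaluation runs accordingly."
   ]
  },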
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'ks'), ('seed', 42), ('h_dim', 64), ('interchange_layer', 1), ('class_num', 5), ('k', 19684), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'lstm'), ('lr', '0.001'), ('counterfactual_type', 'true'))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "212d060725fa459ebfdbfd5e748d441e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n",
      "intervention_h_dim=64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 124/124 [00:01<00:00, 74.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'ks'), ('seed', 66), ('h_dim', 64), ('interchange_layer', 1), ('class_num', 5), ('k', 19684), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'lstm'), ('lr', '0.001'), ('counterfactual_type', 'true'))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b671abfdfce4569baabd9c427e37208",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n",
      "intervention_h_dim=64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 124/124 [00:01<00:00, 72.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'ks'), ('seed', 77), ('h_dim', 64), ('interchange_layer', 1), ('class_num', 5), ('k', 19684), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'lstm'), ('lr', '0.001'), ('counterfactual_type', 'true'))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1abac9f2e0941d2a4c9c35bcb3b3b24",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n",
      "intervention_h_dim=64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 124/124 [00:01<00:00, 70.78it/s]\n"
     ]
    }
   ],
   "source": [
    "for setting in permutations_dicts:\n",
    "    \n",
    "    eval_split = setting[\"eval_split\"]\n",
    "    seed = setting[\"seed\"]\n",
    "    class_num = setting[\"class_num\"]\n",
    "    alpha = setting[\"alpha\"]\n",
    "    beta = setting[\"beta\"]\n",
    "    gemma = setting[\"gemma\"]\n",
    "    h_dim = setting[\"h_dim\"]\n",
    "    dataset_type = f'{class_num}-way'\n",
    "    control = setting[\"control\"]\n",
    "    model_arch = setting[\"model_arch\"]\n",
    "    k = setting[\"k\"]\n",
    "    interchange_layer = setting[\"interchange_layer\"]\n",
    "    lr = setting[\"lr\"]\n",
    "    counterfactual_type = setting[\"counterfactual_type\"]\n",
    "    random_source = setting[\"random_source\"]\n",
    "    \n",
    "    if model_arch == \"bert-base-uncased\":\n",
    "        model_path = \"BERT\"\n",
    "        model_module = BERTForCEBaB\n",
    "        explainer_module = CausalProxyModelForBERT\n",
    "    elif model_arch == \"roberta-base\":\n",
    "        model_path = \"RoBERTa\" \n",
    "        model_module = RoBERTaForCEBaB\n",
    "        explainer_module = CausalProxyModelForRoBERTa\n",
    "    elif model_arch == \"gpt2\":\n",
    "        model_path = \"gpt2\"\n",
    "        model_module = GPT2ForCEBaB\n",
    "        explainer_module = CausalProxyModelForGPT2\n",
    "    elif model_arch == \"lstm\":\n",
    "        model_path = \"lstm\"\n",
    "        model_module = LSTMForCEBaB\n",
    "        explainer_module = CausalProxyModelForLSTM\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported model_arch: {model_arch}\")\n",
    "    model_path += f\"-{control}\"\n",
    "    grid_conditions=(\n",
    "        (\"eval_split\", eval_split),\n",
    "        (\"control\", control),\n",
    "        (\"seed\", seed),\n",
    "        (\"h_dim\", h_dim),\n",
    "        (\"interchange_layer\", interchange_layer),\n",
    "        (\"class_num\", class_num),\n",
    "        (\"k\", k),\n",
    "        (\"alpha\", alpha),\n",
    "        (\"beta\", beta),\n",
    "        (\"gemma\", gemma),\n",
    "        (\"model_arch\", model_arch),\n",
    "        (\"lr\", lr),\n",
    "        (\"counterfactual_type\", counterfactual_type)\n",
    "    )\n",
    "    print(\"Running for this setting: \", grid_conditions)\n",
    "\n",
    "    blackbox_model_path = f'CEBaB/{model_arch}.CEBaB.sa.'\\\n",
    "                          f'{class_num}-class.exclusive.seed_{seed}'\n",
    "    cpm_model_path = f'../proxy_training_results/{model_path}/'\\\n",
    "                     f'cebab.alpha.{alpha}.beta.{beta}.gemma.{gemma}.'\\\n",
    "                     f'lr.{lr}.dim.{h_dim}.hightype.{model_arch}.'\\\n",
    "                     f'CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.'\\\n",
    "                     f'{counterfactual_type}.k.{k}.int.layer.{interchange_layer}.'\\\n",
    "                     f'seed_{seed}/'\n",
    "\n",
    "    # load data from HF\n",
    "    cebab = datasets.load_dataset(\n",
    "        'CEBaB/CEBaB', use_auth_token=True,\n",
    "        cache_dir=\"../train_cache/\"\n",
    "    )\n",
    "\n",
    "    train, dev, test = preprocess_hf_dataset_inclusive(\n",
    "        cebab, verbose=1, dataset_type=dataset_type\n",
    "    )\n",
    "\n",
    "    eval_dataset = dev if eval_split == 'dev' else test\n",
    "\n",
    "    tf_model = model_module(\n",
    "        blackbox_model_path, \n",
    "        device=device, \n",
    "        batch_size=batch_size\n",
    "    )\n",
    "    explainer = explainer_module(\n",
    "        blackbox_model_path,\n",
    "        cpm_model_path, \n",
    "        device=device, \n",
    "        batch_size=batch_size,\n",
    "        intervention_h_dim=h_dim,\n",
    "        random_source=random_source\n",
    "    )\n",
    "\n",
    "    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \\\n",
    "    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \\\n",
    "    ACaCE_per_aspect, performance_report = cebab_pipeline(\n",
    "        tf_model, explainer, \n",
    "        train, eval_dataset,\n",
    "        seed, k, dataset_type=dataset_type, \n",
    "        shorten_model_name=False, \n",
    "        train_setting=\"inclusive\", \n",
    "        approximate=(counterfactual_type != \"true\")\n",
    "    )\n",
    "    \n",
    "    results[grid_conditions] = (\n",
    "        result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction,\n",
    "        CEBaB_metrics_per_aspect, CaCE_per_aspect_direction,\n",
    "        ACaCE_per_aspect, performance_report\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Show your results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>eval_split</th>\n",
       "      <th>control</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>k</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>model_arch</th>\n",
       "      <th>lr</th>\n",
       "      <th>counterfactual_type</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>test</td>\n",
       "      <td>ks</td>\n",
       "      <td>42</td>\n",
       "      <td>64</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>19684</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>lstm</td>\n",
       "      <td>0.001</td>\n",
       "      <td>true</td>\n",
       "      <td>0.6804</td>\n",
       "      <td>0.5774</td>\n",
       "      <td>0.5126</td>\n",
       "      <td>0.592428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test</td>\n",
       "      <td>ks</td>\n",
       "      <td>66</td>\n",
       "      <td>64</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>19684</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>lstm</td>\n",
       "      <td>0.001</td>\n",
       "      <td>true</td>\n",
       "      <td>0.7265</td>\n",
       "      <td>0.5942</td>\n",
       "      <td>0.5212</td>\n",
       "      <td>0.584527</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test</td>\n",
       "      <td>ks</td>\n",
       "      <td>77</td>\n",
       "      <td>64</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>19684</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>lstm</td>\n",
       "      <td>0.001</td>\n",
       "      <td>true</td>\n",
       "      <td>0.6777</td>\n",
       "      <td>0.5776</td>\n",
       "      <td>0.5133</td>\n",
       "      <td>0.596698</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  eval_split control  seed  h_dim  interchange_layer  class_num      k  beta  \\\n",
       "0       test      ks    42     64                  1          5  19684   1.0   \n",
       "1       test      ks    66     64                  1          5  19684   1.0   \n",
       "2       test      ks    77     64                  1          5  19684   1.0   \n",
       "\n",
       "   gemma model_arch     lr counterfactual_type  ICaCE-L2  ICaCE-cosine  \\\n",
       "0    3.0       lstm  0.001                true    0.6804        0.5774   \n",
       "1    3.0       lstm  0.001                true    0.7265        0.5942   \n",
       "2    3.0       lstm  0.001                true    0.6777        0.5776   \n",
       "\n",
       "   ICaCE-normdiff  macro-f1  \n",
       "0          0.5126  0.592428  \n",
       "1          0.5212  0.584527  \n",
       "2          0.5133  0.596698  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "important_keys = [\n",
    "    \"eval_split\",\n",
    "    \"control\", \"seed\", \n",
    "    \"h_dim\", \"interchange_layer\", \n",
    "    \"class_num\", \"k\", \n",
    "    \"beta\", \"gemma\", \n",
    "    \"model_arch\", \"lr\", \"counterfactual_type\"\n",
    "]\n",
    "metric_keys = [\"ICaCE-L2\", \"ICaCE-cosine\", \"ICaCE-normdiff\", \"macro-f1\"]\n",
    "rows = []\n",
    "for conditions, outputs in results.items():\n",
    "    conditions = dict(conditions)\n",
    "    row = [conditions[key] for key in important_keys]\n",
    "    # outputs[2] is CEBaB_metrics (one row of ICaCE errors)\n",
    "    cebab_metrics = outputs[2]\n",
    "    for name in [\"ICaCE-L2\", \"ICaCE-cosine\", \"ICaCE-normdiff\"]:\n",
    "        row.append(cebab_metrics[name].iloc[0])\n",
    "    # outputs[-1] is performance_report; its first cell is reported as macro-f1\n",
    "    row.append(outputs[-1].iloc[0, 0])\n",
    "    rows.append(row)\n",
    "df = pd.DataFrame(rows, columns=important_keys + metric_keys)\n",
    "df.sort_values(by=['seed', 'interchange_layer', 'h_dim'], ascending=True)"
   ]
  },
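  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three `ICaCE-*` columns are errors between the explainer's estimated causal effect of a concept edit and the true effect measured on the black-box model, where each effect is a vector of changes in predicted class probabilities; lower is better. The sketch below shows the three distances as we understand CEBaB to define them (L2 distance, cosine distance, and absolute difference of vector norms). This is an assumption to check against the CEBaB code, which also specifies how per-example scores are averaged; the function names and example vectors are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Assumed CEBaB-style ICaCE error distances; verify against the CEBaB implementation.\n",
    "\n",
    "\n",
    "def norm(v):\n",
    "    return math.sqrt(sum(x * x for x in v))\n",
    "\n",
    "\n",
    "def l2_distance(a, b):\n",
    "    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))\n",
    "\n",
    "\n",
    "def cosine_distance(a, b):\n",
    "    # 1 - cosine similarity; assumes both vectors are non-zero\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    return 1.0 - dot / (norm(a) * norm(b))\n",
    "\n",
    "\n",
    "def normdiff(a, b):\n",
    "    return abs(norm(a) - norm(b))\n",
    "\n",
    "\n",
    "# Hypothetical 3-class effects (changes in predicted class probabilities)\n",
    "true_effect = [0.3, -0.1, -0.2]      # black box: p(counterfactual) - p(original)\n",
    "estimated_effect = [0.3, 0.0, -0.3]  # explainer's estimate of the same change\n",
    "for metric in (l2_distance, cosine_distance, normdiff):\n",
    "    print(metric.__name__, round(metric(true_effect, estimated_effect), 4))\n",
    "```\n",
    "\n",
    "The three distances capture different failure modes: cosine distance ignores magnitude, normdiff ignores direction, and L2 penalizes both."
   ]
  },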
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Aggregating results by a key of your choice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "groupby_key = \"k\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>k</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>19684</td>\n",
       "      <td>61.666667</td>\n",
       "      <td>64.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.694867</td>\n",
       "      <td>0.583067</td>\n",
       "      <td>0.5157</td>\n",
       "      <td>0.591218</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       k       seed  h_dim  interchange_layer  class_num  beta  gemma  \\\n",
       "0  19684  61.666667   64.0                1.0        5.0   1.0    3.0   \n",
       "\n",
       "   ICaCE-L2  ICaCE-cosine  ICaCE-normdiff  macro-f1  \n",
       "0  0.694867      0.583067          0.5157  0.591218  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
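    "# Average only numeric columns; string columns such as lr cannot be averaged\n",
    "numeric_cols = [c for c in df.select_dtypes(include=\"number\").columns if c != groupby_key]\n",
    "df.groupby([groupby_key], as_index=False)[numeric_cols].mean()"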
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>k</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>19684</td>\n",
       "      <td>17.897858</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.027429</td>\n",
       "      <td>0.009642</td>\n",
       "      <td>0.004776</td>\n",
       "      <td>0.006175</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       k       seed  h_dim  interchange_layer  class_num  beta  gemma  \\\n",
       "0  19684  17.897858    0.0                0.0        0.0   0.0    0.0   \n",
       "\n",
       "   ICaCE-L2  ICaCE-cosine  ICaCE-normdiff  macro-f1  \n",
       "0  0.027429      0.009642        0.004776  0.006175  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
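    "# Sample standard deviation (ddof=1) over numeric columns only\n",
    "numeric_cols = [c for c in df.select_dtypes(include=\"number\").columns if c != groupby_key]\n",
    "df.groupby([groupby_key], as_index=False)[numeric_cols].std()"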
   ]
  },
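  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `groupby(...).std()` in pandas returns the *sample* standard deviation (`ddof=1`), so with three seeds the squared deviations are divided by n - 1 = 2. As a quick cross-check outside pandas, Python's `statistics` module reproduces the ICaCE-L2 mean and standard deviation shown above from the three per-seed values in the results table:\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "# Per-seed ICaCE-L2 values copied from the results table above (seeds 42, 66, 77)\n",
    "icace_l2 = [0.6804, 0.7265, 0.6777]\n",
    "\n",
    "mean = statistics.mean(icace_l2)\n",
    "sample_sd = statistics.stdev(icace_l2)       # divides by n - 1, like pandas .std()\n",
    "population_sd = statistics.pstdev(icace_l2)  # divides by n, so it is smaller\n",
    "\n",
    "print(round(mean, 6))       # 0.694867\n",
    "print(round(sample_sd, 6))  # 0.027429\n",
    "print(round(population_sd, 6))\n",
    "print(f'{mean:.4f} +/- {sample_sd:.4f}')\n",
    "```\n",
    "\n",
    "When reporting results as mean +/- std across seeds, it is worth stating which convention was used, since the two differ noticeably with only three seeds."
   ]
  },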
  {
   "cell_type": "code",
   "execution_count": 215,
   "metadata": {},
   "outputs": [],
   "source": [
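    "model_arch = grid[\"model_arch\"][0]\n",
    "control = grid[\"control\"][0]\n",
    "numeric_cols = [c for c in df.select_dtypes(include=\"number\").columns if c != groupby_key]\n",
    "os.makedirs(\"./csv-results\", exist_ok=True)\n",
    "df.groupby([groupby_key], as_index=False)[numeric_cols].mean().to_csv(\n",
    "    f\"./csv-results/{model_arch}-{control}-mean.csv\"\n",
    ")"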
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "metadata": {},
   "outputs": [],
   "source": [
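    "numeric_cols = [c for c in df.select_dtypes(include=\"number\").columns if c != groupby_key]\n",
    "df.groupby([groupby_key], as_index=False)[numeric_cols].std().to_csv(\n",
    "    f\"./csv-results/{model_arch}-{control}-std.csv\"\n",
    ")"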
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Save your results and reload them later to tabulate several runs together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = input(\"Please give an output file name: \")\n",
    "\n",
    "output_directory = f'../proxy_training_results/{model_path}/'\n",
    "output_filename = os.path.join(output_directory, f'{output_name}.pkl')\n",
    "print(\"Writing to file: \", output_filename)\n",
    "with open(output_filename, 'wb') as f:\n",
    "    pickle.dump(results, f)"
   ]
  }
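  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The save step above is meant to let you reload runs later and tabulate them together. Because each key of `results` is a tuple of `(name, value)` pairs describing one setting, several pickled files can be merged into a single dictionary and then passed through the table-building loop in the Show your results section. A minimal sketch, assuming the files were written by the cell above; `load_results` and the file names are hypothetical, and the example writes two small stand-in files to a temporary directory so it runs on its own. Loading the real files also requires pandas, since the stored metrics are DataFrames.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def load_results(paths):\n",
    "    # Hypothetical helper: merge several pickled `results` dicts into one.\n",
    "    # Later files win when two runs share exactly the same setting key.\n",
    "    merged = {}\n",
    "    for path in paths:\n",
    "        with open(path, 'rb') as f:\n",
    "            merged.update(pickle.load(f))\n",
    "    return merged\n",
    "\n",
    "\n",
    "# Stand-in data: real files hold the full cebab_pipeline outputs per setting.\n",
    "run_a = {(('model_arch', 'lstm'), ('seed', 42)): 'lstm outputs'}\n",
    "run_b = {(('model_arch', 'gpt2'), ('seed', 42)): 'gpt2 outputs'}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    paths = []\n",
    "    for name, run in [('lstm-ks', run_a), ('gpt2-ks', run_b)]:\n",
    "        path = os.path.join(tmp, name + '.pkl')\n",
    "        with open(path, 'wb') as f:\n",
    "            pickle.dump(run, f)\n",
    "        paths.append(path)\n",
    "    merged = load_results(paths)\n",
    "\n",
    "for key in merged:\n",
    "    print(dict(key))\n",
    "```"
   ]
  }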
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
