{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from libs import *\n",
    "from modelings.modelings_bert import *\n",
    "from modelings.modelings_roberta import *\n",
    "from modelings.modelings_gpt2 import *\n",
    "from modelings.modelings_lstm import *\n",
    "\"\"\"\n",
    "For evaluate, we use a single random seed, as\n",
    "the models are trained with 5 different seeds\n",
    "already.\n",
    "\"\"\"\n",
    "_ = random.seed(123)\n",
    "_ = np.random.seed(123)\n",
    "_ = torch.manual_seed(123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "# rerunning the boxes below will only append stuffs to the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The following blocks will run CEBaB benchmark in\n",
    "all the combinations of the following conditions.\n",
    "\"\"\"\n",
    "grid = {\n",
    "    \"eval_split\": [\"test\"],\n",
    "    # dev,test\n",
    "    \"control\": [\"approximate\"],\n",
    "    # ks,approximate\n",
    "    \"seed\": [42, 66, 77],\n",
    "    # 42, 66, 77\n",
    "    \"h_dim\": [192],\n",
    "    # 1,16,64,128,192\n",
    "    # 1,16,64,75\n",
    "    \"interchange_layer\" : [8],\n",
    "    # 0,1; 2,4,6,8,10,12\n",
    "    \"class_num\": [5],\n",
    "    \"k\" : [0], \n",
    "    # 0;10,100,500,1000,3000,6000,9848,19684\n",
    "    \"alpha\" : [1.0],\n",
    "    # 0.0,1.0\n",
    "    \"beta\" : [1.0],\n",
    "    # 0.0,1.0\n",
    "    \"gemma\" : [3.0],\n",
    "    # 0.0,3.0\n",
    "    \"model_arch\" : [\"roberta-base\"],\n",
    "    # lstm, bert-base-uncased, roberta-base, gpt2\n",
    "    \"lr\" : [\"8e-05\"],\n",
    "    # 8e-05; 0.001\n",
    "    \"counterfactual_type\" : [\"approximate\"],\n",
    "    # approximate,true\n",
    "    \"self_explain\" : [True],\n",
    "    # True, False\n",
    "}\n",
    "\n",
    "keys, values = zip(*grid.items())\n",
    "permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]\n",
    "\n",
    "device = 'cuda:2'\n",
    "batch_size = 32\n",
    "\n",
    "if grid[\"control\"][0] == \"hdims\" or grid[\"control\"][0] == \"layers\":\n",
    "    assert grid[\"eval_split\"][0] == \"dev\"\n",
    "else:\n",
    "    assert grid[\"eval_split\"][0] == \"test\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'approximate'), ('seed', 42), ('h_dim', 192), ('interchange_layer', 8), ('class_num', 5), ('k', 0), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'roberta-base'), ('lr', '8e-05'), ('counterfactual_type', 'approximate'), ('self_explain', True))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8bacc95d1a6941c39b0a4c1be1d9260b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_42/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_42/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "intervention_h_dim=192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "../eval_pipeline/utils/data_utils.py:219: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  pairs['intervention_type'][pairs[f'{aspect}_aspect_majority_base'] != pairs[f'{aspect}_aspect_majority_counterfactual']] = aspect\n",
      "  \"The `device` argument is deprecated and will be removed in v5 of Transformers.\", FutureWarning\n",
      "100%|██████████| 124/124 [00:21<00:00,  5.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'approximate'), ('seed', 66), ('h_dim', 192), ('interchange_layer', 8), ('class_num', 5), ('k', 0), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'roberta-base'), ('lr', '8e-05'), ('counterfactual_type', 'approximate'), ('self_explain', True))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "82fc66939fb1451e932c857c7e036064",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_66/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_66/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "intervention_h_dim=192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 124/124 [00:21<00:00,  5.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running for this setting:  (('eval_split', 'test'), ('control', 'approximate'), ('seed', 77), ('h_dim', 192), ('interchange_layer', 8), ('class_num', 5), ('k', 0), ('alpha', 1.0), ('beta', 1.0), ('gemma', 3.0), ('model_arch', 'roberta-base'), ('lr', '8e-05'), ('counterfactual_type', 'approximate'), ('self_explain', True))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-ccd674d249652bd4\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-ccd674d249652bd4/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a25c2231d04d4d69993d9b692f8b8c38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train_exclusive dataset.\n",
      "Dropping no majority reviews: 16.03% of train_inclusive dataset.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_77/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at ../proxy_training_results/RoBERTa-approximate/cebab.alpha.1.0.beta.1.0.gemma.3.0.lr.8e-05.dim.192.hightype.roberta-base.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.approximate.k.0.int.layer.8.seed_77/ were not used when initializing RobertaForNonlinearSequenceClassification: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForNonlinearSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "intervention_h_dim=192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 124/124 [00:21<00:00,  5.71it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(permutations_dicts)):\n",
    "    \n",
    "    eval_split=permutations_dicts[i][\"eval_split\"]\n",
    "    seed=permutations_dicts[i][\"seed\"]\n",
    "    class_num=permutations_dicts[i][\"class_num\"]\n",
    "    alpha=permutations_dicts[i][\"alpha\"]\n",
    "    beta=permutations_dicts[i][\"beta\"]\n",
    "    gemma=permutations_dicts[i][\"gemma\"]\n",
    "    h_dim=permutations_dicts[i][\"h_dim\"]\n",
    "    dataset_type = f'{class_num}-way'\n",
    "    control=permutations_dicts[i][\"control\"]\n",
    "    model_arch=permutations_dicts[i][\"model_arch\"]\n",
    "    k=permutations_dicts[i][\"k\"]\n",
    "    interchange_layer=permutations_dicts[i][\"interchange_layer\"]\n",
    "    lr=permutations_dicts[i][\"lr\"]\n",
    "    counterfactual_type=permutations_dicts[i][\"counterfactual_type\"]\n",
    "    self_explain=permutations_dicts[i][\"self_explain\"]\n",
    "    \n",
    "    if model_arch == \"bert-base-uncased\":\n",
    "        model_path = \"BERT\"\n",
    "        model_module = BERTForCEBaB\n",
    "        explainer_module = CausalProxyModelForBERT\n",
    "    elif model_arch == \"roberta-base\":\n",
    "        model_path = \"RoBERTa\" \n",
    "        model_module = RoBERTaForCEBaB\n",
    "        explainer_module = CausalProxyModelForRoBERTa\n",
    "    elif model_arch == \"gpt2\":\n",
    "        model_path = \"gpt2\"\n",
    "        model_module = GPT2ForCEBaB\n",
    "        explainer_module = CausalProxyModelForGPT2\n",
    "    elif model_arch == \"lstm\":\n",
    "        model_path = \"lstm\"\n",
    "        model_module = LSTMForCEBaB\n",
    "        explainer_module = CausalProxyModelForLSTM\n",
    "    model_path += f\"-{control}\"\n",
    "    grid_conditions=(\n",
    "        (\"eval_split\", eval_split),\n",
    "        (\"control\", control),\n",
    "        (\"seed\", seed),\n",
    "        (\"h_dim\", h_dim),\n",
    "        (\"interchange_layer\", interchange_layer),\n",
    "        (\"class_num\", class_num),\n",
    "        (\"k\", k),\n",
    "        (\"alpha\", alpha),\n",
    "        (\"beta\", beta),\n",
    "        (\"gemma\", gemma),\n",
    "        (\"model_arch\", model_arch),\n",
    "        (\"lr\", lr),\n",
    "        (\"counterfactual_type\", counterfactual_type),\n",
    "        (\"self_explain\", self_explain)\n",
    "    )\n",
    "    print(\"Running for this setting: \", grid_conditions)\n",
    "\n",
    "    blackbox_model_path = f'CEBaB/{model_arch}.CEBaB.sa.'\\\n",
    "                          f'{class_num}-class.exclusive.seed_{seed}'\n",
    "    cpm_model_path = f'../proxy_training_results/{model_path}/'\\\n",
    "                     f'cebab.alpha.{alpha}.beta.{beta}.gemma.{gemma}.'\\\n",
    "                     f'lr.{lr}.dim.{h_dim}.hightype.{model_arch}.'\\\n",
    "                     f'CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.'\\\n",
    "                     f'{counterfactual_type}.k.{k}.int.layer.{interchange_layer}.'\\\n",
    "                     f'seed_{seed}/'\n",
    "    if self_explain:\n",
    "        # we throw out the black-box model in this case.\n",
    "        blackbox_model_path = cpm_model_path\n",
    "    \n",
    "    # load data from HF\n",
    "    cebab = datasets.load_dataset(\n",
    "        'CEBaB/CEBaB', use_auth_token=True,\n",
    "        cache_dir=\"../train_cache/\"\n",
    "    )\n",
    "\n",
    "    train, dev, test = preprocess_hf_dataset_inclusive(\n",
    "        cebab, verbose=1, dataset_type=dataset_type\n",
    "    )\n",
    "\n",
    "    eval_dataset = dev if eval_split == 'dev' else test\n",
    "\n",
    "    tf_model = model_module(\n",
    "        blackbox_model_path, \n",
    "        device=device, \n",
    "        batch_size=batch_size\n",
    "    )\n",
    "    explainer = explainer_module(\n",
    "        blackbox_model_path,\n",
    "        cpm_model_path, \n",
    "        device=device, \n",
    "        batch_size=batch_size,\n",
    "        intervention_h_dim=h_dim,\n",
    "        self_explain=self_explain,\n",
    "    )\n",
    "\n",
    "    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \\\n",
    "    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \\\n",
    "    ACaCE_per_aspect, performance_report = cebab_pipeline(\n",
    "        tf_model, explainer, \n",
    "        train, eval_dataset,\n",
    "        seed, k, dataset_type=dataset_type, \n",
    "        shorten_model_name=False, \n",
    "        train_setting=\"inclusive\", \n",
    "        approximate=False if counterfactual_type == \"true\" else True\n",
    "    )\n",
    "    \n",
    "    results[grid_conditions] = (\n",
    "        result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \\\n",
    "        CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \\\n",
    "        ACaCE_per_aspect, performance_report\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>eval_split</th>\n",
       "      <th>control</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>k</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>model_arch</th>\n",
       "      <th>lr</th>\n",
       "      <th>counterfactual_type</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>test</td>\n",
       "      <td>approximate</td>\n",
       "      <td>42</td>\n",
       "      <td>192</td>\n",
       "      <td>8</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>roberta-base</td>\n",
       "      <td>8e-05</td>\n",
       "      <td>approximate</td>\n",
       "      <td>0.5862</td>\n",
       "      <td>0.4944</td>\n",
       "      <td>0.3708</td>\n",
       "      <td>0.691116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>test</td>\n",
       "      <td>approximate</td>\n",
       "      <td>66</td>\n",
       "      <td>192</td>\n",
       "      <td>8</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>roberta-base</td>\n",
       "      <td>8e-05</td>\n",
       "      <td>approximate</td>\n",
       "      <td>0.6279</td>\n",
       "      <td>0.4770</td>\n",
       "      <td>0.4300</td>\n",
       "      <td>0.680137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>test</td>\n",
       "      <td>approximate</td>\n",
       "      <td>77</td>\n",
       "      <td>192</td>\n",
       "      <td>8</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>roberta-base</td>\n",
       "      <td>8e-05</td>\n",
       "      <td>approximate</td>\n",
       "      <td>0.6676</td>\n",
       "      <td>0.4794</td>\n",
       "      <td>0.4592</td>\n",
       "      <td>0.708094</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  eval_split      control  seed  h_dim  interchange_layer  class_num  k  beta  \\\n",
       "0       test  approximate    42    192                  8          5  0   1.0   \n",
       "1       test  approximate    66    192                  8          5  0   1.0   \n",
       "2       test  approximate    77    192                  8          5  0   1.0   \n",
       "\n",
       "   gemma    model_arch     lr counterfactual_type  ICaCE-L2  ICaCE-cosine  \\\n",
       "0    3.0  roberta-base  8e-05         approximate    0.5862        0.4944   \n",
       "1    3.0  roberta-base  8e-05         approximate    0.6279        0.4770   \n",
       "2    3.0  roberta-base  8e-05         approximate    0.6676        0.4794   \n",
       "\n",
       "   ICaCE-normdiff  macro-f1  \n",
       "0          0.3708  0.691116  \n",
       "1          0.4300  0.680137  \n",
       "2          0.4592  0.708094  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "important_keys = [\n",
    "    \"eval_split\",\n",
    "    \"control\", \"seed\", \n",
    "    \"h_dim\", \"interchange_layer\", \n",
    "    \"class_num\", \"k\", \n",
    "    \"beta\", \"gemma\", \n",
    "    \"model_arch\", \"lr\", \"counterfactual_type\", \n",
    "]\n",
    "values = []\n",
    "for k, v in results.items():\n",
    "    _values = []\n",
    "    for ik in important_keys:\n",
    "        _values.append(dict(k)[ik])\n",
    "    _values.append(v[2][\"ICaCE-L2\"].iloc[0])\n",
    "    _values.append(v[2][\"ICaCE-cosine\"].iloc[0])\n",
    "    _values.append(v[2][\"ICaCE-normdiff\"].iloc[0])\n",
    "    _values.append(v[-1].iloc[0][0])\n",
    "    values.append(_values)\n",
    "important_keys.extend([\"ICaCE-L2\", \"ICaCE-cosine\", \"ICaCE-normdiff\", \"macro-f1\"])\n",
    "df = pd.DataFrame(values, columns=important_keys)\n",
    "df.sort_values(by=['seed', 'interchange_layer', 'h_dim'], ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>k</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>61.666667</td>\n",
       "      <td>192.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.627233</td>\n",
       "      <td>0.4836</td>\n",
       "      <td>0.42</td>\n",
       "      <td>0.693115</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   k       seed  h_dim  interchange_layer  class_num  beta  gemma  ICaCE-L2  \\\n",
       "0  0  61.666667  192.0                8.0        5.0   1.0    3.0  0.627233   \n",
       "\n",
       "   ICaCE-cosine  ICaCE-normdiff  macro-f1  \n",
       "0        0.4836            0.42  0.693115  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "groupby_key = \"k\"\n",
    "\n",
    "model_arch = grid[\"model_arch\"][0]\n",
    "control = grid[\"control\"][0]\n",
    "df.groupby([groupby_key], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>k</th>\n",
       "      <th>seed</th>\n",
       "      <th>h_dim</th>\n",
       "      <th>interchange_layer</th>\n",
       "      <th>class_num</th>\n",
       "      <th>beta</th>\n",
       "      <th>gemma</th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "      <th>macro-f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>17.897858</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.040704</td>\n",
       "      <td>0.00943</td>\n",
       "      <td>0.04504</td>\n",
       "      <td>0.014085</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   k       seed  h_dim  interchange_layer  class_num  beta  gemma  ICaCE-L2  \\\n",
       "0  0  17.897858    0.0                0.0        0.0   0.0    0.0  0.040704   \n",
       "\n",
       "   ICaCE-cosine  ICaCE-normdiff  macro-f1  \n",
       "0       0.00943         0.04504  0.014085  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.groupby([groupby_key], as_index=False).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
