{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Causal Mediation Analyses\n",
    "A static version of the active alignment search, without interchange intervention training (IIT). This analysis closely follows the [causal mediation analysis paper](https://proceedings.neurips.cc/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf). We simplify their method to adapt it to CEBaB; details can be found in the code below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from libs import *\n",
    "from modelings.modelings_bert import *\n",
    "from modelings.modelings_roberta import *\n",
    "from modelings.modelings_gpt2 import *\n",
    "from modelings.modelings_lstm import *\n",
    "\n",
    "# For evaluation we fix a single random seed, since the models were\n",
    "# already trained with 5 different seeds.\n",
    "EVAL_SEED = 123\n",
    "# Seed Python's `random`, NumPy's global RNG and PyTorch (in that order).\n",
    "for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "    _ = _seed_fn(EVAL_SEED)\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55\n",
      "Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8971b15126f1464db147f81773feccec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dropping no majority reviews: 16.6382% of train dataset.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.dense.bias', 'multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# --- configuration ---\n",
    "# Training seed of the checkpoint to evaluate (interpolated into model_path).\n",
    "seed=42\n",
    "class_num=5\n",
    "# NOTE(review): beta, gemma, cls_dropout, enc_dropout and control are not\n",
    "# referenced elsewhere in this notebook; presumably kept to mirror the IIT\n",
    "# training configuration.\n",
    "beta=1.0\n",
    "gemma=3.0\n",
    "# Number of neurons assigned to each concept (also passed as intervention_h_dim).\n",
    "h_dim=192\n",
    "dataset_type = f'{class_num}-way'\n",
    "correction_epsilon=None\n",
    "cls_dropout=0.1\n",
    "enc_dropout=0.1\n",
    "control=False\n",
    "model_arch=\"bert-base-uncased\"\n",
    "# Candidate neurons: one h_dim-sized group for each of the 4 concepts.\n",
    "neuron_pool_size = h_dim*4\n",
    "\n",
    "# Control setting: a random partition of the neuron pool into disjoint\n",
    "# h_dim-sized groups, one per concept (driven by the `random` seed set above).\n",
    "concepts = [\"ambiance\", \"food\", \"noise\", \"service\"]\n",
    "random_align_neurons = {}\n",
    "random_neuron_pool = [i for i in range(neuron_pool_size)]\n",
    "random.shuffle(random_neuron_pool)\n",
    "for idx, c in enumerate(concepts):\n",
    "    random_align_neurons[c] = set(random_neuron_pool[idx*h_dim:(idx+1)*h_dim])\n",
    "\n",
    "# Requires a CUDA device.\n",
    "device='cuda:0'\n",
    "batch_size=32\n",
    "    \n",
    "model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'\n",
    "\n",
    "# Load CEBaB from the Hugging Face Hub (use_auth_token=True uses the locally\n",
    "# saved HF token) and cache it under ../train_cache/.\n",
    "cebab = datasets.load_dataset(\n",
    "    'CEBaB/CEBaB', use_auth_token=True,\n",
    "    cache_dir=\"../train_cache/\"\n",
    ")\n",
    "# Use the exclusive training split as the `train` split for preprocessing.\n",
    "cebab['train'] = cebab['train_exclusive']\n",
    "train, dev, test = preprocess_hf_dataset(\n",
    "    cebab, one_example_per_world=True, \n",
    "    verbose=1, dataset_type=dataset_type\n",
    ")\n",
    "\n",
    "train_dataset = train.copy()\n",
    "# NOTE: despite its name, dev_dataset is a copy of the *test* split; it is the\n",
    "# second dataset passed to cebab_pipeline below (presumably the evaluation set).\n",
    "dev_dataset = test.copy()\n",
    "\n",
    "# Wrapped fine-tuned checkpoint; passed as the first argument (the model being\n",
    "# explained, presumably) to cebab_pipeline in the evaluation cells.\n",
    "tf_model = BERTForCEBaB(\n",
    "    model_path, \n",
    "    device=device, \n",
    "    batch_size=batch_size\n",
    ")\n",
    "\n",
    "# Explainer without aligned neurons (align_neurons=None). Until it is rebuilt\n",
    "# after the search, it is only used for its model/tokenizer when preloading\n",
    "# representations and for the per-neuron interventions in evaluate_te.\n",
    "explainer = CausalMediationModelForBERT(\n",
    "    model_path,\n",
    "    device=device, \n",
    "    batch_size=batch_size,\n",
    "    intervention_h_dim=h_dim,\n",
    "    align_neurons=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_rows_with_condition(df, col_name, value, condition=\"==\"):\n",
    "    \"\"\"Filter `df` to rows whose `col_name` compares to `value` via `condition`.\n",
    "\n",
    "    Despite the name, nothing is sampled: all matching rows are returned.\n",
    "    Rows whose `col_name` is the empty string are always excluded.\n",
    "\n",
    "    Args:\n",
    "        df: pandas DataFrame to filter.\n",
    "        col_name: column to compare.\n",
    "        value: value to compare against.\n",
    "        condition: \"==\" keeps rows equal to `value`; \"!=\" keeps rows that differ.\n",
    "\n",
    "    Returns:\n",
    "        The filtered DataFrame (original index preserved).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `condition` is neither \"==\" nor \"!=\".\n",
    "    \"\"\"\n",
    "    column = df[col_name]\n",
    "    if condition == \"==\":\n",
    "        matches = column == value\n",
    "    elif condition == \"!=\":\n",
    "        matches = column != value\n",
    "    else:\n",
    "        # Explicit error instead of `assert False`, which `python -O` strips.\n",
    "        raise ValueError(\n",
    "            f\"Unsupported condition {condition!r}; expected '==' or '!='.\"\n",
    "        )\n",
    "    return df[matches & (column != \"\")]\n",
    "\n",
    "def get_treatment_effect_pairs(dataset, align_concept):\n",
    "    \"\"\"Build (base, source) description pairs that differ only in one aspect.\n",
    "\n",
    "    For every row whose `align_concept` label is non-empty, candidate sources\n",
    "    are rows with a different non-empty `align_concept` label and exactly the\n",
    "    same non-empty label for every other aspect (so a base row with an empty\n",
    "    label on any other aspect gets no pair). One candidate is drawn with\n",
    "    `DataFrame.sample()` (no random_state); rows without candidates are skipped.\n",
    "\n",
    "    Returns:\n",
    "        List of (base_description, source_description) tuples.\n",
    "    \"\"\"\n",
    "    aspect_names = [\"ambiance\", \"food\", \"noise\", \"service\"]\n",
    "    # The aspects that must stay fixed do not depend on the row.\n",
    "    held_fixed = list(set(aspect_names) - set([align_concept]))\n",
    "    target_col = f\"{align_concept}_aspect_majority\"\n",
    "    pairs = []\n",
    "    for _, base_row in dataset.iterrows():\n",
    "        base_label = base_row[target_col]\n",
    "        if base_label == \"\":\n",
    "            continue\n",
    "        candidates = sample_rows_with_condition(dataset, target_col, base_label, \"!=\")\n",
    "        for other in held_fixed:\n",
    "            other_col = f\"{other}_aspect_majority\"\n",
    "            candidates = sample_rows_with_condition(\n",
    "                candidates, other_col, base_row[other_col], \"==\"\n",
    "            )\n",
    "        if len(candidates) == 0:\n",
    "            continue\n",
    "        source_description = candidates.sample().iloc[0][\"description\"]\n",
    "        pairs.append((base_row[\"description\"], source_description))\n",
    "    return pairs\n",
    "\n",
    "def evaluate_te(explainer, tc_pairs, neuron_id):\n",
    "    \"\"\"Mean output change caused by intervening on a single neuron.\n",
    "\n",
    "    For each (base, source) pair, the preloaded base representation (cloned)\n",
    "    and source representation are passed with `neuron_id` to\n",
    "    `intervene_neuron_logits`, and the global `loss` compares the result with\n",
    "    the base's preloaded output.\n",
    "\n",
    "    Reads the notebook globals `preload_representations`, `preload_logits`\n",
    "    and `loss` defined in the preloading cell.\n",
    "\n",
    "    Args:\n",
    "        explainer: explainer forwarded to `intervene_neuron_logits`.\n",
    "        tc_pairs: non-empty sequence of (base_description, source_description).\n",
    "        neuron_id: index of the neuron to intervene on.\n",
    "\n",
    "    Returns:\n",
    "        The loss averaged over `tc_pairs`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `tc_pairs` is empty (instead of a ZeroDivisionError).\n",
    "    \"\"\"\n",
    "    if len(tc_pairs) == 0:\n",
    "        raise ValueError(\"tc_pairs must contain at least one (base, source) pair.\")\n",
    "    te = 0.0\n",
    "    # Evaluation only: skip autograd bookkeeping across hundreds of pairs.\n",
    "    with torch.no_grad():\n",
    "        for base_description, source_description in tc_pairs:\n",
    "            base_logits = preload_logits[base_description]\n",
    "            base_reprs = preload_representations[base_description]\n",
    "            source_reprs = preload_representations[source_description]\n",
    "            counterfactual_logits = intervene_neuron_logits(\n",
    "                explainer,\n",
    "                base_reprs.clone(),\n",
    "                source_reprs,\n",
    "                neuron_id,\n",
    "            )\n",
    "            te += loss(counterfactual_logits, base_logits)\n",
    "    return te / len(tc_pairs)\n",
    "\n",
    "def sequential_alignment_search(aspect_neuron_te_map, concepts=(\"food\", \"service\", \"ambiance\", \"noise\")):\n",
    "    \"\"\"Greedily assign disjoint, equally sized neuron groups to concepts.\n",
    "\n",
    "    Concepts are processed in `concepts` order; each takes the highest-scoring\n",
    "    neurons not yet claimed by an earlier concept, so earlier concepts get\n",
    "    first pick (the default order favours food, then service).\n",
    "\n",
    "    Args:\n",
    "        aspect_neuron_te_map: dict concept -> list of (neuron_id, score); the\n",
    "            lists are assumed to share one length N and to cover ids 0..N-1.\n",
    "        concepts: processing order; every entry must be a key of the map.\n",
    "\n",
    "    Returns:\n",
    "        dict concept -> set of at most N // len(concepts) neuron ids.\n",
    "    \"\"\"\n",
    "    concepts = list(concepts)\n",
    "    if not concepts:\n",
    "        return {}\n",
    "    num_neurons = len(aspect_neuron_te_map[concepts[0]])\n",
    "    neuron_pool = set(range(num_neurons))\n",
    "    h_dim = num_neurons // len(concepts)\n",
    "    align_neurons = {}\n",
    "    for concept in concepts:\n",
    "        available_neurons = [\n",
    "            item for item in aspect_neuron_te_map[concept] if item[0] in neuron_pool\n",
    "        ]\n",
    "        # sorted() is stable even with reverse=True, so ties keep list order.\n",
    "        ranked = sorted(available_neurons, key=lambda item: item[1], reverse=True)\n",
    "        chosen = {neuron_id for neuron_id, _ in ranked[:h_dim]}\n",
    "        align_neurons[concept] = chosen\n",
    "        neuron_pool -= chosen\n",
    "    return align_neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preloading representations and descriptions.\n"
     ]
    }
   ],
   "source": [
    "dataset = train_dataset[[\n",
    "    \"description\",\n",
    "    \"ambiance_aspect_majority\",\n",
    "    \"food_aspect_majority\",\n",
    "    \"noise_aspect_majority\",\n",
    "    \"service_aspect_majority\",\n",
    "]]\n",
    "align_neurons = {}\n",
    "# Candidate neuron ids 0..neuron_pool_size-1 for the search below.\n",
    "neuron_pool = set(range(neuron_pool_size))\n",
    "loss = nn.MSELoss()\n",
    "\n",
    "print(\"preloading representations and descriptions.\")\n",
    "# Pre-compute, once per training description, the last layer's first-token\n",
    "# ([CLS]) hidden state and the softmaxed model output, so the per-neuron\n",
    "# search only has to run interventions.\n",
    "preload_representations = {}\n",
    "preload_logits = {}\n",
    "# The model mode never changes inside the loop, so set eval() once.\n",
    "explainer.model.model.eval()\n",
    "# Inference only: no_grad avoids building an autograd graph per example.\n",
    "with torch.no_grad():\n",
    "    for description in dataset[\"description\"]:\n",
    "        x = explainer.tokenizer([description], padding=True, truncation=True, return_tensors='pt')\n",
    "        x_batch = {k: v.to(explainer.device) for k, v in x.items()}\n",
    "        outputs = explainer.model.model(\n",
    "            **x_batch,\n",
    "            output_hidden_states=True,\n",
    "        )\n",
    "        cls_hidden_state = outputs.hidden_states[-1][:,0,:].detach()\n",
    "        # NOTE(review): assumes outputs.logits[0] holds the (1, num_classes)\n",
    "        # sentiment logits, so the trailing [0] selects the single example.\n",
    "        output_logit = torch.nn.functional.softmax(\n",
    "            outputs.logits[0].cpu(), dim=-1\n",
    "        ).detach()[0]\n",
    "        preload_representations[description] = cls_hidden_state\n",
    "        preload_logits[description] = output_logit\n",
    "\n",
    "# Base/source pairs per aspect: the two reviews differ only in that aspect's\n",
    "# label (all other aspect labels are equal). k pairs are sampled per aspect;\n",
    "# random.sample raises ValueError if fewer than k pairs exist.\n",
    "k = 300\n",
    "ambiance_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, \"ambiance\"), k=k)\n",
    "food_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, \"food\"), k=k)\n",
    "noise_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, \"noise\"), k=k)\n",
    "service_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, \"service\"), k=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 768/768 [11:54<00:00,  1.07it/s]\n"
     ]
    }
   ],
   "source": [
    "# For every candidate neuron, estimate its treatment effect (TE) on each\n",
    "# aspect from that aspect's base/source pairs, then score the neuron for an\n",
    "# aspect as that aspect's TE divided by the sum of the other aspects' TEs.\n",
    "# A high score means intervening on the neuron changes the output much more\n",
    "# for this aspect's pairs than for the other aspects' pairs.\n",
    "aspect_tc_pairs = {\n",
    "    \"ambiance\": ambiance_tc_pairs,\n",
    "    \"food\": food_tc_pairs,\n",
    "    \"noise\": noise_tc_pairs,\n",
    "    \"service\": service_tc_pairs,\n",
    "}\n",
    "aspect_neuron_te_map = {aspect: [] for aspect in aspect_tc_pairs}\n",
    "for neuron_id in tqdm(neuron_pool):\n",
    "    te_est = {\n",
    "        aspect: evaluate_te(explainer, tc_pairs, neuron_id)\n",
    "        for aspect, tc_pairs in aspect_tc_pairs.items()\n",
    "    }\n",
    "    for aspect, own_te in te_est.items():\n",
    "        other_te = sum(te for other, te in te_est.items() if other != aspect)\n",
    "        aspect_neuron_te_map[aspect].append((neuron_id, own_te / other_te))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "align_neurons = sequential_alignment_search(aspect_neuron_te_map)\n",
    "# Sanity checks: every group matches intervention_h_dim and groups are disjoint.\n",
    "assert all(len(group) == h_dim for group in align_neurons.values())\n",
    "assert len(set().union(*align_neurons.values())) == sum(len(group) for group in align_neurons.values())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate with the neurons selected by causal mediation analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.dense.bias', 'multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "100%|██████████| 124/124 [00:31<00:00,  3.89it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42</th>\n",
       "      <th>CausalMediationModelForBERT</th>\n",
       "      <th>mean</th>\n",
       "      <td>0.7943</td>\n",
       "      <td>0.6925</td>\n",
       "      <td>0.6669</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                     ICaCE-L2  \\\n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean    0.7943   \n",
       "\n",
       "                                                                                     ICaCE-cosine  \\\n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean        0.6925   \n",
       "\n",
       "                                                                                     ICaCE-normdiff  \n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean          0.6669  "
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Explainer that intervenes on the neuron groups found by the search above.\n",
    "explainer = CausalMediationModelForBERT(\n",
    "    model_path,\n",
    "    device=device,\n",
    "    batch_size=batch_size,\n",
    "    intervention_h_dim=h_dim,\n",
    "    align_neurons=align_neurons\n",
    ")\n",
    "\n",
    "# Pass the configured dataset_type / correction_epsilon instead of hard-coded\n",
    "# copies, so this cell stays in sync with class_num in the config cell.\n",
    "result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \\\n",
    "CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \\\n",
    "ACaCE_per_aspect, performance_report = cebab_pipeline(\n",
    "    tf_model, explainer,\n",
    "    train_dataset, dev_dataset,\n",
    "    dataset_type=dataset_type,\n",
    "    correction_epsilon=correction_epsilon,\n",
    ")\n",
    "\n",
    "CEBaB_metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Control: evaluate with randomly selected neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.dense.bias', 'multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "100%|██████████| 124/124 [00:32<00:00,  3.78it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>ICaCE-L2</th>\n",
       "      <th>ICaCE-cosine</th>\n",
       "      <th>ICaCE-normdiff</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42</th>\n",
       "      <th>CausalMediationModelForBERT</th>\n",
       "      <th>mean</th>\n",
       "      <td>0.7968</td>\n",
       "      <td>0.6773</td>\n",
       "      <td>0.7225</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                     ICaCE-L2  \\\n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean    0.7968   \n",
       "\n",
       "                                                                                     ICaCE-cosine  \\\n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean        0.6773   \n",
       "\n",
       "                                                                                     ICaCE-normdiff  \n",
       "bert-base-uncased.CEBaB.sa.5-class.exclusive.se... CausalMediationModelForBERT mean          0.7225  "
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Control: same pipeline, but with randomly partitioned neuron groups.\n",
    "control_explainer = CausalMediationModelForBERT(\n",
    "    model_path,\n",
    "    device=device,\n",
    "    batch_size=batch_size,\n",
    "    intervention_h_dim=h_dim,\n",
    "    align_neurons=random_align_neurons\n",
    ")\n",
    "\n",
    "# Pass the configured dataset_type / correction_epsilon instead of hard-coded\n",
    "# copies, so this cell stays in sync with class_num in the config cell.\n",
    "result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \\\n",
    "CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \\\n",
    "ACaCE_per_aspect, performance_report = cebab_pipeline(\n",
    "    tf_model, control_explainer,\n",
    "    train_dataset, dev_dataset,\n",
    "    dataset_type=dataset_type,\n",
    "    correction_epsilon=correction_epsilon,\n",
    ")\n",
    "\n",
    "CEBaB_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
