{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "e61a0720",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "8a120102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import random\n",
    "import collections\n",
    "\n",
    "# ACD Imports\n",
    "import math\n",
    "import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "\n",
    "from argparse import Namespace\n",
    "from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars\n",
    "from pyfunctions.general import extractListFromDic, readJson\n",
    "from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels\n",
    "from pyfunctions.cd import *\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import BertTokenizer, BertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "f6d6ada4-4781-4789-b3b1-1d044c11b3d3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f3795492fd0>"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7651660-7f59-4b59-a574-afecc52dc306",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Model Arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "c3183be1-3bf6-4f5a-8134-9bdd83db0a56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "args = {\n",
    "    'model_type': 'bert', # bert, medical_bert, pubmed_bert, biobert, clinical_biobert\n",
    "    'task': 'path',\n",
    "    'field': 'PrimaryGleason'\n",
    "}\n",
    "\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d377977-fe3b-45dd-9d00-1c19e5366038",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "8053ccd9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if args['model_type'] == 'bert':\n",
    "    bert_path = 'bert-base-uncased'\n",
    "    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "elif args['model_type'] == 'medical_bert':\n",
    "    bert_path = f\"{base_dir}/models/pretrained/bert_pretrain_output_all_notes_150000/\"\n",
    "    tokenizer = BertTokenizer.from_pretrained(bert_path, local_files_only=True)\n",
    "elif args['model_type'] == 'pubmed_bert':\n",
    "    bert_path = \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")\n",
    "elif args['model_type'] == 'pubmed_bert_full':\n",
    "    bert_path = \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext\")\n",
    "elif args['model_type'] == 'biobert':\n",
    "    bert_path = \"dmis-lab/biobert-v1.1\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"dmis-lab/biobert-v1.1\")\n",
    "elif args['model_type'] == 'clinical_biobert':\n",
    "    bert_path = \"emilyalsentzer/Bio_ClinicalBERT\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"emilyalsentzer/Bio_ClinicalBERT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "4ea84464",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1345 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2066 517 324\n"
     ]
    }
   ],
   "source": [
    "# Read in data\n",
    "#field = 'PrimaryGleason' # out of PrimaryGleason, SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'\n",
    "path = f\"../data/prostate.json\"\n",
    "data = readJson(path)\n",
    "\n",
    "# Clean reports\n",
    "data = cleanSplit(data, stripChars)\n",
    "data['dev_test'] = cleanReports(data['dev_test'], stripChars)\n",
    "data = fixLabel(data)\n",
    "\n",
    "train_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['train']]\n",
    "val_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['val']]\n",
    "test_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['test']]\n",
    "print(len(train_documents), len(val_documents),len(test_documents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "dae7a95c-3b73-42b0-84fe-9aa08fd927e3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create datasets\n",
    "train_labels = [patient['labels'][args['field']] for patient in data['train']]\n",
    "val_labels = [patient['labels'][args['field']] for patient in data['val']]\n",
    "test_labels = [patient['labels'][args['field']] for patient in data['test']]\n",
    "\n",
    "train_documents, train_labels = exclude_labels(train_documents, train_labels)\n",
    "val_documents, val_labels = exclude_labels(val_documents, val_labels)\n",
    "test_documents, test_labels = exclude_labels(test_documents, test_labels)\n",
    "\n",
    "le = preprocessing.LabelEncoder()\n",
    "le.fit(train_labels)\n",
    "\n",
    "# Map raw label to processed label\n",
    "le_dict = dict(zip(le.classes_, le.transform(le.classes_)))\n",
    "le_dict = {str(key):le_dict[key] for key in le_dict}\n",
    "\n",
    "for label in val_labels + test_labels:\n",
    "    if str(label) not in le_dict:\n",
    "        le_dict[str(label)] = len(le_dict)\n",
    "\n",
    "# Map processed label back to raw label\n",
    "inv_le_dict = {v: k for k, v in le_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "d23e85ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "documents_full = train_documents + val_documents + test_documents\n",
    "labels_full = train_labels + val_labels + test_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "a7382bb5-9b4d-4b84-845a-53423fa3efba",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformers.models.bert.tokenization_bert.BertTokenizer"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ffd4d3-5e8f-4587-bb2a-f06b61918c09",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Load Trained Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "id": "dd45f9a1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=3, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 133,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#load finetuned model\n",
    "model_path = f\"{base_dir}/models/{args['task']}/{args['model_type']}_{args['field']}\"\n",
    "checkpoint_file = f\"{model_path}/save_output\"\n",
    "config_file = f\"{model_path}/save_output/config.json\"\n",
    "\n",
    "model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)\n",
    "\n",
    "model = model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "841e90ea-d6aa-409b-bf15-30261acef043",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nencoding = tokenizer.encode_plus(train_documents[0], \\n                                         add_special_tokens=True, \\n                                         max_length=512,\\n                                         truncation=True, \\n                                         padding = \"max_length\", \\n                                         return_attention_mask=True, \\n                                         pad_to_max_length=True,\\n                                         return_tensors=\"pt\")\\n\\nlist(encoding.keys())\\n'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "encoding = tokenizer.encode_plus(train_documents[0], \n",
    "                                         add_special_tokens=True, \n",
    "                                         max_length=512,\n",
    "                                         truncation=True, \n",
    "                                         padding = \"max_length\", \n",
    "                                         return_attention_mask=True, \n",
    "                                         pad_to_max_length=True,\n",
    "                                         return_tensors=\"pt\")\n",
    "\n",
    "list(encoding.keys())\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "092e646b-e834-485d-b112-e4cc947e7175",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Path Patching Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "adfa6f56-40d8-42be-b252-fd0aeed4517c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "'''\n",
    "def patch_context(rel, irrel, patched_entries, sa_module):\n",
    "    rel = reshape_separate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_separate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    for entry in patched_entries:\n",
    "        pos = entry[1]\n",
    "        att_head = entry[2]\n",
    "\n",
    "        rel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]\n",
    "        irrel[:, pos, att_head, :] = 0\n",
    "        \n",
    "        # irrel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]\n",
    "        # rel[:, pos, att_head, :] = 0\n",
    "\n",
    "    \n",
    "    rel = reshape_concatenate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_concatenate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    return rel, irrel\n",
    "\n",
    "def prop_self_attention_patched(rel, irrel, attention_mask, \n",
    "                                head_mask, patched_entries, \n",
    "                                sa_module, att_probs = None):\n",
    "    if att_probs is not None:\n",
    "        att_probs = att_probs\n",
    "    else:\n",
    "        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)\n",
    "    \n",
    "    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)\n",
    "    \n",
    "    rel_context = mul_att(att_probs, rel_value, sa_module)\n",
    "    irrel_context = mul_att(att_probs, irrel_value, sa_module)\n",
    "    \n",
    "    rel_context, irrel_context = patch_context(rel_context, irrel_context, patched_entries, sa_module)\n",
    "    \n",
    "    return rel_context, irrel_context\n",
    "\n",
    "def prop_attention_patched(rel, irrel, attention_mask, \n",
    "                           head_mask, patched_entries, a_module, \n",
    "                           att_probs = None):\n",
    "    \n",
    "    rel_context, irrel_context = prop_self_attention_patched(rel, irrel, \n",
    "                                                             attention_mask, \n",
    "                                                             head_mask, \n",
    "                                                             patched_entries,\n",
    "                                                             a_module.self, att_probs)\n",
    "    \n",
    "    # if len(patched_entries):\n",
    "    #     print(rel_context[0, 0, :])\n",
    "    #     print(irrel_context[0, 0, :])\n",
    "    \n",
    "    output_module = a_module.output\n",
    "    \n",
    "    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)\n",
    "    rel_tot = rel_dense + rel\n",
    "    irrel_tot = irrel_dense + irrel\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)\n",
    "    \n",
    "#     print('AttRes: ', torch.norm(rel_tot[:, 0]), torch.norm(irrel_tot[:, 0]))\n",
    "    \n",
    "#     rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)\n",
    "    \n",
    "#     print('AttOut: ', torch.norm(rel_out[:, 0]), torch.norm(irrel_out[:, 0]))\n",
    "\n",
    "    \n",
    "    return rel_out, irrel_out\n",
    "\n",
    "def prop_layer_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None):\n",
    "    rel_a, irrel_a = prop_attention_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_module.attention, att_probs)\n",
    "    \n",
    "    # print('Attention: ', torch.norm(rel_a[:, 0]), torch.norm(irrel_a[:, 0]))\n",
    "    \n",
    "    i_module = layer_module.intermediate\n",
    "    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)\n",
    "    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)\n",
    "    \n",
    "    # print('Intermediate: ', torch.norm(rel_iact[:, 0]), torch.norm(irrel_iact[:, 0]))\n",
    "    \n",
    "    o_module = layer_module.output\n",
    "    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)\n",
    "    \n",
    "    # print('Output: ', torch.norm(rel_od[:, 0]), torch.norm(irrel_od[:, 0]))\n",
    "    \n",
    "    rel_tot = rel_od + rel_a\n",
    "    irrel_tot = irrel_od + irrel_a\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)\n",
    "    \n",
    "    # print('LayerNorm: ', torch.norm(rel_out[:, 0]), torch.norm(irrel_out[:, 0]))\n",
    "    \n",
    "    # import pdb; pdb.set_trace()\n",
    "    \n",
    "    return rel_out, irrel_out\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1981e90a-ff76-4ddb-81b0-a9fe89a7bc47",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "'''\n",
    "def prop_classifier_model_patched(encoding, model, patched_entries, att_list = None):\n",
    "    embedding_output = get_embeddings_bert(encoding, model.bert)\n",
    "    input_shape = encoding['input_ids'].size()\n",
    "    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], \n",
    "                                                          input_shape = input_shape, \n",
    "                                                          bert_model = model.bert)\n",
    "    \n",
    "    head_mask = [None] * model.bert.config.num_hidden_layers\n",
    "    encoder_module = model.bert.encoder\n",
    "    \n",
    "    sh = list(embedding_output.shape)\n",
    "    \n",
    "    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    \n",
    "    irrel[:] = embedding_output[:]\n",
    "    \n",
    "    \n",
    "    for i, layer_module in enumerate(encoder_module.layer):\n",
    "        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]\n",
    "        layer_head_mask = head_mask[i]\n",
    "        \n",
    "        rel_n, irrel_n = prop_layer_patched(rel, irrel, extended_attention_mask, layer_head_mask, layer_patched_entries, layer_module, att_probs = None)\n",
    "        # print(torch.norm(rel_n[:, 0]), torch.norm(irrel_n[:, 0]))\n",
    "        normalize_rel_irrel(rel_n, irrel_n)\n",
    "        rel, irrel = rel_n, irrel_n\n",
    "        # if i == 11:\n",
    "        # print(torch.norm(rel[:, 0]), torch.norm(irrel[:, 0]))\n",
    "    \n",
    "    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)\n",
    "    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)\n",
    "    \n",
    "    return rel_out, irrel_out\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "40283990-ce19-4736-bf9d-2b4a835d1a67",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "text = documents_full[0]\n",
    "label = labels_full[0]\n",
    "encoding = get_encoding(text, tokenizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "8c1880a8-0140-4667-ba70-43cd343fdffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "patched_entries_1 = [(i, i, i) for i in range(12)]\n",
    "patched_entries_2 = [(11, 0, i) for i in range(12)]\n",
    "patched_entries_3 = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "118234c5-ce29-4ad0-965f-936801717c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get raw output\n",
    "raw_logit = ft_model(**encoding, output_hidden_states = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "5dad39b0-184b-43ba-9455-294bb58e76b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifierOutput(loss=None, logits=tensor([[-2.4799, -1.4704,  8.0667, -2.2106, -2.3697]], device='cuda:1'), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 249,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_logit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe8f1326-c1ac-4959-87b2-bb8ae4e31285",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rel_2, irrel_2 = prop_classifier_model_patched(encoding, ft_model, patched_entries_2)\n",
    "rel_3, irrel_3 = prop_classifier_model_patched(encoding, ft_model, patched_entries_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "e096b9bb-7dc3-47ff-ac26-7beebca5eeb6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 5.0000, -0.9562,  5.1698, -1.4072, -1.6844]], device='cuda:1')"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76277da4-2c38-4e47-8c4a-1c5b462943b7",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Head-head Patching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "06f921a8-1d7e-4261-a460-a1fe8ea1d742",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def reshape_separate_attention_heads(context_layer, sa_module):\n",
    "    new_shape = context_layer.size()[:-1] + (sa_module.num_attention_heads, sa_module.attention_head_size)\n",
    "    context_layer = context_layer.view(new_shape)\n",
    "    return context_layer\n",
    "\n",
    "def reshape_concatenate_attention_heads(context_layer, sa_module):\n",
    "    new_shape = context_layer.size()[:-2] + (sa_module.all_head_size,)\n",
    "    context_layer = context_layer.view(*new_shape)\n",
    "    return context_layer\n",
    "\n",
    "def patch_context_hh(rel, irrel, source_node_list, target_nodes, level, sa_module):\n",
    "    rel = reshape_separate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_separate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    target_nodes_at_level = [node for node in target_nodes if node[0] == level]\n",
    "    target_decomps = []\n",
    "    \n",
    "    for s_ind, sn_list in enumerate(source_node_list):\n",
    "        out_shape = (len(target_nodes_at_level), sa_module.attention_head_size)\n",
    "        \n",
    "        rel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)\n",
    "        irrel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)\n",
    "        \n",
    "        for t_ind, t in enumerate(target_nodes_at_level):\n",
    "            if t[0] == level:\n",
    "                t_pos = t[1]\n",
    "                t_head = t[2]\n",
    "\n",
    "                rel_st[t_ind, :] = rel[s_ind, t_pos, t_head, :]\n",
    "                irrel_st[t_ind, :] = irrel[s_ind, t_pos, t_head, :]\n",
    "        \n",
    "        target_decomps.append((rel_st.detach().cpu().numpy(), irrel_st.detach().cpu().numpy()))\n",
    "        \n",
    "        for entry in sn_list:\n",
    "            if entry[0] == level:\n",
    "                pos = entry[1]\n",
    "                att_head = entry[2]\n",
    "\n",
    "                rel[s_ind, pos, att_head, :] = rel[s_ind, pos, att_head, :] + irrel[s_ind, pos, att_head, :]\n",
    "                irrel[s_ind, pos, att_head, :] = 0\n",
    "\n",
    "    \n",
    "    rel = reshape_concatenate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_concatenate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    return rel, irrel, target_decomps\n",
    "\n",
    "def patch_context_hh_mean_ablated(rel, irrel, source_node_list, target_nodes, level, layer_patched_values, sa_module):\n",
    "    rel = reshape_separate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_separate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    target_nodes_at_level = [node for node in target_nodes if node[0] == level]\n",
    "    target_decomps = []\n",
    "    \n",
    "    if layer_patched_values is not None:\n",
    "        layer_patched_values = layer_patched_values[None, :, :, :]\n",
    "\n",
    "    for s_ind, sn_list in enumerate(source_node_list):\n",
    "        out_shape = (len(target_nodes_at_level), sa_module.attention_head_size)\n",
    "        \n",
    "        rel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)\n",
    "        irrel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)\n",
    "\n",
    "        for t_ind, t in enumerate(target_nodes_at_level):\n",
    "            if t[0] == level:\n",
    "                t_pos = t[1]\n",
    "                t_head = t[2]\n",
    "                rel_st[t_ind, :] = rel[s_ind, t_pos, t_head, :]\n",
    "                irrel_st[t_ind, :] = irrel[s_ind, t_pos, t_head, :]\n",
    "\n",
    "        \n",
    "        target_decomps.append((rel_st.detach().cpu().numpy(), irrel_st.detach().cpu().numpy()))\n",
    "        \n",
    "        for entry in sn_list:\n",
    "            if entry[0] == level:\n",
    "                pos = entry[1]\n",
    "                att_head = entry[2]\n",
    "\n",
    "                #rel[s_ind, pos, att_head, :] = rel[s_ind, pos, att_head, :] + irrel[s_ind, pos, att_head, :]\n",
    "                #irrel[s_ind, pos, att_head, :] = 0\n",
    "                \n",
    "                rel[s_ind, pos, att_head, :] = irrel[s_ind, pos, att_head, :] + rel[s_ind, pos, att_head, :] - torch.Tensor(layer_patched_values[:, pos, att_head, :]).to(device)\n",
    "                irrel[s_ind, pos, att_head, :] = torch.Tensor(layer_patched_values[:, pos, att_head, :]).to(device)\n",
    "\n",
    "    \n",
    "    rel = reshape_concatenate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_concatenate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    return rel, irrel, target_decomps\n",
    "\n",
    "def prop_self_attention_hh(rel, irrel, attention_mask, \n",
    "                           head_mask, source_node_list, target_nodes, \n",
    "                           level, sa_module, att_probs = None, output_att_prob=False):\n",
    "    \n",
    "    if att_probs is not None:\n",
    "        att_probs = att_probs\n",
    "    else:\n",
    "        att_probs = get_attention_probs(rel[0].unsqueeze(0) + irrel[0].unsqueeze(0), attention_mask, head_mask, sa_module)\n",
    "\n",
    "    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)\n",
    "\n",
    "    rel_context = mul_att(att_probs, rel_value, sa_module)\n",
    "\n",
    "    irrel_context = mul_att(att_probs, irrel_value, sa_module)\n",
    "    \n",
    "    #rel_context, irrel_context, target_decomps = patch_context_hh(rel_context, irrel_context, source_node_list, target_nodes, level, sa_module)\n",
    "    \n",
    "    if output_att_prob:\n",
    "        return rel_context, irrel_context, att_probs\n",
    "    else:\n",
    "        return rel_context, irrel_context, None\n",
    "    \n",
    "    #return rel_context, irrel_context, target_decomps\n",
    "\n",
    "def prop_attention_hh(rel, irrel, attention_mask, \n",
    "                      head_mask, source_node_list, target_nodes, level,\n",
    "                      layer_patched_values,\n",
    "                      a_module, att_probs = None, output_att_prob=False, mean_ablated=False):\n",
    "    \n",
    "    rel_context, irrel_context, returned_att_probs = prop_self_attention_hh(rel, irrel, \n",
    "                                                                        attention_mask, \n",
    "                                                                        head_mask, \n",
    "                                                                        source_node_list,\n",
    "                                                                        target_nodes,\n",
    "                                                                        level,\n",
    "                                                                        a_module.self, att_probs,\n",
    "                                                                        output_att_prob=output_att_prob)\n",
    "    normalize_rel_irrel(rel_context, irrel_context)\n",
    "    \n",
    "    output_module = a_module.output\n",
    "    \n",
    "    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)\n",
    "    \n",
    "    normalize_rel_irrel(rel_dense, irrel_dense)\n",
    "    \n",
    "    rel_tot = rel_dense + rel\n",
    "    irrel_tot = irrel_dense + irrel\n",
    "    \n",
    "    normalize_rel_irrel(rel_tot, irrel_tot)\n",
    "    \n",
    "    if not mean_ablated:\n",
    "        rel_tot, irrel_tot, target_decomps = patch_context_hh(rel_tot, irrel_tot, source_node_list, target_nodes, level, a_module.self)\n",
    "    else:\n",
    "        rel_tot, irrel_tot, target_decomps = patch_context_hh_mean_ablated(rel_tot, irrel_tot, source_node_list, target_nodes, level,\n",
    "                                                                            layer_patched_values, a_module.self)\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)\n",
    "\n",
    "    normalize_rel_irrel(rel_out, irrel_out)\n",
    "    \n",
    "    return rel_out, irrel_out, target_decomps, returned_att_probs\n",
    "\n",
    "def prop_layer_hh(rel, irrel, attention_mask, head_mask, \n",
    "                  source_node_list, target_nodes, level, layer_patched_values,\n",
    "                  layer_module, att_probs = None, output_att_prob=False, mean_ablated=False):\n",
    "    \n",
    "    rel_a, irrel_a, target_decomps, returned_att_probs = prop_attention_hh(rel, irrel, attention_mask, \n",
    "                                                                           head_mask, source_node_list, \n",
    "                                                                           target_nodes, level, layer_patched_values,\n",
    "                                                                           layer_module.attention,\n",
    "                                                                           att_probs, output_att_prob, mean_ablated=mean_ablated)\n",
    "\n",
    "    i_module = layer_module.intermediate\n",
    "    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)\n",
    "    normalize_rel_irrel(rel_id, irrel_id)\n",
    "    \n",
    "    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)\n",
    "    \n",
    "    o_module = layer_module.output\n",
    "    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)\n",
    "    normalize_rel_irrel(rel_od, irrel_od)\n",
    "    \n",
    "    rel_tot = rel_od + rel_a\n",
    "    irrel_tot = irrel_od + irrel_a\n",
    "    normalize_rel_irrel(rel_tot, irrel_tot)\n",
    "\n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)\n",
    "    \n",
    "    \n",
    "    return rel_out, irrel_out, target_decomps, returned_att_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "631d08da-f5da-4f4e-84a7-8fd892ab0139",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prop_classifier_model_hh(encoding, model, source_node_list, target_nodes, \n",
    "                             patched_values=None, att_list = None, output_att_prob=False, mean_ablated=False):\n",
    "    embedding_output = get_embeddings_bert(encoding, model.bert)\n",
    "    input_shape = encoding['input_ids'].size()\n",
    "    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], \n",
    "                                                          input_shape = input_shape, \n",
    "                                                          bert_model = model.bert,\n",
    "                                                          device = device)\n",
    "    \n",
    "    head_mask = [None] * model.bert.config.num_hidden_layers\n",
    "    encoder_module = model.bert.encoder\n",
    "    \n",
    "    sh = list(embedding_output.shape)\n",
    "    sh[0] = len(source_node_list)\n",
    "    \n",
    "    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    \n",
    "    irrel[:] = embedding_output[:]\n",
    "    \n",
    "    target_decomps = []\n",
    "    att_probs_lst = []\n",
    "    for i, layer_module in enumerate(encoder_module.layer):\n",
    "        layer_head_mask = head_mask[i]\n",
    "        att_probs = None\n",
    "        \n",
    "        if patched_values is not None:\n",
    "            layer_patched_values = patched_values[i] #[512, 12, 64]\n",
    "        else:\n",
    "            layer_patched_values = None\n",
    "            \n",
    "        rel_n, irrel_n, layer_target_decomps, returned_att_probs = prop_layer_hh(rel, irrel, extended_attention_mask, \n",
    "                                                                                 layer_head_mask, source_node_list, \n",
    "                                                                                 target_nodes, i, \n",
    "                                                                                 layer_patched_values,\n",
    "                                                                                 layer_module, att_probs, output_att_prob,\n",
    "                                                                                 mean_ablated=mean_ablated)\n",
    "        target_decomps.append(layer_target_decomps)\n",
    "        normalize_rel_irrel(rel_n, irrel_n)\n",
    "        rel, irrel = rel_n, irrel_n\n",
    "        \n",
    "        if output_att_prob:\n",
    "            att_probs_lst.append(returned_att_probs.squeeze(0))\n",
    "    \n",
    "    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)\n",
    "    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)\n",
    "    \n",
    "    out_decomps = []\n",
    "\n",
    "    for i, sn_list in enumerate(source_node_list):\n",
    "        rel_vec = rel_out[i, :].detach().cpu().numpy()\n",
    "        irrel_vec = irrel_out[i, :].detach().cpu().numpy()\n",
    "        \n",
    "        out_decomps.append((rel_vec, irrel_vec))\n",
    "    \n",
    "    return out_decomps, target_decomps, att_probs_lst\n",
    "\n",
    "def prop_classifier_model_hh_batched(encoding, model, source_node_list, target_nodes, patched_values=None, \n",
    "                                     num_at_time = 64, n_layers = 12, att_list = None, output_att_prob=False, mean_ablated=False):\n",
    "    out_decomps = []\n",
    "    target_decomps = [[] for i in range(n_layers)]\n",
    "    \n",
    "    n_source_lists = len(source_node_list)\n",
    "    n_batches = int((n_source_lists + (num_at_time - 1)) / num_at_time)\n",
    "\n",
    "    for b_no in range(n_batches):\n",
    "        b_st = b_no * num_at_time\n",
    "        b_end = min(b_st + num_at_time, n_source_lists)\n",
    "        layer_out_decomps, layer_target_decomps, att_probs_lst = prop_classifier_model_hh(encoding, model, \n",
    "                                                                           source_node_list[b_st: b_end],\n",
    "                                                                           target_nodes, patched_values,\n",
    "                                                                           att_list=att_list,\n",
    "                                                                           output_att_prob=output_att_prob,\n",
    "                                                                           mean_ablated=mean_ablated)\n",
    "        out_decomps = out_decomps + layer_out_decomps\n",
    "        target_decomps = [target_decomps[i] + layer_target_decomps[i] for i in range(n_layers)]\n",
    "    \n",
    "    return out_decomps, target_decomps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "1e570be6-3ba0-4ab5-8ef4-1655c263719d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# codes for second pass: ablate the target nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "id": "066702b9-e8c3-4f33-b1c9-bbeae9c51269",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prop_self_attention_patched(rel, irrel, attention_mask, \n",
    "                                head_mask, patched_entries, layer_patched_values,\n",
    "                                sa_module, att_probs = None):\n",
    "    \n",
    "    if att_probs is not None:\n",
    "        att_probs = att_probs\n",
    "    else:\n",
    "        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)\n",
    "    \n",
    "    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)\n",
    "    \n",
    "    rel_context = mul_att(att_probs, rel_value, sa_module)\n",
    "    irrel_context = mul_att(att_probs, irrel_value, sa_module)\n",
    "    \n",
    "    #rel_context, irrel_context = patch_context(rel_context, irrel_context, patched_entries, layer_patched_values, sa_module)\n",
    "    \n",
    "    return rel_context, irrel_context\n",
    "    \n",
    "def patch_context_baseline(rel, irrel, patched_entries, layer_patched_values, sa_module):\n",
    "    rel = reshape_separate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_separate_attention_heads(irrel, sa_module)\n",
    "\n",
    "    for i, entry in enumerate(patched_entries):\n",
    "        pos = entry[1]\n",
    "        att_head = entry[2]\n",
    "        \n",
    "        saved_rel = torch.Tensor(layer_patched_values[0][i])\n",
    "        saved_irrel = torch.Tensor(layer_patched_values[1][i])\n",
    "        \n",
    "        rel[:, pos, att_head, :] = saved_rel\n",
    "        irrel[:, pos, att_head, :] = saved_irrel\n",
    "        \n",
    "    rel = reshape_concatenate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_concatenate_attention_heads(irrel, sa_module)\n",
    "    return rel, irrel\n",
    "\n",
    "def prop_attention_patched_baseline(rel, irrel, attention_mask, \n",
    "                           head_mask, patched_entries, layer_patched_values, a_module, \n",
    "                           att_probs = None):\n",
    "\n",
    "    \n",
    "    rel_context, irrel_context = prop_self_attention_patched(rel, irrel, \n",
    "                                                             attention_mask, \n",
    "                                                             head_mask, \n",
    "                                                             patched_entries,\n",
    "                                                             layer_patched_values,\n",
    "                                                             a_module.self, att_probs)\n",
    "\n",
    "    output_module = a_module.output\n",
    "    \n",
    "    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)\n",
    "    rel_tot = rel_dense + rel\n",
    "    irrel_tot = irrel_dense + irrel\n",
    "        \n",
    "    rel_tot, irrel_tot = patch_context_baseline(rel_tot, irrel_tot, patched_entries, layer_patched_values, a_module.self)\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)\n",
    "    \n",
    "    return rel_out, irrel_out\n",
    "\n",
    "def prop_layer_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_patched_values, layer_module, att_probs = None):\n",
    "    \n",
    "    rel_a, irrel_a = prop_attention_patched_baseline(rel, irrel, attention_mask, head_mask,\n",
    "                                                     patched_entries, layer_patched_values,\n",
    "                                                     layer_module.attention, att_probs)\n",
    "    \n",
    "    i_module = layer_module.intermediate\n",
    "    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)\n",
    "    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)\n",
    "    \n",
    "    o_module = layer_module.output\n",
    "    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)\n",
    "    \n",
    "    rel_tot = rel_od + rel_a\n",
    "    irrel_tot = irrel_od + irrel_a\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)\n",
    "    \n",
    "    # import pdb; pdb.set_trace()\n",
    "    \n",
    "    return rel_out, irrel_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "5e921dbf-39de-41d6-9dfb-4c7776cf7be5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def ablate_target_nodes(encoding, model, patched_entries, patched_values=None, att_list = None):\n",
    "    \n",
    "    embedding_output = get_embeddings_bert(encoding, model.bert)\n",
    "    input_shape = encoding['input_ids'].size()\n",
    "    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], \n",
    "                                                          input_shape = input_shape, \n",
    "                                                          bert_model = model.bert,\n",
    "                                                          device = device)\n",
    "    \n",
    "    head_mask = [None] * model.bert.config.num_hidden_layers\n",
    "    encoder_module = model.bert.encoder\n",
    "    \n",
    "    sh = list(embedding_output.shape)\n",
    "    \n",
    "    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    \n",
    "    irrel[:] = embedding_output[:]\n",
    "\n",
    "    for i, layer_module in enumerate(encoder_module.layer):\n",
    "        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]\n",
    "        layer_head_mask = head_mask[i]\n",
    "        att_probs = None\n",
    "        \n",
    "        if patched_values is not None:\n",
    "            layer_patched_values = patched_values[i]\n",
    "        else:\n",
    "            layer_patched_values = None\n",
    "        \n",
    "        rel_n, irrel_n = prop_layer_patched(rel, irrel, extended_attention_mask,\n",
    "                                            layer_head_mask, layer_patched_entries,\n",
    "                                            layer_patched_values,\n",
    "                                            layer_module, att_probs)\n",
    "        normalize_rel_irrel(rel_n, irrel_n)\n",
    "        rel, irrel = rel_n, irrel_n\n",
    "    \n",
    "    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)\n",
    "    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)\n",
    "    \n",
    "    return rel_out, irrel_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "0b008fcb-8bcf-4033-b58a-bcb5aadc5462",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "path = f\"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h_to_logits\"\n",
    "os.makedirs(path, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(path, f\"mean_head_out_res_500.pkl\"), 'rb') as handle:\n",
    "    back = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "693cd0e8-2fd3-4e60-bb2f-00efa1356dff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "def patch_hh_at_pos_baseline(encoding, label_idx, model, target_nodes, pos=0, mean_act=None, mean_ablated=False):\n",
    "    \n",
    "    raw_logit = model(**encoding)[0][0][label_idx]\n",
    "    \n",
    "    pos_specific_hs = [\n",
    "        [i for i in range(12)],\n",
    "        [pos],\n",
    "        [i for i in range(12)]\n",
    "    ]\n",
    "    all_heads = list(itertools.product(*pos_specific_hs))\n",
    "\n",
    "    # patch one node at a time\n",
    "    h_ctbn_list = []\n",
    "    \n",
    "    source_list = [[node] for node in all_heads if node not in target_nodes]\n",
    "    out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes,\n",
    "                                                                  patched_values=mean_act, mean_ablated=True)\n",
    "    \n",
    "    for i, _ in enumerate(source_list):\n",
    "        tmp = []\n",
    "        for l in range(12):\n",
    "            if target_decomps[l][i][0].shape[0] != 0:\n",
    "                tmp.append(target_decomps[l][i])\n",
    "            else:\n",
    "                tmp.append([])\n",
    "        \n",
    "        rel_out, irrel_out = ablate_target_nodes(encoding, model, target_nodes, tmp, att_list = None)\n",
    "        logit_diff = (rel_out[0][label_idx] + irrel_out[0][label_idx]) - raw_logit\n",
    "        h_ctbn_list.append(logit_diff / abs(raw_logit) * 100)\n",
    "        \n",
    "    return source_list, h_ctbn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "dff0fe2e-2ba8-4470-9dcd-7c65b048ccfe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "text = documents_full[0]\n",
    "label = labels_full[0]\n",
    "encoding = get_encoding(text, tokenizer, device)\n",
    "label_idx = le_dict[label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "bd8d1919-0e91-4f59-a07d-05847645d6ea",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 512/512 [3:37:26<00:00, 25.48s/it]  \n"
     ]
    }
   ],
   "source": [
    "#target_nodes = [(11, 0, 1), (11, 0, 7), (11, 0, 5), (11, 0, 3), (11, 0, 0), (11, 0, 8)]\n",
    "#target_nodes = [(8, 132, 1), (8, 275, 0), (6, 397, 1), (8, 66, 6), (8, 380, 8), (1, 195, 0)]\n",
    "#target_nodes = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3), (2, 169, 1)]\n",
    "#### backup^^^\n",
    "\n",
    "target_nodes = [(11, 506, 6), (11, 506, 7), (11, 506, 8), (11, 506, 9), (11, 506, 10), (11, 506, 11)]\n",
    "\n",
    "all_source_hs = []\n",
    "all_htbn = []\n",
    "for pos in tqdm.tqdm(range(512)):\n",
    "    with torch.no_grad():\n",
    "        source_list, h_ctbn_list = patch_hh_at_pos_baseline(encoding, label_idx, model, target_nodes,\n",
    "                                                            pos=pos, mean_act=back, mean_ablated=True)\n",
    "    torch.cuda.empty_cache()\n",
    "    all_source_hs.append(source_list)\n",
    "    all_htbn.append(h_ctbn_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9fd727-b1bd-4f60-a374-f8ba09399c6d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "h_ctbn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "e6c0bf3b-0617-4884-80d7-8091ddedce64",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "flat_ctbn = [c for sublist in all_htbn for c in sublist]\n",
    "flat_source_h = [c for sublist in all_source_hs for c in sublist]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "d0e92299-54f4-4428-bf98-b920b237831e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "top_idx = sorted(range(len(flat_ctbn)), key=lambda i: flat_ctbn[i])[-6:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "7d48ead3-05fc-43d4-b90f-a32d822bc2a6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(11, 511, 6)] tensor(-1.5471e-05, device='cuda:0')\n",
      "[(11, 511, 7)] tensor(-1.5471e-05, device='cuda:0')\n",
      "[(11, 511, 8)] tensor(-1.5471e-05, device='cuda:0')\n",
      "[(11, 511, 9)] tensor(-1.5471e-05, device='cuda:0')\n",
      "[(11, 511, 10)] tensor(-1.5471e-05, device='cuda:0')\n",
      "[(11, 511, 11)] tensor(-1.5471e-05, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "for i in top_idx:\n",
    "    print(flat_source_h[i], flat_ctbn[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "98bc384b-e405-4815-b919-3b6d03739e2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# mean-ablated\n",
    "import pickle\n",
    "\n",
    "path = f\"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h2\"\n",
    "os.makedirs(path, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(path, f\"flat_source_h_baseline.pkl\"), 'wb') as handle:\n",
    "    pickle.dump(flat_source_h, handle)\n",
    "    \n",
    "with open(os.path.join(path, f\"flat_source_h_baseline.pkl\"), 'rb') as handle:\n",
    "    back = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "47c83989-1799-4a6b-b754-5f2e91832080",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(2, 169, 1)]"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "back[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "ec1bf9c5-979b-4473-8958-e34adab200a1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def collect_attended_tokens(positives_heads, N=100, Z_thres=2):\n",
    "    index_lst = random.sample(range(0, len(documents_full)), N)\n",
    "    docs = [documents_full[i] for i in index_lst]\n",
    "    \n",
    "    collect = collections.defaultdict(int)\n",
    "    for doc in docs:\n",
    "        encoding = get_encoding(doc, tokenizer, device)\n",
    "        \n",
    "        _, _, raw_att_probs_lst = prop_classifier_model_hh(encoding, model, [[]], [], output_att_prob=True)\n",
    "        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()\n",
    "\n",
    "        avg_att_m = np.zeros((512))\n",
    "        for level, pos, h in positives_heads:\n",
    "            att_m = raw_att_probs[level, h, pos, :]\n",
    "            avg_att_m += att_m\n",
    "\n",
    "        avg_att_m /= len(positives)\n",
    "        \n",
    "        # convert to word level\n",
    "        interval_dict, word_lst = compute_word_intervals(encoding)\n",
    "        word_att_m = combine_token_attn(interval_dict, avg_att_m)\n",
    "        \n",
    "        Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)\n",
    "\n",
    "        positive_words = np.where(Z > Z_thres)\n",
    "        \n",
    "        for w_idx in positive_words[0]:\n",
    "            w = word_lst[w_idx]\n",
    "            #collect[w] += 1\n",
    "            collect[w] += word_att_m[w_idx]\n",
    "            \n",
    "    return collect\n",
    "\n",
    "\n",
    "def combine_token_attn(interval_dict, avg_att_m):\n",
    "    word_cnt = len(interval_dict)\n",
    "    new_att_m = np.zeros(word_cnt)\n",
    "    for i in range(word_cnt):\n",
    "        t_idx_lst = interval_dict[i+1]\n",
    "        if len(t_idx_lst) == 1:\n",
    "            new_att_m[i] = avg_att_m[t_idx_lst[0]]\n",
    "        else:\n",
    "            new_att_m[i] = np.sum(avg_att_m[t_idx_lst[0]:t_idx_lst[-1]+1])\n",
    "    return new_att_m\n",
    "\n",
    "\n",
    "def compute_word_intervals(encoding):\n",
    "    word_cnt = 0\n",
    "    interval_dict = collections.defaultdict(list)\n",
    "    \n",
    "    pretok_sent = \"\"\n",
    "    for i in range(512):\n",
    "        tok = tokenizer.decode(encoding['input_ids'][:, i])\n",
    "        if tok.startswith(\"##\"):\n",
    "            interval_dict[word_cnt].append(i)\n",
    "            pretok_sent += tok[2:]\n",
    "        else:\n",
    "            word_cnt += 1\n",
    "            interval_dict[word_cnt].append(i)\n",
    "            pretok_sent += \" \" + tok\n",
    "    pretok_sent = pretok_sent[1:]\n",
    "    word_lst = pretok_sent.split(\" \")\n",
    "    \n",
    "    assert(len(interval_dict) == len(word_lst))\n",
    "    \n",
    "    return interval_dict, word_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "42c0c0cc-0f63-466d-b586-822bbca6e8f7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#h1 = [(11, 0, 1), (11, 0, 7), (11, 0, 5), (11, 0, 3), (11, 0, 0), (11, 0, 8)]\n",
    "#h2 = [(8, 132, 1), (8, 275, 0), (6, 397, 1), (8, 66, 6), (8, 380, 8), (1, 195, 0)]\n",
    "h3 = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3), (2, 169, 1)]\n",
    "positives = h3\n",
    "positive_attended_token_freq = collect_attended_tokens(positives, N=500, Z_thres=3)\n",
    "positive_attended_token_freq = sorted(positive_attended_token_freq.items(), key=lambda k_v: k_v[1], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "5fb143fa-8d8d-4185-8931-53d94c329144",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "with open('result_h3_baseline.json', 'w') as fp:\n",
    "    json.dump(positive_attended_token_freq, fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf57390b-4eed-48e5-a721-0c7afc9d1747",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e41caeed-2bc6-4d62-9ff7-8b382ebfe062",
   "metadata": {},
   "source": [
    "### Function description:\n",
    "\n",
    "prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes):\n",
    "\n",
    "- encoding - Encoding given by tokenizer\n",
    "- model - BERT model\n",
    "- source_list - List of lists where each list consists of tuples (layer, position, head) indexing a particular attention head whose influence is to be calculated\n",
    "- target_nodes - A single list of tuples (layer, position, head) containing attention heads on whom the influence is to be measured\n",
    "- num_at_time (optional) - Number of source_lists to be processed in a batch\n",
    "- n_layers - Number of layers\n",
    "- att_list - Attention probabilities if precomputed\n",
    "\n",
    "Output consists of two lists - out_decomps and target_decomps:\n",
    "- out_decomps - Consists of a list of tuples (rel, irrel) reflecting the decomposition of the _output_\n",
    "- target_decomps - A list containining 12 (one for each layer) where each list is of length len(source_list). For any layer l, each entry of target_decomps[l] is a tuple (rel, irrel) decomposition of the target nodes at that layer for the corresponding set of source nodes. rel, irrel are of dimension #number of target nodes in layer l x head_size and the ordering of the target nodes in this layer is the same as provided "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "82d049e2-f8f0-4bc9-b8e3-61916f9dbf98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "id": "b186954a-e2e3-4680-80d3-ffd53b187920",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.329473972320557\n",
      "4.332589626312256\n",
      "8.348850011825562\n",
      "8.351091861724854\n"
     ]
    }
   ],
   "source": [
    "st = time.time()\n",
    "out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, ft_model, source_list_30, target_nodes)\n",
    "end = time.time()\n",
    "print(end - st)\n",
    "\n",
    "st = time.time()\n",
    "out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, ft_model, source_list_60, target_nodes)\n",
    "end = time.time()\n",
    "print(end - st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00bbce98-d33b-4799-9855-ddeff4a9396a",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_decomps[11]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18aac387-41f8-450a-8adb-e6ae9d26469d",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "# Appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58ce1576-692c-4ea2-91dc-35bfed96e2e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def patch_context_dot_w_embed(embed, rel, irrel, patched_entries, sa_module):\n",
    "    rel = reshape_separate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_separate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    for entry in patched_entries:\n",
    "        pos = entry[1]\n",
    "        att_head = entry[2]\n",
    "\n",
    "        rel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]\n",
    "        irrel[:, pos, att_head, :] = 0\n",
    "        #rel[:, pos, att_head, :] = 0\n",
    "        #irrel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]\n",
    "\n",
    "        \n",
    "    \n",
    "    rel = reshape_concatenate_attention_heads(rel, sa_module)\n",
    "    irrel = reshape_concatenate_attention_heads(irrel, sa_module)\n",
    "    \n",
    "    return rel, irrel\n",
    "\n",
    "def prop_self_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, \n",
    "                                head_mask, patched_entries, \n",
    "                                sa_module, att_probs = None, output_att_prob=False):\n",
    "    if att_probs is not None:\n",
    "        att_probs = att_probs\n",
    "    else:\n",
    "        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)\n",
    "    \n",
    "    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)\n",
    "    \n",
    "    rel_context = mul_att(att_probs, rel_value, sa_module)\n",
    "    irrel_context = mul_att(att_probs, irrel_value, sa_module)\n",
    "    \n",
    "    rel_context, irrel_context = patch_context(embed, rel_context, irrel_context, patched_entries, sa_module)\n",
    "    \n",
    "    if output_att_prob:\n",
    "        return rel_context, irrel_context, att_probs\n",
    "    else:\n",
    "        return rel_context, irrel_context, None\n",
    "    \n",
    "def prop_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, \n",
    "                           head_mask, patched_entries, a_module, \n",
    "                           att_probs = None,\n",
    "                           output_att_prob=False):\n",
    "    \n",
    "    rel_context, irrel_context, returned_att_probs = prop_self_attention_patched(rel, irrel, \n",
    "                                                             attention_mask, \n",
    "                                                             head_mask, \n",
    "                                                             patched_entries,\n",
    "                                                             a_module.self, att_probs, output_att_prob)\n",
    "    \n",
    "    # if len(patched_entries):\n",
    "    #     print(rel_context[0, 0, :])\n",
    "    #     print(irrel_context[0, 0, :])\n",
    "    \n",
    "    output_module = a_module.output\n",
    "    \n",
    "    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)\n",
    "    rel_tot = rel_dense + rel\n",
    "    irrel_tot = irrel_dense + irrel\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)\n",
    "    \n",
    "    return rel_out, irrel_out, returned_att_probs\n",
    "\n",
    "def prop_layer_patched_dot_w_embed(embed, rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None, output_att_prob=False):\n",
    "    rel_a, irrel_a, returned_att_probs = prop_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, head_mask, patched_entries, layer_module.attention, att_probs, output_att_prob)\n",
    "    \n",
    "    i_module = layer_module.intermediate\n",
    "    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)\n",
    "    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)\n",
    "    \n",
    "    o_module = layer_module.output\n",
    "    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)\n",
    "    \n",
    "    rel_tot = rel_od + rel_a\n",
    "    irrel_tot = irrel_od + irrel_a\n",
    "    \n",
    "    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)\n",
    "    \n",
    "    # import pdb; pdb.set_trace()\n",
    "    \n",
    "    return rel_out, irrel_out, returned_att_probs\n",
    "\n",
    "def prop_classifier_model_patched_dot_w_embed(encoding, model, patched_entries, att_list = None, output_att_prob=False):\n",
    "    # patched_entries: attention heads to patch. format: [(level, pos, head)]\n",
    "    # level: 0-11, pos: 0-511, head: 0-11\n",
    "    # rel_out: the contribution of the patched_entries\n",
    "    # irrel_out: the contribution of everything else\n",
    "    \n",
    "    embedding_output = get_embeddings_bert(encoding, model.bert)\n",
    "    input_shape = encoding['input_ids'].size()\n",
    "    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], \n",
    "                                                          input_shape = input_shape, \n",
    "                                                          bert_model = model.bert)\n",
    "    \n",
    "    head_mask = [None] * model.bert.config.num_hidden_layers\n",
    "    encoder_module = model.bert.encoder\n",
    "    \n",
    "    sh = list(embedding_output.shape)\n",
    "    \n",
    "    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "    \n",
    "    #rel[:] = embedding_output[:]\n",
    "    irrel[:] = embedding_output[:]\n",
    "\n",
    "    att_probs_lst = []\n",
    "    for i, layer_module in enumerate(encoder_module.layer):\n",
    "        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]\n",
    "        layer_head_mask = head_mask[i]\n",
    "        att_probs = None\n",
    "        rel_n, irrel_n, returned_att_probs = prop_layer_patched_dot_w_embed(embedding_output, rel, irrel, extended_attention_mask,\n",
    "                                                                layer_head_mask, layer_patched_entries,\n",
    "                                                                layer_module, att_probs, output_att_prob)\n",
    "        normalize_rel_irrel(rel_n, irrel_n)\n",
    "        rel, irrel = rel_n, irrel_n\n",
    "        \n",
    "        if output_att_prob:\n",
    "            att_probs_lst.append(returned_att_probs.squeeze(0))\n",
    "    \n",
    "    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)\n",
    "    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)\n",
    "    \n",
    "    return rel_out, irrel_out, att_probs_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da689b72-7751-4c27-862f-edb4b89b123c",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = documents_full[0]\n",
    "label = labels_full[0]\n",
    "encoding = get_encoding(text, tokenizer).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3807ab08-5c62-4fc4-93b2-4f55f844201c",
   "metadata": {},
   "outputs": [],
   "source": [
    "rel, irrel, _ = prop_classifier_model_patched_dot_w_embed(encoding, model, [(11, 0, 0)])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
