{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e61a0720",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8a120102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import random\n",
    "import collections\n",
    "import functools\n",
    "\n",
    "# CD-T Imports\n",
    "import math\n",
    "import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import itertools\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "\n",
    "from argparse import Namespace\n",
    "from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars\n",
    "from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals\n",
    "from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels\n",
    "from pyfunctions.cdt_from_source_nodes import *\n",
    "from pyfunctions.cdt_source_to_target import *\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import BertTokenizer, BertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f6d6ada4-4781-4789-b3b1-1d044c11b3d3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x717706b71820>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.autograd.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7651660-7f59-4b59-a574-afecc52dc306",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Model Arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3183be1-3bf6-4f5a-8134-9bdd83db0a56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "args = {\n",
    "    'model_type': 'bert', # bert, medical_bert, pubmed_bert, biobert, clinical_biobert\n",
    "    'task': 'path',\n",
    "    'field': 'PrimaryGleason'\n",
    "}\n",
    "\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8053ccd9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if args['model_type'] == 'bert':\n",
    "    bert_path = 'bert-base-uncased'\n",
    "    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "elif args['model_type'] == 'medical_bert':\n",
    "    bert_path = f\"{base_dir}/models/pretrained/bert_pretrain_output_all_notes_150000/\"\n",
    "    tokenizer = BertTokenizer.from_pretrained(bert_path, local_files_only=True)\n",
    "elif args['model_type'] == 'pubmed_bert':\n",
    "    bert_path = \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")\n",
    "elif args['model_type'] == 'pubmed_bert_full':\n",
    "    bert_path = \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext\")\n",
    "elif args['model_type'] == 'biobert':\n",
    "    bert_path = \"dmis-lab/biobert-v1.1\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"dmis-lab/biobert-v1.1\")\n",
    "elif args['model_type'] == 'clinical_biobert':\n",
    "    bert_path = \"emilyalsentzer/Bio_ClinicalBERT\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"emilyalsentzer/Bio_ClinicalBERT\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d377977-fe3b-45dd-9d00-1c19e5366038",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cba4d6b-b67c-4ae9-8bc2-b4acb7ce1a65",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "you can cutomize the code here to read in your own data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4ea84464",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1345 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2066 517 324\n"
     ]
    }
   ],
   "source": [
    "# Read in data\n",
    "#field = 'PrimaryGleason' # out of PrimaryGleason, SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'\n",
    "path = f\"../data/prostate.json\"\n",
    "data = readJson(path)\n",
    "\n",
    "# Clean reports\n",
    "data = cleanSplit(data, stripChars)\n",
    "data['dev_test'] = cleanReports(data['dev_test'], stripChars)\n",
    "data = fixLabel(data)\n",
    "\n",
    "train_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['train']]\n",
    "val_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['val']]\n",
    "test_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['test']]\n",
    "print(len(train_documents), len(val_documents),len(test_documents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "925a80b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'synoptic comment for prostate tumors null - type of tumor : small acinar adenocarcinoma. - location of tumor : left posterior mid gland ( slides d17 and d18 ). - estimated volume of tumor : 0. 3 cm3. - gleason score : 4 + 3. - estimated volume > gleason pattern 3 : 70 %. - involvement of capsule : tumor invades capsule in left posterior mid section slide d18. - extraprostatic extension : none. - margin status for tumor : - negative. - margin status for benign prostate glands : - no benign glands present at inked excision margins. - high - grade prostatic intraepithelial neoplasia ( hgpin ) : present extensively. - tumor involvement of seminal vesicle : none. - perineural infiltration : present ( slides d17 and d18 ). - lymph node status : - negative ; total number of nodes examined : 15 ( parts b and c ). - ajcc / uicc stage : pt2an0. null null null null specimen ( s ) received a : anterior prostatic fat b : lymph node right pelvic c : lymph node left pelvic d : prostate and bilateral seminal vesicles null null clinical history the patient is a 60 - year - old man with a history of prostate cancer. null null gross description the case is received in four parts each labeled with the patient\\'s name and medical number. parts a - c are received in formalin and part d is received fresh. null part a is additionally labeled \" anterior prostatic fat \" and consists of two soft yellow adipose tissue fragments ( 1. 5 x 1. 5 x 0. 4 cm in aggregate ). no nodes are palpated. the entire specimen is submitted in cassette a1. ( rsm ) null part b is additionally labeled \" right pelvic lymph nodes \" and consists of multiple soft tan - yellow fibrofatty tissue fragments ( 5 x 4 x 1. 5 cm in aggregate ). the specimen is extensively searched for lymph nodes. nine yellow tan rubbery candidate nodes ( 0. 6 - 1. 4 cm in greatest dimension ) are found and submitted as follows : b1 : four intact nodes. b2 : four intact nodes. b3 : one node inked blue and bisected. 
( rsm ) null part c is additionally labeled \" left pelvic lymph nodes \" consists of multiple soft'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#data['train'][0]\n",
    "train_documents[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dae7a95c-3b73-42b0-84fe-9aa08fd927e3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create datasets\n",
    "train_labels = [patient['labels'][args['field']] for patient in data['train']]\n",
    "val_labels = [patient['labels'][args['field']] for patient in data['val']]\n",
    "test_labels = [patient['labels'][args['field']] for patient in data['test']]\n",
    "\n",
    "train_documents, train_labels = exclude_labels(train_documents, train_labels)\n",
    "val_documents, val_labels = exclude_labels(val_documents, val_labels)\n",
    "test_documents, test_labels = exclude_labels(test_documents, test_labels)\n",
    "\n",
    "le = preprocessing.LabelEncoder()\n",
    "le.fit(train_labels)\n",
    "\n",
    "# Map raw label to processed label\n",
    "le_dict = dict(zip(le.classes_, le.transform(le.classes_)))\n",
    "le_dict = {str(key):le_dict[key] for key in le_dict}\n",
    "\n",
    "for label in val_labels + test_labels:\n",
    "    if str(label) not in le_dict:\n",
    "        le_dict[str(label)] = len(le_dict)\n",
    "\n",
    "# Map processed label back to raw label\n",
    "inv_le_dict = {v: k for k, v in le_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d23e85ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "documents_full = train_documents + val_documents + test_documents\n",
    "labels_full = train_labels + val_labels + test_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ffd4d3-5e8f-4587-bb2a-f06b61918c09",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Load Trained Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dd45f9a1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSdpaSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=3, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#load finetuned model\n",
    "model_path = f\"{base_dir}/PG_best_ckpts/{args['model_type']}\" #{args['task']}/{args['model_type']}_{args['field']}\"\n",
    "checkpoint_file = f\"{model_path}/save_output\"\n",
    "config_file = f\"{model_path}/save_output/config.json\"\n",
    "\n",
    "model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)\n",
    "\n",
    "model = model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9d739595",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n"
     ]
    }
   ],
   "source": [
    "print(next(model.parameters()).dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "142dc16a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BertLayer(\n",
      "  (attention): BertAttention(\n",
      "    (self): BertSdpaSelfAttention(\n",
      "      (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (output): BertSelfOutput(\n",
      "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "  )\n",
      "  (intermediate): BertIntermediate(\n",
      "    (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "    (intermediate_act_fn): GELUActivation()\n",
      "  )\n",
      "  (output): BertOutput(\n",
      "    (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "    (dropout): Dropout(p=0.1, inplace=False)\n",
      "  )\n",
      ")\n",
      "<class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'>\n"
     ]
    }
   ],
   "source": [
    "#model\n",
    "print(model.bert.encoder.layer[0])\n",
    "print(type(model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8bad2cbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "python(11648) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting torchsummary\n",
      "  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)\n",
      "Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)\n",
      "Installing collected packages: torchsummary\n",
      "Successfully installed torchsummary-1.5.1\n"
     ]
    }
   ],
   "source": [
    "!pip install torchsummary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4094e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsummary import summary\n",
    "summary(model, )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76277da4-2c38-4e47-8c4a-1c5b462943b7",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Head to head direct influence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b008fcb-8bcf-4033-b58a-bcb5aadc5462",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/home/shawnghu/ml/CD_Circuit/output/bert_mean_acts_random_500.pkl'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# read in the pre-calculated mean head response\u001b[39;00m\n\u001b[1;32m      2\u001b[0m path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbase_dir\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/output/\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel_type\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m_mean_acts_random_500.pkl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m handle:\n\u001b[1;32m      5\u001b[0m     back \u001b[38;5;241m=\u001b[39m pickle\u001b[38;5;241m.\u001b[39mload(handle)\n",
      "File \u001b[0;32m~/ml/CD_Circuit/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m    317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m    318\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    319\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    320\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    321\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    322\u001b[0m     )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/shawnghu/ml/CD_Circuit/output/bert_mean_acts_random_500.pkl'"
     ]
    }
   ],
   "source": [
    "# read in the pre-calculated mean head response\n",
    "path = f\"{base_dir}/output/\"\n",
    "\n",
    "with open(os.path.join(path, f\"{args['model_type']}_mean_acts_random_500.pkl\"), 'rb') as handle:\n",
    "    back = pickle.load(handle)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0409a42",
   "metadata": {},
   "source": [
    "### Function description:\n",
    "\n",
    "prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes):\n",
    "\n",
    "- encoding - Encoding given by tokenizer\n",
    "- model - BERT model\n",
    "- source_list - List of lists where each list consists of tuples (layer, position, head) indexing a particular attention head whose influence is to be calculated\n",
    "- target_nodes - A single list of tuples (layer, position, head) containing attention heads on whom the influence is to be measured\n",
    "- num_at_time (optional) - Number of source_lists to be processed in a batch\n",
    "- n_layers - Number of layers\n",
    "- att_list - Attention probabilities if precomputed\n",
    "\n",
    "Output consists of two lists - out_decomps and target_decomps:\n",
    "- out_decomps - Consists of a list of tuples (rel, irrel) reflecting the decomposition of the _output_\n",
    "- target_decomps - A list containining 12 (one for each layer) where each list is of length len(source_list). For any layer l, each entry of target_decomps[l] is a tuple (rel, irrel) decomposition of the target nodes at that layer for the corresponding set of source nodes. rel, irrel are of dimension #number of target nodes in layer l x head_size and the ordering of the target nodes in this layer is the same as provided "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "693cd0e8-2fd3-4e60-bb2f-00efa1356dff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def patch_hh_at_pos(encoding, model, target_nodes, pos=0, mean_acts=None, set_irrel_to_mean=False):\n",
    "    pos_specific_hs = [\n",
    "        [i for i in range(12)],\n",
    "        [pos],\n",
    "        [i for i in range(12)]\n",
    "    ]\n",
    "    all_heads = list(itertools.product(*pos_specific_hs))\n",
    "\n",
    "    # patch one node at a time\n",
    "    h_ctbn_list = []\n",
    "    \n",
    "    source_node_list = [node for node in all_heads if node not in target_nodes]\n",
    "    print(source_node_list[0])\n",
    "    prop_fn = functools.partial(prop_BERT_hh, encoding, model, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=set_irrel_to_mean)\n",
    "    out_decomps, target_decomps = batch_run(prop_fn, source_node_list)\n",
    "    for i, _ in enumerate(source_node_list):\n",
    "        ctbn = 0\n",
    "        for l in range(12):\n",
    "            if target_decomps[l][i][0].shape[0] != 0:\n",
    "                rel_part = np.mean(abs(target_decomps[l][i][0]))\n",
    "                irrel_part = np.mean(abs(target_decomps[l][i][1]))\n",
    "                ctbn += rel_part / abs(rel_part + irrel_part) * 100\n",
    "        h_ctbn_list.append(ctbn)\n",
    "        \n",
    "    return source_node_list, h_ctbn_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dff0fe2e-2ba8-4470-9dcd-7c65b048ccfe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# perform on one doc as an example\n",
    "text = documents_full[0]\n",
    "label = labels_full[0]\n",
    "encoding = get_encoding(text, tokenizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bd8d1919-0e91-4f59-a07d-05847645d6ea",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 0, 0)\n"
     ]
    }
   ],
   "source": [
    "# perform one iteration measuring effect of the source nodes to target nodes as an example\n",
    "# note that target nodes get updated in each iteration\n",
    "\n",
    "target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]\n",
    "\n",
    "all_source_hs = []\n",
    "all_htbn = []\n",
    "for pos in tqdm.tqdm(range(2)):\n",
    "    with torch.no_grad():\n",
    "        source_list, h_ctbn_list = patch_hh_at_pos(encoding, model, target_nodes, pos=pos, mean_acts=None, set_irrel_to_mean=False)\n",
    "    torch.cuda.empty_cache()\n",
    "    all_source_hs.append(source_list)\n",
    "    all_htbn.append(h_ctbn_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e6c0bf3b-0617-4884-80d7-8091ddedce64",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "flat_ctbn = [c for sublist in all_htbn for c in sublist]\n",
    "flat_source_h = [c for sublist in all_source_hs for c in sublist]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "d0e92299-54f4-4428-bf98-b920b237831e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "top_idx = sorted(range(len(flat_ctbn)), key=lambda i: flat_ctbn[i])[-6:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "7d48ead3-05fc-43d4-b90f-a32d822bc2a6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(6, 82, 4)] 36.659424751996994\n",
      "[(5, 82, 4)] 42.90880411863327\n",
      "[(3, 82, 0)] 51.443591713905334\n",
      "[(4, 82, 0)] 87.46089041233063\n",
      "[(5, 82, 0)] 103.08798849582672\n",
      "[(6, 82, 0)] 132.23715126514435\n"
     ]
    }
   ],
   "source": [
    "for i in top_idx:\n",
    "    print(flat_source_h[i], flat_ctbn[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "98bc384b-e405-4815-b919-3b6d03739e2f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# save the identified heads\n",
    "path = f\"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h3\"\n",
    "os.makedirs(path, exist_ok=True)\n",
    "\n",
    "with open(os.path.join(path, f\"flat_source_h.pkl\"), 'wb') as handle:\n",
    "    pickle.dump(flat_source_h, handle)\n",
    "    \n",
    "with open(os.path.join(path, f\"flat_source_h.pkl\"), 'rb') as handle:\n",
    "    back = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4515c210-6040-41c2-a3a5-9d55d7112863",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Examine the attended words by the identified heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ec1bf9c5-979b-4473-8958-e34adab200a1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def collect_attended_tokens_hh(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):\n",
    "    index_lst = random.sample(range(0, len(documents_full)), N)\n",
    "    docs = [documents_full[i] for i in index_lst]\n",
    "    \n",
    "    collect = collections.defaultdict(int)\n",
    "    for doc in docs:\n",
    "        encoding = get_encoding(doc, tokenizer, device)\n",
    "        \n",
    "        _, _, raw_att_probs_lst = prop_BERT_hh(encoding, model, [[]], [], device=device, output_att_prob=True)\n",
    "        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()\n",
    "\n",
    "        avg_att_m = np.zeros((512))\n",
    "        for level, pos, h in positives_heads:\n",
    "            att_m = raw_att_probs[level, h, pos, :]\n",
    "            avg_att_m += att_m\n",
    "\n",
    "        avg_att_m /= len(positives)\n",
    "        \n",
    "        # convert to word level\n",
    "        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)\n",
    "        word_att_m = combine_token_attn(interval_dict, avg_att_m)\n",
    "        \n",
    "        if use_perc:\n",
    "            perc_cutoff = np.percentile(word_att_m, percentile)\n",
    "            positive_words = np.where(word_att_m > perc_cutoff)\n",
    "        else:\n",
    "            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)\n",
    "            positive_words = np.where(Z > Z_thres)\n",
    "        \n",
    "        for w_idx in positive_words[0]:\n",
    "            w = word_lst[w_idx]\n",
    "            #collect[w] += 1\n",
    "            collect[w] += word_att_m[w_idx]\n",
    "            \n",
    "    return collect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "2b2bbf39-9363-45ae-b516-2da5e1d9d773",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def collect_attended_tokens_hh_rm_pos(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):\n",
    "    index_lst = random.sample(range(0, len(documents_full)), N)\n",
    "    docs = [documents_full[i] for i in index_lst]\n",
    "    \n",
    "    collect = collections.defaultdict(int)\n",
    "    for doc in docs:\n",
    "        encoding = get_encoding(doc, tokenizer, device)\n",
    "        \n",
    "        _, _, raw_att_probs_lst = prop_BERT_hh(encoding, model, [[]], [], device=device, output_att_prob=True)\n",
    "        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()\n",
    "\n",
    "        avg_att_m = np.zeros((512))\n",
    "        for level, _, h in positives_heads:\n",
    "            att_m = raw_att_probs[level, h, :, :]\n",
    "            #att_m = np.mean(att_m, axis=0)\n",
    "            max_row = np.unravel_index(np.argmax(att_m, axis=None), att_m.shape)[0]\n",
    "            avg_att_m += att_m[max_row, :]\n",
    "\n",
    "        avg_att_m /= len(positives)\n",
    "        \n",
    "        # convert to word level\n",
    "        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)\n",
    "        word_att_m = combine_token_attn(interval_dict, avg_att_m)\n",
    "        \n",
    "        if use_perc:\n",
    "            perc_cutoff = np.percentile(word_att_m, percentile)\n",
    "            positive_words = np.where(word_att_m > perc_cutoff)\n",
    "        else:\n",
    "            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)\n",
    "            positive_words = np.where(Z > Z_thres)\n",
    "        \n",
    "        for w_idx in positive_words[0]:\n",
    "            w = word_lst[w_idx]\n",
    "            #collect[w] += 1\n",
    "            collect[w] += word_att_m[w_idx]\n",
    "            \n",
    "    return collect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "bfd9dacb-2fb8-40e3-bdd1-94a1f62b1973",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "negatives = [(11, 2), (11, 5), (11, 6), (11, 9), (11, 10), (11, 11),\n",
    "             (10, 3), (10, 4), (10, 5), (10, 6), (10, 9), (10, 10), (10, 11),\n",
    "             (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 8), (9, 9), (9, 10), (9, 11),\n",
    "             (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (8, 10), (8, 11),\n",
    "             (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 7), (7, 8), (7, 9), (7, 10),\n",
    "             (6, 1), (6, 2), (6, 3), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11),\n",
    "             (5, 1), (5, 2), (5, 3), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11),\n",
    "             (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (4, 11),\n",
    "             (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11),\n",
    "             (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11),\n",
    "             (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 0), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11),\n",
    "             (0, 0), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 8), (0, 10), (0, 11),\n",
    "            ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "36cebd9e-5d02-41cc-923b-65e88bf24b2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# the identified attn heads using CD-T\n",
    "pos_specific_hs = [\n",
    "            [i for i in range(12)],\n",
    "            [i for i in range(512)],\n",
    "            [i for i in range(12)]\n",
    "        ]\n",
    "all_heads = list(itertools.product(*pos_specific_hs))\n",
    "random_heads = random.sample(all_heads, 6)\n",
    "positives = random_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "87528ebc-415b-45f6-a6bc-5cdf193d482b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "positives = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "dc7df4ef-9292-4b83-8201-2ad57cfee8c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rank the items of the returned mapping by value, highest first.\n",
    "positive_attended_token_freq = sorted(\n",
    "    collect_attended_tokens_hh_rm_pos(positives, device, tokenizer, N=200, use_perc=True).items(),\n",
    "    key=lambda word_score: word_score[1],\n",
    "    reverse=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "42c0c0cc-0f63-466d-b586-822bbca6e8f7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Alternative head sets (commented out; kept for reference):\n",
    "# h1 = [(10, 82, 0), (10, 61, 8), (10, 82, 7), (10, 176, 2), (10, 467, 1), (10, 91, 7)]\n",
    "# h2 = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]\n",
    "# h3 = [(6, 82, 4), (5, 82, 4), (3, 82, 0), (4, 82, 0), (5, 82, 0), (6, 82, 0)]\n",
    "# positives = [(0, 82, 9), (0, 82, 1), (0, 82, 7), (1, 82, 6), (0, 82, 6), (2, 82, 0)]\n",
    "# NOTE: this rebinds positive_attended_token_freq, replacing the ranking from the\n",
    "# collect_attended_tokens_hh_rm_pos cell if that ran first.\n",
    "positive_attended_token_freq = sorted(\n",
    "    collect_attended_tokens_hh(positives, device, tokenizer, N=500, use_perc=True).items(),\n",
    "    key=lambda word_score: word_score[1],\n",
    "    reverse=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "5fb143fa-8d8d-4185-8931-53d94c329144",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Persist the ranked (key, value) pairs; json writes each tuple as a 2-element list.\n",
    "import json\n",
    "\n",
    "with open('pp15_h2.json', 'w') as json_out:\n",
    "    json.dump(positive_attended_token_freq, json_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbb563fe-5586-4600-a700-a5fe4c23d87d",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "## Tests: ad-hoc sanity checks of `prop_BERT_hh` on a single example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "d4494d04-0b59-4e12-ae4f-660c1680398e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Encode the first document as a single sanity-check example.\n",
    "text, label = documents_full[0], labels_full[0]\n",
    "encoding = get_encoding(text, tokenizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "6ccb2c31-7d29-4307-8309-76454da052ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "source_list_30 = [#list(itertools.product(range(12), range(512), range(12))), \n",
    "                  # list(itertools.product(range(12), range(70, 85), range(12))), \n",
    "                  # [(11, 0, i) for i in range(12)]\n",
    "                  [(0, 0, 0)]] * 30\n",
    "source_list_60 = [#list(itertools.product(range(12), range(512), range(12))), \n",
    "                  # list(itertools.product(range(12), range(70, 85), range(12))), \n",
    "                  # [(11, 0, i) for i in range(12)]\n",
    "                  [(0, 0, 0)], []] * 30\n",
    "\"\"\"\n",
    "target_nodes = [(11, 8), (11, 0), (11, 1), (11, 4), (11, 3), (11, 7)]\n",
    "source_list = [[(5, 7, 0)], [(5, 5, 0)]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "709e27d0-769c-4c06-9114-e567034a1372",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Bind the unused third return value to a named throwaway rather than `_`:\n",
    "# assigning `_` in IPython shadows the output-history variable, and IPython then\n",
    "# stops updating `_`/`__`/`___`.\n",
    "out_decomps, target_decomps, _prop_extra = prop_BERT_hh(encoding, model, source_list, target_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "cdcc84d2-f700-4b3b-a62d-a397b657dbb8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(tensor([-0.0171,  0.0302, -0.0149], device='cuda:0'),\n",
       "  tensor([-3.2685,  6.1339, -3.2953], device='cuda:0')),\n",
       " (tensor([-0.0179,  0.0306, -0.0147], device='cuda:0'),\n",
       "  tensor([-3.2677,  6.1335, -3.2955], device='cuda:0'))]"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the output decomposition per source set (recorded run: one pair of\n",
    "# 3-element tensors for each of the two source sets; meaning of each member TODO confirm).\n",
    "out_decomps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "1ccea5db-d9ad-4468-8e36-3f0f0903dd3b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 64])"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# target_decomps appears to be indexed by layer first; layer 11 holds all the\n",
    "# target_nodes. Inner [0][0] indices presumably select a component -- TODO confirm.\n",
    "target_decomps[11][0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "882e2b06-770b-4e60-8031-76693614b5a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([0, 64])"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same check at layer 0, which has no target nodes in this run; presumably that\n",
    "# is why the recorded first dimension is 0 -- TODO confirm.\n",
    "target_decomps[0][0][0].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
