{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e61a0720",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a120102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import random\n",
    "import collections\n",
    "\n",
    "# CD-T Imports\n",
    "import math\n",
    "import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import itertools\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "\n",
    "from argparse import Namespace\n",
    "from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars\n",
    "from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals, compare_same\n",
    "from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels\n",
    "from pyfunctions.cdt_basic import *\n",
    "from pyfunctions.cdt_source_to_target import *\n",
    "from pyfunctions.cdt_from_source_nodes import *\n",
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "from pyfunctions.wrappers import *\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import GPT2Tokenizer, GPT2Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f6d6ada4-4781-4789-b3b1-1d044c11b3d3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f63d16402c0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No gradients are needed for the checks in this notebook, so autograd is switched off globally.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3183be1-3bf6-4f5a-8134-9bdd83db0a56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Prefer the first GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a520f760",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-small into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Model code adapted from Callum McDougall's ARENA notebook on reproducing the IOI paper using TransformerLens.\n",
    "# EasyTransformer (the library released by the IOI authors) and TransformerLens share a common lineage, so using\n",
    "# TransformerLens keeps the reproduction closer to the original setup: the same weight-processing options\n",
    "# (e.g. LayerNorm folding, weight centering) are available here.\n",
    "# NOTE(review): fold_ln is False below although the note above mentions folding LayerNorms -- confirm this is intended.\n",
    "# TransformerLens was chosen over using the HuggingFace model directly because its internals are easier to follow.\n",
    "\n",
    "from transformer_lens import utils, HookedTransformer, ActivationCache\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"gpt2-small\",\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=False,\n",
    "    refactor_factored_attn_matrices=True,\n",
    "    device=device,  # keep the weights on the same device the inputs are moved to later\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe8f76d9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Do not prepend a BOS token by default, so model.to_tokens yields the same ids as the raw tokenizer encoding used later.\n",
    "model.cfg.default_prepend_bos = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bf3c253",
   "metadata": {},
   "source": [
    "## Model inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "529e2aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the last expression of a cell is echoed, so the individual settings are printed explicitly.\n",
    "# Prepends a beginning-of-sequence token? GPT-2 presumably doesn't do this by itself (set to False in the cell above).\n",
    "print(\"default_prepend_bos:\", model.cfg.default_prepend_bos)\n",
    "print(\"padding_side:\", model.tokenizer.padding_side)  # expected: 'right'\n",
    "# 'standard', as opposed to 'shortformer', which puts the positional embedding in the attention pattern\n",
    "print(\"positional_embedding_type:\", model.cfg.positional_embedding_type)\n",
    "model.cfg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28216f8c",
   "metadata": {},
   "source": [
    "## Example forward pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7023f10b",
   "metadata": {},
   "source": [
    "## Inspect model / hooks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "11dbf9fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hook_embed                     (1, 16, 768)\n",
      "hook_pos_embed                 (1, 16, 768)\n",
      "blocks.0.hook_resid_pre        (1, 16, 768)\n",
      "blocks.0.ln1.hook_scale        (1, 16, 1)\n",
      "blocks.0.ln1.hook_normalized   (1, 16, 768)\n",
      "blocks.0.attn.hook_q           (1, 16, 12, 64)\n",
      "blocks.0.attn.hook_k           (1, 16, 12, 64)\n",
      "blocks.0.attn.hook_v           (1, 16, 12, 64)\n",
      "blocks.0.attn.hook_attn_scores (1, 12, 16, 16)\n",
      "blocks.0.attn.hook_pattern     (1, 12, 16, 16)\n",
      "blocks.0.attn.hook_z           (1, 16, 12, 64)\n",
      "blocks.0.hook_attn_out         (1, 16, 768)\n",
      "blocks.0.hook_resid_mid        (1, 16, 768)\n",
      "blocks.0.ln2.hook_scale        (1, 16, 1)\n",
      "blocks.0.ln2.hook_normalized   (1, 16, 768)\n",
      "blocks.0.mlp.hook_pre          (1, 16, 3072)\n",
      "blocks.0.mlp.hook_post         (1, 16, 3072)\n",
      "blocks.0.hook_mlp_out          (1, 16, 768)\n",
      "blocks.0.hook_resid_post       (1, 16, 768)\n",
      "ln_final.hook_scale            (1, 16, 1)\n",
      "ln_final.hook_normalized       (1, 16, 768)\n"
     ]
    }
   ],
   "source": [
    "text = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
    "tokens = model.to_tokens(text).to(device)\n",
    "logits, cache = model.run_with_cache(tokens)\n",
    "\n",
    "# List the cached activations of block 0, plus the hooks that live outside any block,\n",
    "# each with its shape.\n",
    "for hook_name, hook_value in cache.items():\n",
    "    in_first_block = \".0.\" in hook_name\n",
    "    outside_blocks = \"blocks\" not in hook_name\n",
    "    if in_first_block or outside_blocks:\n",
    "        print(f\"{hook_name:30} {tuple(hook_value.shape)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "061cbe24",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Test correctness of forward pass\n",
    "\n",
    "This is especially hard given that the TransformerLens (TL) GPT-2 implementation has so many discrepancies from the HuggingFace (HF) BERT implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a64470a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two source entries, each a 1-tuple holding a single Node; the target list is left empty.\n",
    "source_list = [(Node(0, 0, 0),), (Node(1, 1, 1),)]\n",
    "target_nodes = []\n",
    "\n",
    "# Encode with the underlying tokenizer so that an attention mask is returned alongside the ids.\n",
    "encoding = model.tokenizer.encode_plus(\n",
    "    text,\n",
    "    add_special_tokens=True,\n",
    "    max_length=512,\n",
    "    truncation=True,\n",
    "    padding=\"longest\",\n",
    "    return_attention_mask=True,\n",
    "    return_tensors=\"pt\",\n",
    ").to(device)\n",
    "encoding_idxs = encoding.input_ids\n",
    "attention_mask = encoding.attention_mask\n",
    "input_shape = encoding_idxs.size()\n",
    "extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)\n",
    "\n",
    "out_decomps, target_decomps, _, _ = prop_GPT(\n",
    "    encoding_idxs,\n",
    "    extended_attention_mask,\n",
    "    model,\n",
    "    source_list,\n",
    "    target_nodes,\n",
    "    device=device,\n",
    "    mean_acts=None,\n",
    "    set_irrel_to_mean=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "855602d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99.64% of the values are equal\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9964407943172096"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loose consistency check: rel + irrel is summed for each of the two source decompositions and the two sums\n",
    "# are compared with each other. The comparison against the actual forward pass (the logits) is done further below.\n",
    "sum_first = out_decomps[0].rel + out_decomps[0].irrel\n",
    "sum_second = out_decomps[1].rel + out_decomps[1].irrel\n",
    "compare_same(sum_first, sum_second)\n",
    "\n",
    "# compare_same(cache['blocks.0.attn.hook_q'], cache['blocks.0.attn.hook_k']) # checking that these aren't just the inputs, because the source code names them as such"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "19553016-c25b-4514-a178-ded1b6db1e9e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 16, 768])\n"
     ]
    }
   ],
   "source": [
    "# Token + positional embeddings, computed here so this cell does not depend on a later cell having been run first.\n",
    "embedding_output = model.embed(encoding.input_ids) + model.pos_embed(encoding.input_ids)\n",
    "print(embedding_output.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a376f69e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100.00% of the values are equal\n",
      "100.00% of the values are equal\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compare_same(cache['hook_embed'], cache['blocks.0.hook_resid_pre'])\n",
    "# Check an assumption about the model's input convention: the residual stream entering block 0\n",
    "# should equal token embedding + positional embedding.\n",
    "in_0 = cache['blocks.0.hook_resid_pre']\n",
    "cached_embed_sum = cache['hook_embed'] + cache['hook_pos_embed']\n",
    "compare_same(cached_embed_sum, in_0)\n",
    "\n",
    "# Ensure that the encodings are the same (they won't be if you set prepend_bos=True for one and not the other!)\n",
    "# print(tokens)\n",
    "# print(encoding.input_ids)\n",
    "\n",
    "# Note that model.embed alone isn't the same thing as what we might normally call \"the embedding\"!\n",
    "token_ids = encoding.input_ids\n",
    "embedding_output = model.embed(token_ids) + model.pos_embed(token_ids)\n",
    "compare_same(in_0, embedding_output)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f2c3e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra analysis of value matrix calculation: shapes of the value and output weights of block 0's attention.\n",
    "first_attn = model.blocks[0].attn\n",
    "print(first_attn.W_V.shape)\n",
    "print(first_attn.W_O.shape)\n",
    "# print(model.blocks[0].attn.hook_pattern) # 1, 12, seq_len, seq_len\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da0edc71",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "not enough values to unpack (expected 3, got 2)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 54\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;66;03m# level = 0\u001b[39;00m\n\u001b[1;32m     50\u001b[0m \u001b[38;5;66;03m# source_node_list = [(0, 0, 0), (1, 1, 1)]\u001b[39;00m\n\u001b[1;32m     51\u001b[0m \u001b[38;5;66;03m# target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]\u001b[39;00m\n\u001b[1;32m     52\u001b[0m attn_wrapper \u001b[38;5;241m=\u001b[39m GPTAttentionWrapper(attention_module)\n\u001b[0;32m---> 54\u001b[0m rel_summed_values, irrel_summed_values, returned_att_probs \u001b[38;5;241m=\u001b[39m prop_self_attention(rel_ln, irrel_ln, attention_mask, \n\u001b[1;32m     55\u001b[0m                                                                         \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m     56\u001b[0m                                                                         attn_wrapper,\n\u001b[1;32m     57\u001b[0m                                                                         att_probs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mZ\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m matrix:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     59\u001b[0m desired_z \u001b[38;5;241m=\u001b[39m cache[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mblocks.0.attn.hook_z\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 3, got 2)"
     ]
    }
   ],
   "source": [
    "# Test equivalence of attention layer\n",
    "from fancy_einsum import einsum\n",
    "'''\n",
    "Reference shapes of the block-0 attention hooks.\n",
    "NOTE(review): these list 17 positions while the prompt is cached with 16 above -- presumably recorded with a BOS token; confirm.\n",
    "blocks.0.attn.hook_q           (1, 17, 12, 64)\n",
    "blocks.0.attn.hook_k           (1, 17, 12, 64)\n",
    "blocks.0.attn.hook_v           (1, 17, 12, 64)\n",
    "blocks.0.attn.hook_attn_scores (1, 12, 17, 17)\n",
    "blocks.0.attn.hook_pattern     (1, 12, 17, 17)\n",
    "blocks.0.attn.hook_z           (1, 17, 12, 64)\n",
    "'''\n",
    "\n",
    "input_shape = encoding['input_ids'].size()\n",
    "seq_len = input_shape[-1]  # number of token positions; used in the slices below instead of a hard-coded 16\n",
    "\n",
    "# Stored as extended_attention_mask so that the tokenizer's raw attention_mask is not overwritten.\n",
    "extended_attention_mask = get_extended_attention_mask(encoding['attention_mask'],\n",
    "                                                      input_shape,\n",
    "                                                      model,\n",
    "                                                      device)\n",
    "\n",
    "# Start with the whole embedding in the irrelevant part and zeros in the relevant part.\n",
    "sh = list(embedding_output.shape)\n",
    "rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "irrel[:] = embedding_output[:]\n",
    "layer_module = model.blocks[0]\n",
    "attention_module = layer_module.attn\n",
    "\n",
    "\n",
    "rel_ln, irrel_ln = prop_layer_norm(rel, irrel, GPTLayerNormWrapper(layer_module.ln1))\n",
    "'''\n",
    "print(\"LayerNorm: \")\n",
    "compare_same(cache['blocks.0.ln1.hook_normalized'], (rel_ln + irrel_ln)[0, :16, :])\n",
    "\n",
    "desired_value_output = cache['blocks.0.attn.hook_v']\n",
    "rel_value, irrel_value = prop_linear(rel_ln, irrel_ln, GPTAttentionWrapper(attention_module, layer_module.ln2).value)\n",
    "value = rel_value + irrel_value\n",
    "print(\"Value matrix:\")\n",
    "# desired_value_output = desired_value_output.transpose(-1, -2)\n",
    "old_shape = desired_value_output.size()\n",
    "new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)\n",
    "desired_value_output = desired_value_output.reshape(new_shape)\n",
    "print(desired_value_output.shape)\n",
    "print(value.shape)\n",
    "\n",
    "compare_same(value[0, :16, :], desired_value_output, rtol = 1e-2) # some elements differ by more than 1e-3 in proportion, not great\n",
    "'''\n",
    "\n",
    "# desired_block_output = cache['blocks.0.hook_resid_mid']\n",
    "\n",
    "desired_block_output = cache['blocks.0.hook_attn_out']\n",
    "desired_attention_pattern = cache['blocks.0.attn.hook_pattern']\n",
    "# level = 0\n",
    "# source_node_list = [(0, 0, 0), (1, 1, 1)]\n",
    "# target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]\n",
    "attn_wrapper = GPTAttentionWrapper(attention_module)\n",
    "\n",
    "rel_summed_values, irrel_summed_values, returned_att_probs = prop_attention_no_output_hh(rel_ln, irrel_ln, extended_attention_mask, \n",
    "                                                                        None,\n",
    "                                                                        attn_wrapper,\n",
    "                                                                        att_probs=None)\n",
    "print(\"'Z' matrix:\")\n",
    "desired_z = cache['blocks.0.attn.hook_z']\n",
    "old_shape = desired_z.size()\n",
    "new_shape = old_shape[:-2] + (old_shape[-2] * old_shape[-1],)  # merge the last two axes into one\n",
    "desired_z = desired_z.reshape(new_shape)\n",
    "z = rel_summed_values + irrel_summed_values\n",
    "compare_same(desired_z, z[:, :seq_len, :])\n",
    "\n",
    "rel_attn_residual, irrel_attn_residual = prop_linear(rel_summed_values, irrel_summed_values, attn_wrapper.output)\n",
    "\n",
    "print(\"Output matrix:\")\n",
    "output = rel_attn_residual + irrel_attn_residual\n",
    "# output = rel_attn_residual + irrel_attn_residual + rel_ln + irrel_ln\n",
    "compare_same(output[:, :seq_len, :], desired_block_output)\n",
    "\n",
    "\n",
    "# Adding the attention output back onto the block input should reproduce the cached mid-block residual stream.\n",
    "desired_mid_output = cache['blocks.0.hook_resid_mid']\n",
    "rel_mid, irrel_mid = rel + rel_attn_residual, irrel + irrel_attn_residual\n",
    "mid_output = rel_mid + irrel_mid\n",
    "compare_same(mid_output[:, :seq_len, :], desired_mid_output)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ab56041",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test equivalence of entire transformer block\n",
    "in_1 = cache['blocks.1.hook_resid_pre']\n",
    "input_shape = encoding['input_ids'].size()\n",
    "seq_len = input_shape[-1]  # number of token positions; used in the slice below instead of a hard-coded 16\n",
    "\n",
    "extended_attention_mask = get_extended_attention_mask(encoding['attention_mask'], \n",
    "                                                        input_shape, \n",
    "                                                        model,\n",
    "                                                        device)\n",
    "# Start with the whole embedding in the irrelevant part and zeros in the relevant part.\n",
    "sh = list(embedding_output.shape)\n",
    "rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)\n",
    "irrel[:] = embedding_output[:]\n",
    "layer_module = model.blocks[0]\n",
    "\n",
    "source_node_list = [(0, 0, 0), (1, 1, 1)]\n",
    "target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]\n",
    "# NOTE(review): the meaning of the positional None / 0 / False arguments is not visible in this notebook;\n",
    "# check them against the signature of prop_GPT_layer_hh.\n",
    "rel_out, irrel_out, layer_target_decomps, _ = prop_GPT_layer_hh(rel, irrel, extended_attention_mask, \n",
    "                                                                                 None, source_node_list, \n",
    "                                                                                 target_nodes, 0, \n",
    "                                                                                 None,\n",
    "                                                                                 layer_module, \n",
    "                                                                                 device,\n",
    "                                                                                 None, False,\n",
    "                                                                                 mean_ablated=False)\n",
    "\n",
    "# The decomposed output of block 0 should equal the residual stream entering block 1.\n",
    "layer0_output = rel_out + irrel_out\n",
    "compare_same(in_1, layer0_output[:, :seq_len, :])\n",
    "\n",
    "# TODO: check that layer_target_decomps is what we think it is (not possible yet; the decomposition part isn't implemented)\n",
    "# also, rel + irrel sums to the correct output\n",
    "# layer_target_decomps[0].rel + layer_target_decomps[0].irrel == layer_target_decomps[1].rel + layer_target_decomps[1].irrel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c1c0bb6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 16, 50257])\n",
      "(1, 16, 50257)\n",
      "100.00% of the values are equal\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test equivalence of model run to logits\n",
    "# Recompute first, so the printed shape belongs to this prompt's logits rather than to whatever `logits` last held.\n",
    "logits, cache = model.run_with_cache(tokens)\n",
    "print(logits.shape)\n",
    "\n",
    "# rel + irrel of the first decomposition should reproduce the logits of the plain forward pass.\n",
    "model_out = out_decomps[0].rel + out_decomps[0].irrel\n",
    "print(model_out.shape)\n",
    "compare_same(logits[0], model_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6ffd4d3-5e8f-4587-bb2a-f06b61918c09",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Test correctness of batching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1f242eed-09da-408b-8f31-0703bfc22bdb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-30 13:20:47.277592: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-08-30 13:20:51.596009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "\n",
    "# Generate a dataset all consisting of one template, randomly chosen.\n",
    "# nb_templates = 2 due to some logic internal to IOIDataset:\n",
    "# essentially, the nouns can be an ABBA or ABAB order and that counts as separate templates.\n",
    "ioi_dataset = IOIDataset(\n",
    "    prompt_type=\"mixed\",\n",
    "    N=3,\n",
    "    tokenizer=model.tokenizer,\n",
    "    prepend_bos=False,\n",
    "    nb_templates=2,\n",
    ")\n",
    "\n",
    "# This is the P_ABC that is mentioned in the IOI paper, which we use for mean ablation.\n",
    "# Importantly, passing in prompt_type=\"ABC\" or similar is NOT the same thing as this.\n",
    "# The three flips are applied one after another, each to the result of the previous one.\n",
    "abc_dataset = ioi_dataset\n",
    "for name_flip in ((\"IO\", \"RAND\"), (\"S\", \"RAND\"), (\"S1\", \"RAND\")):\n",
    "    abc_dataset = abc_dataset.gen_flipped_prompts(name_flip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4405d437-1534-463e-bab7-9976296e99a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 16, 768])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the entire ABC dataset along the batch dimension. Separate names are used so that the single-prompt\n",
    "# `logits` / `cache` the earlier cells rely on are not overwritten.\n",
    "abc_logits, abc_cache = model.run_with_cache(abc_dataset.toks)\n",
    "attention_outputs = [abc_cache[f'blocks.{layer}.hook_attn_out'] for layer in range(model.cfg.n_layers)]\n",
    "attention_outputs = torch.stack(attention_outputs, dim=1) # now batch, layer, seq, d_model\n",
    "mean_acts = torch.mean(attention_outputs, dim=0) # average over the batch -> layer, seq, d_model\n",
    "mean_acts.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "42f6c5d1-0170-438f-b5e2-ae07cd3a8e2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "text = ioi_dataset.sentences[0]\n",
    "encoding = model.tokenizer.encode_plus(\n",
    "    text,\n",
    "    add_special_tokens=True,\n",
    "    max_length=512,\n",
    "    truncation=True,\n",
    "    padding=\"longest\",\n",
    "    return_attention_mask=True,\n",
    "    return_tensors=\"pt\",\n",
    ").to(device)\n",
    "encoding_idxs, attention_mask = encoding.input_ids, encoding.attention_mask\n",
    "input_shape = encoding_idxs.size()\n",
    "extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)\n",
    "\n",
    "import functools\n",
    "\n",
    "# Source nodes are (layer, sequence position, attention head) triples: layer 11, every position, head 0.\n",
    "ranges = [\n",
    "    list(range(11, 12)),\n",
    "    list(range(input_shape[-1])),\n",
    "    list(range(1)),\n",
    "]\n",
    "\n",
    "source_nodes = list(itertools.product(*ranges))\n",
    "# print(source_nodes[:64])\n",
    "# target_nodes = [(7, 0, 1)]\n",
    "target_nodes = []\n",
    "\n",
    "# Run all source nodes through the batch_run helper.\n",
    "prop_fn = lambda snl: prop_GPT(encoding_idxs[0:1, :], extended_attention_mask, model, snl, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True)\n",
    "batch_out_decomps, target_decomps = batch_run(prop_fn, source_nodes)\n",
    "\n",
    "# batching is broken for now, just run one by one\n",
    "# (source_nodes is already in the order nested layer / position / head loops would visit, so it is reused here)\n",
    "iter_out_decomps = []\n",
    "for source_node in source_nodes:\n",
    "    target_nodes = []\n",
    "    out_decomp, _, _ = prop_GPT(encoding_idxs[0:1, :], extended_attention_mask, model, [source_node], target_nodes, mean_acts=mean_acts, set_irrel_to_mean=True, device=device)\n",
    "    iter_out_decomps.append(out_decomp[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f874e1-74bf-430a-8e95-aae77ed7cdc2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Hand the complete source-node list to prop_GPT in a single call, without the batch_run helper.\n",
    "manual_batch_out_decomps, _, _ = prop_GPT(\n",
    "    encoding_idxs[0:1, :],\n",
    "    extended_attention_mask,\n",
    "    model,\n",
    "    source_nodes,\n",
    "    target_nodes=target_nodes,\n",
    "    device=device,\n",
    "    mean_acts=mean_acts,\n",
    "    set_irrel_to_mean=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51208791-d88a-4a60-a363-189c8142eacd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Print one entry of the relevant part from the batched run and then from the one-by-one run,\n",
    "# followed by the source node the first batched decomposition belongs to.\n",
    "for first_decomp in (batch_out_decomps[0], iter_out_decomps[0]):\n",
    "    print(first_decomp.rel[0][4][11])\n",
    "print(batch_out_decomps[0].source_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce77f2ad-c9e6-48aa-b1e5-ab4508d6766b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Compare the batched and the one-by-one decomposition of the first source node, entry by entry.\n",
    "# The number of positions comes from the tokenized prompt instead of a hard-coded 16.\n",
    "# NOTE(review): the inner index is named head_idx, but the last axis of `rel` may not be attention heads -- confirm.\n",
    "for seq_pos in range(input_shape[-1]):\n",
    "    for head_idx in range(12):\n",
    "        print(seq_pos, head_idx)\n",
    "        compare_same(batch_out_decomps[0].rel[0][seq_pos][head_idx], iter_out_decomps[0].rel[0][seq_pos][head_idx])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "72d799ea-6692-4979-a4e7-c25caaf81207",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100.00% of the values are correct\n"
     ]
    }
   ],
   "source": [
    "# Whole-tensor comparison of the relevant part for the first source node: batched run vs. one-by-one run.\n",
    "batched_rel = batch_out_decomps[0].rel\n",
    "iterated_rel = iter_out_decomps[0].rel\n",
    "compare_same(batched_rel, iterated_rel)\n",
    "# compare_same(manual_batch_out_decomps[0].rel, iter_out_decomps[0].rel)\n",
    "# compare_same(manual_batch_out_decomps[0].rel, batch_out_decomps[0].rel)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df96d304-59fd-4947-81c5-209ce252093e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import pyfunctions.cdt_source_to_target\n",
    "\n",
    "# Reload first, then star-import: names copied by `from ... import *` are not refreshed by a later reload,\n",
    "# so importing before reloading would leave the notebook holding the stale objects.\n",
    "importlib.reload(pyfunctions.cdt_source_to_target)\n",
    "from pyfunctions.cdt_source_to_target import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17f6a5c-b7c4-482e-9efa-0b4238b9aef9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Compare shapes pairwise over however many decompositions were produced, instead of a hard-coded 16.\n",
    "assert len(batch_out_decomps) == len(iter_out_decomps), \"batched and one-by-one runs returned different numbers of decompositions\"\n",
    "for batch, it in zip(batch_out_decomps, iter_out_decomps):\n",
    "    print(batch.rel.shape, it.rel.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (ipykernel)",
   "language": "python",
   "name": "python3.12"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
