{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "34ed52a4-3b8b-4fab-b1f0-fedde07837a9",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e61a0720",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d1c21a9-dc2c-43c9-9d0c-64aa73381d18",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "Task and reference circuit: [How does GPT-2 compute greater-than?](https://arxiv.org/pdf/2305.00586)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a120102",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "import random\n",
    "import sys\n",
    "import os\n",
    "import collections\n",
    "import operator\n",
    "import functools\n",
    "import itertools\n",
    "\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "from pyfunctions.general import compare_same\n",
    "from pyfunctions.cdt_basic import *\n",
    "from pyfunctions.cdt_source_to_target import *\n",
    "from pyfunctions.cdt_from_source_nodes import *\n",
    "from pyfunctions.toy_model import *\n",
    "from greater_than_task.greater_than_dataset import *\n",
    "from greater_than_task.utils import get_valid_years\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "import torch\n",
    "Result = collections.namedtuple('Result', ('ablation_set', 'score'))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7651660-7f59-4b59-a574-afecc52dc306",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Load model and dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a520f760",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run on the first CUDA device when available, otherwise on the CPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "# Switch autograd off globally; nothing in the visible cells turns it back on.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "# NOTE(review): HookedTransformer is already imported in the imports cell; of the visible cells,\n",
    "# only this line brings `utils` and `ActivationCache` into scope.\n",
    "from transformer_lens import utils, HookedTransformer, ActivationCache\n",
    "# Load pretrained GPT-2 small through TransformerLens with the weight-processing flags below.\n",
    "model = HookedTransformer.from_pretrained(\"gpt2-small\",\n",
    "                                          center_unembed=True,\n",
    "                                          center_writing_weights=True,\n",
    "                                          fold_ln=False,\n",
    "                                          refactor_factored_attn_matrices=True)\n",
    "                                          "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "id": "bdc77aff-a0ce-47bc-aa69-f9ced8df1497",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# https://github.com/hannamw/gpt2-greater-than/blob/main/circuit_discovery.py; also these files came with their repo\n",
    "# NOTE(review): `Path` is not imported explicitly in the visible cells; presumably it arrives through\n",
    "# one of the star imports - confirm.\n",
    "years_to_sample_from = get_valid_years(model.tokenizer, 1000, 1900)\n",
    "# Number of examples requested from YearDataset.\n",
    "# NOTE(review): a saved output further down shows a batch dimension of 490 for ds.good_toks, so the\n",
    "# dataset may hold a different number of rows than N - confirm before indexing with range(N).\n",
    "N = 5000\n",
    "ds = YearDataset(years_to_sample_from, N, Path(\"../greater_than_task/cache/potential_nouns.txt\"), model.tokenizer, balanced=True, device=device, eos=True)\n",
    "# NOTE(review): torch.load unpickles the file, so only load caches that come from a trusted source.\n",
    "year_indices = torch.load(\"../greater_than_task/cache/logit_indices.pt\")# .to(device)\n",
    "\n",
    "# Model/dataset dimensions reused by the cells below.\n",
    "num_layers = len(model.blocks)\n",
    "seq_len = ds.good_toks.size()[-1]  # size of the last axis of the tokenised prompts\n",
    "num_attention_heads = model.cfg.n_heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84bbf183-a5be-4d0b-83fd-75714a2241e1",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a6c93ab2-081a-4247-89c6-ab5a6ee434af",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(model.tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "45b7d59f-e8c5-412f-b4af-9030d271fc3e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "e84af14c-d078-41b6-bd86-0b220b182217",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<|endoftext|> The clash lasted from the year 1594 to the year 15', '<|endoftext|> The program lasted from the year 1395 to the year 13', '<|endoftext|> The challenge lasted from the year 1496 to the year 14', '<|endoftext|> The confrontation lasted from the year 1597 to the year 15', '<|endoftext|> The marriage lasted from the year 1098 to the year 10', '<|endoftext|> The journey lasted from the year 1202 to the year 12', '<|endoftext|> The insurgency lasted from the year 1803 to the year 18', '<|endoftext|> The improvement lasted from the year 1404 to the year 14', '<|endoftext|> The consultation lasted from the year 1705 to the year 17', '<|endoftext|> The domination lasted from the year 1606 to the year 16']\n",
      "tensor([ 405,  486, 2999, 3070, 3023, 2713, 3312, 2998, 2919, 2931,  940, 1157,\n",
      "        1065, 1485, 1415, 1314, 1433, 1558, 1507, 1129, 1238, 2481, 1828, 1954,\n",
      "        1731, 1495, 2075, 1983, 2078, 1959, 1270, 3132, 2624, 2091, 2682, 2327,\n",
      "        2623, 2718, 2548, 2670, 1821, 3901, 3682, 3559, 2598, 2231, 3510, 2857,\n",
      "        2780, 2920, 1120, 4349, 4309, 4310, 4051, 2816, 3980, 3553, 3365, 3270,\n",
      "        1899, 5333, 5237, 5066, 2414, 2996, 2791, 3134, 3104, 3388, 2154, 4869,\n",
      "        4761, 4790, 4524, 2425, 4304, 3324, 3695, 3720, 1795, 6659, 6469, 5999,\n",
      "        5705, 5332, 4521, 5774, 3459, 4531, 3829, 6420, 5892, 6052, 5824, 3865,\n",
      "        4846, 5607, 4089, 2079])\n",
      "['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']\n"
     ]
    }
   ],
   "source": [
    "# Scratch inspection of the YearDataset fields and of the year-token indices.\n",
    "# print(ds)\n",
    "'''\n",
    "These guys weirdly implemented all the functionality of their class in class-level attributes\n",
    "years_to_sample_from: torch.Tensor\n",
    "    N: int\n",
    "    ordered: bool\n",
    "    eos: bool\n",
    "\n",
    "    nouns: List[str]\n",
    "    years: torch.Tensor\n",
    "    years_YY: torch.Tensor\n",
    "    good_sentences: List[str]\n",
    "    bad_sentences: List[str]\n",
    "    good_toks: torch.Tensor\n",
    "    bad_toks: torch.Tensor\n",
    "    good_prompt: List[str]\n",
    "    bad_prompt: List[str]\n",
    "    good_mask: torch.Tensor\n",
    "    tokenizer: PreTrainedTokenizer\n",
    "    '''\n",
    "\n",
    "# ds.N\n",
    "# ds.nouns\n",
    "# print(ds.years[:20]) # not sorted by XX for some reason\n",
    "# print(ds.years_YY[:]) # but does correspond to these YYs, which are mostly sorted\n",
    "print(ds.good_sentences[-10:]) # includes 'The endeavor lasted from the year 1098 to the year 10', but 1099 isn't in the list of years?\n",
    "# Note: we want the prediction at the last token, unlike with the IOI dataset where we want second-to-last.\n",
    "# I checked and there is no internal logic to prevent such sentences from being produced, so there is no safeguard if we sample one.\n",
    "# print(ds.bad_sentences[-10:]) # these all start with 01, e.g 1601 to. they're bad because there is no possible incorrect input\n",
    "# print(ds.good_mask.size()) # n, 100 (100 different years)\n",
    "# print(ds.good_toks.size()) # n, 13\n",
    "# print(ds.bad_toks.size()) # there isn't any necessary correspondence, N is just the number of good sequences and bad sequences alike\n",
    "# list(ds.years.cpu().numpy()).index(1099)\n",
    "print(year_indices)\n",
    "print(model.tokenizer.convert_ids_to_tokens(year_indices)) # length 100, starts with index for '00' and ends with index for '99', great\n",
    "# print(model.tokenizer.decode(year_indices, clean_up_tokenization_spaces=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "6f3a6e99-a858-431b-a96b-815c9622f72e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([False, False, False, False,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.good_mask[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "86b2f8b4-0513-4663-8cd3-1dccffca20f6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|endoftext|> The attempts lasted from the year 1603 to the year 16'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.good_sentences[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "a241453d-5471-4547-9bc3-d370c570e212",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([False, False, False,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.good_mask[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b46ac8f-a1b0-40c5-bd07-ceb27d474497",
   "metadata": {},
   "source": [
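    "The task-specific objective here is a difference of probability sums at the final position: (total probability of the valid year tokens) minus (total probability of the invalid year tokens), averaged over the sampled prompts."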
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9702b0cf-f9e6-4386-9603-1b35924cb129",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Setup attention mask and mean activations for ablation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f54e51e-77a7-4732-885d-d58f9eba9842",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# A single 1 x seq_len row of integer ones.\n",
    "attention_mask = torch.ones(1, seq_len, dtype=torch.long).to(device)\n",
    "# Taking the shape from a one-example slice yields an extended mask with batch size 1,\n",
    "# which then broadcasts against batches of any size.\n",
    "input_shape = ds.good_toks[0:1, :].size()\n",
    "extended_attention_mask = get_extended_attention_mask(\n",
    "    attention_mask,\n",
    "    input_shape,\n",
    "    model,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3a83177e-9605-4f53-8067-d4a3ae4f22f5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# Drop activations left over from an earlier run so their memory can be reclaimed.\n",
    "# On a fresh kernel `logits` and `cache` are not bound yet (they are first assigned in a later cell),\n",
    "# so remove them tolerantly instead of using `del`, which would raise NameError.\n",
    "for _stale_name in ('logits', 'cache'):\n",
    "    globals().pop(_stale_name, None)\n",
    "gc.collect()\n",
    "model.cfg.use_attn_result = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e68a4e59-e9ec-4521-b66b-6ff14330bdcd",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([490, 12, 13, 12, 64])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 13, 768])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the entire dataset through the model along the batch dimension, keeping the activation cache.\n",
    "logits, cache = model.run_with_cache(ds.good_toks)\n",
    "\n",
    "# There is one hook_z entry per *layer*, so iterate over layers rather than heads\n",
    "# (the two counts merely happen to be equal for gpt2-small).\n",
    "attention_outputs = torch.stack(\n",
    "    [cache[f'blocks.{layer_idx}.attn.hook_z'] for layer_idx in range(num_layers)],\n",
    "    dim=1,\n",
    ")  # now batch, layer, seq, n_heads, dim_attn\n",
    "print(attention_outputs.shape)\n",
    "\n",
    "# Average over the batch, then merge the trailing (n_heads, dim_attn) axes into one.\n",
    "mean_acts = torch.mean(attention_outputs, dim=0)\n",
    "old_shape = mean_acts.shape\n",
    "last_dim = old_shape[-2] * old_shape[-1]\n",
    "new_shape = old_shape[:-2] + (last_dim,)  # also referenced by later cells\n",
    "mean_acts = mean_acts.view(new_shape)\n",
    "mean_acts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2f6aa206-1608-4a4e-aa73-24ffcbadb3ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sanity check: rel + irrel of a decomposition should match the logits of a plain forward pass,\n",
    "# which in particular verifies that the attention mask was built correctly.\n",
    "ranges = [\n",
    "    list(range(num_layers)),\n",
    "    list(range(seq_len)),\n",
    "    list(range(num_attention_heads)),\n",
    "]\n",
    "\n",
    "# One single-node ablation set for every combination drawn from the three ranges.\n",
    "source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]\n",
    "ablation_sets = [(node,) for node in source_nodes]\n",
    "target_nodes = []\n",
    "out_decomp, _, _, _ = prop_GPT(\n",
    "    ds.good_toks[0:1, :],\n",
    "    extended_attention_mask,\n",
    "    model,\n",
    "    [ablation_sets[0]],\n",
    "    target_nodes=target_nodes,\n",
    "    device=device,\n",
    "    mean_acts=None,\n",
    "    set_irrel_to_mean=False,\n",
    ")\n",
    "\n",
    "# Note: this rebinds `logits` and `cache` to a run over the first example only.\n",
    "logits, cache = model.run_with_cache(ds.good_toks[0])\n",
    "\n",
    "compare_same(out_decomp[0].rel + out_decomp[0].irrel, logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f1eda4-a2c2-4b35-b95e-ab1cc2149be9",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "# Loose experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "id": "159db1cb-e612-4dda-87cb-6e4bbbff075e",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_SAMPLES = 100\n",
    "SAMPLE_SEED = 0  # fixed so the sampled subset, and every score derived from it, is reproducible\n",
    "\n",
    "# The examples are arranged in increasing order of YY token, so a prefix slice would be biased:\n",
    "# sample at random instead. Draw from the rows that actually exist in ds.good_toks rather than from\n",
    "# the requested N, so an index can never fall outside the dataset if the two differ.\n",
    "num_available = ds.good_toks.size(0)\n",
    "sample_idxs = random.Random(SAMPLE_SEED).sample(range(num_available), NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f239cf-69bd-4c03-ba4f-5f87e5942c7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[172, 392, 73, 394, 157, 273, 369, 200, 402, 373, 202, 127, 163, 365, 186, 326, 124, 438, 227, 129]\n"
     ]
    }
   ],
   "source": [
    "print(sample_idxs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "7e9dd30c-39a8-48b8-9cb1-a712ca65605b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|endoftext|> The pursuit lasted from the year 1290 to the year 12\n"
     ]
    }
   ],
   "source": [
    "print (ds.good_sentences[88])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "58a48d12-ed55-491e-b61d-245e66635bed",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenized prompt: ['<|endoftext|>', '<|endoftext|>', ' The', ' pursuit', ' lasted', ' from', ' the', ' year', ' 12', '90', ' to', ' the', ' year', ' 12']\n",
      "Tokenized answer: ['03']\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Performance on answer token:\n",
       "<span style=\"font-weight: bold\">Rank: </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">52</span><span style=\"font-weight: bold\">       Logit: </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">20.55</span><span style=\"font-weight: bold\"> Prob:  </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.22</span><span style=\"font-weight: bold\">% Token: |</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">03</span><span style=\"font-weight: bold\">|</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Performance on answer token:\n",
       "\u001b[1mRank: \u001b[0m\u001b[1;36m52\u001b[0m\u001b[1m       Logit: \u001b[0m\u001b[1;36m20.55\u001b[0m\u001b[1m Prob:  \u001b[0m\u001b[1;36m0.22\u001b[0m\u001b[1m% Token: |\u001b[0m\u001b[1;36m03\u001b[0m\u001b[1m|\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 0th token. Logit: 25.19 Prob: 22.77% Token: |90|\n",
      "Top 1th token. Logit: 24.12 Prob:  7.78% Token: |99|\n",
      "Top 2th token. Logit: 23.77 Prob:  5.51% Token: |94|\n",
      "Top 3th token. Logit: 23.73 Prob:  5.31% Token: |95|\n",
      "Top 4th token. Logit: 23.63 Prob:  4.81% Token: |92|\n",
      "Top 5th token. Logit: 23.30 Prob:  3.44% Token: |60|\n",
      "Top 6th token. Logit: 23.22 Prob:  3.18% Token: |98|\n",
      "Top 7th token. Logit: 23.18 Prob:  3.06% Token: |96|\n",
      "Top 8th token. Logit: 23.14 Prob:  2.94% Token: |50|\n",
      "Top 9th token. Logit: 23.10 Prob:  2.82% Token: |91|\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Ranks of the answer tokens:</span> <span style=\"font-weight: bold\">[(</span><span style=\"color: #008000; text-decoration-color: #008000\">'03'</span>, <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">52</span><span style=\"font-weight: bold\">)]</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m'03'\u001b[0m, \u001b[1;36m52\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "example_prompt = ds.good_sentences[88] # GPT2 doesn't always perform this task correctly, only about 99% of the time.\n",
    "# On example input <|endoftext|> The pursuit lasted from the year 1290 to the year 12 , the top prediction is '90'.\n",
    "example_answer = '03'\n",
    "\n",
    "# Call through `utils` (imported from transformer_lens next to the model); the bare name\n",
    "# `transformer_lens` is never bound by the visible imports.\n",
    "# The dataset sentences already start with '<|endoftext|>', so a second BOS must not be prepended,\n",
    "# otherwise the prompt is one token longer than the rows of ds.good_toks.\n",
    "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=False, prepend_space_to_answer=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "id": "3b4e0769-739d-460d-bb4e-1ced376a28d0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def score_logits(logits, sample_idxs_0):\n",
    "    '''Score how much last-position probability lands on valid versus invalid year tokens.\n",
    "\n",
    "    Not a pure function: it reads the notebook globals `ds.good_mask` and `year_indices`.\n",
    "\n",
    "    Args:\n",
    "        logits: array or tensor indexed as [example, position, vocab]; only the last position is used.\n",
    "        sample_idxs_0: dataset row indices of the examples in `logits`, in the same order; they select\n",
    "            the matching rows of `ds.good_mask`.\n",
    "\n",
    "    Returns:\n",
    "        Probability on the year tokens flagged in `ds.good_mask` minus probability on the remaining\n",
    "        year tokens, summed over all examples and divided by the number of examples.\n",
    "    '''\n",
    "    # as_tensor accepts numpy arrays and tensors alike, without the copy-construct warning that\n",
    "    # torch.tensor(tensor) emits; detach() keeps the later .numpy() call valid either way.\n",
    "    last_logits = torch.as_tensor(logits[:, -1, :], device='cpu').detach()\n",
    "    probs = torch.nn.functional.softmax(last_logits, dim=-1).numpy()\n",
    "    probs_for_year_tokens = probs[:, year_indices.cpu().numpy()]\n",
    "    # Move the mask off the device once and reuse it for both halves of the score.\n",
    "    correct_mask = ds.good_mask.cpu().numpy()[sample_idxs_0]\n",
    "    correct_score = np.sum(probs_for_year_tokens[correct_mask])\n",
    "    incorrect_score = np.sum(probs_for_year_tokens[np.logical_not(correct_mask)])\n",
    "    return (correct_score - incorrect_score) / len(sample_idxs_0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "id": "80318fda-1f1c-4c85-a2c4-d338ecac0bc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running input 0\n"
     ]
    }
   ],
   "source": [
    "model.reset_hooks(including_permanent=True)\n",
    "\n",
    "mean_acts = mean_acts.view(new_shape)\n",
    "'''\n",
    "ranges = [\n",
    "        [layer for layer in range(num_layers)],\n",
    "        [sequence_position for sequence_position in range(seq_len)],\n",
    "        [attention_head_idx for attention_head_idx in range(num_attention_heads)]\n",
    "    ]\n",
    "\n",
    "source_nodes = [Node(*x) for x in itertools.product(*ranges)]\n",
    "ablation_sets = [(n,) for n in source_nodes]\n",
    "'''\n",
    "# One ablation set per (layer, head): the head's node at every sequence position.\n",
    "ablation_sets = [\n",
    "    tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len))\n",
    "    for layer in range(num_layers)\n",
    "    for head_idx in range(num_attention_heads)\n",
    "]\n",
    "target_nodes = []\n",
    "\n",
    "# Run once on a single ablation set to obtain the pre-layer activations, which are cached\n",
    "# and passed to every later call for a faster batch run.\n",
    "out_decomp, _, _, pre_layer_activations = prop_GPT(\n",
    "    ds.good_toks[sample_idxs, :],\n",
    "    extended_attention_mask,\n",
    "    model,\n",
    "    [ablation_sets[0]],\n",
    "    target_nodes=target_nodes,\n",
    "    device=device,\n",
    "    mean_acts=mean_acts,\n",
    "    set_irrel_to_mean=True,\n",
    ")\n",
    "\n",
    "\n",
    "def prop_fn(ablation_list):\n",
    "    '''Run prop_GPT on the sampled prompts for `ablation_list`, reusing the cached activations.'''\n",
    "    return prop_GPT(\n",
    "        ds.good_toks[sample_idxs, :],\n",
    "        extended_attention_mask,\n",
    "        model,\n",
    "        ablation_list,\n",
    "        target_nodes=target_nodes,\n",
    "        device=device,\n",
    "        mean_acts=mean_acts,\n",
    "        set_irrel_to_mean=True,\n",
    "        cached_pre_layer_acts=pre_layer_activations,\n",
    "    )\n",
    "\n",
    "\n",
    "# Number of ablation sets per batch_run call: 64 // (number of prompts), but at least 1.\n",
    "sets_per_call = max(64 // len(sample_idxs), 1)\n",
    "out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=sets_per_call)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "id": "1ebea8bf-a588-4f3f-ad71-45c2f87f23f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=False):\n",
    "    '''Score every attention head by the task score of its relevant logits, relative to the full score.\n",
    "\n",
    "    Expects one decomposition per (layer, head) in layer-major order, i.e. the entry for a head sits at\n",
    "    index `layer * num_attention_heads + head`. Reads the notebook globals `num_layers`,\n",
    "    `num_attention_heads`, `Result` and `score_logits`.\n",
    "\n",
    "    Args:\n",
    "        out_decomps: indexable collection of decompositions exposing `.rel`, `.irrel` and `.ablation_set`.\n",
    "        sample_idxs: dataset row indices of the scored prompts, forwarded to `score_logits`.\n",
    "        normalized: if True, divide each layer's scores by the sum of their absolute values and print\n",
    "            those per-layer sums.\n",
    "\n",
    "    Returns:\n",
    "        (results, relevances): `results` is a list of Result(ablation_set, score) sorted by descending\n",
    "        score; `relevances` is the (num_layers, num_attention_heads) array holding the same scores.\n",
    "    '''\n",
    "    # rel + irrel of the first decomposition is taken as the full logits.\n",
    "    logits = out_decomps[0].rel + out_decomps[0].irrel\n",
    "    full_score = score_logits(logits, sample_idxs)\n",
    "    # TODO: replace with a check higher in the pipeline; GPT2 succeeds at this like 99%+ of the time but not always\n",
    "    assert full_score > 0, 'full score must be positive before it can be used as a normaliser'\n",
    "\n",
    "    head_decomps = [out_decomps[idx] for idx in range(num_layers * num_attention_heads)]\n",
    "    raw_scores = [score_logits(decomp.rel, sample_idxs) / full_score for decomp in head_decomps]\n",
    "    relevances = np.array(raw_scores, dtype=float).reshape(num_layers, num_attention_heads)\n",
    "\n",
    "    if normalized:\n",
    "        sums_per_layer = np.sum(np.abs(relevances), axis=1)\n",
    "        print(sums_per_layer)\n",
    "        # An all-zero layer gets a tiny positive denominator; a negative one would flip signs.\n",
    "        sums_per_layer[sums_per_layer == 0] = 1e-8\n",
    "        relevances = relevances / np.expand_dims(sums_per_layer, 1)\n",
    "        scores = relevances.reshape(-1)\n",
    "    else:\n",
    "        scores = raw_scores\n",
    "\n",
    "    # Ablation sets come from `out_decomps` itself in both modes, so the function no longer depends\n",
    "    # on whatever the global `target_decomps` happens to hold.\n",
    "    results = [Result(decomp.ablation_set, score) for decomp, score in zip(head_decomps, scores)]\n",
    "    results.sort(key=operator.attrgetter('score'), reverse=True)\n",
    "\n",
    "    return results, relevances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "id": "3819c7ec-a856-4236-ad3e-dfb0fd749883",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "def compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=False):\n",
    "    '''Score decompositions by the task score of their relevant logits, relative to the full score.\n",
    "\n",
    "    This definition replaces the earlier per-head function of the same name, so it accepts both\n",
    "    layouts of `out_decomps` and tells them apart by length:\n",
    "      * per position: one entry per (layer, seq_pos, head), at index\n",
    "        `(layer * seq_len + seq_pos) * num_attention_heads + head`;\n",
    "      * per head: one entry per (layer, head), at index `layer * num_attention_heads + head`.\n",
    "    Reads the notebook globals `num_layers`, `seq_len`, `num_attention_heads`, `Result` and\n",
    "    `score_logits`.\n",
    "\n",
    "    Args:\n",
    "        out_decomps: sequence of decompositions exposing `.rel`, `.irrel` and `.ablation_set`.\n",
    "        sample_idxs: dataset row indices of the scored prompts, forwarded to `score_logits`.\n",
    "        normalized: if True, rescale the scores layer by layer and print the per-layer sums. The\n",
    "            per-position layout divides by |sum of the layer's scores|, the per-head layout by the\n",
    "            sum of the layer's absolute scores.\n",
    "\n",
    "    Returns:\n",
    "        (results, relevances): `results` is a list of Result(ablation_set, score) sorted by descending\n",
    "        score; `relevances` has shape (num_layers, seq_len, num_attention_heads) for the per-position\n",
    "        layout and (num_layers, num_attention_heads) for the per-head layout.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the length of `out_decomps` matches neither layout.\n",
    "    '''\n",
    "    per_position_count = num_layers * seq_len * num_attention_heads\n",
    "    per_head_count = num_layers * num_attention_heads\n",
    "    if len(out_decomps) == per_position_count:\n",
    "        grid_shape = (num_layers, seq_len, num_attention_heads)\n",
    "    elif len(out_decomps) == per_head_count:\n",
    "        grid_shape = (num_layers, num_attention_heads)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f'expected {per_position_count} (per-position) or {per_head_count} (per-head) '\n",
    "            f'decompositions, got {len(out_decomps)}'\n",
    "        )\n",
    "    per_position = len(grid_shape) == 3\n",
    "\n",
    "    # rel + irrel of the first decomposition is taken as the full logits.\n",
    "    logits = out_decomps[0].rel + out_decomps[0].irrel\n",
    "    full_score = score_logits(logits, sample_idxs)\n",
    "    # TODO: replace with a check higher in the pipeline; GPT2 succeeds at this like 99%+ of the time but not always\n",
    "    assert full_score > 0, 'full score must be positive before it can be used as a normaliser'\n",
    "\n",
    "    # The flat order of out_decomps is exactly the C-order of grid_shape, so a reshape places each score.\n",
    "    raw_scores = [score_logits(decomp.rel, sample_idxs) / full_score for decomp in out_decomps]\n",
    "    relevances = np.array(raw_scores, dtype=float).reshape(grid_shape)\n",
    "\n",
    "    if normalized:\n",
    "        if per_position:\n",
    "            sums_per_layer = np.sum(relevances, axis=(1, 2))\n",
    "            print(sums_per_layer)\n",
    "            sums_per_layer = np.abs(sums_per_layer)\n",
    "        else:\n",
    "            sums_per_layer = np.sum(np.abs(relevances), axis=1)\n",
    "            print(sums_per_layer)\n",
    "        # A zero denominator is replaced by a tiny positive value; a negative one would flip signs.\n",
    "        sums_per_layer[sums_per_layer == 0] = 1e-8\n",
    "        relevances = relevances / sums_per_layer.reshape((-1,) + (1,) * (relevances.ndim - 1))\n",
    "        scores = relevances.reshape(-1)\n",
    "    else:\n",
    "        scores = raw_scores\n",
    "\n",
    "    # Ablation sets come from `out_decomps` itself in both modes, so the function no longer depends\n",
    "    # on whatever the global `target_decomps` happens to hold.\n",
    "    results = [Result(decomp.ablation_set, score) for decomp, score in zip(out_decomps, scores)]\n",
    "    results.sort(key=operator.attrgetter('score'), reverse=True)\n",
    "\n",
    "    return results, relevances\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "id": "c397ecd2-9d0d-4edc-b795-de7b94484fdf",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.03981848 0.01361028 0.00921484 0.00663896 0.00468989 0.00379394\n",
      " 0.00245998 0.00181219 0.00163173 0.00281442 0.00125149 0.00022332]\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.8089691831315646\n",
      "Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5720321180711452\n",
      "Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.4250676382511792\n",
      "Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.3279695854853505\n",
      "Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2508779920846017\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.22523116602251275\n",
      "Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.20666289500310217\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=11) 0.18331776048103574\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.14820980889046276\n",
      "Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.14483214622315416\n",
      "Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.1270143676209398\n",
      "Node(layer_idx=2, sequence_idx=0, attn_head_idx=1) 0.1268682623768458\n",
      "Node(layer_idx=5, sequence_idx=0, attn_head_idx=10) 0.12238649375986464\n",
      "Node(layer_idx=3, sequence_idx=0, attn_head_idx=5) 0.11745679305442037\n",
      "Node(layer_idx=5, sequence_idx=0, attn_head_idx=1) 0.1152104827034423\n",
      "Node(layer_idx=5, sequence_idx=0, attn_head_idx=5) 0.1148199086505459\n",
      "Node(layer_idx=1, sequence_idx=0, attn_head_idx=7) 0.11411368442711638\n",
      "Node(layer_idx=2, sequence_idx=0, attn_head_idx=9) 0.10970435140398688\n",
      "Node(layer_idx=3, sequence_idx=0, attn_head_idx=7) 0.10837733463800243\n",
      "Node(layer_idx=0, sequence_idx=0, attn_head_idx=1) 0.10690872876349256\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n a9.h1, while\\nMLP 8 relies on a8.h11, a8.h8, a7.h10, a6.h9, a5.h5, and a5.h1\\n\\n(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)\\n'"
      ]
     },
     "execution_count": 296,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results, relevances = compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=True)\n",
    "\n",
    "# `results` is returned already sorted by descending score, so it is not sorted again here.\n",
    "for result in results[:20]:  # show the 20 highest-scoring ablation sets\n",
    "    # Only the first node of each ablation set is printed.\n",
    "    print(result.ablation_set[0], result.score)\n",
    "\n",
    "# Reference heads for comparison, kept as comments so the cell does not echo them as its output.\n",
    "# NOTE(review): this reads like a quote from the paper linked at the top of the notebook - confirm.\n",
    "#    a9.h1, while\n",
    "#   MLP 8 relies on a8.h11, a8.h8, a7.h10, a6.h9, a5.h5, and a5.h1\n",
    "#\n",
    "#   (9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "afca9ea8-9588-42ec-8f5a-a88408e46acd",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.013245150161018487\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): indexes relevances as [layer, seq_pos, head], so it needs the 3-D\n",
    "# (num_layers, seq_len, num_attention_heads) array; a 2-D (num_layers, num_attention_heads)\n",
    "# array raises IndexError here.\n",
    "print(relevances[9, 12, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "85369371-3cfb-4852-a3bf-9a9910d56e4a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.006119984588483005\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): like the probe above, this needs the 3-D [layer, seq_pos, head] relevances array.\n",
    "print(relevances[10, 12, 4]) # 0.000484417607102171"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "3825269f-402c-4b8a-96d3-ddf57429a09b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(sample_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "id": "32b2b05c-eee7-4a95-aef5-126b082ec550",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running input 0\n"
     ]
    }
   ],
   "source": [
    "model.reset_hooks(including_permanent=True)\n",
    "\n",
    "mean_acts = mean_acts.view(new_shape)\n",
    "'''\n",
    "target_nodes = [Node(9, 12, 1), Node(10, 12, 4)] # (10, 12, 7), (7, 12, 10)\n",
    "ranges = [\n",
    "        [layer for layer in range(num_layers)],\n",
    "        [sequence_position for sequence_position in range(seq_len)],\n",
    "        # [ioi_dataset.word_idx['IO'][0]],\n",
    "        [attention_head_idx for attention_head_idx in range(num_attention_heads)]\n",
    "    ]\n",
    "source_nodes = [Node(*x) for x in itertools.product(*ranges)]\n",
    "ablation_sets = [(n,) for n in source_nodes]\n",
    "'''\n",
    "# One ablation set per (layer, head): the head's node at every sequence position.\n",
    "ablation_sets = [\n",
    "    tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len))\n",
    "    for layer in range(num_layers)\n",
    "    for head_idx in range(num_attention_heads)\n",
    "]\n",
    "# Targets: heads (layer 9, head 1) and (layer 10, head 4), each at every sequence position.\n",
    "target_nodes = [\n",
    "    Node(layer, seq_pos, head_idx)\n",
    "    for layer, head_idx in [(9, 1), (10, 4)]\n",
    "    for seq_pos in range(seq_len)\n",
    "]\n",
    "\n",
    "# Run once on a single ablation set to obtain the pre-layer activations reused by every later call.\n",
    "_, _, _, pre_layer_activations = prop_GPT(\n",
    "    ds.good_toks[sample_idxs, :],\n",
    "    extended_attention_mask,\n",
    "    model,\n",
    "    [ablation_sets[0]],\n",
    "    target_nodes=target_nodes,\n",
    "    device=device,\n",
    "    mean_acts=mean_acts,\n",
    "    set_irrel_to_mean=True,\n",
    ")\n",
    "\n",
    "\n",
    "def prop_fn(ablation_list):\n",
    "    '''Run prop_GPT on the sampled prompts for `ablation_list`, reusing the cached activations.'''\n",
    "    return prop_GPT(\n",
    "        ds.good_toks[sample_idxs, :],\n",
    "        extended_attention_mask,\n",
    "        model,\n",
    "        ablation_list,\n",
    "        target_nodes=target_nodes,\n",
    "        device=device,\n",
    "        mean_acts=mean_acts,\n",
    "        set_irrel_to_mean=True,\n",
    "        cached_pre_layer_acts=pre_layer_activations,\n",
    "    )\n",
    "\n",
    "\n",
    "# Number of ablation sets per batch_run call: 64 // (number of prompts), but at least 1.\n",
    "sets_per_call = max(64 // len(sample_idxs), 1)\n",
    "out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=sets_per_call)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "8ebab7b5-5f80-40b6-98f1-bccf4a16274a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def calculate_target_decomposition_scores(target_decomps, normalized=False):\n",
    "    \"\"\"Score each head-level ablation set by its share of relevance at the target nodes.\n",
    "\n",
    "    Reads the notebook globals num_layers, num_attention_heads and target_nodes.\n",
    "\n",
    "    Args:\n",
    "        target_decomps: flat sequence holding one entry per (layer, head), at index\n",
    "            layer_idx * num_attention_heads + head_idx. Each entry exposes\n",
    "            ablation_set, target_nodes, rels and irrels.\n",
    "        normalized: when True, every score is divided by the absolute sum of its\n",
    "            layer's scores, and all heads (skipped ones included) are reported.\n",
    "\n",
    "    Returns:\n",
    "        (results, relevances): results is a list of Result(ablation_set, score) sorted\n",
    "        by descending score; relevances is the (num_layers, num_attention_heads) array.\n",
    "    \"\"\"\n",
    "    head_grid = list(itertools.product(range(num_layers), range(num_attention_heads)))\n",
    "    relevances = np.zeros((num_layers, num_attention_heads))\n",
    "    results = []\n",
    "\n",
    "    for layer_idx, head_idx in head_grid:\n",
    "        decomp = target_decomps[layer_idx * num_attention_heads + head_idx]\n",
    "        # Ablation sets that start on a target node are not scored; their entry stays 0.\n",
    "        if decomp.ablation_set[0] in target_nodes:\n",
    "            continue\n",
    "\n",
    "        score = 0\n",
    "        # NOTE(review): rels is indexed by target node while irrels is indexed by the\n",
    "        # inner loop variable, whose range is len(rels) - confirm this pairing is intended.\n",
    "        for target_node_idx in range(len(decomp.target_nodes)):\n",
    "            for batch_idx in range(len(decomp.rels)):\n",
    "                rels_magnitude = torch.mean(abs(decomp.rels[target_node_idx])) # np.mean if you are on cpu\n",
    "                irrels_magnitude = torch.mean(abs(decomp.irrels[batch_idx])) # np.mean if you are on cpu\n",
    "                score += rels_magnitude / (rels_magnitude + irrels_magnitude)\n",
    "        if score != 0:\n",
    "            score /= len(decomp.rels)\n",
    "\n",
    "        relevances[layer_idx, head_idx] = score\n",
    "        if not normalized:\n",
    "            results.append(Result(decomp.ablation_set, relevances[layer_idx, head_idx]))\n",
    "\n",
    "    if normalized:\n",
    "        sums_per_layer = np.abs(relevances.sum(axis=1))\n",
    "        sums_per_layer[sums_per_layer == 0] = -1e-8  # placeholder so the division below never hits 0\n",
    "        relevances = relevances / sums_per_layer[:, np.newaxis]\n",
    "        results.extend(\n",
    "            Result(target_decomps[layer_idx * num_attention_heads + head_idx].ablation_set, relevances[layer_idx, head_idx])\n",
    "            for layer_idx, head_idx in head_grid\n",
    "        )\n",
    "\n",
    "    results.sort(key=operator.attrgetter('score'), reverse=True)\n",
    "    return results, relevances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3cde90-9e60-414d-96ac-f17c8a0992c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def calculate_target_decomposition_scores(target_decomps, normalized=False):\n",
    "    \"\"\"Score every single-node (layer, sequence position, head) ablation set against the target nodes.\n",
    "\n",
    "    NOTE: this definition has the same name as the head-level scorer in the previous\n",
    "    cell; whichever of the two cells ran last is the one in effect.\n",
    "\n",
    "    Reads the notebook globals num_layers, seq_len, num_attention_heads and target_nodes.\n",
    "\n",
    "    Args:\n",
    "        target_decomps: flat sequence ordered by layer, then sequence position, then\n",
    "            head (the order itertools.product yields them), so the entry for a node\n",
    "            sits at (layer * seq_len + seq_pos) * num_attention_heads + head.\n",
    "            Each entry exposes ablation_set, target_nodes, rels and irrels.\n",
    "        normalized: when True, every score is divided by the absolute sum of its\n",
    "            layer's scores, and all nodes (skipped ones included) are reported.\n",
    "\n",
    "    Returns:\n",
    "        (results, relevances): results is a list of Result(ablation_set, score) sorted by\n",
    "        descending score; relevances is a (num_layers, seq_len, num_attention_heads) array.\n",
    "    \"\"\"\n",
    "    def flat_index(layer_idx, seq_pos, head_idx):\n",
    "        # The original scoring loop used layer_idx * num_layers * seq_len, which only lines up\n",
    "        # with the list layout when num_layers == num_attention_heads; this is the real stride.\n",
    "        return (layer_idx * seq_len + seq_pos) * num_attention_heads + head_idx\n",
    "\n",
    "    results = []\n",
    "    relevances = np.zeros((num_layers, seq_len, num_attention_heads))\n",
    "    for layer_idx in range(num_layers):\n",
    "        for seq_pos in range(seq_len):\n",
    "            for head_idx in range(num_attention_heads):\n",
    "                target_decomp = target_decomps[flat_index(layer_idx, seq_pos, head_idx)]\n",
    "                # Ablation sets that start on a target node are not scored; their entry stays 0.\n",
    "                if target_decomp.ablation_set[0] in target_nodes:\n",
    "                    continue\n",
    "                score = 0\n",
    "                # NOTE(review): rels is indexed by target node while irrels is indexed by the\n",
    "                # inner loop variable, whose range is len(rels) - confirm this pairing is intended.\n",
    "                for target_node_idx in range(len(target_decomp.target_nodes)):\n",
    "                    for batch_idx in range(len(target_decomp.rels)):\n",
    "                        rels_magnitude = torch.mean(abs(target_decomp.rels[target_node_idx])) # np.mean if you are on cpu\n",
    "                        irrels_magnitude = torch.mean(abs(target_decomp.irrels[batch_idx])) # np.mean if you are on cpu\n",
    "                        target_node_score = rels_magnitude / (rels_magnitude + irrels_magnitude)\n",
    "                        score += target_node_score\n",
    "                if score != 0:\n",
    "                    score /= len(target_decomp.rels)\n",
    "\n",
    "                relevances[layer_idx, seq_pos, head_idx] = score\n",
    "                if not normalized:\n",
    "                    results.append(Result(target_decomp.ablation_set, relevances[layer_idx, seq_pos, head_idx]))\n",
    "\n",
    "    if normalized:\n",
    "        sums_per_layer = np.abs(np.sum(relevances, axis=(1, 2)))\n",
    "        sums_per_layer[sums_per_layer == 0] = -1e-8  # placeholder so the division below never hits 0\n",
    "        relevances = relevances / np.expand_dims(sums_per_layer, (1, 2))\n",
    "\n",
    "        for layer_idx in range(num_layers):\n",
    "            for seq_pos in range(seq_len):\n",
    "                for head_idx in range(num_attention_heads):\n",
    "                    target_decomp = target_decomps[flat_index(layer_idx, seq_pos, head_idx)]\n",
    "                    results.append(Result(target_decomp.ablation_set, relevances[layer_idx, seq_pos, head_idx]))\n",
    "\n",
    "    results.sort(key=operator.attrgetter('score'), reverse=True)\n",
    "    return results, relevances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "dba0453d-5193-4288-8bf1-342964d0c053",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0368, -0.0941,  0.1855,  0.2428, -0.0738,  0.0607, -0.1385,  0.1011,\n",
       "         0.0750,  0.0955, -0.0027,  0.0664,  0.0500,  0.0146,  0.0697, -0.0249,\n",
       "         0.0653,  0.1335,  0.1746, -0.1930,  0.1157, -0.0148,  0.2589,  0.1349,\n",
       "        -0.0696,  0.0200,  0.0364,  0.0313, -0.0468, -0.0105, -0.0036,  0.1675,\n",
       "        -0.1756,  0.0926, -0.1959, -0.0925, -0.0743,  0.1034,  0.0553,  0.1374,\n",
       "        -0.0344, -0.1161,  0.0424, -0.2551,  0.0880,  0.0200, -0.0320, -0.2025,\n",
       "        -0.1221, -0.1395,  0.0507, -0.1609,  0.2095, -0.0270, -0.0257, -0.0935,\n",
       "        -0.0396,  0.0354, -0.0641,  0.0662,  0.0389,  0.0927, -0.0686,  0.1923],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_decomps[0].rels[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "id": "c0e5db37-3c6c-407a-a7b8-86e2723c2f98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=3) 0.1447705091308364\n",
      "Node(layer_idx=7, sequence_idx=0, attn_head_idx=8) 0.13397999674087258\n",
      "Node(layer_idx=7, sequence_idx=0, attn_head_idx=5) 0.1245598699413583\n",
      "Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.11800533643794124\n",
      "Node(layer_idx=5, sequence_idx=0, attn_head_idx=10) 0.11756238262186239\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.11658563322680085\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=8) 0.11190624527839693\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=0) 0.10658123669559956\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=5) 0.10474013448306205\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=10) 0.10472372792549454\n",
      "Node(layer_idx=6, sequence_idx=0, attn_head_idx=4) 0.10395374556525286\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=5) 0.10361383371558026\n",
      "Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.10250919493076233\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=7) 0.10090647101533862\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.09827043812296836\n",
      "Node(layer_idx=7, sequence_idx=0, attn_head_idx=3) 0.0966235911234089\n",
      "Node(layer_idx=5, sequence_idx=0, attn_head_idx=7) 0.09627517967871815\n",
      "Node(layer_idx=9, sequence_idx=0, attn_head_idx=2) 0.0959595681001785\n",
      "Node(layer_idx=6, sequence_idx=0, attn_head_idx=0) 0.09588239668599481\n",
      "Node(layer_idx=8, sequence_idx=0, attn_head_idx=9) 0.09464267831148702\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)\\n'"
      ]
     },
     "execution_count": 234,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results, relevances = calculate_target_decomposition_scores(target_decomps, normalized=True)\n",
    "\n",
    "# Show the 20 highest-scoring ablation sets: first node of each set, then its score.\n",
    "for result in results[:20]:\n",
    "    print(result.ablation_set[0], result.score)\n",
    "\n",
    "# For comparison, the (layer, head) pairs used below as the greater-than paper's result:\n",
    "# (9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "9abb034f-8ded-4824-bf4e-5c37cda6e468",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result(ablation_set=(Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),), score=0.0010632349292983398)\n"
     ]
    }
   ],
   "source": [
    "print(results[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "15bb0c2f-bef4-4bc3-8f5a-f2177a17e538",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node(layer_idx=9, sequence_idx=14, attn_head_idx=9)\n",
      "Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)\n",
      "Node(layer_idx=9, sequence_idx=14, attn_head_idx=6)\n",
      "Node(layer_idx=9, sequence_idx=2, attn_head_idx=6)\n",
      "Node(layer_idx=0, sequence_idx=1, attn_head_idx=1)\n",
      "Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)\n"
     ]
    }
   ],
   "source": [
    "# Collect the first node of every result across all iterations, keeping first-seen order\n",
    "# and dropping repeats.\n",
    "all_nodes = []\n",
    "for it in outliers_per_iter:\n",
    "    for result in it:\n",
    "        first_node = result.ablation_set[0]\n",
    "        if first_node in all_nodes:\n",
    "            continue\n",
    "        all_nodes.append(first_node)\n",
    "\n",
    "for node in all_nodes:\n",
    "    print(node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a438ad2b-5fa8-4211-bed6-c0d70f152341",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Then, Vanessa and Paul went to the house. Vanessa gave a basketball to Paul\n",
      "Then, Jessica and Lindsay went to the school. Jessica gave a snack to Lindsay\n"
     ]
    }
   ],
   "source": [
    "print(ioi_dataset.sentences[0])\n",
    "print(test_ioi_dataset.sentences[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce7cb13f-4b1f-4bd0-9b8b-1a3bccbfe9ad",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "# Circuit evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "id": "5b2ad668-f0bc-492b-9afd-92c5716b2d95",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7698.91748046875\n",
      "10790.0\n",
      "1087.55322265625\n",
      "4184.0\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "\n",
    "# Candidates to free before measuring (uncomment as needed):\n",
    "# del out_decomps\n",
    "# del target_decomps\n",
    "# del logits\n",
    "# del cache # pretty sure it's this one\n",
    "\n",
    "# Before cleanup: allocated, then reserved, memory on CUDA device 0 (bytes scaled down by 1024 twice).\n",
    "allocated_before = torch.cuda.memory_allocated(0)\n",
    "print(allocated_before/1024/1024)\n",
    "reserved_before = torch.cuda.memory_reserved(0)\n",
    "print(reserved_before/1024/1024)\n",
    "\n",
    "# Collect unreachable Python objects first, then ask torch to release its cached blocks.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# After cleanup: the same two figures again.\n",
    "allocated_after = torch.cuda.memory_allocated(0)\n",
    "print(allocated_after/1024/1024)\n",
    "reserved_after = torch.cuda.memory_reserved(0)\n",
    "print(reserved_after/1024/1024)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "33374a25-c2a8-493e-8480-ecc2274e6ba3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Index ranges spanning every (layer, sequence position, attention head) coordinate.\n",
    "ranges = [\n",
    "    list(range(num_layers)),\n",
    "    list(range(seq_len)),\n",
    "    # [ioi_dataset.word_idx['IO'][0]],\n",
    "    list(range(num_attention_heads)),\n",
    "]\n",
    "\n",
    "# One Node per coordinate, then 20 of them drawn without replacement as a random baseline circuit.\n",
    "source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]\n",
    "random_circuit = random.sample(source_nodes, 20)\n",
    "\n",
    "# sample_idxs = random.sample(range(N), NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "d5fe00b8-8a79-49f4-bb0d-26af18a46525",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# implicitly depends on the notebook globals year_indices and ds\n",
    "def correctness_rate(logits, sample_idxs_0):\n",
    "    \"\"\"Return the fraction of samples whose highest-logit year token is flagged in ds.good_mask.\n",
    "\n",
    "    Args:\n",
    "        logits: model output indexed as [sample, position, vocabulary]; only the last\n",
    "            position is read, restricted to the vocabulary entries in year_indices.\n",
    "        sample_idxs_0: dataset row index for each row of logits, in the same order.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the selected ds.good_mask entries divided by len(sample_idxs_0).\n",
    "    \"\"\"\n",
    "    logits_for_year_tokens = logits[:, -1, year_indices]\n",
    "    # Position of the winning year within year_indices (not a vocabulary id).\n",
    "    predicted_year_idxs = np.argmax(logits_for_year_tokens.cpu().numpy(), axis=-1)\n",
    "    # assumes the columns of ds.good_mask follow the same order as year_indices - TODO confirm\n",
    "    correct_per_input = ds.good_mask.cpu().numpy()[sample_idxs_0, predicted_year_idxs]\n",
    "    return np.sum(correct_per_input) / len(sample_idxs_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8266c4-7c2b-4b14-bc5d-2698140a168b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2482656/2324425928.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7689858919143681\n",
      "0.9896000000000005\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# greater-than paper's result, expanded so each head is included at every sequence position\n",
    "circuit = []\n",
    "for (layer_idx, head_idx), seq_pos in itertools.product(\n",
    "    [(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)], range(seq_len)\n",
    "):\n",
    "    circuit.append(Node(layer_idx, seq_pos, head_idx))\n",
    "\n",
    "# Earlier candidates, kept for reference.\n",
    "#\n",
    "# # simply results from first iter\n",
    "# circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),\n",
    "#     # Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),\n",
    "#     # Node(layer_idx=10, sequence_idx=12, attn_head_idx=7),\n",
    "#     Node(layer_idx=7, sequence_idx=12, attn_head_idx=10),\n",
    "# ]\n",
    "# # 711, 965\n",
    "#\n",
    "# circuit = []\n",
    "# for (layer_idx, head_idx) in [(9, 1), (7, 10)]: # the above but without seq pos\n",
    "#     for seq_pos in range(seq_len):\n",
    "#         circuit.append(Node(layer_idx, seq_pos, head_idx))\n",
    "\n",
    "evaluate_circuit(circuit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ffb5dd-88dd-4da0-b912-95b076d55c40",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2482656/2324425928.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8166431448936461\n",
      "0.9920000000000004\n"
     ]
    }
   ],
   "source": [
    "evaluate_circuit(None, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "id": "de2cde63-196a-46e8-8fa7-8f10cbf60be1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2482656/2051035158.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6743128757087552\n",
      "0.9183673469387756\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\ncircuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),\\n    Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),\\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=1),\\nNode(layer_idx=9, sequence_idx=9, attn_head_idx=9),\\nNode(layer_idx=9, sequence_idx=9, attn_head_idx=1),\\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=9),\\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=6),\\nNode(layer_idx=5, sequence_idx=3, attn_head_idx=10),\\nNode(layer_idx=7, sequence_idx=3, attn_head_idx=8),\\nNode(layer_idx=9, sequence_idx=5, attn_head_idx=3),\\nNode(layer_idx=6, sequence_idx=3, attn_head_idx=7),\\nNode(layer_idx=4, sequence_idx=3, attn_head_idx=3),\\nNode(layer_idx=9, sequence_idx=10, attn_head_idx=1),\\nNode(layer_idx=9, sequence_idx=7, attn_head_idx=10),\\nNode(layer_idx=8, sequence_idx=12, attn_head_idx=11),\\n# Node(layer_idx=7, sequence_idx=4, attn_head_idx=5),\\n# Node(layer_idx=9, sequence_idx=4, attn_head_idx=3),\\n# Node(layer_idx=8, sequence_idx=9, attn_head_idx=3),\\n# Node(layer_idx=7, sequence_idx=4, attn_head_idx=8),\\n# Node(layer_idx=5, sequence_idx=4, attn_head_idx=10),\\n# Node(layer_idx=9, sequence_idx=8, attn_head_idx=10),\\n]\\n'"
      ]
     },
     "execution_count": 218,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-position candidate this head list was derived from (kept for reference):\n",
    "#\n",
    "# circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),\n",
    "#     Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),\n",
    "# Node(layer_idx=9, sequence_idx=9, attn_head_idx=9),\n",
    "# Node(layer_idx=9, sequence_idx=11, attn_head_idx=1),\n",
    "# Node(layer_idx=9, sequence_idx=11, attn_head_idx=9),\n",
    "# # Node(layer_idx=9, sequence_idx=11, attn_head_idx=6),\n",
    "# # Node(layer_idx=9, sequence_idx=4, attn_head_idx=2),\n",
    "# # Node(layer_idx=5, sequence_idx=3, attn_head_idx=10),\n",
    "# # Node(layer_idx=9, sequence_idx=9, attn_head_idx=1),\n",
    "# # Node(layer_idx=7, sequence_idx=2, attn_head_idx=1),\n",
    "# # Node(layer_idx=7, sequence_idx=2, attn_head_idx=3),\n",
    "# # Node(layer_idx=7, sequence_idx=3, attn_head_idx=8),\n",
    "# # Node(layer_idx=9, sequence_idx=9, attn_head_idx=4),\n",
    "# # Node(layer_idx=7, sequence_idx=2, attn_head_idx=5),\n",
    "# # Node(layer_idx=9, sequence_idx=8, attn_head_idx=10),\n",
    "# # Node(layer_idx=7, sequence_idx=2, attn_head_idx=4),\n",
    "# # Node(layer_idx=6, sequence_idx=3, attn_head_idx=7),\n",
    "# # Node(layer_idx=7, sequence_idx=2, attn_head_idx=11),\n",
    "# # Node(layer_idx=9, sequence_idx=5, attn_head_idx=3),\n",
    "# # Node(layer_idx=6, sequence_idx=2, attn_head_idx=7),\n",
    "# ]\n",
    "\n",
    "# the above but without seq pos: each listed head is included at every sequence position\n",
    "circuit = []\n",
    "for (layer_idx, head_idx) in [(9, 1), (10, 4), (9, 9), (9, 6), (9, 2), (5, 10), (7, 1), (7, 3), (7, 8)]:\n",
    "    for seq_pos in range(seq_len):\n",
    "        circuit.append(Node(layer_idx, seq_pos, head_idx))\n",
    "evaluate_circuit(circuit)\n",
    "\n",
    "# Another per-position candidate (kept for reference). These notes are comments rather\n",
    "# than a bare string so the cell does not echo them as its result.\n",
    "#\n",
    "# circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),\n",
    "#     Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),\n",
    "# Node(layer_idx=9, sequence_idx=11, attn_head_idx=1),\n",
    "# Node(layer_idx=9, sequence_idx=9, attn_head_idx=9),\n",
    "# Node(layer_idx=9, sequence_idx=9, attn_head_idx=1),\n",
    "# Node(layer_idx=9, sequence_idx=11, attn_head_idx=9),\n",
    "# Node(layer_idx=9, sequence_idx=11, attn_head_idx=6),\n",
    "# Node(layer_idx=5, sequence_idx=3, attn_head_idx=10),\n",
    "# Node(layer_idx=7, sequence_idx=3, attn_head_idx=8),\n",
    "# Node(layer_idx=9, sequence_idx=5, attn_head_idx=3),\n",
    "# Node(layer_idx=6, sequence_idx=3, attn_head_idx=7),\n",
    "# Node(layer_idx=4, sequence_idx=3, attn_head_idx=3),\n",
    "# Node(layer_idx=9, sequence_idx=10, attn_head_idx=1),\n",
    "# Node(layer_idx=9, sequence_idx=7, attn_head_idx=10),\n",
    "# Node(layer_idx=8, sequence_idx=12, attn_head_idx=11),\n",
    "# # Node(layer_idx=7, sequence_idx=4, attn_head_idx=5),\n",
    "# # Node(layer_idx=9, sequence_idx=4, attn_head_idx=3),\n",
    "# # Node(layer_idx=8, sequence_idx=9, attn_head_idx=3),\n",
    "# # Node(layer_idx=7, sequence_idx=4, attn_head_idx=8),\n",
    "# # Node(layer_idx=5, sequence_idx=4, attn_head_idx=10),\n",
    "# # Node(layer_idx=9, sequence_idx=8, attn_head_idx=10),\n",
    "# ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "id": "ef44e151-a5d9-4b23-bdfc-9da55be07bfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2482656/2324425928.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7616304731369021\n",
      "0.9806000000000002\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\nNode(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.7919426329190301\\nNode(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5688594302281663\\nNode(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.3724930226405632\\nNode(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2592619470224411\\nNode(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.2262736787177263\\nNode(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.21704383205004027\\nNode(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.2017522938898915\\nNode(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.17671132314414148\\n'"
      ]
     },
     "execution_count": 293,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recorded score listing (node, score) from an earlier run, kept for reference:\n",
    "#\n",
    "# Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.4083923634530539\n",
    "# Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.2221098391988148\n",
    "# Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.2202389078105817\n",
    "# Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.15423218716630077\n",
    "# Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.1332779102978367\n",
    "# Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.12993389925707383\n",
    "# Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.12727556319370611\n",
    "# Node(layer_idx=8, sequence_idx=0, attn_head_idx=11) 0.12419588242524643\n",
    "# Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.11703012724861953\n",
    "# Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.10890546106985093\n",
    "# Node(layer_idx=5, sequence_idx=0, attn_head_idx=10) 0.10690893744165206\n",
    "\n",
    "# the above but without seq pos: each listed head is included at every sequence position\n",
    "circuit = []\n",
    "for (layer_idx, head_idx) in [(9, 1), (10, 4), (7, 10), (11, 8), (10, 7), (6, 9), (8, 11), (8, 8)]:\n",
    "    for seq_pos in range(seq_len):\n",
    "        circuit.append(Node(layer_idx, seq_pos, head_idx))\n",
    "evaluate_circuit(circuit)\n",
    "\n",
    "# A second recorded score listing (kept for reference). These notes are comments rather\n",
    "# than a bare string so the cell does not echo them as its result.\n",
    "#\n",
    "# Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.7919426329190301\n",
    "# Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5688594302281663\n",
    "# Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.3724930226405632\n",
    "# Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2592619470224411\n",
    "# Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.2262736787177263\n",
    "# Node(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.21704383205004027\n",
    "# Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.2017522938898915\n",
    "# Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.17671132314414148\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "9db57878-aaf0-4ef6-83ad-021902bf9e4c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pyfunctions.faithfulness_ablations import add_mean_ablation_hook\n",
    "def evaluate_circuit(circuit, full_model=False):\n",
    "    \"\"\"Print the dataset-wide score and correctness rate of the model, optionally mean-ablated.\n",
    "\n",
    "    Reads the notebook globals model, mean_acts, old_shape, ds, N, score_logits and\n",
    "    correctness_rate. All hooks are removed from the model before and after the run,\n",
    "    including when a batch raises.\n",
    "\n",
    "    Args:\n",
    "        circuit: nodes handed to add_mean_ablation_hook as circuit=...; ignored when\n",
    "            full_model is True. Presumably these nodes are kept while everything else\n",
    "            is patched with mean_acts - confirm against add_mean_ablation_hook.\n",
    "        full_model: when True, evaluate the model without installing any ablation hooks.\n",
    "\n",
    "    Returns:\n",
    "        None. The score and then the correctness rate are printed, each averaged over\n",
    "        all N rows of ds.good_toks with batches weighted by their size.\n",
    "    \"\"\"\n",
    "    # mean_acts = mean_acts.view(old_shape)\n",
    "    model.reset_hooks(including_permanent=True)\n",
    "    # current findings:\n",
    "    # full model: 0.817, 0.989 correctness\n",
    "    # ablate all attention layers entirely: 0.515, 0.891\n",
    "    # random circuit of 20 \"head, seq_pos\": 0.532, 0.891\n",
    "    # our \"four head, seq_pos\" circuit: 0.711, 0.955\n",
    "    # their circuit: 0.765, 0.985\n",
    "\n",
    "    NUM_AT_TIME = 64  # rows per forward pass\n",
    "    score = 0\n",
    "    correctness = 0\n",
    "    ablation_model = model\n",
    "    try:\n",
    "        if not full_model:\n",
    "            ablation_model = add_mean_ablation_hook(model, patch_values=mean_acts.view(old_shape), circuit=circuit)\n",
    "\n",
    "        for start_idx in range(0, N, NUM_AT_TIME):\n",
    "            end_idx = min(start_idx + NUM_AT_TIME, N)\n",
    "            batch_idxs = range(start_idx, end_idx)\n",
    "            # Run through the object returned by add_mean_ablation_hook, so the ablation applies\n",
    "            # even if that helper hands back something other than the global model.\n",
    "            logits, _unused_cache = ablation_model.run_with_cache(ds.good_toks[start_idx:end_idx])\n",
    "            # Weight each batch by its share of the dataset; the last batch may be short.\n",
    "            batch_weight = len(batch_idxs) / N\n",
    "            score += score_logits(logits, batch_idxs) * batch_weight\n",
    "            correctness += correctness_rate(logits, batch_idxs) * batch_weight\n",
    "    finally:\n",
    "        # Never leave ablation hooks installed for whichever cell runs next.\n",
    "        ablation_model.reset_hooks(including_permanent=True)\n",
    "    print(score)\n",
    "    print(correctness)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea429b1a-516d-4210-a0c7-7278aae0f8cc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tentatively improved score to -1.255778   by removing node  Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)\n",
      "removing  Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)  to achieve score of -1.255778\n",
      "tentatively improved score to -1.220982   by removing node  Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)\n",
      "removing  Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)  to achieve score of -1.220982\n",
      "Done\n"
     ]
    }
   ],
   "source": [
    "# speculative: try to generate a better circuit by greedy search\n",
    "# Each pass tries removing every non-protected node in turn and permanently drops the one\n",
    "# that raises the score the most; the search stops once no single removal helps.\n",
    "# NOTE(review): test_abc_dataset, test_ioi_dataset and logits_to_ave_logit_diff_2 look like\n",
    "# IOI-task names - confirm they are defined before running this in the greater-than notebook.\n",
    "\n",
    "NAME_MOVER_HEADS = [Node(9, 14, 9), Node(10, 14, 0), Node(9, 14, 6)]  # never considered for removal\n",
    "old_circuit = circuit.copy()  # snapshot taken before the search mutates circuit in place\n",
    "best_score = -1.4686 # starting threshold; presumably the unpruned circuit's score - confirm\n",
    "while True:\n",
    "    node_to_remove = None\n",
    "    for node in circuit:\n",
    "        if node in NAME_MOVER_HEADS:\n",
    "            continue\n",
    "        new_circuit = circuit.copy()\n",
    "        new_circuit.remove(node)\n",
    "        # print(new_circuit)\n",
    "        model.reset_hooks(including_permanent=True)\n",
    "        model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset, circuit=new_circuit)\n",
    "        logits, cache = model.run_with_cache(test_ioi_dataset.toks) # run on entire dataset along batch dimension\n",
    "        ave_logit_diff = logits_to_ave_logit_diff_2(logits, test_ioi_dataset).cpu().numpy().item()\n",
    "        if ave_logit_diff > best_score:\n",
    "            best_score = ave_logit_diff\n",
    "            node_to_remove = node\n",
    "            print('tentatively improved score to %f ' % best_score, ' by removing node ', node_to_remove)\n",
    "    if node_to_remove is None: \n",
    "        # then we can't improve any further so the algorithm terminates\n",
    "        break\n",
    "    print(\"removing \", node_to_remove, \" to achieve score of %f\" % best_score)\n",
    "    circuit.remove(node_to_remove)\n",
    "# The hooks still installed belong to the last candidate tried, not to the final circuit;\n",
    "# clear them so later cells start from an un-ablated model.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "print('Done')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2be65d92-2c5f-4638-8ff8-ec2f2aab6e71",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "## Without sequence positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24788897-fccd-4e04-912a-b5ff16bfc887",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (ipykernel)",
   "language": "python",
   "name": "python3.12"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
