{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "87133c47-c439-431b-876a-b8a5d2e1b646",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "508fe26c-4731-4245-9779-1e22182bda40",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Reload imported modules automatically before each cell executes, so edits to project code are picked up without a kernel restart.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "996849a5-b234-46a8-96a1-e55965b51725",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Only needed when executing on a remote JupyterHub, where the working directory starts at the server root; otherwise the working directory is presumably already correct.\n",
    "# %cd CD_Circuit/\n",
    "\n",
    "import argparse\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import warnings\n",
    "import random\n",
    "import collections\n",
    "\n",
    "# CD-T Imports\n",
    "import math\n",
    "import tqdm\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import itertools\n",
    "import operator\n",
    "\n",
    "from torch import nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "base_dir = os.path.split(os.getcwd())[0]\n",
    "sys.path.append(base_dir)\n",
    "\n",
    "from argparse import Namespace\n",
    "from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars\n",
    "from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals, compare_same\n",
    "from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels\n",
    "from pyfunctions.cdt_basic import *\n",
    "from pyfunctions.cdt_source_to_target import *\n",
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from transformers import GPT2Tokenizer, GPT2Model\n",
    "from pyfunctions.wrappers import Node, AblationSet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "25aee65e-f2d4-4450-8d37-da280543cb01",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-small into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Prefer the first CUDA device; fall back to the CPU when no GPU is visible.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Gradient tracking is switched off globally for everything that follows.\n",
    "torch.autograd.set_grad_enabled(False)\n",
    "\n",
    "# Model code adapted from Callum McDougall's ARENA notebook on reproducing the IOI paper with TransformerLens.\n",
    "# EasyTransformer, the library released by the IOI authors, was forked from TransformerLens, so loading the\n",
    "# model this way keeps the reproduction closer to the original setup (e.g. weight-processing options such as\n",
    "# \"folding\" LayerNorms are available here). The plain HuggingFace implementation proved much harder to\n",
    "# navigate and instrument, so it is not used.\n",
    "# NOTE(review): the note above mentions folding LayerNorms, but fold_ln=False is passed below - confirm this is intended.\n",
    "from transformer_lens import utils, HookedTransformer, ActivationCache\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"gpt2-small\",\n",
    "    center_unembed=True,\n",
    "    center_writing_weights=True,\n",
    "    fold_ln=False,\n",
    "    refactor_factored_attn_matrices=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f72254a5-a721-4a87-8a88-3b950eacec20",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-26 02:17:44.339763: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-09-26 02:17:46.323034: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-09-26 02:17:49.799224: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 16, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyfunctions.ioi_dataset import IOIDataset\n",
    "\n",
    "# Seed the random number generators up front so that Restart & Run All regenerates the same prompts.\n",
    "# NOTE(review): assumes IOIDataset draws from `random` / `numpy.random` - confirm against pyfunctions/ioi_dataset.py.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Generate a dataset all consisting of one template, randomly chosen.\n",
    "# nb_templates = 2 due to some logic internal to IOIDataset:\n",
    "# essentially, the nouns can be an ABBA or ABAB order and that counts as separate templates.\n",
    "ioi_dataset = IOIDataset(prompt_type=\"mixed\", N=3, tokenizer=model.tokenizer, prepend_bos=False, nb_templates=2)\n",
    "\n",
    "# This is the P_ABC that is mentioned in the IOI paper, which we use for mean ablation.\n",
    "# Importantly, passing in prompt_type=\"ABC\" or similar is NOT the same thing as this.\n",
    "abc_dataset = (\n",
    "    ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "    .gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "    .gen_flipped_prompts((\"S1\", \"RAND\"))\n",
    ")\n",
    "\n",
    "logits, cache = model.run_with_cache(abc_dataset.toks) # run on entire dataset along batch dimension\n",
    "\n",
    "# Collect every layer's per-head attention output (hook_z). The layer count is read from the model config\n",
    "# instead of being hard-coded to 12, so the cell keeps working if a different model is loaded.\n",
    "n_layers = model.cfg.n_layers\n",
    "attention_outputs = torch.stack(\n",
    "    [cache[f'blocks.{layer}.attn.hook_z'] for layer in range(n_layers)],\n",
    "    dim=1,\n",
    ") # now batch, layer, seq, n_heads, dim_attn\n",
    "\n",
    "# Average over the batch, then merge the last two axes (n_heads, dim_attn) into a single axis.\n",
    "mean_acts = attention_outputs.mean(dim=0).flatten(start_dim=-2)\n",
    "mean_acts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c92fe366-20fa-4aaf-bd5b-e9b65c466b1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "# Example inputs for the (currently disabled) prop_GPT call at the bottom of this cell:\n",
    "# source_list = [Node(0, 0, 0), Node(1, 1, 1)]\n",
    "# target_nodes = [(7, 0, 1)]\n",
    "\n",
    "# Tokenize the first IOI sentence and move the resulting tensors to the compute device.\n",
    "text = ioi_dataset.sentences[0]\n",
    "encoding = model.tokenizer.encode_plus(\n",
    "    text,\n",
    "    add_special_tokens=True,\n",
    "    max_length=512,\n",
    "    truncation=True,\n",
    "    padding=\"longest\",\n",
    "    return_attention_mask=True,\n",
    "    return_tensors=\"pt\",\n",
    ").to(device)\n",
    "\n",
    "encoding_idxs = encoding.input_ids\n",
    "attention_mask = encoding.attention_mask\n",
    "input_shape = encoding_idxs.size()\n",
    "\n",
    "# Build the extended attention mask with the project helper (star-imported earlier in the notebook).\n",
    "extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)\n",
    "# out_decomps, target_decomps, _ = prop_GPT(encoding_idxs, extended_attention_mask, model, source_list, target_nodes, mean_acts=mean_acts, set_irrel_to_mean=True, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee5e1a5-0493-43a3-b5fc-39917a9fddd1",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Circuit evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88f96ad6-7a3f-417a-b0fd-2615963f365b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c41f2571-a91e-4c7e-aa45-1cb1f9eaa739",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pyfunctions.faithfulness_ablations import logits_to_ave_logit_diff_2, add_mean_ablation_hook\n",
    "\n",
    "# Evaluation set: 100 prompts with prompt_type=\"mixed\".\n",
    "ioi_dataset = IOIDataset(prompt_type=\"mixed\", N=100, tokenizer=model.tokenizer, prepend_bos=False)\n",
    "\n",
    "# Counterpart used for mean ablation: the IO, S and S1 names are replaced by random ones, one flip at a time.\n",
    "abc_dataset = ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S1\", \"RAND\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb334dd7-3118-4bb2-bf18-a1927f4791dc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(3.6853, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Baseline: clear every hook (permanent ones included) and score the model on the IOI prompts with no ablation.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks) # run on entire dataset along batch dimension\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bea958aa-ba52-4a55-82f4-7f21e2c98e07",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(3.2721, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Start from a hook-free model, like every other ablation cell in this notebook, so that re-running this\n",
    "# cell (or running it out of order) never stacks a second set of ablation hooks on top of existing ones.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "\n",
    "# Mean-ablate using activations averaged over abc_dataset, then score on the IOI prompts.\n",
    "model = add_mean_ablation_hook(model, means_dataset=abc_dataset) # IOI paper's circuit, by default\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks) # run on entire dataset along batch dimension\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f122e88c-d33b-434d-8533-599593429c58",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(4.0973, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# This one does better than the full model very consistently, which is kind of alarming.\n",
    "# Wasn't ablating away the IOI dataset information supposed to destroy information, thereby resulting in a worse performance?\n",
    "# (Here the ablation means come from ioi_dataset itself rather than from abc_dataset.)\n",
    "\n",
    "model.reset_hooks(including_permanent=True)\n",
    "model = add_mean_ablation_hook(\n",
    "    model,\n",
    "    means_dataset=ioi_dataset,\n",
    ")\n",
    "\n",
    "# Forward pass over the whole dataset along the batch dimension, then report the average logit difference.\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks)\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57b62938-06bd-4729-834a-dd207d42b681",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check implementation of explicit index\n",
    "# Attention heads of the circuit as (layer, head) pairs, grouped by role.\n",
    "CIRCUIT = {\n",
    "    \"name mover\": [(9, 9), (10, 0), (9, 6)],\n",
    "    \"backup name mover\": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],\n",
    "    \"negative name mover\": [(10, 7), (11, 10)],\n",
    "    \"s2 inhibition\": [(7, 3), (7, 9), (8, 6), (8, 10)],\n",
    "    \"induction\": [(5, 5), (5, 8), (5, 9), (6, 9)],\n",
    "    \"duplicate token\": [(0, 1), (0, 10), (3, 0)],\n",
    "    \"previous token\": [(2, 2), (4, 11)],\n",
    "}\n",
    "\n",
    "# For each role, the key into ioi_dataset.word_idx naming the sequence position that is kept.\n",
    "SEQ_POS_TO_KEEP = {\n",
    "    \"name mover\": \"end\",\n",
    "    \"backup name mover\": \"end\",\n",
    "    \"negative name mover\": \"end\",\n",
    "    \"s2 inhibition\": \"end\",\n",
    "    \"induction\": \"S2\",\n",
    "    \"duplicate token\": \"S2\",\n",
    "    \"previous token\": \"S+1\",\n",
    "}\n",
    "\n",
    "# Build one Node(layer, sequence position, head) per circuit head.\n",
    "# The position is read from the FIRST prompt only ([0]), so the resulting nodes are only meaningful for\n",
    "# prompts that share that prompt's token layout.\n",
    "# `.item()` yields a plain Python int and works wherever the index tensor lives; the previous\n",
    "# `.numpy()[0]` raises for tensors that are not on the CPU and produced NumPy scalars.\n",
    "nodes = [\n",
    "    Node(layer, ioi_dataset.word_idx[SEQ_POS_TO_KEEP[role]][0].item(), head)\n",
    "    for role, heads in CIRCUIT.items()\n",
    "    for layer, head in heads\n",
    "]\n",
    "\n",
    "# you can now run this circuit instead of the IOI circuit in the usual way on ioi_dataset[0] and observe that it achieves the same performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "32a2e1bd-0832-42f2-a636-6c71546b5cf1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-1.1722, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pyfunctions.ioi_dataset import (\n",
    "    ABC_TEMPLATES, BAC_TEMPLATES, BABA_TEMPLATES, BABA_LONG_TEMPLATES, BABA_LATE_IOS,\n",
    "    BABA_EARLY_IOS, ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS,\n",
    ")\n",
    "\n",
    "# Restrict the data to a single template, because the circuit below uses fixed sequence positions.\n",
    "template = BABA_TEMPLATES[0]\n",
    "ioi_dataset = IOIDataset(prompt_type=[template], N=100, tokenizer=model.tokenizer, prepend_bos=False)\n",
    "\n",
    "# Randomized-name counterpart of the prompts, built one flip at a time.\n",
    "abc_dataset = ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S1\", \"RAND\"))\n",
    "\n",
    "# Perform analysis only for the template that this thing was formed on\n",
    "# Each row is (layer_idx, sequence_idx, attn_head_idx).\n",
    "circuit_spec = [\n",
    "    (8, 14, 6), (8, 11, 6), (9, 14, 9), (9, 14, 6),\n",
    "    (5, 10, 5), (7, 11, 9), (6, 10, 9), (6, 11, 0),\n",
    "    (5, 10, 9), (3, 10, 0), (4, 5, 11), (3, 5, 7),\n",
    "    (3, 3, 6), (2, 3, 2), (2, 3, 9), (1, 3, 7),\n",
    "    (1, 3, 10), (0, 2, 1), (0, 2, 4),\n",
    "]\n",
    "circuit = [\n",
    "    Node(layer_idx=layer, sequence_idx=position, attn_head_idx=head)\n",
    "    for layer, position, head in circuit_spec\n",
    "]\n",
    "\n",
    "# Fresh hooks, mean-ablate with abc_dataset means for this circuit, then score on the whole dataset.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "model = add_mean_ablation_hook(model, means_dataset=abc_dataset, circuit=circuit)\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks)\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6267d1c3-0441-4a67-9b6d-27ad51a1411f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.9411, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pyfunctions.ioi_dataset import (\n",
    "    ABC_TEMPLATES, BAC_TEMPLATES, BABA_TEMPLATES, BABA_LONG_TEMPLATES, BABA_LATE_IOS,\n",
    "    BABA_EARLY_IOS, ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS,\n",
    ")\n",
    "\n",
    "# Restrict the data to a single template, because the circuit below uses fixed sequence positions.\n",
    "template = ABBA_TEMPLATES[0]\n",
    "ioi_dataset = IOIDataset(prompt_type=[template], N=100, tokenizer=model.tokenizer, prepend_bos=False)\n",
    "\n",
    "# Randomized-name counterpart of the prompts, built one flip at a time.\n",
    "abc_dataset = ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S1\", \"RAND\"))\n",
    "\n",
    "# Perform analysis only for the template that this thing was formed on\n",
    "# Each row holds the three positional arguments passed to Node, in order.\n",
    "circuit_spec = [\n",
    "    (9, 14, 9), (9, 14, 6), (10, 14, 0),\n",
    "    (8, 14, 10), (7, 14, 9), (7, 14, 3),\n",
    "    (5, 10, 5), (5, 10, 8), (5, 10, 9),\n",
    "    (0, 10, 10), (0, 10, 1), (3, 10, 0),\n",
    "    (0, 5, 6), (0, 5, 7), (0, 5, 10),\n",
    "]\n",
    "circuit = [Node(*spec) for spec in circuit_spec]\n",
    "\n",
    "# Fresh hooks, mean-ablate with abc_dataset means for this circuit, then score on the whole dataset.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "model = add_mean_ablation_hook(model, means_dataset=abc_dataset, circuit=circuit)\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks)\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "528c56dd-b2f1-44ff-9ef7-bac210782bfa",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b\n",
      "tensor(4.1236, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pyfunctions.ioi_dataset import (\n",
    "    ABC_TEMPLATES, BAC_TEMPLATES, BABA_TEMPLATES, BABA_LONG_TEMPLATES, BABA_LATE_IOS,\n",
    "    BABA_EARLY_IOS, ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS,\n",
    ")\n",
    "\n",
    "# Restrict the data to a single template, because the circuit below uses fixed sequence positions.\n",
    "template = BABA_TEMPLATES[0]\n",
    "ioi_dataset = IOIDataset(prompt_type=[template], N=100, tokenizer=model.tokenizer, prepend_bos=False)\n",
    "\n",
    "# Randomized-name counterpart of the prompts; it is built here but the ablation below takes its means from ioi_dataset.\n",
    "abc_dataset = ioi_dataset.gen_flipped_prompts((\"IO\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S\", \"RAND\"))\n",
    "abc_dataset = abc_dataset.gen_flipped_prompts((\"S1\", \"RAND\"))\n",
    "\n",
    "# Perform analysis only for the template that this thing was formed on\n",
    "# Each row is (layer_idx, sequence_idx, attn_head_idx).\n",
    "circuit_spec = [\n",
    "    (8, 14, 6), (8, 11, 6), (9, 14, 9), (9, 14, 6),\n",
    "    (5, 10, 5), (7, 11, 9), (6, 10, 9), (6, 11, 0),\n",
    "    (5, 10, 9), (3, 10, 0), (4, 5, 11), (3, 5, 7),\n",
    "    (3, 3, 6), (2, 3, 2), (2, 3, 9), (1, 3, 7),\n",
    "    (1, 3, 10), (0, 2, 1), (0, 2, 4),\n",
    "]\n",
    "circuit = [\n",
    "    Node(layer_idx=layer, sequence_idx=position, attn_head_idx=head)\n",
    "    for layer, position, head in circuit_spec\n",
    "]\n",
    "\n",
    "# Fresh hooks, mean-ablate with ioi_dataset means for this circuit, then score on the whole dataset.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "model = add_mean_ablation_hook(model, means_dataset=ioi_dataset, circuit=circuit)\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks)\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "145111bb-39b4-4b8a-9b26-86522569517a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19\n"
     ]
    }
   ],
   "source": [
    "# Number of nodes in the most recently defined `circuit`, for comparison with the 26 heads listed in CIRCUIT above.\n",
    "print(len(circuit)) # IOI paper uses 26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "4486b9db-14f7-4409-b381-ed54a32bb6bd",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b\n",
      "tensor(1.6466, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Alternative, larger circuit. This cell does not rebuild the dataset: it reuses whatever `ioi_dataset`\n",
    "# is currently defined in the kernel.\n",
    "# Each row is (layer_idx, sequence_idx, attn_head_idx).\n",
    "circuit_spec = [\n",
    "    (9, 14, 9), (9, 14, 6), (10, 14, 0),\n",
    "    (9, 11, 9), (9, 11, 6), (9, 9, 6),\n",
    "    (8, 3, 6), (8, 9, 3), (8, 3, 10),\n",
    "    (7, 3, 9), (7, 3, 3),\n",
    "    (6, 3, 4), (6, 3, 1), (6, 2, 4),\n",
    "    (5, 3, 10), (5, 2, 10),\n",
    "    (4, 3, 3), (4, 3, 11), (4, 2, 4), (4, 2, 7), (4, 2, 3),\n",
    "    (3, 2, 5), (3, 2, 6), (3, 2, 2),\n",
    "    (2, 2, 10), (2, 2, 7), (2, 2, 1),\n",
    "    (1, 2, 7), (1, 2, 6), (1, 2, 3),\n",
    "    (0, 2, 1), (0, 2, 4), (0, 2, 5),\n",
    "]\n",
    "circuit = [\n",
    "    Node(layer_idx=layer, sequence_idx=position, attn_head_idx=head)\n",
    "    for layer, position, head in circuit_spec\n",
    "]\n",
    "\n",
    "# Fresh hooks, mean-ablate with ioi_dataset means for this circuit, then score on the whole dataset.\n",
    "model.reset_hooks(including_permanent=True)\n",
    "model = add_mean_ablation_hook(model, means_dataset=ioi_dataset, circuit=circuit)\n",
    "logits, cache = model.run_with_cache(ioi_dataset.toks)\n",
    "ave_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n",
    "print(len(circuit)) # IOI paper uses 26\n",
    "print(ave_logit_diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "ab5a10d2-ebb8-4aaa-be21-b299a50678eb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11546394624\n",
      "6150946816\n",
      "3359564288\n"
     ]
    }
   ],
   "source": [
    "# Report GPU memory for device 0, in bytes: total capacity, memory reserved by PyTorch's caching\n",
    "# allocator, and memory currently occupied by tensors.\n",
    "# Guarded because the notebook falls back to the CPU when CUDA is unavailable, in which case the\n",
    "# torch.cuda queries below would raise.\n",
    "if torch.cuda.is_available():\n",
    "    total_bytes = torch.cuda.get_device_properties(0).total_memory\n",
    "    reserved_bytes = torch.cuda.memory_reserved(0)\n",
    "    allocated_bytes = torch.cuda.memory_allocated(0)\n",
    "    print(total_bytes)\n",
    "    print(reserved_bytes)\n",
    "    print(allocated_bytes)\n",
    "else:\n",
    "    print(\"CUDA is not available; no GPU memory statistics to report.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (ipykernel)",
   "language": "python",
   "name": "python3.12"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
