{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"T4","authorship_tag":"ABX9TyNYIJA2mf13TBxRgmzQ3rnR"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","gpuClass":"standard"},"cells":[{"cell_type":"markdown","source":["# Neuron to Graph Notebook\n","\n","This notebook contains code and instructions for reproducing the paper results.\n","\n","This is designed to be run in Colab, and you'll need to ensure you're using a GPU runtime.\n","\n","Note that running the full baselines takes about 12 hours each, and running the full N2G build takes about 48 hours.\n","\n","Set the correct `intermediate_path` below, run the below cell to install the correct version of NumPy and restart the runtime, then run all cells up to the `Run Experiments` section. You can then reproduce any experiments you want by running the cells in that section."],"metadata":{"id":"VWhLuqAr2nPu"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"LjaU_aco2kcR"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"nsH592cE18d9"},"outputs":[],"source":["# Run this cell, then restart the runtime\n","!pip install numpy==1.23"]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"sVanHlBi2gQh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import os\n","\n","# Set this to point to the directory containing this notebook\n","intermediate_path = \n","base_path = f\"/content/drive/MyDrive/{intermediate_path}/supplementary_materials\" "],"metadata":{"id":"FVNl5bhx22lh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!pip install git+https://github.com/neelnanda-io/TransformerLens"],"metadata":{"id":"99FZfm143zI5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional 
as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","import tqdm.notebook as tqdm\n","\n","import random\n","import time\n","\n","# from google.colab import drive\n","from pathlib import Path\n","import pickle\n","import os\n","\n","\n","import matplotlib.pyplot as plt\n","\n","%matplotlib inline\n","import plotly.express as px\n","import plotly.graph_objects as go\n","\n","from torch.utils.data import DataLoader\n","\n","from functools import *\n","import pandas as pd\n","import gc\n","import collections\n","import copy\n","\n","# import comet_ml\n","import itertools\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","\n","from pprint import pprint"],"metadata":{"id":"SW1iwfDD33UG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from transformer_lens.utils import (\n","    gelu_new,\n","    to_numpy,\n","    get_corner,\n","    lm_cross_entropy_loss,\n",")  # Helper functions\n","from transformer_lens.hook_points import (\n","    HookedRootModule,\n","    HookPoint,\n",")  # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig"],"metadata":{"id":"D0mIN73736At"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["device = \"cuda\" if torch.cuda.is_available() else \"cpu\""],"metadata":{"id":"BJ4dKLV-36X6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from transformers import AutoModelForCausalLM\n","\n","model_name = \"solu-6l-pile\"\n","layer_ending = \"mlp.hook_mid\"\n","model = HookedTransformer.from_pretrained(model_name).to(device)"],"metadata":{"id":"05HWxIiV4EqO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from transformers import AutoModelForMaskedLM\n","from transformers import AutoTokenizer\n","\n","aug_model_checkpoint = \"distilbert-base-uncased\"\n","aug_model = 
# Marks the start of each serialised token list in a Neuroscope page.
# Raw string: "\{" inside a normal string literal is an invalid escape sequence.
parser = re.compile(r'\{"tokens": ')


def get_snippets(model_name, layer, neuron):
  """Get the max activating dataset examples for a given neuron in a model.

  Downloads https://neuroscope.io/{model_name}/{layer}/{neuron}.html, splits the
  page on the '{"tokens": ' marker, JSON-decodes the token list that follows the
  marker and joins the tokens back into a single string per example.

  Args:
    model_name: model identifier used in the Neuroscope URL.
    layer: layer index used in the Neuroscope URL.
    neuron: neuron index used in the Neuroscope URL.

  Returns:
    A list of exactly 20 text snippets.

  Raises:
    requests.HTTPError: if the page request returns an error status.
    ValueError: if the page does not yield exactly 20 snippets.
  """
  base_url = f"https://neuroscope.io/{model_name}/{layer}/{neuron}.html"

  response = requests.get(base_url)
  # Fail with the real HTTP error rather than trying to parse an error page
  response.raise_for_status()
  webpage = response.text

  parts = parser.split(webpage)

  snippets = []
  # parts[0] is everything before the first marker. Of the remaining chunks only the
  # even-indexed ones are used (presumably each token list is embedded twice per
  # example - TODO confirm against the page layout).
  for part in parts[2::2]:
    token_str = part.split(', "values": ')[0]
    tokens = json.loads(token_str)
    snippets.append("".join(tokens))

  if len(snippets) != 20:
    raise ValueError(f"Expected 20 snippets from {base_url}, found {len(snippets)}")
  return snippets
class WordTokenizer:
  """Simple tokenizer for splitting text into words.

  Every character of the input is preserved: separators are emitted as their own
  single-character tokens, so "".join(tokens) reproduces the original text.

  Args:
    split_tokens: characters that always act as separators.
    stick_tokens: non-alphabetic characters that should stay attached to the
      surrounding word (e.g. the apostrophe in "don't").
  """

  def __init__(self, split_tokens, stick_tokens):
    self.split_tokens = split_tokens
    self.stick_tokens = stick_tokens

  def __call__(self, text):
    return self.tokenize(text)

  def is_split(self, char):
    """Split on any non-alphabet chars unless excluded, and split on any specified chars"""
    # Use the instance's own stick_tokens (previously this read a module-level global,
    # silently ignoring the constructor argument)
    return char in self.split_tokens or (not char.isalpha() and char not in self.stick_tokens)

  def tokenize(self, text):
    """Tokenize text, preserving all characters"""
    tokens = []
    current_token = ""
    for char in text:
      if self.is_split(char):
        # Close the word in progress, then emit the separator as its own token
        tokens.append(current_token)
        tokens.append(char)
        current_token = ""
        continue
      current_token += char
    tokens.append(current_token)
    # Adjacent separators (or a leading/trailing separator) leave empty words behind
    return [token for token in tokens if token]
def evaluate(neuron_model, data, fire_threshold=0.2, **kwargs):
  """Score a neuron model's predicted firings against the real activations.

  Each token is labelled as firing (1) when its activation is at or above
  fire_threshold, otherwise as not firing (0). The same rule is applied to the
  model's predicted activations, and the two label sequences are compared over
  all tokens of all prompts.

  Args:
    neuron_model: object with a forward(prompts, return_activations=True) method;
      it is called with a one-element list and the first element of the result is
      used as the per-token predicted activations.
    data: iterable of (prompt_tokens, activations) pairs.
    fire_threshold: activation value at or above which a token counts as firing.
    **kwargs: accepted so callers can forward shared options; ignored here.

  Returns:
    The dict produced by classification_report(..., output_dict=True). The text
    report is also printed.
  """
  y = []
  y_pred = []
  for prompt_tokens, activations in data:
    pred_activations = neuron_model.forward([prompt_tokens], return_activations=True)[0]

    # The context window that used to be computed here was never used, and it raised
    # IndexError for prompts without any positive activation - it has been removed.
    y.extend(int(activation >= fire_threshold) for activation in activations)
    y_pred.extend(int(pred_activation >= fire_threshold) for pred_activation in pred_activations)

  print(classification_report(y, y_pred))
  report = classification_report(y, y_pred, output_dict=True)
  return report
class FastAugmenter:
  """Generates variations of a text by masking one word at a time and substituting
  the words a masked language model considers most likely at that position."""

  def __init__(self, model, model_tokenizer, word_tokenizer, neuron_model, device="cuda:0"):
    # neuron_model is kept in the signature for existing call sites; it is not used here
    self.model = model
    self.model_tokenizer = model_tokenizer
    self.stops = set(stopwords.words('english'))
    self.punctuation_set = set(punctuation)
    self.to_strip = " " + punctuation
    self.word_tokenizer = word_tokenizer
    self.device = device

  def _normalise(self, token):
    """Strip surrounding spaces/punctuation and lowercase; lone punctuation marks are left as they are"""
    return token if token in self.punctuation_set else token.strip(self.to_strip).lower()

  def augment(self, text, max_char_position=None, exclude_stopwords=False, n=5, important_tokens=None, **kwargs):
    """Create variations of text by replacing individual words.

    Args:
      text: the text to augment.
      max_char_position: if given, stop considering words that start after this
        character offset.
      exclude_stopwords: if True, never replace English stopwords.
      n: number of substitutions to keep for each replaced word.
      important_tokens: if given, only words whose normalised form matches one of
        these tokens are replaced. None means every eligible word is replaced.
      **kwargs: accepted so callers can forward shared options; ignored here.

    Returns:
      (new_texts, positions): the augmented texts and, for each one, the character
      offset in text of the word that was replaced.
    """
    joiner = ""
    mask_token = self.model_tokenizer.mask_token
    tokens = self.word_tokenizer(text)

    # None means "no filter"; iterating over it used to raise a TypeError
    if important_tokens is not None:
      important_tokens = {token.strip(self.to_strip).lower() for token in important_tokens}

    masked_texts = []
    masked_positions = []
    masked_tokens = []

    # Running character offset of the current token (avoids re-joining the prefix every time)
    next_position = 0
    for i, token in enumerate(tokens):
      position = next_position
      next_position += len(token)

      has_alpha = any(c.isalpha() for c in token)
      norm_token = token.strip(self.to_strip).lower() if has_alpha else token

      if not token or self.word_tokenizer.is_split(token) or (exclude_stopwords and norm_token in self.stops) or (important_tokens is not None and norm_token not in important_tokens):
        continue

      # Tokens without any letters are not substituted with the language model
      if not has_alpha:
        continue

      # Don't bother if we're beyond the max activating token, as these tokens have no effect on the activation
      if max_char_position is not None and position > max_char_position:
        break

      copy_tokens = list(tokens)
      copy_tokens[i] = mask_token
      masked_texts.append(joiner.join(copy_tokens))
      masked_positions.append(position)
      masked_tokens.append(token)

    if len(masked_texts) == 0:
      return [], []

    inputs = self.model_tokenizer(masked_texts, padding=True, return_tensors="pt").to(self.device)
    # Inference only, so there is no need to build an autograd graph
    with torch.no_grad():
      logits = self.model(**inputs).logits
    token_probs = softmax(logits.cpu().detach().numpy(), axis=-1)
    input_ids = inputs["input_ids"].cpu().numpy()

    new_texts = []
    positions = []
    seen_texts = set()

    for i, (masked_text, char_position, original_token) in enumerate(zip(masked_texts, masked_positions, masked_tokens)):
      mask_token_index = np.argwhere(input_ids[i] == self.model_tokenizer.mask_token_id)[0, 0]
      mask_token_probs = token_probs[i, mask_token_index, :]

      # We negate the array before argsort to get the largest, not the smallest, probabilities first
      top_tokens = np.argsort(-mask_token_probs)

      # Compare candidates against the word that was actually masked at this position
      # (previously a stale variable left over from the masking loop was used)
      normalised_token = self._normalise(original_token)

      subbed = 0

      # Substitute the given token with the best predictions
      for top_token in top_tokens:
        if mask_token_probs[top_token] < 0.00001:
          break

        candidate_token = self.model_tokenizer.decode(top_token)

        # Check that the predicted token isn't the same as the token that was already there
        normalised_candidate = self._normalise(candidate_token)
        if normalised_candidate == normalised_token or not any(c.isalpha() for c in candidate_token):
          continue

        # Get most common casing of the word
        most_common_casing = word_to_casings.get(candidate_token, [(candidate_token, 1)])[0][0]

        # Title case normally has meaning (e.g., start of sentence, in a proper noun, etc.) so follow original token, otherwise use most common
        best_casing = candidate_token.title() if original_token.istitle() else most_common_casing

        # BERT uses ## to denote a tokenisation within a word, so we remove it to glue the word back together
        new_text = masked_text.replace(mask_token, best_casing, 1).replace(" ##", "")

        if new_text in seen_texts:
          continue
        # Record the text so the duplicate check above can actually trigger
        seen_texts.add(new_text)

        new_texts.append(new_text)
        positions.append(char_position)
        subbed += 1

        if subbed >= n:
          break

    return new_texts, positions
# A token ends a sentence if it contains ".", "!" or a newline.
# Raw string: "\." inside a normal string literal is an invalid escape sequence.
splitter = re.compile(r"[.!\n]")


def sentence_tokenizer(str_tokens):
  """Split tokenized text into sentences.

  A sentence ends with the first token that contains ".", "!" or a newline; the
  final token always closes the last sentence.

  Args:
    str_tokens: sequence of string tokens.

  Returns:
    (sentences, sentence_to_token_indices, token_to_sentence_indices) where
    sentences is a list of token lists, sentence_to_token_indices maps a sentence
    index to the indices of its tokens (a defaultdict(list)) and
    token_to_sentence_indices maps a token index to its sentence index.
  """
  sentences = []
  sentence = []
  sentence_to_token_indices = defaultdict(list)
  token_to_sentence_indices = {}
  last_index = len(str_tokens) - 1

  for i, str_token in enumerate(str_tokens):
    # Index of the sentence currently being built
    sentence_index = len(sentences)
    sentence.append(str_token)
    sentence_to_token_indices[sentence_index].append(i)
    token_to_sentence_indices[i] = sentence_index
    if splitter.search(str_token) is not None or i == last_index:
      sentences.append(sentence)
      sentence = []

  return sentences, sentence_to_token_indices, token_to_sentence_indices
final_max_index = None\n","\n","  truncated_prompts = []\n","  added_tokens = []\n","\n","  count = 0\n","  full_prior = prior_context[:max(0, initial_argmax - window + 1)]\n","\n","  for i, str_token in reversed(list(enumerate(full_prior))):\n","    count += 1\n","\n","    if count > cutoff:\n","      break\n","\n","    if not count == len(full_prior) and count >= skip_threshold and count % skip_interval != 0:\n","      continue\n","\n","    truncated_prompt = prior_context[i:]\n","    joined = \"\".join(truncated_prompt)\n","    truncated_prompts.append(joined)\n","    added_tokens.append(i)\n","\n","  batched_truncated_prompts = batch(truncated_prompts, batch_size=batch_size)\n","  batched_added_tokens = batch(added_tokens, batch_size=batch_size)\n","  \n","  finished = False\n","  intermediates = []\n","  for i, (truncated_batch, added_tokens_batch) in enumerate(zip(batched_truncated_prompts, batched_added_tokens)):\n","\n","    truncated_tokens = model.to_tokens(truncated_batch, prepend_bos=prepend_bos)\n","\n","    logits, cache = model.run_with_cache(truncated_tokens)\n","    all_truncated_activations = cache[layer][:, :, neuron]\n","\n","    for j, truncated_activations in enumerate(all_truncated_activations):\n","      num_added_tokens = added_tokens_batch[j]\n","      truncated_argmax = torch.argmax(truncated_activations).cpu().item() + num_added_tokens\n","      final_max_index = torch.argmax(truncated_activations).cpu().item()\n","\n","      if prepend_bos:\n","        truncated_argmax -= 1\n","        final_max_index -= 1\n","      truncated_max = torch.max(truncated_activations).cpu().item()\n","\n","      shortest_prompt = truncated_batch[j]\n","\n","      if not shortest_prompt.startswith(\"<|endoftext|>\"):\n","        truncated_str_tokens = model.to_str_tokens(truncated_batch[j], prepend_bos=False)\n","        intermediates.append((shortest_prompt, truncated_str_tokens[0], truncated_max))\n","\n","      if (truncated_argmax == initial_argmax and 
def fast_measure_importance(model, layer, neuron, prompt, initial_argmax=None, max_length=1024, max_activation=None, masking_token=1, threshold=0.8, scale_factor=1, return_all=False, activation_threshold=0.1, **kwargs):
  """Compute a measure of token importance by masking each token and measuring the drop in activation on the max activating token.

  The prompt is run once unmasked and once per token with that token replaced by
  masking_token, all in a single batch. The importance of a token is
  1 - (masked activation / unmasked activation).

  Args:
    model: object providing to_tokens, to_str_tokens and run_with_cache.
    layer: key into the cache returned by run_with_cache.
    neuron: index of the neuron along the last axis of the cached activations.
    prompt: the text to analyse.
    initial_argmax: token index to measure the activation at. If None, the index
      with the highest unmasked activation is used; otherwise it is clamped to the
      last token.
    max_length: the token sequence is truncated to this many tokens.
    max_activation: value activations are normalised by in tokens_and_activations.
      Defaults to the activation at initial_argmax.
    masking_token: token id written over each position in turn.
    threshold: minimum importance for a token to be listed in important_tokens.
    scale_factor: multiplier applied to activations in tokens_and_activations.
    return_all: if True, return the importance of every token for every position
      instead of only for the position initial_argmax.
    activation_threshold: unused.
    **kwargs: accepted so callers can forward shared options; ignored here.

  Returns:
    (importances, initial_max, important_tokens, tokens_and_activations, initial_argmax).
    importances is a list of [str_token, importance] pairs, or when return_all is
    True a numpy array with one row per masked token holding (str_token, importance)
    for every position.

  Raises:
    ZeroDivisionError: if an activation used as a denominator is exactly zero.
  """
  prepend_bos = True
  tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
  str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)

  if len(tokens[0]) > max_length:
    tokens = tokens[0, :max_length].unsqueeze(0)

  # Row 0 is the unmasked prompt; row i (i >= 1) has token i - 1 replaced by the masking token
  masked_prompts = tokens.repeat(len(tokens[0]) + 1, 1)
  for i in range(1, len(masked_prompts)):
    masked_prompts[i, i - 1] = masking_token

  # Inference only, so there is no need to build an autograd graph
  with torch.no_grad():
    _, cache = model.run_with_cache(masked_prompts)

  neuron_activations = cache[layer][:, :, neuron].detach()

  # Move everything to the host in one transfer instead of one sync per element
  activation_rows = neuron_activations.cpu().tolist()
  activations = activation_rows[0]
  all_masked_activations = activation_rows[1:]

  if initial_argmax is None:
    initial_argmax = torch.argmax(neuron_activations[0]).item()
  else:
    initial_argmax = min(initial_argmax, len(activations) - 1)

  initial_max = activations[initial_argmax]

  if max_activation is None:
    max_activation = initial_max

  tokens_and_activations = [[str_token, round(activation * scale_factor / max_activation, 3)] for str_token, activation in zip(str_tokens, activations)]
  important_tokens = []
  tokens_and_importances = [[str_token, 0] for str_token in str_tokens]
  importances_matrix = []

  for i, masked_activations in enumerate(all_masked_activations):
    if return_all:
      # Get importance of the given token for all tokens
      importances_row = [
        (str_tokens[j], 1 - (activation / activations[j]))
        for j, activation in enumerate(masked_activations)
      ]
      importances_matrix.append(np.array(importances_row))

    masked_max = masked_activations[initial_argmax]
    normalised_activation = 1 - (masked_max / initial_max)

    str_token = tokens_and_importances[i][0]
    tokens_and_importances[i][1] = normalised_activation
    if normalised_activation >= threshold and str_token != "<|endoftext|>":
      important_tokens.append(str_token)

  if return_all:
    # Flip so we have the importance of all tokens for a given token
    importances_matrix = np.array(importances_matrix)
    return importances_matrix, initial_max, important_tokens, tokens_and_activations, initial_argmax

  return tokens_and_importances, initial_max, important_tokens, tokens_and_activations, initial_argmax
**kwargs)\n","    all_info.append(info)\n","\n","  neuron_model = NeuronModel(layer_num, neuron, **kwargs)\n","  paths = neuron_model.fit(all_info)\n","\n","  print(\"Fitted model\")\n","\n","  max_test_data = []\n","  for snippet in test_snippets:\n","    tokens = model.to_tokens(snippet, prepend_bos=True)\n","    str_tokens = model.to_str_tokens(snippet, prepend_bos=True)\n","    logits, cache = model.run_with_cache(tokens)\n","    activations = cache[layer][0, :, neuron]\n","    max_test_data.append((str_tokens, activations.cpu() / base_max_act))\n","\n","  print(\"Max Activating Evaluation Data\")\n","  try:\n","    stats = evaluate(neuron_model, max_test_data, fire_threshold=fire_threshold, **kwargs)\n","  except Exception as e:\n","    stats = {}\n","    print(f\"Stats failed with error: {e}\")\n","\n","  if return_paths:\n","    return stats, paths\n","  return stats\n","\n","\n","def augment_and_return(model, layer, neuron, aug, pruned_prompt, base_max_act=None, use_index=False, scale_factor=1, **kwargs):\n","  info = []\n","  importances_matrix, initial_max_act, important_tokens, tokens_and_activations, initial_max_index = fast_measure_importance(model, layer, neuron, pruned_prompt, max_activation=base_max_act, scale_factor=scale_factor, return_all=True)\n","  \n","  if base_max_act is not None:\n","    initial_max_act = base_max_act\n","\n","  positive_prompts, negative_prompts = augment(model, layer, neuron, pruned_prompt, aug, important_tokens=set(important_tokens), **kwargs)  \n","\n","  for i, (prompt, activation, change) in enumerate(positive_prompts):\n","    title = prompt\n","    if i == 0:\n","      title = \"Original - \" + prompt\n","\n","    if use_index:\n","      importances_matrix, max_act, _, tokens_and_activations, max_index = fast_measure_importance(model, layer, neuron, prompt, max_activation=initial_max_act, initial_argmax=initial_max_index, scale_factor=scale_factor, return_all=True)\n","    else:\n","      importances_matrix, 
def train_and_eval_baseline(model, layer, neuron, Baseline, train_proportion=0.5, fire_threshold=0.5, random_state=0, train_indexes=None, **kwargs):
  """Fit a baseline predictor for one neuron and score it on held-out snippets.

  `Baseline` is a class constructed as `Baseline(model, layer_num, neuron, **kwargs)`
  and fitted on the raw training snippets. `layer` may be an index or a full hook
  name. Snippets are split with `train_test_split` unless `train_indexes` names the
  training positions explicitly. Returns whatever `evaluate` returns.
  """
  layer_name = layer_index_to_name(layer) if isinstance(layer, int) else layer
  layer_num = int(layer_name.split(".")[1])

  max_activation = float(activation_matrix[layer_num, neuron])

  snippets = get_snippets(model_name, layer_num, neuron)

  if train_indexes is not None:
    # Explicit split: listed positions train, everything else is held out.
    fit_snippets, held_out_snippets = [], []
    for position, snippet in enumerate(snippets):
      if position in train_indexes:
        fit_snippets.append(snippet)
      else:
        held_out_snippets.append(snippet)
  else:
    fit_snippets, held_out_snippets = train_test_split(snippets, train_size=train_proportion, random_state=random_state)

  extra_training_snippets = []
  training_texts = fit_snippets + extra_training_snippets

  predictor = Baseline(model, layer_num, neuron, **kwargs)
  predictor.fit(training_texts)

  print("Fitted model")

  def normalised_example(snippet):
    # Not pruning so don't need to prepend_bos
    token_ids = model.to_tokens(snippet, prepend_bos=False)
    token_strings = model.to_str_tokens(snippet, prepend_bos=False)
    _, cache = model.run_with_cache(token_ids)
    neuron_activations = cache[layer_name][0, :, neuron]
    return (token_strings, neuron_activations.cpu() / max_activation)

  max_test_data = [normalised_example(snippet) for snippet in held_out_snippets]

  print("Max Activating Evaluation Data")
  return evaluate(predictor, max_test_data, fire_threshold=fire_threshold, **kwargs)
class NeuronStore:
  """JSON-backed index from tokens to the neurons they are associated with.

  The store has two token types, "activating" and "important". For each type it
  maps a token string to a set of neuron identifiers (in memory) / a list of
  identifiers (on disk).
  """

  def __init__(self, path):
    """Load the store at `path`, creating an empty store file if none exists."""
    if not os.path.exists(path):
      neuron_store = {
          "activating": {},
          "important": {}
      }
      # Tokens are written with ensure_ascii=False, so pin the encoding explicitly
      # instead of relying on the platform default.
      with open(path, "w", encoding="utf-8") as ofh:
        json.dump(neuron_store, ofh, indent=2, ensure_ascii=False)

    with open(path, encoding="utf-8") as ifh:
      self.store = json.load(ifh)

    self.to_sets()
    self.path = path
    self.count_tokens()
    self.by_neuron()

  def save(self):
    """Write the store back to `self.path` as JSON."""
    self.to_lists()
    try:
      with open(self.path, "w", encoding="utf-8") as ofh:
        json.dump(self.store, ofh, indent=2, ensure_ascii=False)
    finally:
      # Always restore the in-memory set representation, even if the write failed,
      # so later `.add(...)` calls on the store keep working.
      self.to_sets()

  def to_sets(self):
    """Convert every token's neuron collection to a set (in-memory form)."""
    self.store = {token_type: {token: set(info) for token, info in token_dict.items()} for token_type, token_dict in self.store.items()}

  def to_lists(self):
    """Convert every token's neuron collection to a de-duplicated, sorted list (JSON form)."""
    # Sorting makes the saved file deterministic from run to run.
    self.store = {token_type: {token: sorted(set(info)) for token, info in token_dict.items()} for token_type, token_dict in self.store.items()}

  def by_neuron(self):
    """Build `self.neuron_to_tokens`: neuron -> {token_type -> set of tokens}."""
    self.neuron_to_tokens = {}
    for token_type, token_dict in self.store.items():
      for token, neurons in token_dict.items():
        for neuron in neurons:
          per_neuron = self.neuron_to_tokens.setdefault(neuron, {"activating": set(), "important": set()})
          per_neuron.setdefault(token_type, set()).add(token)

  def search(self, tokens_and_types):
    """Return the neurons associated with every (token, token_type) pair given.

    A `token_type` of None matches the token under either type. An empty query
    returns an empty set.
    """
    match_arr = []

    for token, token_type in tokens_and_types:
      wanted_types = [token_type] if token_type is not None else ["activating", "important"]
      token_matches = set()

      for wanted_type in wanted_types:
        token_matches |= self.store[wanted_type].get(token, set())

      match_arr.append(token_matches)

    # set.intersection() with no arguments raises TypeError, so handle the empty query.
    if not match_arr:
      return set()
    return set.intersection(*match_arr)

  def count_tokens(self):
    """Build per-neuron token counters from the current store."""
    self.neuron_individual_token_counts = defaultdict(Counter)
    self.neuron_total_token_counts = Counter()
    for token_type, token_dict in self.store.items():
      for token, neurons in token_dict.items():
        for neuron in neurons:
          self.neuron_individual_token_counts[neuron][token] += 1
          self.neuron_total_token_counts[neuron] += 1

  def find_similar(self, target_token_types=None, threshold=0.9):
    """Find pairs of neurons whose token sets overlap heavily.

    For each target token type the overlap is measured against the larger set
    ("similar") and against the smaller set ("subset"). A pair is reported only
    when the criterion holds for every target token type.

    Returns:
      (similar_pairs, subset_pairs). In a subset pair the first neuron is the
      superset and the second the subset, decided by the total number of tokens
      across the target token types.
    """
    if target_token_types is None:
      target_token_types = {"activating", "important"}

    similar_pairs = []
    subset_pairs = []

    items = list(self.neuron_to_tokens.items())

    for i, (neuron_1, neuron_dict_1) in enumerate(items):
      if i % 1000 == 0:
        print(f"{i} of {len(items)} complete")

      # Each unordered pair is visited once, with neuron_2 preceding neuron_1.
      for neuron_2, neuron_dict_2 in items[:i]:
        all_similar = []
        all_subset = []
        total_1 = 0
        total_2 = 0

        for token_type in target_token_types:
          length_1 = len(neuron_dict_1[token_type])
          length_2 = len(neuron_dict_2[token_type])
          total_1 += length_1
          total_2 += length_2

          intersection = neuron_dict_1[token_type] & neuron_dict_2[token_type]
          similar = (len(intersection) / max(length_1, length_2, 1)) >= threshold
          subset = len(intersection) / max(min(length_1, length_2), 1) >= threshold

          all_similar.append(similar)
          all_subset.append(subset)

        if all(all_similar):
          similar_pairs.append((neuron_1, neuron_2))
        elif all(all_subset):
          # The first entry is the superset neuron and the second the subset neuron.
          # Use totals over all target types: the per-type lengths left over from the
          # loop would depend on set iteration order.
          subset_pair = (neuron_1, neuron_2) if total_2 < total_1 else (neuron_2, neuron_1)
          subset_pairs.append(subset_pair)

    return similar_pairs, subset_pairs
class NeuronModel:
  """Graph/trie model of the token contexts that make a single neuron fire.

  Two structures are maintained from the same training lines:
    - `self.root`: a graph without ignore nodes, used only for visualisation.
    - `self.trie_root`: a trie including ignore and end nodes, used for prediction.
  Both are stored as `(NeuronNode, NeuronEdge)` tuples.
  """

  def __init__(self, layer, neuron, activation_threshold=0.1, importance_threshold=0.5, folder_name=None, neuron_store=None, **kwargs):
    """Create an empty model.

    Args:
      layer: Layer index, used when naming output files and store entries.
      neuron: Neuron index, used likewise.
      activation_threshold: Activation above which a token counts as activating.
      importance_threshold: Importance above which a context token is kept.
      folder_name: Optional sub-folder for rendered graphs.
      neuron_store: Optional store updated by `save_neurons`; may be None.
      **kwargs: Ignored; lets callers forward a shared keyword dict.
    """
    self.layer = layer
    self.neuron = neuron
    self.Element = namedtuple("Element", "importance, activation, token, important, activator, ignore, is_end, token_value")
    self.neuron_store = neuron_store

    self.root_token = "**ROOT**"
    self.ignore_token = "**IGNORE**"
    self.end_token = "**END**"
    self.special_tokens = {self.root_token, self.ignore_token, self.end_token}

    self.root = (NeuronNode(-1, self.Element(0, 0, self.root_token, False, False, True, False, self.root_token), depth=-1), NeuronEdge())
    self.trie_root = (NeuronNode(-1, self.Element(0, 0, self.root_token, False, False, True, False, self.root_token), depth=-1), NeuronEdge())
    self.activation_threshold = activation_threshold
    self.importance_threshold = importance_threshold

    self.net = Digraph(
        graph_attr={"rankdir": "RL", "splines": "spline", "ranksep": "1.5", "nodesep": "0.2"},
        node_attr={"fixedsize": "true", "width": "2", "height": "0.75"}
    )
    self.node_count = 0
    self.trie_node_count = 0
    self.max_depth = 0
    self.folder_name = folder_name

  def __call__(self, tokens_arr: List[List[str]]) -> List[List[float]]:
    """Shorthand for `forward(tokens_arr)`."""
    return self.forward(tokens_arr)

  @staticmethod
  def _bfs_order(start_tuple):
    """Return ((node, edge) tuples in breadth-first order, set of visited node ids)."""
    from collections import deque

    visited = {start_tuple[0].id_}
    queue = deque([start_tuple])
    ordered = []

    while queue:
      current = queue.popleft()
      ordered.append(current)

      for neighbour in current[0].children.values():
        neighbour_id = neighbour[0].id_
        if neighbour_id not in visited:
          visited.add(neighbour_id)
          queue.append(neighbour)

    return ordered, visited

  def fit(self, data):
    """Build the graph and trie from augmented examples and return the trie paths.

    `data` is an iterable of examples; each example is a sequence of
    `(importances_matrix, tokens_and_activations, max_index)` tuples. The first
    tuple of an example fixes which positions count as important for the rest.
    """
    for example_data in data:
      for j, info in enumerate(example_data):
        if j == 0:
          lines, important_index_sets = self.make_line(info)
        else:
          lines, _ = self.make_line(info, important_index_sets)

        for line in lines:
          self.add(self.root, line, graph=True)
          self.add(self.trie_root, line, graph=False)

    self.build(self.root)
    self.merge_ignores()

    self.save_neurons()

    print("Paths after merge")
    return list(self.trie_root[0].paths())

  def save_neurons(self):
    """Record this neuron against every non-special trie token in the neuron store.

    Does nothing when the model was created without a neuron store.
    """
    if self.neuron_store is None:
      return

    neuron_id = f"{self.layer}_{self.neuron}"
    ordered, _ = self._bfs_order(self.trie_root)

    for node, _edge in ordered:
      token = node.value.token
      if token in self.special_tokens:
        continue

      add_dict = self.neuron_store.store["activating"] if node.value.activator else self.neuron_store.store["important"]
      add_dict.setdefault(token, set()).add(neuron_id)

  @staticmethod
  def normalise(token):
      """Lower-case title-cased tokens and strip whitespace around tokens containing letters."""
      normalised_token = token.lower() if token.istitle() and len(token) > 1 else token
      normalised_token = normalised_token.strip() if len(normalised_token) > 1 and any(c.isalpha() for c in normalised_token) else normalised_token
      return normalised_token

  def make_line(self, info, important_index_sets=None):
    """Turn one prompt's importances and activations into training lines.

    One line is produced per activating token. A line lists elements from the
    activating token backwards, truncated after the earliest kept token and
    terminated by an end element.

    Args:
      info: `(importances_matrix, tokens_and_activations, max_index)`.
      important_index_sets: If None, the per-position sets of important indices
        are computed and returned; otherwise the given sets are reused so that
        augmented prompts keep the same positions as the original.

    Returns:
      `(all_lines, important_index_sets)`.
    """
    if important_index_sets is None:
      important_index_sets = []
      create_indices = True
    else:
      create_indices = False

    importances_matrix, tokens_and_activations, max_index = info

    all_lines = []

    for i, (token, activation) in enumerate(tokens_and_activations):
      if create_indices:
        important_index_sets.append(set())

      if not activation > self.activation_threshold:
        continue

      before = tokens_and_activations[:i + 1]

      line = []
      last_important = 0

      if create_indices:
        important_indices = set()
      elif not important_index_sets:
        # Nothing recorded for the original prompt: treat no position as pre-marked.
        important_indices = set()
      else:
        # Prompts longer than the original reuse the last recorded set.
        important_indices = important_index_sets[i] if i < len(important_index_sets) else important_index_sets[-1]

      for j, (seq_token, seq_activation) in enumerate(reversed(before)):
        if seq_token == "<|endoftext|>":
          continue

        seq_index = len(before) - j - 1
        _, importance = importances_matrix[seq_index, i]
        importance = float(importance)

        important = importance > self.importance_threshold or (not create_indices and seq_index in important_indices)
        activator = seq_activation > self.activation_threshold

        if important and create_indices:
          important_indices.add(seq_index)

        # The activating token itself (j == 0) is never ignored.
        ignore = not important and j != 0
        is_end = False

        seq_token_identifier = self.ignore_token if ignore else seq_token

        new_element = self.Element(importance, seq_activation, seq_token_identifier, important, activator, ignore, is_end, seq_token)

        if not ignore:
          last_important = j

        line.append(new_element)

      line = line[:last_important + 1]
      # Add an end node
      line.append(self.Element(0, activation, self.end_token, False, False, True, True, self.end_token))
      all_lines.append(line)

      if create_indices:
        important_index_sets[i] = important_indices

    return all_lines, important_index_sets

  def add(self, start_tuple, line, graph=True):
    """Insert `line` below `start_tuple`, reusing existing children where possible.

    With `graph=True` ignore elements (including the end element) are skipped and
    `is_end` is set on the element just before the end element. Returns the
    `(node, edge)` tuple reached last.
    """
    current_tuple = start_tuple
    important_count = 0

    start_depth = current_tuple[0].depth

    for i, element in enumerate(line):
      if element is None:
        break

      if element.ignore and graph:
        continue

      # Normalise token
      element = element._replace(token=self.normalise(element.token))

      if graph:
        # Set end value as we don't have end nodes in the graph
        # The current node is an end if there's only one more node, as that will be the end node that we don't add
        is_end = i == len(line) - 2
        element = element._replace(is_end=is_end)

      important_count += 1

      current_node, _ = current_tuple

      if element.token in current_node.children:
        current_tuple = current_node.children[element.token]
        continue

      depth = start_depth + important_count
      new_node = NeuronNode(self.node_count, element, {}, depth=depth)
      new_tuple = (new_node, NeuronEdge(0, current_node, new_node))

      self.max_depth = depth if depth > self.max_depth else self.max_depth

      current_node.children[element.token] = new_tuple

      current_tuple = new_tuple

      self.node_count += 1

    return current_tuple

  def merge_ignores(self):
    """
    Where a set of children contain an ignore token, merge the other nodes into it:
      - Fully merge if the other node is not an end node
      - Give the ignore node the other node's children (if it has any) if the other node is an end node
    """
    from collections import deque

    visited = {self.trie_root[0].id_}
    queue = deque([self.trie_root])

    while queue:
      node, _edge = queue.popleft()

      if self.ignore_token in node.children:
        ignore_tuple = node.children[self.ignore_token]

        to_remove = []

        for child_token, child_tuple in node.children.items():
          if child_token == self.ignore_token:
            continue

          child_node, _child_edge = child_tuple

          for path in child_node.paths():
            # Don't merge if the path is only the first tuple, or the first tuple and an end tuple
            if len(path) <= 1 or (len(path) == 2 and path[-1].token == self.end_token):
              continue
            # Merge the path (not including the first tuple that we're merging)
            self.add(ignore_tuple, path[1:], graph=False)

          # Add the node to a list to be removed later if it isn't an end node and doesn't have an end node in its children
          if not child_node.value.is_end and self.end_token not in child_node.children:
            to_remove.append(child_token)

        # Removal is deferred so the children dict is not mutated while iterating it.
        for child_token in to_remove:
          node.children.pop(child_token)

      # Children are enqueued after merging, so removed nodes are never visited.
      for neighbour in node.children.values():
        neighbour_id = neighbour[0].id_
        if neighbour_id not in visited:
          visited.add(neighbour_id)
          queue.append(neighbour)

  def search(self, tokens: List[str]) -> float:
    """Predict the activation on the LAST token of `tokens`.

    The trie is walked from the last token backwards, following an exact child
    when present and the ignore child otherwise. Returns the activation stored
    on the end node of the longest matched sequence, or 0 if nothing matched or
    the last token is not an activator.
    """
    current_tuple = self.trie_root

    activations = [0]

    for i, token in enumerate(reversed(tokens)):
      token = self.normalise(token)

      current_node, _ = current_tuple

      if token in current_node.children:
        current_tuple = current_node.children[token]
      elif self.ignore_token in current_node.children:
        current_tuple = current_node.children[self.ignore_token]
      else:
        break

      node, _ = current_tuple

      # If the final token of the input is not an activator, stop immediately
      if i == 0 and not node.value.activator:
        break

      if self.end_token in node.children:
        end_node, _ = node.children[self.end_token]
        activations.append(end_node.value.activation)

    # Return the activation on the longest sequence
    return activations[-1]

  def forward(self, tokens_arr: List[List[str]], return_activations=True) -> List[List[float]]:
    """Evaluate the activation on each token in some input tokens.

    Returns one list per input sequence: predicted activations, or booleans
    (activation > activation_threshold) when `return_activations` is False.

    Raises:
      ValueError: if `tokens_arr` is a flat list of strings.
    """
    if tokens_arr and isinstance(tokens_arr[0], str):
      raise ValueError(f"tokens_arr must be of type List[List[str]]")

    all_activations = []
    all_firings = []

    for tokens in tokens_arr:
      # Score each prefix; the prediction for position k uses tokens[:k + 1].
      activations = [self.search(tokens[:end]) for end in range(1, len(tokens) + 1)]
      firings = [token_activation > self.activation_threshold for token_activation in activations]

      all_activations.append(activations)
      all_firings.append(firings)

    if return_activations:
      return all_activations
    return all_firings

  def build(self, start_node, graph=True):
    """Build a graph to visualise and render it as SVG and PNG.

    Files are written under `<base_path>/neuron_graphs/<model_name>/[<folder_name>/]<layer>_<neuron>/`
    as `graph.*` (or `trie.*` when `graph` is False).
    """
    node_edge_tuples, visited = self._bfs_order(start_node)

    zero_width = u'\u200b'

    tokens_by_layer = {}
    node_id_to_graph_id = {}
    token_by_layer_count = defaultdict(Counter)
    added_ids = set()
    node_count = 0
    depth_to_subgraph = {}
    added_edges = set()

    def adjust(x, y):
      # Shift and scale x so that y maps to 0 and 1 stays at 1.
      return (x - y) / (1 - y)

    for node, edge in node_edge_tuples:
      token = node.value.token
      depth = node.depth

      if depth not in tokens_by_layer:
        tokens_by_layer[depth] = {}
        depth_to_subgraph[depth] = Digraph(name=f"cluster_{str(self.max_depth - depth)}")
        depth_to_subgraph[depth].attr(pencolor="white", penwidth="3")

      token_by_layer_count[depth][token] += 1

      if not graph:
        # This is a horrible hack to allow us to have a dict with the "same" token as multiple keys - by adding zero width spaces the tokens look the same but are actually different. This allows us to display a trie rather than a node-collapsed graph
        seen_count = token_by_layer_count[depth][token] - 1
        token += zero_width * seen_count

      if token not in tokens_by_layer[depth]:
        tokens_by_layer[depth][token] = str(node_count)
        node_count += 1

      graph_node_id = tokens_by_layer[depth][token]
      node_id_to_graph_id[node.id_] = graph_node_id

      current_graph = depth_to_subgraph[depth]

      if depth == 0:
        # colour red according to activation for depth 0 tokens
        scaled_activation = int(adjust(node.value.activation, max(0, self.activation_threshold - 0.2)) * 255)
        rgb = (255, 255 - scaled_activation, 255 - scaled_activation)
      else:
        # colour blue according to importance for all other tokens
        # Shift and scale importance so the importance threshold becomes 0
        scaled_importance = int(adjust(node.value.importance, max(0.1, self.importance_threshold - 0.2)) * 255)
        rgb = (255 - scaled_importance, 255 - scaled_importance, 255)

      fill_colour = "#{0:02x}{1:02x}{2:02x}".format(*self.clamp(rgb))

      if graph_node_id not in added_ids and not node.value.ignore:
        display_token = token.strip(zero_width)
        display_token = json.dumps(display_token).strip('[]"') if '"' not in token else display_token
        if set(display_token) == {" "}:
          display_token = f"'{display_token}'"

        fontcolor = "white" if depth != 0 and rgb[1] < 130 else "black"
        fontsize = "25" if len(display_token) < 12 else "18"
        edge_width = "7" if node.value.is_end else "3"

        current_graph.node(
            graph_node_id, f"{escape(display_token)}", fillcolor=fill_colour, shape="box",
            style="filled,solid", fontcolor=fontcolor, fontsize=fontsize,
            penwidth=edge_width
        )
        added_ids.add(graph_node_id)

      if edge.parent is not None and edge.parent.id_ in visited and not edge.parent.value.ignore:
        graph_parent_id = node_id_to_graph_id[edge.parent.id_]
        edge_tuple = (graph_parent_id, graph_node_id)
        if edge_tuple not in added_edges:
          self.net.edge(*edge_tuple, penwidth="3", dir="back")
          added_edges.add(edge_tuple)

    for depth, subgraph in depth_to_subgraph.items():
      self.net.subgraph(subgraph)

    path_parts = ['neuron_graphs', model_name]

    if self.folder_name is not None:
      path_parts.append(self.folder_name)

    path_parts.append(f"{self.layer}_{self.neuron}")

    save_path = base_path
    for path_part in path_parts:
      save_path += f"/{path_part}"
    # Creates any missing parents too, and is safe if the folder already exists.
    os.makedirs(save_path, exist_ok=True)

    self.net.format = 'svg'
    filename = "graph" if graph else "trie"
    self.net.render(f"{save_path}/{filename}", view=False)
    self.net.format = 'png'
    self.net.render(f"{save_path}/{filename}", view=False)

  @staticmethod
  def clamp(arr):
    """Clip every value in `arr` to the range [0, 255]."""
    return [max(0, min(x, 255)) for x in arr]
class NGramBaseline:
  """Baseline that predicts a neuron's activation by n-gram lookup.

  During `fit`, every token whose activation reaches the threshold is stored
  together with up to `prior_context` preceding tokens; the prediction for a
  token is the maximum normalised activation seen for that exact token sequence.
  """

  def __init__(self, model, layer, neuron, prior_context=1, activation_threshold=0.5):
    """
    Args:
      model: Hooked transformer exposing `to_tokens`, `to_str_tokens` and `run_with_cache`.
      layer: Layer index.
      neuron: Neuron index within the layer.
      prior_context: Number of preceding tokens included in each lookup key.
      activation_threshold: Threshold used both to select training tokens and to
        decide firing in `forward`.
    """
    self.model = model
    self.layer = layer
    self.neuron = neuron
    self.activation_threshold = activation_threshold
    self.prior_context = prior_context

    self.layer_name = layer_index_to_name(layer)
    self.max_activation = activation_matrix[layer, neuron]

  def __call__(self, tokens_arr: List[List[str]]) -> List[List[float]]:
    """Shorthand for `forward(tokens_arr)`."""
    return self.forward(tokens_arr)

  def fit(self, texts):
    """Record the maximum normalised activation for every activating n-gram in `texts`."""
    prepend_bos = False

    self.seq_to_activations = defaultdict(list)
    self.activating_tokens = set()

    for text in texts:
      # Use the model this baseline was constructed with, not a notebook global.
      all_tokens = self.model.to_tokens(text, prepend_bos=prepend_bos)
      _, cache = self.model.run_with_cache(all_tokens)
      neuron_activations = cache[self.layer_name][0, :, self.neuron]

      tokens = self.model.to_str_tokens(text, prepend_bos=prepend_bos)
      neuron_activations = neuron_activations.to("cpu")
      for j, (token, activation) in enumerate(zip(tokens, neuron_activations)):
        activation = activation.item()
        # NOTE(review): this compares the raw activation with the threshold, while
        # `forward` compares the normalised activation - confirm this is intended.
        if activation < self.activation_threshold:
          continue
        token_seq = tokens[max(0, j - self.prior_context):j + 1]
        self.activating_tokens.add(token)
        self.seq_to_activations["".join(token_seq)].append(activation / self.max_activation)

    self.seq_to_activation = {seq: np.max(activations) for seq, activations in self.seq_to_activations.items()}

  def forward(self, tokens_arr: List[List[str]], return_activations=True) -> List[List[float]]:
    """Predict activations (or firings) for every token of every sequence.

    Tokens never seen activating during `fit`, and n-grams not seen during
    `fit`, are predicted as 0.
    """
    all_activations = []
    all_firings = []

    for tokens in tokens_arr:
      activations = []
      firings = []

      for j, token in enumerate(tokens):
        if token in self.activating_tokens:
          token_seq = tokens[max(0, j - self.prior_context):j + 1]
          activation = self.seq_to_activation.get("".join(token_seq), 0)
        else:
          activation = 0

        activations.append(activation)
        firings.append(activation > self.activation_threshold)

      all_activations.append(activations)
      all_firings.append(firings)

    if return_activations:
      return all_activations
    return all_firings
chosen_neuron_indices = all_neuron_indices\n","  all_stats[layer] = {}\n","  for i, neuron in enumerate(chosen_neuron_indices):\n","    print(f\"{layer=} {neuron=}\")\n","    try:\n","      stats = train_and_eval(model, layer, neuron, n=5, max_train_size=None, train_proportion=0.5, max_eval_size=None, activation_threshold=0.5, fire_threshold=0.5, importance_threshold=0.75, folder_name=folder_name, neuron_store=neuron_store)\n","      \n","      all_stats[layer][neuron] = stats\n","\n","      if i % 10 == 0:\n","        neuron_store.save()\n","        with open(f\"{folder_path}/stats.json\", \"w\") as ofh:\n","          json.dump(all_stats, ofh, indent=2)\n","\n","    except Exception as e:\n","      print(e)\n","      print(\"Failed\")\n","\n","neuron_store.save()\n","with open(f\"{folder_path}/stats.json\", \"w\") as ofh:\n","  json.dump(all_stats, ofh, indent=2)"],"metadata":{"id":"cz5LLNWYIfb9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["get_summary_stats(f\"{base_path}/neuron_graphs/{model_name}/n2g/stats.json\")"],"metadata":{"id":"d_avWTh1JxVK"},"execution_count":null,"outputs":[]}]}