{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b339403",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import HookedTransformer\n",
    "import transformer_lens.patching as patching\n",
    "from fancy_einsum import einsum\n",
    "\n",
    "from functools import partial\n",
    "import scipy as sp\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "from ioi_dataset import IOIDataset\n",
    "import networkx as nx\n",
    "import einops\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "import utils\n",
    "import scipy as sp\n",
    "import scipy.cluster as cl\n",
    "\n",
    "import h5py\n",
    "import matplotlib.lines as mlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4883fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelContext:\n",
    "    \n",
    "    def __init__(self, model_name, model_family, device, num_prompts):\n",
    "        self.model_name = model_name\n",
    "        self.model_family = model_family\n",
    "        self.device = device\n",
    "        # Loading the model with processing; fold_ln = True, center_writing_weights = True\n",
    "        self.model = HookedTransformer.from_pretrained(model_name, device=device)\n",
    "        if model_family == \"gemma\":\n",
    "            prepend_bos = True\n",
    "        else:\n",
    "            prepend_bos = False\n",
    "        self.ioi_dataset = IOIDataset(\n",
    "            model_family=self.model_family,\n",
    "            prompt_type=\"mixed\",\n",
    "            N=num_prompts,\n",
    "            tokenizer=self.model.tokenizer,\n",
    "            prepend_bos=prepend_bos,\n",
    "            seed=0,\n",
    "            device=device)\n",
    "        \n",
    "        # run on this set of prompts\n",
    "        self.logits, self.cache = self.model.run_with_cache(self.ioi_dataset.toks)\n",
    "\n",
    "        # This creates the keys , which have the individual attention heads outputs.\n",
    "        self.cache.compute_head_results()\n",
    "\n",
    "        self.ALL_AHS = [(i, j) for i in range(self.model.cfg.n_layers) for j in range(self.model.cfg.n_heads)]\n",
    "        self.d_model = self.model.cfg.d_model\n",
    "\n",
    "        self.svs_used_u = {}\n",
    "        self.svs_used_v = {}\n",
    "        self.dfs_i = {}\n",
    "        self.dfs_j = {}\n",
    "\n",
    "        self.mean_ips = {}\n",
    "        self.std_ips = {}\n",
    "\n",
    "    def get_svds(self):\n",
    "        self.U, self.S, self.VT = utils.get_omega_decomposition_all_ahs(self.model, self.model_name, new_defn_omega=True) \n",
    "\n",
    "    def trace_prompts(self, prompt_list, firing_criteria = 'threshold', attn_thresh = 0.5, trace_mlp = False, use_svs = True):\n",
    "        '''\n",
    "        Trace the prompts through the model, and store the singular vectors used in each firing.\n",
    "        Default is to use firings where attn > 1/2, and to use intersection method to find the orhogonal slices.\n",
    "        Args:\n",
    "            prompt_list: list of prompt ids to trace\n",
    "            firing_criteria: 'threshold' or '1/n'\n",
    "            attn_thresh: threshold for attention firing if firing_criteria is 'threshold'\n",
    "            trace_mlp: whether to trace the MLPs\n",
    "            use_svs: whether to use singular vectors for decomposition\n",
    "\n",
    "        This is based on the first half of the code in utils/__trace_firing_optimized_rope_new_defn_omega() - don't need the tracing part, just the svd part.\n",
    "        '''\n",
    "        frac_contrib_thresh = 1.0\n",
    "        candidates = []\n",
    "        for prompt_id in prompt_list:\n",
    "            for layer in range(0, self.model.cfg.n_layers):\n",
    "                for ah_idx in range(self.model.cfg.n_heads):\n",
    "                    # skipping dest = 0; special case where contrib can be negative\n",
    "                    for dest_token in range(1, self.ioi_dataset.word_idx[\"end\"][prompt_id]+1):\n",
    "                        for src_token in range(0, dest_token+1):\n",
    "                            # Dynamic threshold\n",
    "                            if firing_criteria == \"1/n\":\n",
    "                                attn_thresh = 1/(dest_token+1) # n in this case is the number of src_tokens\n",
    "\n",
    "                            # We cannot trace firings with attention score < 1/n\n",
    "                            if firing_criteria == \"threshold\" and attn_thresh < 1/(dest_token+1):\n",
    "                                continue\n",
    "                                \n",
    "                            # did the attention head fire on this source/dest combination?\n",
    "                            if self.cache[f\"blocks.{layer}.attn.hook_pattern\"][prompt_id, ah_idx, dest_token, src_token].item() < attn_thresh:\n",
    "                                continue\n",
    "\n",
    "                            # NOT skipping punct token\n",
    "                            #if src_token == ioi_dataset.word_idx[\"punct\"][prompt_id].item():\n",
    "                            #    continue\n",
    "\n",
    "                            candidates.append((prompt_id, layer, ah_idx, dest_token, src_token))\n",
    "        \n",
    "        for prompt_id, layer, ah_idx, dest_token, src_token in tqdm(candidates, total=len(candidates)):\n",
    "            X = self.cache[f\"blocks.{layer}.ln1.hook_normalized\"][prompt_id, :, :] #Float[Tensor, 'n_tokens d_model'] \n",
    "\n",
    "            if self.model_name == \"gemma-2-2b\":\n",
    "                df_decomp_i, df_decomp_j = utils.get_components_used_comparative_no_bias(X, src_token, dest_token, layer, \n",
    "                                                                            ah_idx, self.U, self.S, self.VT, \n",
    "                                                                            self.model_name, self.device)\n",
    "            else:\n",
    "                df_decomp_i, df_decomp_j = utils.get_components_used_comparative_new_defn(X, src_token, dest_token, layer, \n",
    "                                                                            ah_idx, self.U, self.S, self.VT, \n",
    "                                                                            self.model_name, self.device)\n",
    "\n",
    "            # df_decomp_i will be None when contribution comes from the bias term c_1 and can't be traced\n",
    "            # see last paragraph of Appendix A\n",
    "            if df_decomp_i is not None:\n",
    "                # Decomposing on x_i\n",
    "                if use_svs:\n",
    "                    last_sv_idx = np.where(df_decomp_i['sv_frac_contribution'].values.round(5) >= frac_contrib_thresh)[0][0]\n",
    "                else:\n",
    "                    last_sv_idx = self.model.cfg.d_head # all SVs\n",
    "                svs_decomp_i = df_decomp_i.iloc[:last_sv_idx+1].idx.astype(int).values\n",
    "                self.svs_used_u[(prompt_id, layer, ah_idx, dest_token, src_token)] = svs_decomp_i\n",
    "                self.dfs_i[(prompt_id, layer, ah_idx, dest_token, src_token)] = df_decomp_i\n",
    "\n",
    "            # Decomposing on x_j\n",
    "            if df_decomp_j is not None:\n",
    "                if use_svs:\n",
    "                    last_sv_idx = np.where(df_decomp_j['sv_frac_contribution'].values.round(5) >= frac_contrib_thresh)[0][0]\n",
    "                else:\n",
    "                    last_sv_idx = self.model.cfg.d_head-1 # all SVs\n",
    "                svs_decomp_j = df_decomp_j.iloc[:last_sv_idx+1].idx.astype(int).values\n",
    "                self.svs_used_v[(prompt_id, layer, ah_idx, dest_token, src_token)] = svs_decomp_j\n",
    "                self.dfs_j[(prompt_id, layer, ah_idx, dest_token, src_token)] = df_decomp_j\n",
    "    \n",
    "    def get_contrib_u_signals(self, df, svs, SVecs, x):\n",
    "        # compute the signal used in this firing\n",
    "        # weighting each singular vector by its contribution\n",
    "        retvec = torch.zeros(self.d_model)\n",
    "        for sv in svs:\n",
    "            df_row = df[df['idx'] == sv]\n",
    "            x_i_ip = utils.apply_projection(SVecs[:, sv].reshape(-1, 1), x).T @ x\n",
    "            #retvec += SVecs[:, sv] * torch.Tensor(np.sign(x_i_ip) * df_row['contrib'].values)\n",
    "            retvec += SVecs[:, sv] * torch.Tensor(df_row['contrib'].values)\n",
    "        return retvec / torch.linalg.norm(retvec)\n",
    "    \n",
    "    def get_contrib_v_signals(self, df, svs, SVecs, x):\n",
    "        # compute the signal used in this firing\n",
    "        # weighting each singular vector by its contribution (not a projection)\n",
    "        retvec = torch.zeros(self.d_model)\n",
    "        for sv in svs:\n",
    "            df_row = df[df['idx'] == sv]\n",
    "            x_j_ip = utils.apply_projection(SVecs[:, sv].reshape(-1, 1), x).T @ x\n",
    "            #retvec += SVecs[:, sv] * torch.Tensor(np.sign(x_j_ip) * df_row['contrib'].values)\n",
    "            retvec += SVecs[:, sv] * torch.Tensor(df_row['contrib'].values)\n",
    "        return retvec / torch.linalg.norm(retvec)\n",
    "\n",
    "    def get_u_signals(self, df, svs, SVecs, x):\n",
    "        # compute the signal used in this firing\n",
    "        # as the projection of the residual on the signal subspace \n",
    "        retvec = torch.zeros(self.d_model)\n",
    "        for sv in svs:\n",
    "            retvec += utils.apply_projection(SVecs[:, sv].reshape(-1, 1), x)\n",
    "        return retvec / torch.linalg.norm(retvec)\n",
    "    \n",
    "    def get_v_signals(self, df, svs, SVecs, x):\n",
    "        # compute the signal used in this firing\n",
    "        # as the projection of the residual on the signal subspace\n",
    "        retvec = torch.zeros(self.d_model)\n",
    "        for sv in svs:\n",
    "            retvec += utils.apply_projection(SVecs[:, sv].reshape(-1, 1), x)\n",
    "        return retvec / torch.linalg.norm(retvec)\n",
    "    \n",
    "    def compute_signals(self):\n",
    "        # compile the set of all signals used across all firings of the model \n",
    "        self.u_signals = []\n",
    "        self.v_signals = []\n",
    "        self.contrib_u_signals = [] \n",
    "        self.contrib_v_signals = []\n",
    "        # firings for which we can trace the destination token\n",
    "        for key in tqdm(self.svs_used_u.keys(), total=len(self.svs_used_u), desc=\"Destination signals\"):\n",
    "            prompt_id, layer, ah_idx, dest_token, src_token = key\n",
    "            if (self.model_name == 'gpt2-small'):\n",
    "                diff = -1\n",
    "            else:\n",
    "                diff = dest_token - src_token\n",
    "            #X = deepcopy(self.cache[f\"blocks.{layer}.ln1.hook_normalized\"][prompt_id, :, :])\n",
    "            X = self.cache[f\"blocks.{layer}.ln1.hook_normalized\"][prompt_id, :, :] # no deepcopy needed here\n",
    "            if self.model_name == \"gemma-2-2b\":\n",
    "                contrib_u_signal = self.get_contrib_u_signals(self.dfs_i[key], self.svs_used_u[key], self.U[layer, ah_idx, diff], X[dest_token, :])\n",
    "                u_signal = self.get_u_signals(self.dfs_i[key], self.svs_used_u[key], self.U[layer, ah_idx, diff], X[dest_token, :])\n",
    "            else:\n",
    "                contrib_u_signal = self.get_contrib_u_signals(self.dfs_i[key], self.svs_used_u[key], self.U['d'][layer, ah_idx, diff], X[dest_token, :])\n",
    "                u_signal = self.get_u_signals(self.dfs_i[key], self.svs_used_u[key], self.U['d'][layer, ah_idx, diff], X[dest_token, :])\n",
    "        \n",
    "            self.u_signals.append(u_signal)\n",
    "            self.contrib_u_signals.append(contrib_u_signal)\n",
    "        # firings for which we can trace the source token\n",
    "        for key in tqdm(self.svs_used_v.keys(), total=len(self.svs_used_v), desc=\"Destination signals\"):\n",
    "            prompt_id, layer, ah_idx, dest_token, src_token = key\n",
    "            if (self.model_name == 'gpt2-small'):\n",
    "                diff = -1\n",
    "            else:\n",
    "                diff = dest_token - src_token\n",
    "            #X = deepcopy(self.cache[f\"blocks.{layer}.ln1.hook_normalized\"][prompt_id, :, :])\n",
    "            X = self.cache[f\"blocks.{layer}.ln1.hook_normalized\"][prompt_id, :, :] # no deepcopy needed here\n",
    "            if self.model_name == \"gemma-2-2b\":\n",
    "                contrib_v_signal = self.get_contrib_v_signals(self.dfs_j[key], self.svs_used_v[key], self.VT[layer, ah_idx, diff].T, X[src_token, :])\n",
    "                v_signal = self.get_v_signals(self.dfs_j[key], self.svs_used_v[key], self.VT[layer, ah_idx, diff].T, X[src_token, :])\n",
    "            else:\n",
    "                contrib_v_signal = self.get_contrib_v_signals(self.dfs_j[key], self.svs_used_v[key], self.VT['s'][layer, ah_idx, diff].T, X[src_token, :])\n",
    "                v_signal = self.get_v_signals(self.dfs_j[key], self.svs_used_v[key], self.VT['s'][layer, ah_idx, diff].T, X[src_token, :])\n",
    "            self.v_signals.append(v_signal)\n",
    "            self.contrib_v_signals.append(contrib_v_signal)\n",
    "        \n",
    "        \n",
    "        self.u_signals = np.array([x.numpy() for x in self.u_signals])\n",
    "        self.v_signals = np.array([x.numpy() for x in self.v_signals])\n",
    "        self.contrib_u_signals = np.array([x.numpy() for x in self.contrib_u_signals])\n",
    "        self.contrib_v_signals = np.array([x.numpy() for x in self.contrib_v_signals])\n",
    "        \n",
    "    def compute_control_signals(self, similarity_threshold = 0.75):\n",
    "        self.ctrl_sigs = {}\n",
    "        self.ctrld_heads = {}\n",
    "        test_signal_candidates = [self.u_signals, self.v_signals, \n",
    "                              self.contrib_u_signals, self.contrib_v_signals]\n",
    "        test_signal_types = ['u_signals', 'v_signals', 'contrib_u_signals', 'contrib_v_signals']\n",
    "        # for each of the four signal types, compute the control signals\n",
    "        # and the controlled heads\n",
    "        # controlled heads are those that have a consistent control signal\n",
    "        # across all firings\n",
    "        for test_signals, signal_type in tqdm(zip(test_signal_candidates, test_signal_types), total=len(test_signal_candidates)):\n",
    "            # for each head, determine whether the head is \"controlled\", ie\n",
    "            # uses a single predominant signal for firing on the default\n",
    "            # token(s), and compute an estimate of that signal\n",
    "            if signal_type in ['contrib_u_signals', 'u_signals']:\n",
    "                firings = list(self.svs_used_u.keys())\n",
    "            else:\n",
    "                firings = list(self.svs_used_v.keys())\n",
    "            vec_sets = {}\n",
    "            mean_ips = np.zeros((self.model.cfg.n_layers, self.model.cfg.n_heads))\n",
    "            std_ips = np.zeros((self.model.cfg.n_layers, self.model.cfg.n_heads))\n",
    "            n_ips = np.zeros((self.model.cfg.n_layers, self.model.cfg.n_heads))\n",
    "            for (test_layer, test_ah_idx) in self.ALL_AHS:\n",
    "                # get all zero-firings for this head\n",
    "                vec_sets[(test_layer, test_ah_idx)] = []\n",
    "                for key, sig in zip(firings, test_signals):\n",
    "                    prompt_id, layer, ah_idx, dest_token, src_token = key\n",
    "                    # if this firing is for this head\n",
    "                    if (layer == test_layer) and (ah_idx == test_ah_idx):\n",
    "                        # we have specific rules for default tokens in each model\n",
    "                        if self.model_name == 'gpt2-small':\n",
    "                            if (src_token == 0):\n",
    "                                vec_sets[(test_layer, test_ah_idx)].append(sig)\n",
    "                        elif self.model_name == 'EleutherAI/pythia-160m' or self.model_name == \"gemma-2-2b\":\n",
    "                            if src_token in [0, self.ioi_dataset.word_idx[\"punct\"][prompt_id].item()]:\n",
    "                                vec_sets[(test_layer, test_ah_idx)].append(sig)\n",
    "                        else:\n",
    "                            raise ValueError('default tokens are not yet configured for this model')\n",
    "                vec_sets[(test_layer, test_ah_idx)] = np.array(vec_sets[(test_layer, test_ah_idx)])\n",
    "                # we now have all the signals used in zero-firings for this head\n",
    "                # next, compute statistics used to identify whether head has consistent control signals\n",
    "                # specifically, compute the average mean cosine distance between each signal for this head\n",
    "                n_ips[test_layer, test_ah_idx] = vec_sets[test_layer, test_ah_idx].shape[0]\n",
    "                if n_ips[test_layer, test_ah_idx] > 1:\n",
    "                    dists = sp.spatial.distance.pdist(vec_sets[test_layer, test_ah_idx], metric='cosine')\n",
    "                    mean_ips[test_layer, test_ah_idx] = np.mean(dists)\n",
    "                    std_ips[test_layer, test_ah_idx] = np.std(dists)\n",
    "                else:\n",
    "                    mean_ips[test_layer, test_ah_idx] = np.nan \n",
    "                    std_ips[test_layer, test_ah_idx] = np.nan \n",
    "            # some heads are controlled (have consistent control signals), some not\n",
    "            controlled_heads = []\n",
    "            control_signals = []\n",
    "            for (test_layer, test_ah_idx) in self.ALL_AHS:\n",
    "                if (mean_ips[test_layer, test_ah_idx] < similarity_threshold):\n",
    "                    controlled_heads.append((test_layer, test_ah_idx))\n",
    "                    control_signals.append(np.mean(vec_sets[(test_layer, test_ah_idx)], axis = 0))\n",
    "            # control signals has one (mean) control signal per controlled head\n",
    "            self.ctrl_sigs[signal_type] = np.array(control_signals)\n",
    "            self.ctrld_heads[signal_type] = controlled_heads\n",
    "            self.mean_ips[signal_type] = mean_ips\n",
    "            self.std_ips[signal_type] = std_ips\n",
    "\n",
    "    def load_cached_tracing(self, prompt_list):\n",
    "        for prompt_id in prompt_list:\n",
    "            with h5py.File(f'control_signals_cache/dicts_{self.model_name}_{prompt_id}.hdf5', 'r') as f:\n",
    "                for key in f.keys():\n",
    "                    dict_type, dict_key = key.split(\"_\")\n",
    "                    dict_key = eval(dict_key)\n",
    "                    if dict_type == \"dfs-i\":\n",
    "                        df = pd.DataFrame(f[key][:], columns=['idx', 'singular_value', 'x_i_ip', 'x_j_ip', 'denom_avg', 'product', 'contrib', 'sv_frac_contribution'])\n",
    "                        df[\"idx\"] = df[\"idx\"].astype(int)            \n",
    "                        self.dfs_i[dict_key] = df\n",
    "                    elif dict_type == \"dfs-j\":\n",
    "                        df = pd.DataFrame(f[key][:], columns=['idx', 'singular_value', 'x_i_ip', 'x_j_ip', 'denom_avg', 'product', 'contrib', 'sv_frac_contribution'])\n",
    "                        df[\"idx\"] = df[\"idx\"].astype(int)            \n",
    "                        self.dfs_j[dict_key] = df\n",
    "                    elif dict_type == \"svs-used-u\":\n",
    "                        self.svs_used_u[dict_key] = f[key][:]\n",
    "                    elif dict_type == \"svs-used-v\":\n",
    "                        self.svs_used_v[dict_key] = f[key][:]\n",
    "                    else:\n",
    "                        print(f\"Error in the dict_type={dict_type}. Key={key}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92593c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main_clusters(cluster_id):\n",
    "    '''\n",
    "    Given a list of cluster ids, return the sorted list of cluster ids\n",
    "    sorted by the number of elements in each cluster.'''\n",
    "    cluster_sizes = {}\n",
    "    for id in cluster_id:\n",
    "        if id not in cluster_sizes:\n",
    "            cluster_sizes[id] = 1\n",
    "        else:\n",
    "            cluster_sizes[id] += 1\n",
    "    return sorted(cluster_sizes, key = cluster_sizes.get, reverse=True)\n",
    "\n",
    "def f_counts(fset): \n",
    "    '''\n",
    "    Given a set of firings, return the counts of each firing.\n",
    "    fset: set of firings\n",
    "    '''\n",
    "    # fset is a list of tuples, where each tuple is (prompt, layer, head, dest_token, src_token)\n",
    "    # f1_cnts is a dictionary where the keys are the firing sources and the values are the counts\n",
    "    # of each firing source\n",
    "    f1_cnts = {}    \n",
    "    for f in fset:\n",
    "        f_src = f[4]\n",
    "        if f_src in f1_cnts:\n",
    "            f1_cnts[f_src] += 1\n",
    "        else:\n",
    "            f1_cnts[f_src] = 1\n",
    "    return f1_cnts\n",
    "\n",
    "def conf_matrix(c_sorted, subject_cluster, targets):\n",
    "    f = []\n",
    "    cm = {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}\n",
    "    for c in range(len(c_sorted)):\n",
    "        c_firings = [firings[i] for i in np.where(cluster_id == c_sorted[c])[0]]\n",
    "        f.append(c_firings)\n",
    "        if subject_cluster == c:\n",
    "            for firing in c_firings:\n",
    "                if firing[4] in targets:\n",
    "                    cm['tp'] += 1\n",
    "                else:\n",
    "                    cm['fp'] += 1\n",
    "        else:\n",
    "            for firing in c_firings:\n",
    "                if firing[4] == 0:\n",
    "                    cm['fn'] += 1\n",
    "                else:\n",
    "                    cm['tn'] += 1\n",
    "    return cm, f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a92ee47",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_set = [1, 2, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "420ae2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt2s = ModelContext('gpt2-small', 'gpt2', 'cpu', 16)\n",
    "gpt2s.get_svds()\n",
    "gpt2s.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh = 0.4) # attn_thresh = 0.5\n",
    "gpt2s.compute_signals()\n",
    "# similarity_threshold is the threshold for the mean cosine distance between signals\n",
    "# that labels a head as \"controlled\" and assigns it a single signal\n",
    "gpt2s.compute_control_signals(similarity_threshold = 0.5) # no big change between 0.5 and 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0bd10fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pyth = ModelContext('EleutherAI/pythia-160m', 'pythia', 'cpu', 16)\n",
    "pyth.get_svds()\n",
    "pyth.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh = 0.4) # attn_thresh = 0.5\n",
    "pyth.compute_signals()\n",
    "# similarity_threshold is the threshold for the mean cosine distance between signals\n",
    "# that labels a head as \"controlled\" and assigns it a single signal\n",
    "pyth.compute_control_signals(similarity_threshold = 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "219ebed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "gemma = ModelContext('gemma-2-2b', 'gemma', 'cpu', 16)\n",
    "gemma.get_svds()\n",
    "gemma.load_cached_tracing(prompt_set)\n",
    "# To fully trace the prompts again:\n",
    "#gemma.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh = 0.4) # attn_thresh = 0.5\n",
    "gemma.compute_signals()\n",
    "# # similarity_threshold is the threshold for the mean cosine distance between signals\n",
    "# # that labels a head as \"controlled\" and assigns it a single signal\n",
    "gemma.compute_control_signals(similarity_threshold = 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5f965af",
   "metadata": {},
   "source": [
    "## GPT-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56da7fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_signals = gpt2s.v_signals\n",
    "cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(test_signals, metric='cosine'))\n",
    "Z = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)\n",
    "sns.clustermap(cosine_distances, row_linkage = Z, col_linkage = Z, figsize = (10,10))\n",
    "plt.title('Cosine distances of Source signals');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b285b0c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for average linkage with \"regular\" U signal, a threshold of 0.99 separates the two big clusters\n",
    "# note that it gets much worse at threshold of 1.0 !\n",
    "# for average linkage with \"regular\" C signal, a threshold of 1 separates the two big clusters\n",
    "# note that it gets much worse at threshold of 1.0 !\n",
    "plt.figure()\n",
    "cluster_threshold = .97\n",
    "# cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(contrib_u_signals, metric='cosine'))\n",
    "Z_contrib = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)\n",
    "cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion = 'distance')\n",
    "c_sorted = main_clusters(cluster_id)\n",
    "firings = list(gpt2s.svs_used_v.keys())\n",
    "f = []\n",
    "for c in range(len(c_sorted)):\n",
    "    f.append([firings[i] for i in np.where(cluster_id == c_sorted[c])[0]])\n",
    "#cl.hierarchy.dendrogram(Z_contrib, color_threshold = cluster_threshold)\n",
    "print('Source tokens in each cluster:')\n",
    "for st in range(len(f)):\n",
    "    print(f_counts(f[st]))\n",
    "# confusion matrix tells us how cluster zero works as a classifier of a zero-firing\n",
    "# conf_matrix(c_sorted, subject_cluster, targets)\n",
    "print(conf_matrix(c_sorted, 0, [0])[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee3b5c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# find vector that is \"average\" for cluster zero, meaning first cluster in sorted list\n",
    "# assuming the step above has found that cluster zero gives good separation (confusion matrix)\n",
    "cluster_zero_signals = [s for i, s in enumerate(test_signals) if i in np.where(cluster_id == c_sorted[0])[0]]\n",
    "cluster_nonzero_signals = [s for i, s in enumerate(test_signals) if i not in np.where(cluster_id == c_sorted[0])[0]]\n",
    "# compute average cosine dist \n",
    "cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))\n",
    "# find vector that has minimum avg cos dist to all others in the cluster\n",
    "average_cos_dist = np.mean(cosine_distances, axis = 1)\n",
    "min_idx = np.where(average_cos_dist == np.min(average_cos_dist))[0][0]\n",
    "min_vec = cluster_zero_signals[min_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bbe6bca",
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_token_signals = np.array([s for s, f in zip(gpt2s.v_signals, gpt2s.svs_used_v) if f[4] == 0])\n",
    "nonzero_token_signals = np.array([s for s, f in zip(gpt2s.v_signals, gpt2s.svs_used_v) if f[4] != 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f3b6ddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "plt.rc('font', size=8)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))\n",
    "sns.kdeplot(zero_token_signals @ min_vec, label = 'zero tokens', bw_adjust = 0.125)\n",
    "sns.kdeplot(nonzero_token_signals @ min_vec, label = 'non-zero tokens', bw_adjust = 0.125)\n",
    "#plt.legend(loc = 'best')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2);\n",
    "plt.xlabel(\"Inner product\");\n",
    "plt.tight_layout()\n",
    "#plt.title('IP of prototype signal with other signals');\n",
    "plt.savefig(\"figures/control_signals/gpt2-small_control-signals_ip_v-signals.pdf\", bbox_inches='tight', dpi=800);\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e8a9e13",
   "metadata": {},
   "source": [
    "## Pythia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bb030ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_signals = pyth.v_signals\n",
    "cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(test_signals, metric='cosine'))\n",
    "Z = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)\n",
    "sns.clustermap(cosine_distances, row_linkage = Z, col_linkage = Z, figsize = (10,10))\n",
    "plt.title('Cosine distances of Src signals');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1eedd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pythia: very good separation obtained for v_signals, average linkage, threshold = 0.89\n",
    "plt.figure()\n",
    "cluster_threshold = 0.99\n",
    "Z_contrib = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)\n",
    "cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion = 'distance')\n",
    "c_sorted = main_clusters(cluster_id)\n",
    "f = []\n",
    "firings = list(pyth.svs_used_v.keys())\n",
    "for c in range(len(c_sorted)):\n",
    "    f.append([firings[i] for i in np.where(cluster_id == c_sorted[c])[0]])\n",
    "cl.hierarchy.dendrogram(Z_contrib, color_threshold = cluster_threshold)\n",
    "plt.title('Contrib signal defn')\n",
    "for st in range(len(f)):\n",
    "    print(f_counts(f[st]))\n",
    "# note this is only approximately right; we need to apply the punct on a per-prompt basis\n",
    "# potentially inflating false negatives at expense of true negatives\n",
    "# and true positives at expnse of false positives\n",
    "pyth_puncts = [pyth.ioi_dataset.word_idx[\"punct\"][prompt_id] for prompt_id in prompt_set]\n",
    "print(conf_matrix(c_sorted, 0, [0] + pyth_puncts)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01eaab2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the vector that is \"average\" for cluster zero: the member signal with the smallest\n",
    "# mean cosine distance to the members of that cluster.\n",
    "# Membership is computed once as a boolean mask instead of re-running np.where for every signal.\n",
    "in_cluster_zero = cluster_id == c_sorted[0]\n",
    "cluster_zero_signals = [s for s, is_member in zip(test_signals, in_cluster_zero) if is_member]\n",
    "cluster_nonzero_signals = [s for s, is_member in zip(test_signals, in_cluster_zero) if not is_member]\n",
    "# pairwise cosine distances inside cluster zero, expanded to a square matrix\n",
    "cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))\n",
    "# row mean = average distance from one member to every member (its own zero distance included)\n",
    "average_cos_dist = np.mean(cosine_distances, axis=1)\n",
    "# np.argmin returns the first index attaining the minimum, like the former np.where(...)[0][0]\n",
    "min_idx = np.argmin(average_cos_dist)\n",
    "min_vec = cluster_zero_signals[min_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd47e675",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the signals on entry 4 of their firing key: position 0 or one of the punctuation\n",
    "# positions versus everything else (entry 4 is presumably a token position - TODO confirm).\n",
    "zero_positions = [0] + pyth_puncts\n",
    "zero_token_signals = np.array(\n",
    "    [sig for sig, firing in zip(test_signals, firings) if firing[4] in zero_positions]\n",
    ")\n",
    "nonzero_token_signals = np.array(\n",
    "    [sig for sig, firing in zip(test_signals, firings) if firing[4] not in zero_positions]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04f57304",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "\n",
    "# Type 42 (TrueType) font embedding in PDF/PS output, small base font for the paper figure\n",
    "matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})\n",
    "plt.rc('font', size=8)\n",
    "\n",
    "# Density of the inner product between the prototype signal (min_vec) and each group of signals\n",
    "fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))\n",
    "kde_groups = ((zero_token_signals, 'zero tokens'), (nonzero_token_signals, 'non-zero tokens'))\n",
    "for group_signals, group_label in kde_groups:\n",
    "    sns.kdeplot(group_signals @ min_vec, label=group_label, bw_adjust=0.125, ax=ax)\n",
    "ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)\n",
    "ax.set_xlabel(\"Inner product\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/control_signals/pythia-160m_control-signals_ip_v-signals.pdf\", bbox_inches='tight', dpi=800)\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f96ea178",
   "metadata": {},
   "source": [
    "## Gemma-2 2B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b196c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_signals = gemma.v_signals\n",
    "\n",
    "# pairwise cosine distances between all signals, expanded from condensed to square form\n",
    "condensed_cosine = sp.spatial.distance.pdist(test_signals, metric='cosine')\n",
    "cosine_distances = sp.spatial.distance.squareform(condensed_cosine)\n",
    "\n",
    "# one average-linkage hierarchy orders both the rows and the columns of the map\n",
    "Z = cl.hierarchy.linkage(test_signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "sns.clustermap(cosine_distances, row_linkage=Z, col_linkage=Z, figsize=(10, 10))\n",
    "plt.title('Cosine distances of Source signals');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10306ec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gemma:\n",
    "cluster_threshold = 0.99\n",
    "firings = list(gemma.svs_used_v.keys())\n",
    "\n",
    "# average-linkage clustering under cosine distance, cut into flat clusters at the threshold\n",
    "Z_contrib = cl.hierarchy.linkage(test_signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion='distance')\n",
    "c_sorted = main_clusters(cluster_id)\n",
    "\n",
    "# f[k] lists the firings whose signal was assigned to cluster c_sorted[k]\n",
    "f = [\n",
    "    [firings[i] for i in np.flatnonzero(cluster_id == label)]\n",
    "    for label in c_sorted\n",
    "]\n",
    "\n",
    "plt.figure()\n",
    "cl.hierarchy.dendrogram(Z_contrib, color_threshold=cluster_threshold)\n",
    "plt.title('Contrib signal defn')\n",
    "\n",
    "for cluster_firings in f:\n",
    "    print(f_counts(cluster_firings))\n",
    "\n",
    "# note this is only approximately right; we need to apply the punct on a per-prompt basis\n",
    "# potentially inflating false negatives at expense of true negatives\n",
    "# and true positives at expense of false positives\n",
    "gemma_puncts = [gemma.ioi_dataset.word_idx[\"punct\"][pid] for pid in prompt_set]\n",
    "print(conf_matrix(c_sorted, 0, [0] + gemma_puncts)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14334992",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the vector that is \"average\" for cluster zero: the member signal with the smallest\n",
    "# mean cosine distance to the members of that cluster.\n",
    "# Membership is computed once as a boolean mask instead of re-running np.where for every signal.\n",
    "in_cluster_zero = cluster_id == c_sorted[0]\n",
    "cluster_zero_signals = [s for s, is_member in zip(test_signals, in_cluster_zero) if is_member]\n",
    "cluster_nonzero_signals = [s for s, is_member in zip(test_signals, in_cluster_zero) if not is_member]\n",
    "# pairwise cosine distances inside cluster zero, expanded to a square matrix\n",
    "cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))\n",
    "# row mean = average distance from one member to every member (its own zero distance included)\n",
    "average_cos_dist = np.mean(cosine_distances, axis=1)\n",
    "# np.argmin returns the first index attaining the minimum, like the former np.where(...)[0][0]\n",
    "min_idx = np.argmin(average_cos_dist)\n",
    "min_vec = cluster_zero_signals[min_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfbe378",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the signals on entry 4 of their firing key: position 0 or one of the punctuation\n",
    "# positions versus everything else (entry 4 is presumably a token position - TODO confirm).\n",
    "zero_positions = [0] + gemma_puncts\n",
    "zero_token_signals = np.array(\n",
    "    [sig for sig, firing in zip(test_signals, firings) if firing[4] in zero_positions]\n",
    ")\n",
    "nonzero_token_signals = np.array(\n",
    "    [sig for sig, firing in zip(test_signals, firings) if firing[4] not in zero_positions]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e74661f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "\n",
    "# Type 42 (TrueType) font embedding in PDF/PS output, small base font for the paper figure\n",
    "matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})\n",
    "plt.rc('font', size=8)\n",
    "\n",
    "# Density of the inner product between the prototype signal (min_vec) and each group of signals\n",
    "fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))\n",
    "kde_groups = ((zero_token_signals, 'zero tokens'), (nonzero_token_signals, 'non-zero tokens'))\n",
    "for group_signals, group_label in kde_groups:\n",
    "    sns.kdeplot(group_signals @ min_vec, label=group_label, bw_adjust=0.125, ax=ax)\n",
    "ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)\n",
    "ax.set_xlabel(\"Inner product\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/control_signals/gemma-2-2b_control-signals_ip_v-signals.pdf\", bbox_inches='tight', dpi=800)\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa0bd62e",
   "metadata": {},
   "source": [
    "## GPT-2 at the head level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aadb2d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "\n",
    "# Type 42 (TrueType) font embedding in PDF/PS output, small base font for the figures below\n",
    "matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})\n",
    "plt.rc('font', size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae664a3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per signal type (U, then V): dendrogram of the per-head control signals on top,\n",
    "# layer x head map coloured by cluster underneath.\n",
    "fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))\n",
    "\n",
    "# Has to be set manually after seeing the plot\n",
    "n_clusters_appearing_dendogram = {\n",
    "    0: 6,\n",
    "    1: 9\n",
    "}\n",
    "\n",
    "signal_types = {0: \"u_signals\", 1: \"v_signals\"}\n",
    "\n",
    "WHITE = (1.0, 1.0, 1.0)\n",
    "# links merged at or above the threshold are drawn in translucent gray (RGBA kept in the hex string)\n",
    "above_threshold_hex = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)\n",
    "\n",
    "for t in range(2):\n",
    "    control_signals = gpt2s.ctrl_sigs[signal_types[t]]\n",
    "    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "    n_samples = current_Z.shape[0] + 1\n",
    "\n",
    "    current_cluster_threshold = 0.6\n",
    "\n",
    "    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')\n",
    "    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))\n",
    "    n_main_clusters = len(c_sorted_list)\n",
    "\n",
    "    # One palette for both panels: entry 0 is reserved for \"no cluster\",\n",
    "    # entry k + 1 belongs to the cluster at position k of c_sorted_list.\n",
    "    shared_palette = list(sns.color_palette(\"deep\", n_main_clusters + 1))\n",
    "    hex_dendrogram_link_colors = [mcolors.to_hex(rgb, keep_alpha=False) for rgb in shared_palette[1:]]\n",
    "\n",
    "    def custom_dendrogram_link_color_func(link_k):\n",
    "        \"\"\"Return the hex colour of dendrogram node `link_k` (a merge node, id >= n_samples).\"\"\"\n",
    "        if current_Z[link_k - n_samples, 2] >= current_cluster_threshold:\n",
    "            return above_threshold_hex\n",
    "        # Below the threshold all leaves under this node share one flat cluster,\n",
    "        # so follow the first child down to any leaf and read its cluster id.\n",
    "        node = link_k\n",
    "        while node >= n_samples:\n",
    "            node = int(current_Z[node - n_samples, 0])\n",
    "        leaf_cluster = flat_cluster_ids_for_samples[node]\n",
    "        if leaf_cluster in c_sorted_list:\n",
    "            return hex_dendrogram_link_colors[c_sorted_list.index(leaf_cluster)]\n",
    "        # clusters not returned by main_clusters() stay gray\n",
    "        return above_threshold_hex\n",
    "\n",
    "    # --- Dendrogram ---\n",
    "    cl.hierarchy.dendrogram(\n",
    "        current_Z,\n",
    "        link_color_func=custom_dendrogram_link_color_func,\n",
    "        ax=axs[0, t],\n",
    "        above_threshold_color=above_threshold_hex\n",
    "    )\n",
    "    axs[0, t].set_xticks([])\n",
    "    # the dendrogram height is a cosine distance; ticks are relabelled as similarity = 1 - distance\n",
    "    axs[0, t].set_ylabel(\"Cosine similarity\")\n",
    "    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])\n",
    "    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])\n",
    "\n",
    "    # --- Heatmap: -1 = head in no main cluster, k = position of its cluster in c_sorted_list ---\n",
    "    headmap = np.full((gpt2s.model.cfg.n_layers, gpt2s.model.cfg.n_heads), -1, dtype=int)\n",
    "    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):\n",
    "        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):\n",
    "            layer, ah_idx = gpt2s.ctrld_heads[signal_types[t]][i]\n",
    "            headmap[layer, ah_idx] = c_map_idx\n",
    "\n",
    "    # Colour list indexed by (headmap value + 1): white background, then one colour per cluster.\n",
    "    # It has n_main_clusters + 1 entries so that every value gets its own colour (a list of only\n",
    "    # n_main_clusters entries made the last cluster share the colour of the one before it), and\n",
    "    # it is a valid list even when no cluster was found.\n",
    "    final_heatmap_cmap = [WHITE] + shared_palette[1:]\n",
    "    for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):\n",
    "        final_heatmap_cmap[c] = WHITE # setting every cluster not appearing in the dendogram to white\n",
    "\n",
    "    # vmin/vmax are pinned so that integer value v always falls in the middle of colour bin v + 1,\n",
    "    # independent of which values happen to be present in headmap.\n",
    "    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False,\n",
    "                vmin=-1.5, vmax=n_main_clusters - 0.5)\n",
    "    axs[1, t].set_ylabel(\"Layer\")\n",
    "    axs[1, t].set_xlabel(\"Attention Head Index\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figures/control_signals/gpt-2-small_control-signals.pdf\", bbox_inches='tight', dpi=800);\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c860484c",
   "metadata": {},
   "source": [
    "## Pythia at the Head level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "699df500",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per signal type (U, then V): dendrogram of the per-head control signals on top,\n",
    "# layer x head map coloured by cluster underneath.\n",
    "fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))\n",
    "\n",
    "# Has to be set manually after seeing the plot\n",
    "n_clusters_appearing_dendogram = {\n",
    "    0: 6,\n",
    "    1: 9\n",
    "}\n",
    "\n",
    "signal_types = {0: \"u_signals\", 1: \"v_signals\"}\n",
    "\n",
    "WHITE = (1.0, 1.0, 1.0)\n",
    "# links merged at or above the threshold are drawn in translucent gray (RGBA kept in the hex string)\n",
    "above_threshold_hex = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)\n",
    "\n",
    "for t in range(2):\n",
    "    control_signals = pyth.ctrl_sigs[signal_types[t]]\n",
    "    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "    n_samples = current_Z.shape[0] + 1\n",
    "\n",
    "    current_cluster_threshold = 0.7\n",
    "\n",
    "    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')\n",
    "    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))\n",
    "    n_main_clusters = len(c_sorted_list)\n",
    "\n",
    "    # One palette for both panels: entry 0 is reserved for \"no cluster\",\n",
    "    # entry k + 1 belongs to the cluster at position k of c_sorted_list.\n",
    "    shared_palette = list(sns.color_palette(\"deep\", n_main_clusters + 1))\n",
    "    hex_dendrogram_link_colors = [mcolors.to_hex(rgb, keep_alpha=False) for rgb in shared_palette[1:]]\n",
    "\n",
    "    def custom_dendrogram_link_color_func(link_k):\n",
    "        \"\"\"Return the hex colour of dendrogram node `link_k` (a merge node, id >= n_samples).\"\"\"\n",
    "        if current_Z[link_k - n_samples, 2] >= current_cluster_threshold:\n",
    "            return above_threshold_hex\n",
    "        # Below the threshold all leaves under this node share one flat cluster,\n",
    "        # so follow the first child down to any leaf and read its cluster id.\n",
    "        node = link_k\n",
    "        while node >= n_samples:\n",
    "            node = int(current_Z[node - n_samples, 0])\n",
    "        leaf_cluster = flat_cluster_ids_for_samples[node]\n",
    "        if leaf_cluster in c_sorted_list:\n",
    "            return hex_dendrogram_link_colors[c_sorted_list.index(leaf_cluster)]\n",
    "        # clusters not returned by main_clusters() stay gray\n",
    "        return above_threshold_hex\n",
    "\n",
    "    # --- Dendrogram ---\n",
    "    cl.hierarchy.dendrogram(\n",
    "        current_Z,\n",
    "        link_color_func=custom_dendrogram_link_color_func,\n",
    "        ax=axs[0, t],\n",
    "        above_threshold_color=above_threshold_hex\n",
    "    )\n",
    "    axs[0, t].set_xticks([])\n",
    "    # the dendrogram height is a cosine distance; ticks are relabelled as similarity = 1 - distance\n",
    "    axs[0, t].set_ylabel(\"Cosine similarity\")\n",
    "    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])\n",
    "    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])\n",
    "\n",
    "    # --- Heatmap: -1 = head in no main cluster, k = position of its cluster in c_sorted_list ---\n",
    "    headmap = np.full((pyth.model.cfg.n_layers, pyth.model.cfg.n_heads), -1, dtype=int)\n",
    "    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):\n",
    "        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):\n",
    "            layer, ah_idx = pyth.ctrld_heads[signal_types[t]][i]\n",
    "            headmap[layer, ah_idx] = c_map_idx\n",
    "\n",
    "    # Colour list indexed by (headmap value + 1): white background, then one colour per cluster.\n",
    "    # It has n_main_clusters + 1 entries so that every value gets its own colour (a list of only\n",
    "    # n_main_clusters entries made the last cluster share the colour of the one before it), and\n",
    "    # it is a valid list even when no cluster was found.\n",
    "    final_heatmap_cmap = [WHITE] + shared_palette[1:]\n",
    "    for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):\n",
    "        final_heatmap_cmap[c] = WHITE # setting every cluster not appearing in the dendogram to white\n",
    "\n",
    "    # vmin/vmax are pinned so that integer value v always falls in the middle of colour bin v + 1,\n",
    "    # independent of which values happen to be present in headmap.\n",
    "    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False,\n",
    "                vmin=-1.5, vmax=n_main_clusters - 0.5)\n",
    "    axs[1, t].set_ylabel(\"Layer\")\n",
    "    axs[1, t].set_xlabel(\"Attention Head Index\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figures/control_signals/pythia-160m_control-signals.pdf\", bbox_inches='tight', dpi=800);\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c52483e",
   "metadata": {},
   "source": [
    "## Gemma-2 2B at the Head level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1310623b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per signal type (U, then V): dendrogram of the per-head control signals on top,\n",
    "# layer x head map coloured by cluster underneath.\n",
    "fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))\n",
    "\n",
    "# Has to be set manually after seeing the plot\n",
    "n_clusters_appearing_dendogram = {\n",
    "    0: 4,\n",
    "    1: 9\n",
    "}\n",
    "\n",
    "signal_types = {0: \"u_signals\", 1: \"v_signals\"}\n",
    "\n",
    "WHITE = (1.0, 1.0, 1.0)\n",
    "# links merged at or above the threshold are drawn in translucent gray (RGBA kept in the hex string)\n",
    "above_threshold_hex = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)\n",
    "\n",
    "for t in range(2):\n",
    "    control_signals = gemma.ctrl_sigs[signal_types[t]]\n",
    "    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "    n_samples = current_Z.shape[0] + 1\n",
    "\n",
    "    current_cluster_threshold = 0.75\n",
    "\n",
    "    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')\n",
    "    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))\n",
    "    n_main_clusters = len(c_sorted_list)\n",
    "\n",
    "    # One palette for both panels: entry 0 is reserved for \"no cluster\",\n",
    "    # entry k + 1 belongs to the cluster at position k of c_sorted_list.\n",
    "    shared_palette = list(sns.color_palette(\"deep\", n_main_clusters + 1))\n",
    "    hex_dendrogram_link_colors = [mcolors.to_hex(rgb, keep_alpha=False) for rgb in shared_palette[1:]]\n",
    "\n",
    "    def custom_dendrogram_link_color_func(link_k):\n",
    "        \"\"\"Return the hex colour of dendrogram node `link_k` (a merge node, id >= n_samples).\"\"\"\n",
    "        if current_Z[link_k - n_samples, 2] >= current_cluster_threshold:\n",
    "            return above_threshold_hex\n",
    "        # Below the threshold all leaves under this node share one flat cluster,\n",
    "        # so follow the first child down to any leaf and read its cluster id.\n",
    "        node = link_k\n",
    "        while node >= n_samples:\n",
    "            node = int(current_Z[node - n_samples, 0])\n",
    "        leaf_cluster = flat_cluster_ids_for_samples[node]\n",
    "        if leaf_cluster in c_sorted_list:\n",
    "            return hex_dendrogram_link_colors[c_sorted_list.index(leaf_cluster)]\n",
    "        # clusters not returned by main_clusters() stay gray\n",
    "        return above_threshold_hex\n",
    "\n",
    "    # --- Dendrogram ---\n",
    "    cl.hierarchy.dendrogram(\n",
    "        current_Z,\n",
    "        link_color_func=custom_dendrogram_link_color_func,\n",
    "        ax=axs[0, t],\n",
    "        above_threshold_color=above_threshold_hex\n",
    "    )\n",
    "    axs[0, t].set_xticks([])\n",
    "    # the dendrogram height is a cosine distance; ticks are relabelled as similarity = 1 - distance\n",
    "    axs[0, t].set_ylabel(\"Cosine similarity\")\n",
    "    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])\n",
    "    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])\n",
    "\n",
    "    # --- Heatmap: -1 = head in no main cluster, k = position of its cluster in c_sorted_list ---\n",
    "    headmap = np.full((gemma.model.cfg.n_layers, gemma.model.cfg.n_heads), -1, dtype=int)\n",
    "    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):\n",
    "        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):\n",
    "            layer, ah_idx = gemma.ctrld_heads[signal_types[t]][i]\n",
    "            headmap[layer, ah_idx] = c_map_idx\n",
    "\n",
    "    # Colour list indexed by (headmap value + 1): white background, then one colour per cluster.\n",
    "    # It has n_main_clusters + 1 entries so that every value gets its own colour (a list of only\n",
    "    # n_main_clusters entries made the last cluster share the colour of the one before it), and\n",
    "    # it is a valid list even when no cluster was found.\n",
    "    final_heatmap_cmap = [WHITE] + shared_palette[1:]\n",
    "    for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):\n",
    "        final_heatmap_cmap[c] = WHITE # setting every cluster not appearing in the dendogram to white\n",
    "\n",
    "    # vmin/vmax are pinned so that integer value v always falls in the middle of colour bin v + 1,\n",
    "    # independent of which values happen to be present in headmap.\n",
    "    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False,\n",
    "                vmin=-1.5, vmax=n_main_clusters - 0.5)\n",
    "    axs[1, t].set_ylabel(\"Layer\")\n",
    "    axs[1, t].set_xlabel(\"Attention Head Index\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figures/control_signals/gemma-2-2b_control-signals.pdf\", bbox_inches='tight', dpi=800);\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8526e18f",
   "metadata": {},
   "source": [
    "# Which components are adding the control signals?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8065b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_proto_signals(mod, test_signals, cluster_threshold = 0.65):\n",
    "    \"\"\"Cluster a model's control signals and return one prototype signal per cluster.\n",
    "\n",
    "    The signals in ``mod.ctrl_sigs[test_signals]`` are clustered with average linkage\n",
    "    under cosine distance and cut into flat clusters at ``cluster_threshold``. As a side\n",
    "    effect the dendrogram is drawn through ``cl.hierarchy.dendrogram`` (no ``ax`` is passed).\n",
    "\n",
    "    Args:\n",
    "        mod: object exposing ``ctrl_sigs`` and ``ctrld_heads``, both indexed by ``test_signals``.\n",
    "        test_signals: key selecting the signal family (callers pass 'u_signals' or 'v_signals').\n",
    "        cluster_threshold: distance at which the hierarchy is cut.\n",
    "\n",
    "    Returns:\n",
    "        ``(proto_signals, head_map)``: ``proto_signals[k]`` is the mean of the signals in the\n",
    "        k-th cluster returned by ``main_clusters``; ``head_map`` maps ``(layer, head index)``\n",
    "        to that position ``k``.\n",
    "    \"\"\"\n",
    "    signals = mod.ctrl_sigs[test_signals]\n",
    "    heads = mod.ctrld_heads[test_signals]\n",
    "\n",
    "    Z = cl.hierarchy.linkage(signals, method='average', metric='cosine', optimal_ordering=False)\n",
    "    cl.hierarchy.dendrogram(Z, color_threshold=cluster_threshold)\n",
    "    cluster_id = cl.hierarchy.fcluster(Z, cluster_threshold, criterion='distance')\n",
    "\n",
    "    head_map = {}\n",
    "    prototypes = []\n",
    "    for rank, label in enumerate(main_clusters(cluster_id)):\n",
    "        members = np.flatnonzero(cluster_id == label)\n",
    "        for idx in members:\n",
    "            member_layer, member_head = heads[idx]\n",
    "            head_map[member_layer, member_head] = rank\n",
    "        # yes, we are taking the mean of means here; does it matter?\n",
    "        prototypes.append(np.array([signals[idx] for idx in members]).mean(axis=0))\n",
    "    return np.array(prototypes), head_map"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1826be4",
   "metadata": {},
   "source": [
    "## GPT-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cf6fb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-cluster mean (prototype) V-signals of gpt2s and the head -> cluster map; also draws the dendrogram\n",
    "gpt2s_v_proto_signals, gpt2s_v_head_map = get_proto_signals(gpt2s, 'v_signals', cluster_threshold = 0.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9167aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-cluster mean (prototype) U-signals of gpt2s and the head -> cluster map; also draws the dendrogram\n",
    "gpt2s_u_proto_signals, gpt2s_u_head_map = get_proto_signals(gpt2s, 'u_signals', cluster_threshold = 0.6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4c4962",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # # Who is adding this?\n",
    "# For one prompt, measure the cosine similarity between the residual stream and the first two\n",
    "# prototype control signals at every residual hook of every layer, then plot it at one token.\n",
    "\n",
    "model_name = \"gpt2-small\"\n",
    "prompt_id = 8\n",
    "n_tokens = gpt2s.ioi_dataset.word_idx[\"end\"][prompt_id].item() + 1\n",
    "n_max_tokens = 21\n",
    "\n",
    "# residual hooks read per layer, in plotting order, with the marker and legend label of each\n",
    "resid_hooks = [\"hook_resid_pre\", \"hook_resid_mid\", \"hook_resid_post\"]\n",
    "markers_to_use = ['o', 's', '^']\n",
    "marker_labels = [\"Residual (pre)\", \"Residual (mid)\", \"Residual (post)\"]\n",
    "n_stages = len(resid_hooks)\n",
    "\n",
    "# Colours for clusters 0 and 1, defined here so the figure no longer depends on a\n",
    "# `final_heatmap_cmap` left behind by whichever head-level cell ran last. They are\n",
    "# entries 1 and 2 of seaborn's \"deep\" palette, the entries previously read from that list.\n",
    "cluster_colors = list(sns.color_palette(\"deep\", 3))[1:]\n",
    "\n",
    "# signal type -> (prototype signals, token of interest, label printed while running)\n",
    "signal_setup = {\n",
    "    \"u-signals\": (gpt2s_u_proto_signals, \"end\", \"U-signals\"),      # U signals: the end token\n",
    "    \"v-signals\": (gpt2s_v_proto_signals, \"starts\", \"V-signals\"),   # V signals: the starts token\n",
    "}\n",
    "\n",
    "for signal_type in [\"u-signals\", \"v-signals\"]:\n",
    "    proto_signals, token_plot_gram, printed_label = signal_setup[signal_type]\n",
    "    token_plot = gpt2s.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()\n",
    "    print(printed_label)\n",
    "\n",
    "    # [cluster, layer * n_stages + stage, token]; -1 marks entries that were never filled\n",
    "    ip_diffs = torch.zeros((2, gpt2s.model.cfg.n_layers * n_stages, n_max_tokens)) - 1\n",
    "\n",
    "    plt.figure(figsize=(4, 1.9))\n",
    "\n",
    "    for cid in range(2):\n",
    "        control_signal = proto_signals[cid]\n",
    "        control_signal = torch.from_numpy(control_signal / np.linalg.norm(control_signal))\n",
    "\n",
    "        for layer in range(gpt2s.model.cfg.n_layers):\n",
    "            for stage, hook in enumerate(resid_hooks):\n",
    "                resid = gpt2s.cache[f\"blocks.{layer}.{hook}\"]\n",
    "                for tok in range(n_max_tokens):\n",
    "                    ip_diffs[cid, layer * n_stages + stage, tok] = F.cosine_similarity(resid[prompt_id, tok], control_signal, dim=0)\n",
    "\n",
    "        # Plotting: a light gray connecting line, then one marker style per residual hook\n",
    "        similarity_at_token = ip_diffs[cid, :, token_plot]\n",
    "        indices = np.arange(len(similarity_at_token))\n",
    "        plt.plot(indices, similarity_at_token, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')\n",
    "\n",
    "        for stage in range(n_stages):\n",
    "            plt.plot(indices[stage::n_stages], similarity_at_token[stage::n_stages],\n",
    "                     marker=markers_to_use[stage],\n",
    "                     linestyle='None',  # markers only, no line through this subset\n",
    "                     label=marker_labels[stage],\n",
    "                     markersize=3,\n",
    "                     color=cluster_colors[cid])\n",
    "\n",
    "    # --- Single custom legend: one entry per cluster colour, one per marker type ---\n",
    "    legend_handles = [\n",
    "        mlines.Line2D([], [], color=cluster_colors[cid], linestyle='-', marker='None', label=f\"Cluster {cid}\")\n",
    "        for cid in range(2)\n",
    "    ]\n",
    "    legend_handles += [\n",
    "        mlines.Line2D([], [], color=\"black\", marker=marker, linestyle='None', label=marker_label)\n",
    "        for marker, marker_label in zip(markers_to_use, marker_labels)\n",
    "    ]\n",
    "    plt.legend(handles=legend_handles,\n",
    "               loc='lower center',\n",
    "               bbox_to_anchor=(0.5, 1.02),\n",
    "               title=None,\n",
    "               ncol=3,\n",
    "               fontsize=6)\n",
    "\n",
    "    plt.xlabel('Layer')\n",
    "    plt.ylabel('Cosine similarity')\n",
    "    # one tick per layer, placed at that layer's first (resid_pre) point\n",
    "    plt.xticks(range(0, ip_diffs.shape[1], n_stages),\n",
    "               labels=[f\"Layer {layer}\" for layer in range(gpt2s.model.cfg.n_layers)],\n",
    "               rotation=90);\n",
    "    plt.grid(True)\n",
    "    filename = f\"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf\"\n",
    "    plt.savefig(filename, bbox_inches='tight', dpi=800);\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0b6b2ab",
   "metadata": {},
   "source": [
    "## Pythia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c8070e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototype V-signals (and the accompanying head map) for Pythia, clustered at threshold 0.7.\n",
    "pyth_v_proto_signals, pyth_v_head_map = get_proto_signals(\n",
    "    pyth,\n",
    "    'v_signals',\n",
    "    cluster_threshold=0.7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1da21d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototype U-signals (and the accompanying head map) for Pythia, clustered at threshold 0.7.\n",
    "pyth_u_proto_signals, pyth_u_head_map = get_proto_signals(\n",
    "    pyth,\n",
    "    'u_signals',\n",
    "    cluster_threshold=0.7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f077ac94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which layers write the prototype control signals into the residual stream? (Pythia)\n",
    "#\n",
    "# For each signal type, measure the cosine similarity between each of the two\n",
    "# prototype signals and the residual stream before/after every block, then plot\n",
    "# the values at one token position of one prompt and save the figure as a PDF.\n",
    "\n",
    "model_name = \"pythia-160m\"\n",
    "prompt_id = 8\n",
    "# Fixed: this index was previously read from `gpt2s` (copy-paste leftover), which\n",
    "# made this cell depend on the GPT-2 context's dataset instead of Pythia's.\n",
    "n_tokens = pyth.ioi_dataset.word_idx[\"end\"][prompt_id].item() + 1\n",
    "n_max_tokens = 21\n",
    "\n",
    "# Residual-stream hooks read per block, in plotting order, with one marker each.\n",
    "resid_hook_names = [\"hook_resid_pre\", \"hook_resid_post\"]\n",
    "markers_to_use = ['o', 's']\n",
    "marker_labels = [\"Residual (pre)\", \"Residual (post)\"]\n",
    "n_stages = len(resid_hook_names)\n",
    "\n",
    "for signal_type in [\"u-signals\", \"v-signals\"]:\n",
    "\n",
    "    # ip_diffs[cluster, layer * n_stages + stage, token]; every entry starts at -1.\n",
    "    ip_diffs = torch.zeros((2, pyth.model.cfg.n_layers * n_stages, n_max_tokens)) - 1\n",
    "\n",
    "    plt.figure(figsize=(4, 1.9))\n",
    "\n",
    "    for cid in range(2):\n",
    "        if signal_type == \"u-signals\":\n",
    "            print(\"U-signals\")\n",
    "            control_signal = pyth_u_proto_signals[cid]\n",
    "            # U signals are inspected at the \"end\" token\n",
    "            token_plot_gram = \"end\"\n",
    "        else:\n",
    "            print(\"V-signals\")\n",
    "            control_signal = pyth_v_proto_signals[cid]\n",
    "            # V signals are inspected at the \"starts\" token\n",
    "            token_plot_gram = \"starts\"\n",
    "\n",
    "        # Scale the prototype (a NumPy array) to unit norm and convert it to a tensor.\n",
    "        control_signal = control_signal / np.linalg.norm(control_signal)\n",
    "        control_signal = torch.from_numpy(control_signal)\n",
    "        token_plot = pyth.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()\n",
    "\n",
    "        for layer in range(pyth.model.cfg.n_layers):\n",
    "            for stage, hook_name in enumerate(resid_hook_names):\n",
    "                resid = pyth.cache[f\"blocks.{layer}.{hook_name}\"][prompt_id]\n",
    "                row = layer * n_stages + stage\n",
    "                for tok in range(n_max_tokens):\n",
    "                    ip_diffs[cid, row, tok] = F.cosine_similarity(resid[tok], control_signal, dim=0)\n",
    "\n",
    "        # Plot the trajectory at the token of interest: a light connecting line\n",
    "        # plus one marker style per residual stage, coloured by cluster.\n",
    "        series = ip_diffs[cid, :, token_plot]\n",
    "        indices = np.arange(len(series))\n",
    "        plt.plot(indices, series, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')\n",
    "\n",
    "        for stage, (marker, stage_label) in enumerate(zip(markers_to_use, marker_labels)):\n",
    "            plt.plot(indices[stage::n_stages], series[stage::n_stages],\n",
    "                     marker=marker,\n",
    "                     linestyle='None',  # markers only, no connecting line\n",
    "                     label=stage_label,\n",
    "                     markersize=3,\n",
    "                     color=final_heatmap_cmap[cid + 1])\n",
    "\n",
    "    # Custom legend: coloured lines identify clusters, black markers identify stages.\n",
    "    legend_handles = [\n",
    "        mlines.Line2D([], [], color=final_heatmap_cmap[k + 1], linestyle='-', marker='None', label=f\"Cluster {k}\")\n",
    "        for k in range(2)\n",
    "    ]\n",
    "    legend_handles += [\n",
    "        mlines.Line2D([], [], color=\"black\", marker=marker, linestyle='None', label=stage_label)\n",
    "        for marker, stage_label in zip(markers_to_use, marker_labels)\n",
    "    ]\n",
    "    plt.legend(handles=legend_handles,\n",
    "               loc='lower center',\n",
    "               bbox_to_anchor=(0.5, 1.02),\n",
    "               title=None,\n",
    "               ncol=3,\n",
    "               fontsize=6)\n",
    "\n",
    "    plt.xlabel('Layer')\n",
    "    plt.ylabel('Cosine similarity')\n",
    "    # One tick per layer, placed on that layer's first (pre) entry.\n",
    "    plt.xticks(range(0, ip_diffs.shape[1], n_stages),\n",
    "               labels=[f\"Layer {layer}\" for layer in range(pyth.model.cfg.n_layers)],\n",
    "               rotation=90)\n",
    "    plt.grid(True)\n",
    "    filename = f\"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf\"\n",
    "    plt.savefig(filename, bbox_inches='tight', dpi=800)\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "741d291d",
   "metadata": {},
   "source": [
    "## Gemma-2 2B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deed7325",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototype V-signals (and the accompanying head map) for Gemma, clustered at threshold 0.75.\n",
    "gemma_v_proto_signals, gemma_v_head_map = get_proto_signals(\n",
    "    gemma,\n",
    "    'v_signals',\n",
    "    cluster_threshold=0.75,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905f7727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototype U-signals (and the accompanying head map) for Gemma, clustered at threshold 0.75.\n",
    "gemma_u_proto_signals, gemma_u_head_map = get_proto_signals(\n",
    "    gemma,\n",
    "    'u_signals',\n",
    "    cluster_threshold=0.75,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e1acbb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.lines as mlines  # Line2D proxy artists for the custom legends\n",
    "\n",
    "# Figure-export defaults: font type 42 for PDF/PS output and an 8 pt base font.\n",
    "# NOTE(review): the figure cells above this one call plt.savefig before these\n",
    "# settings are applied on a fresh top-to-bottom run; consider moving this cell up.\n",
    "matplotlib.rcParams.update({\n",
    "    'pdf.fonttype': 42,\n",
    "    'ps.fonttype': 42,\n",
    "    'font.size': 8,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779343b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which layers write the prototype control signals into the residual stream? (Gemma-2 2B)\n",
    "#\n",
    "# For each signal type, measure the cosine similarity between each of the two\n",
    "# prototype signals and the residual stream at the pre/mid/post hooks of every\n",
    "# block, then plot the values at one token position of one prompt and save a PDF.\n",
    "\n",
    "model_name = \"gemma-2-2b\"\n",
    "prompt_id = 8\n",
    "n_tokens = gemma.ioi_dataset.word_idx[\"end\"][prompt_id].item() + 1\n",
    "n_max_tokens = 21\n",
    "\n",
    "# Residual-stream hooks read per block, in plotting order, with one marker each.\n",
    "resid_hook_names = [\"hook_resid_pre\", \"hook_resid_mid\", \"hook_resid_post\"]\n",
    "markers_to_use = ['o', 's', '^']\n",
    "marker_labels = [\"Residual (pre)\", \"Residual (mid)\", \"Residual (post)\"]\n",
    "n_stages = len(resid_hook_names)\n",
    "\n",
    "for signal_type in [\"u-signals\", \"v-signals\"]:\n",
    "    # ip_diffs[cluster, layer * n_stages + stage, token]; every entry starts at -1.\n",
    "    ip_diffs = torch.zeros((2, gemma.model.cfg.n_layers * n_stages, n_max_tokens)) - 1\n",
    "\n",
    "    plt.figure(figsize=(4, 1.9))\n",
    "\n",
    "    for cid in range(2):\n",
    "        if signal_type == \"u-signals\":\n",
    "            print(\"U-signals\")\n",
    "            control_signal = gemma_u_proto_signals[cid]\n",
    "            # U signals are inspected at the \"end\" token\n",
    "            token_plot_gram = \"end\"\n",
    "        else:\n",
    "            print(\"V-signals\")\n",
    "            control_signal = gemma_v_proto_signals[cid]\n",
    "            # V signals are inspected at the \"starts\" token\n",
    "            token_plot_gram = \"starts\"\n",
    "\n",
    "        # Scale the prototype (a NumPy array) to unit norm and convert it to a tensor.\n",
    "        control_signal = control_signal / np.linalg.norm(control_signal)\n",
    "        control_signal = torch.from_numpy(control_signal)\n",
    "        token_plot = gemma.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()\n",
    "\n",
    "        # The former `if layer == 0 / else` had two identical branches, so every\n",
    "        # layer is handled uniformly here.\n",
    "        for layer in range(gemma.model.cfg.n_layers):\n",
    "            for stage, hook_name in enumerate(resid_hook_names):\n",
    "                resid = gemma.cache[f\"blocks.{layer}.{hook_name}\"][prompt_id]\n",
    "                row = layer * n_stages + stage\n",
    "                for tok in range(n_max_tokens):\n",
    "                    ip_diffs[cid, row, tok] = F.cosine_similarity(resid[tok], control_signal, dim=0)\n",
    "\n",
    "        # Plot the trajectory at the token of interest: a light connecting line\n",
    "        # plus one marker style per residual stage, coloured by cluster.\n",
    "        series = ip_diffs[cid, :, token_plot]\n",
    "        indices = np.arange(len(series))\n",
    "        plt.plot(indices, series, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')\n",
    "\n",
    "        for stage, (marker, stage_label) in enumerate(zip(markers_to_use, marker_labels)):\n",
    "            plt.plot(indices[stage::n_stages], series[stage::n_stages],\n",
    "                     marker=marker,\n",
    "                     linestyle='None',  # markers only, no connecting line\n",
    "                     label=stage_label,\n",
    "                     markersize=2.5,\n",
    "                     color=final_heatmap_cmap[cid + 1])\n",
    "\n",
    "    # Custom legend: coloured lines identify clusters, black markers identify stages.\n",
    "    legend_handles = [\n",
    "        mlines.Line2D([], [], color=final_heatmap_cmap[k + 1], linestyle='-', marker='None', label=f\"Cluster {k}\")\n",
    "        for k in range(2)\n",
    "    ]\n",
    "    legend_handles += [\n",
    "        mlines.Line2D([], [], color=\"black\", marker=marker, linestyle='None', label=stage_label)\n",
    "        for marker, stage_label in zip(markers_to_use, marker_labels)\n",
    "    ]\n",
    "    plt.legend(handles=legend_handles,\n",
    "               loc='lower center',\n",
    "               bbox_to_anchor=(0.5, 1.02),\n",
    "               title=None,\n",
    "               ncol=3,\n",
    "               fontsize=6)\n",
    "\n",
    "    plt.xlabel('Layer')\n",
    "    plt.ylabel('Cosine similarity')\n",
    "    # One tick per layer, placed on that layer's first (pre) entry.\n",
    "    plt.xticks(range(0, ip_diffs.shape[1], n_stages),\n",
    "               labels=[f\"Layer {layer}\" for layer in range(gemma.model.cfg.n_layers)],\n",
    "               rotation=90)\n",
    "    plt.grid(True)\n",
    "    filename = f\"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf\"\n",
    "    plt.savefig(filename, bbox_inches='tight', dpi=800)\n",
    "    plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "signals-stream",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
