{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Circuit Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple, List\n",
    "from functools import partial\n",
    "\n",
    "import os\n",
    "import numpy as np \n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformer_lens import HookedTransformer\n",
    "from einops import rearrange\n",
    "\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "\n",
    "from utils.data_processing import (\n",
    "    load_edge_scores_into_dictionary,\n",
    "    get_ckpts,\n",
    "    load_metrics,\n",
    "    get_ckpts\n",
    ")\n",
    "from utils.backup_analysis import load_model, get_past_nmhs_for_checkpoints, compute_copy_score\n",
    "from utils.data_utils import generate_data_and_caches\n",
    "from utils.cspa_main import prepare_data\n",
    "from path_patching_cm.ioi_dataset import IOIDataset\n",
    "from path_patching_cm.path_patching import Node, path_patch\n",
    "\n",
    "from utils.visualization import imshow_p, plot_attention_heads\n",
    "\n",
    "#%%\n",
    "def convert_head_names_to_tuple(head_name):\n",
    "    head_name = head_name.replace('a', '')\n",
    "    head_name = head_name.replace('h', '')\n",
    "    layer, head = head_name.split('.')\n",
    "    return (int(layer), int(head))\n",
    "\n",
    "def collate_fn(ds, device):\n",
    "    if not ds:\n",
    "        return {}\n",
    "    return {k: torch.stack([d[k] for d in ds], dim=0).to(device) for k in ds[0].keys()}\n",
    "\n",
    "class BatchIOIDataset(IOIDataset):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.N\n",
    "    \n",
    "    def __getitem__(self, i):\n",
    "        return {'toks':self.toks[i], 'io_token_id': torch.tensor(self.io_tokenIDs[i]), 's_token_id': torch.tensor(self.s_tokenIDs[i]), **{f'{k}_pos':v[i] for k, v in self.word_idx.items()}}\n",
    "    \n",
    "##%\n",
    "\n",
    "def make_s2i(layer, head):\n",
    "    return Node(f'blocks.{layer}.attn.hook_z', layer, head)\n",
    "def make_nmh(layer, head):\n",
    "    return Node(f'blocks.{layer}.hook_q_input', layer, head)\n",
    "\n",
    "from utils.head_metrics import S2I_head_metrics, S2I_token_pos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK = 'ioi'\n",
    "PERFORMANCE_METRIC = 'logit_diff'\n",
    "BASE_MODEL = \"pythia-160m\"\n",
    "VARIANT = None #\"EleutherAI/pythia-160m-attndropout\"\n",
    "MODEL_SHORTNAME = BASE_MODEL if not VARIANT else VARIANT[11:]\n",
    "CACHE = \"model_cache\"\n",
    "DATASET_SIZE = 100\n",
    "SEED = 42\n",
    "BATCH_SIZE = 70\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_path = f'results/graphs/{MODEL_SHORTNAME}/{TASK}'\n",
    "df = load_edge_scores_into_dictionary(folder_path)\n",
    "\n",
    "directory_path = 'results'\n",
    "perf_metrics = load_metrics(directory_path)\n",
    "\n",
    "ckpts = get_ckpts(schedule=\"exp_plus_detail\")\n",
    "\n",
    "# filter everything before 1000 steps\n",
    "df = df[df['checkpoint'] >= 1000]\n",
    "\n",
    "df[['source', 'target']] = df['edge'].str.split('->', expand=True)\n",
    "len(df['target'].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_metrics = torch.load(f'results/backup/{MODEL_SHORTNAME}/nmh_backup_metrics.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.cspa_main import get_result_mean, get_cspa_results_batched, get_performance_recovered\n",
    "\n",
    "def get_cspa_for_head(model, data_toks, cspa_semantic_dict, layer, head, verbose=False):\n",
    "\n",
    "    current_batch_size = 17 # Smaller values so we can check more checkpoints in a reasonable amount of time\n",
    "    current_seq_len = 61\n",
    "\n",
    "    result_mean = get_result_mean([(layer, head)], data_toks[:100, :], model, verbose=True)\n",
    "    cspa_results_qk_ov = get_cspa_results_batched(\n",
    "        model=model,\n",
    "        toks=data_toks[:current_batch_size, :current_seq_len],\n",
    "        max_batch_size=1,  # 50,\n",
    "        negative_head=(layer, head),\n",
    "        interventions=[\"ov\", \"qk\"],\n",
    "        only_keep_negative_components=True,\n",
    "        K_unembeddings=0.05,  # most interesting in range 3-8 (out of 80)\n",
    "        K_semantic=1,  # either 1 or up to 8 to capture all sem similar\n",
    "        semantic_dict=cspa_semantic_dict,\n",
    "        result_mean=result_mean,\n",
    "        use_cuda=True,\n",
    "        verbose=True,\n",
    "        compute_s_sstar_dict=False,\n",
    "        computation_device=\"cpu\",  # device\n",
    "    )\n",
    "    head_results = get_performance_recovered(cspa_results_qk_ov)\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Layer {layer}, head {head} done. Performance: {head_results:.2f}\")\n",
    "\n",
    "    return head_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_attention_to_ioi_token(\n",
    "        model: HookedTransformer, \n",
    "        ioi_dataset: IOIDataset,  \n",
    "        head_list: List[Tuple[int, int]], \n",
    "        batch_size\n",
    "    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: \n",
    "\n",
    "    ioi_dataset.__class__ = BatchIOIDataset\n",
    "    ioi_dataloader = DataLoader(ioi_dataset, batch_size=batch_size, collate_fn=partial(collate_fn, device=model.cfg.device))\n",
    "    \n",
    "    NMH_layers, NMH_heads = zip(*head_list)\n",
    "    NMH_layers = torch.tensor(NMH_layers, device=model.cfg.device)\n",
    "    NMH_heads = torch.tensor(NMH_heads, device=model.cfg.device)\n",
    "\n",
    "    # Initialize tensors to accumulate attention values\n",
    "    n_layers = model.cfg.n_layers\n",
    "    n_heads = model.cfg.n_heads\n",
    "    s1_attention_accum = torch.zeros((n_layers, n_heads), device=model.cfg.device)\n",
    "    s2_attention_accum = torch.zeros((n_layers, n_heads), device=model.cfg.device)\n",
    "    io_attention_accum = torch.zeros((n_layers, n_heads), device=model.cfg.device)\n",
    "    batch_count = 0\n",
    "\n",
    "    for batch in ioi_dataloader:\n",
    "        batch_count += 1\n",
    "        toks = batch['toks']\n",
    "        io_pos = batch['IO_pos']\n",
    "        end_pos = batch['end_pos']\n",
    "        s2_pos = batch['S2_pos']\n",
    "        s1_pos = batch['S1_pos']\n",
    "        s_token_ids = batch['s_token_id']\n",
    "        io_token_ids = batch['io_token_id']\n",
    "\n",
    "        cache, caching_hooks, _ = model.get_caching_hooks(lambda name: 'hook_pattern' in name)\n",
    "        with model.hooks(caching_hooks):\n",
    "            logits = model(toks)[torch.arange(len(toks)), end_pos]\n",
    "\n",
    "        attention_patterns = torch.stack([cache[f'blocks.{n}.attn.hook_pattern'] for n in range(n_layers)])  #layer, batch, head, query, key\n",
    "        attention_patterns_by_head = attention_patterns[NMH_layers, :, NMH_heads]\n",
    "\n",
    "        nmh_s1_attention_values = attention_patterns_by_head[:, torch.arange(len(toks)), end_pos, s1_pos]  # batch, layer, head\n",
    "        nmh_s2_attention_values = attention_patterns_by_head[:, torch.arange(len(toks)), end_pos, s2_pos]  # batch, layer, head\n",
    "        nmh_io_attention_values = attention_patterns_by_head[:, torch.arange(len(toks)), end_pos, io_pos]  # batch, layer, head\n",
    "\n",
    "        # Accumulate attention values\n",
    "        for i, (layer, head) in enumerate(head_list):\n",
    "            s1_attention_accum[layer, head] += nmh_s1_attention_values[:, i].mean()\n",
    "            s2_attention_accum[layer, head] += nmh_s2_attention_values[:, i].mean()\n",
    "            io_attention_accum[layer, head] += nmh_io_attention_values[:, i].mean()\n",
    "\n",
    "    # Calculate mean attention values\n",
    "    s1_attention_means = s1_attention_accum / batch_count\n",
    "    s2_attention_means = s2_attention_accum / batch_count\n",
    "    io_attention_means = io_attention_accum / batch_count\n",
    "\n",
    "    return s1_attention_means, s2_attention_means, io_attention_means\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get NMH candidates\n",
    "def evaluate_direct_effect_heads(model, edge_df, dataset, verbose=False):\n",
    "    direct_effect_heads = edge_df[edge_df['target']=='logits']\n",
    "    direct_effect_heads = direct_effect_heads[direct_effect_heads['in_circuit'] == True]\n",
    "\n",
    "    head_list = direct_effect_heads['source'].unique().tolist()\n",
    "    head_list = [convert_head_names_to_tuple(c) for c in head_list if (c[0] != 'm' and c != 'input')]\n",
    "\n",
    "    head_data = dict()\n",
    "\n",
    "\n",
    "    # Test for NMH behavior\n",
    "    head_data['copy_scores'] = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))\n",
    "    for layer, head in head_list:\n",
    "        head_data['copy_scores'][layer, head] = compute_copy_score(model, layer, head, dataset, verbose=False, neg=False)\n",
    "\n",
    "    # Test for attention to IOI tokens\n",
    "    s1_attn_scores, s2_attn_scores, io_attn_scores = get_attention_to_ioi_token(model, dataset, head_list, batch_size=70)\n",
    "    head_data['s1_attn_scores'], head_data['s2_attn_scores'], head_data['io_attn_scores'] = s1_attn_scores, s2_attn_scores, io_attn_scores\n",
    "    \n",
    "    # Test for copy suppression behavior\n",
    "    model.cfg.use_split_qkv_input = False\n",
    "    DATA_TOKS, DATA_STR_TOKS_PARSED, cspa_semantic_dict, indices = prepare_data(model)\n",
    "\n",
    "    head_data['copy_suppression_scores'] = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))\n",
    "    for layer, head in head_list:\n",
    "        head_data['copy_suppression_scores'][layer, head] = get_cspa_for_head(model, DATA_TOKS, cspa_semantic_dict, layer, head, verbose=verbose)\n",
    "\n",
    "    model.cfg.use_split_qkv_input = True\n",
    "    \n",
    "    return head_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_name_movers(direct_effect_scores, copy_score_threshold):\n",
    "    direct_effect_scores['filtered_copy_scores'] = direct_effect_scores['copy_scores'].clone()\n",
    "\n",
    "    nmh_list = []\n",
    "\n",
    "    for layer in range(direct_effect_scores['copy_scores'].shape[0]):\n",
    "        for head in range(direct_effect_scores['copy_scores'].shape[1]):\n",
    "            if direct_effect_scores['copy_scores'][layer, head] < copy_score_threshold:\n",
    "                direct_effect_scores['filtered_copy_scores'][layer, head] = 0\n",
    "\n",
    "            if direct_effect_scores['copy_scores'][layer, head] > copy_score_threshold \\\n",
    "                and direct_effect_scores['io_attn_scores'][layer, head] > direct_effect_scores['s1_attn_scores'][layer, head] \\\n",
    "                     and direct_effect_scores['io_attn_scores'][layer, head] > direct_effect_scores['s2_attn_scores'][layer, head]:\n",
    "                nmh_list.append((layer, head))\n",
    "\n",
    "    return nmh_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_s2i_candidates(model, checkpoint_df, ioi_dataset, name_mover_heads, batch_size, verbose=False):\n",
    "\n",
    "    patch_dataset_names = ['token_same_pos_oppo', 'token_oppo_pos_same', 'token_oppo_pos_oppo']\n",
    "    targeting_nmh = np.logical_or.reduce(np.array([checkpoint_df['target'] == f'a{layer}.h{head}<q>' for layer, head in name_mover_heads]))\n",
    "    candidate_s2i = checkpoint_df[targeting_nmh]\n",
    "    candidate_s2i = candidate_s2i[candidate_s2i['in_circuit'] == True]\n",
    "\n",
    "    candidate_list = candidate_s2i['source'].unique().tolist()\n",
    "    candidate_list = [convert_head_names_to_tuple(c) for c in candidate_list if (c[0] != 'm' and c != 'input')]\n",
    "\n",
    "\n",
    "    s2i_heads = candidate_list # [(7,9), (7,2), (6,6), (6,5),]\n",
    "\n",
    "\n",
    "    s2i_ablated_logit_diff_deltas = {patch_dataset_name: torch.zeros((model.cfg.n_layers, model.cfg.n_heads)) for patch_dataset_name in patch_dataset_names}\n",
    "    s2i_io_attention_deltas = {patch_dataset_name: torch.zeros((model.cfg.n_layers, model.cfg.n_heads)) for patch_dataset_name in patch_dataset_names}\n",
    "    s2i_s1_attention_deltas = {patch_dataset_name: torch.zeros((model.cfg.n_layers, model.cfg.n_heads)) for patch_dataset_name in patch_dataset_names}\n",
    "    s2i_s2_attention_deltas = {patch_dataset_name: torch.zeros((model.cfg.n_layers, model.cfg.n_heads)) for patch_dataset_name in patch_dataset_names}\n",
    "    true_s2i_mask = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))\n",
    "    true_s2i_heads = []\n",
    "\n",
    "    for head in s2i_heads:\n",
    "        s2i_token_pos_results = S2I_token_pos(model, ioi_dataset, [head], name_mover_heads, batch_size)\n",
    "        if verbose:\n",
    "            print(f'Head {head}')\n",
    "        mean_original_logit_diff = s2i_token_pos_results['ablated_logit_diffs']['token_same_pos_same'].mean()\n",
    "        mean_original_io_attention = s2i_token_pos_results['io_attention_values']['token_same_pos_same'].mean()\n",
    "        mean_original_s1_attention = s2i_token_pos_results['s1_attention_values']['token_same_pos_same'].mean()\n",
    "        mean_original_s2_attention = s2i_token_pos_results['s2_attention_values']['token_same_pos_same'].mean()\n",
    "\n",
    "        for dataset_name in patch_dataset_names:\n",
    "\n",
    "            mean_ablated_logit_diff = s2i_token_pos_results['ablated_logit_diffs'][dataset_name].mean()\n",
    "            mean_ablated_io_attention = s2i_token_pos_results['io_attention_values'][dataset_name].mean()\n",
    "            mean_ablated_s1_attention = s2i_token_pos_results['s1_attention_values'][dataset_name].mean()\n",
    "            mean_ablated_s2_attention = s2i_token_pos_results['s2_attention_values'][dataset_name].mean()\n",
    "\n",
    "            logit_diff_delta = (mean_ablated_logit_diff - mean_original_logit_diff) / mean_original_logit_diff\n",
    "            io_attention_delta = (mean_ablated_io_attention - mean_original_io_attention) / mean_original_io_attention\n",
    "            s1_attention_delta = (mean_ablated_s1_attention - mean_original_s1_attention) / mean_original_s1_attention\n",
    "            s2_attention_delta = (mean_ablated_s2_attention - mean_original_s2_attention) / mean_original_s2_attention\n",
    "\n",
    "            s2i_ablated_logit_diff_deltas[dataset_name][head] = logit_diff_delta\n",
    "            s2i_io_attention_deltas[dataset_name][head] = io_attention_delta\n",
    "            s2i_s1_attention_deltas[dataset_name][head] = s1_attention_delta\n",
    "            s2i_s2_attention_deltas[dataset_name][head] = s2_attention_delta\n",
    "            \n",
    "            if verbose:\n",
    "                print(dataset_name)\n",
    "                print(f\"Logit diff after patching: {100 * logit_diff_delta:.2f}%\")\n",
    "                # should be high with pos = same, low with pos = diff\n",
    "                print(f\"NMH IO Attention Change: {100 * io_attention_delta:.2f}%\")\n",
    "                # should be low with pos = same, high with pos = diff\n",
    "                print(f\"NMH S1 Attention Change: {100 * s1_attention_delta:.2f}%\")\n",
    "                # shouldn't change much\n",
    "                print(f\"NMH S2 Attention Change: {100 * s2_attention_delta:.2f}%\")\n",
    "                print('\\n')\n",
    "        \n",
    "        layer, head_idx = head\n",
    "        if s2i_ablated_logit_diff_deltas['token_same_pos_oppo'][layer, head_idx] < 0 \\\n",
    "            and s2i_io_attention_deltas['token_same_pos_oppo'][layer, head_idx] < 0 \\\n",
    "            and s2i_s1_attention_deltas['token_same_pos_oppo'][layer, head_idx] > 0:\n",
    "            true_s2i_mask[layer, head_idx] = 1\n",
    "            true_s2i_heads.append(head)\n",
    "\n",
    "    # mask the deltas\n",
    "    s2i_ablated_logit_diff_deltas = {k: v * true_s2i_mask for k, v in s2i_ablated_logit_diff_deltas.items()}\n",
    "    s2i_io_attention_deltas = {k: v * true_s2i_mask for k, v in s2i_io_attention_deltas.items()}\n",
    "    s2i_s1_attention_deltas = {k: v * true_s2i_mask for k, v in s2i_s1_attention_deltas.items()}\n",
    "    s2i_s2_attention_deltas = {k: v * true_s2i_mask for k, v in s2i_s2_attention_deltas.items()}\n",
    "\n",
    "    return {\n",
    "        's2i_ablated_logit_diff_deltas': s2i_ablated_logit_diff_deltas, \n",
    "        's2i_io_attention_deltas': s2i_io_attention_deltas, \n",
    "        's2i_s1_attention_deltas': s2i_s1_attention_deltas, \n",
    "        's2i_s2_attention_deltas': s2i_s2_attention_deltas\n",
    "    }, true_s2i_heads\n",
    "        \n",
    "        \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_induction_scores(model):\n",
    "    seq_len = 100\n",
    "    batch_size = 2\n",
    "\n",
    "    prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n",
    "\n",
    "    def prev_token_hook(pattern, hook):\n",
    "        layer = hook.layer()\n",
    "        diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)\n",
    "        prev_token_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n",
    "\n",
    "    duplicate_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n",
    "\n",
    "    def duplicate_token_hook(pattern, hook):\n",
    "        layer = hook.layer()\n",
    "        diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)\n",
    "        duplicate_token_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n",
    "\n",
    "    induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=\"cuda\")\n",
    "\n",
    "    def induction_hook(pattern, hook):\n",
    "        layer = hook.layer()\n",
    "        diagonal = pattern.diagonal(offset=seq_len-1, dim1=-1, dim2=-2)\n",
    "        induction_scores[layer] = einops.reduce(diagonal, \"batch head_index diagonal -> head_index\", \"mean\")\n",
    "\n",
    "    original_tokens = torch.randint(100, 20000, size=(batch_size, seq_len))\n",
    "    repeated_tokens = einops.repeat(original_tokens, \"batch seq_len -> batch (2 seq_len)\").cuda()\n",
    "\n",
    "    pattern_filter = lambda act_name: act_name.endswith(\"hook_pattern\")\n",
    "    loss = model.run_with_hooks(repeated_tokens, return_type=\"loss\", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])\n",
    "\n",
    "    return induction_scores, prev_token_scores, duplicate_token_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_induction_scores(model, checkpoint_df):\n",
    "    \n",
    "    circuit_heads = checkpoint_df[checkpoint_df['in_circuit'] == True]['source'].unique().tolist()\n",
    "    circuit_heads = [convert_head_names_to_tuple(c) for c in circuit_heads if (c[0] != 'm' and c != 'input')]\n",
    "    \n",
    "    circuit_mask = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)\n",
    "    for layer, head in circuit_heads:\n",
    "        circuit_mask[layer, head] = 1\n",
    "    \n",
    "    induction_scores, prev_token_scores, duplicate_token_scores = get_induction_scores(model)\n",
    "    induction_scores = induction_scores * circuit_mask\n",
    "    prev_token_scores = prev_token_scores * circuit_mask\n",
    "    duplicate_token_scores = duplicate_token_scores * circuit_mask\n",
    "\n",
    "    return {\n",
    "        'induction_scores': induction_scores, \n",
    "        'prev_token_scores': prev_token_scores, \n",
    "        'duplicate_token_scores': duplicate_token_scores\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(overwrite=False):\n",
    "\n",
    "    model = load_model(BASE_MODEL, VARIANT, 143000, CACHE)\n",
    "    ioi_dataset, abc_dataset = generate_data_and_caches(model, 70, verbose=True, prepend_bos=True)\n",
    "\n",
    "    for checkpoint in [5000, 10000, 30000]:\n",
    "        # check if file exists; if not, create\n",
    "        if not os.path.exists(f'results/components/{MODEL_SHORTNAME}/components_over_time.pt'):\n",
    "            os.makedirs(f'results/components/{MODEL_SHORTNAME}', exist_ok=True)\n",
    "            components_over_time = dict()\n",
    "            heads_over_time = dict()\n",
    "        else:\n",
    "            components_over_time = torch.load(f'results/components/{MODEL_SHORTNAME}/components_over_time.pt')\n",
    "            heads_over_time = torch.load(f'results/components/{MODEL_SHORTNAME}/heads_over_time.pt')\n",
    "\n",
    "        if checkpoint in components_over_time and not overwrite:\n",
    "            continue\n",
    "\n",
    "\n",
    "        model = load_model(BASE_MODEL, VARIANT, checkpoint, CACHE)\n",
    "        checkpoint_df = df[df['checkpoint'] == checkpoint].copy()\n",
    "        component_scores = dict()\n",
    "        model_heads = dict()\n",
    "\n",
    "        component_scores['direct_effect_scores'] = evaluate_direct_effect_heads(model, checkpoint_df, ioi_dataset, verbose=False)\n",
    "        nmh_list = filter_name_movers(component_scores['direct_effect_scores'], copy_score_threshold=10)\n",
    "        \n",
    "        model_heads['nmh'] = nmh_list\n",
    "        print(f\"Found {len(nmh_list)} NMHs\")\n",
    "        print(nmh_list)\n",
    "        \n",
    "        if len(nmh_list) > 0:\n",
    "            component_scores['s2i_scores'], s2i_list = evaluate_s2i_candidates(model, checkpoint_df, ioi_dataset, nmh_list, batch_size=70, verbose=False)\n",
    "            print(f\"Found {len(s2i_list)} S2I heads\")\n",
    "            print(s2i_list)\n",
    "        else:\n",
    "            component_scores['s2i_scores'] = None\n",
    "            s2i_list = []\n",
    "\n",
    "        model_heads['s2i'] = s2i_list\n",
    "\n",
    "        component_scores['tertiary_head_scores'] = evaluate_induction_scores(model, checkpoint_df)\n",
    "\n",
    "        components_over_time[checkpoint] = component_scores\n",
    "        heads_over_time[checkpoint] = model_heads\n",
    "\n",
    "        torch.save(components_over_time, f'results/components/{MODEL_SHORTNAME}/components_over_time.pt')\n",
    "        torch.save(heads_over_time, f'results/components/{MODEL_SHORTNAME}/heads_over_time.pt')\n",
    "\n",
    "    return components_over_time\n",
    "   \n",
    "components_over_time = main()\n",
    "\n",
    "# # #baseline_logit_diffs, end_s2_attention_values, baseline_nmh_s1_attention_values, new_logit_diffs, new_nmh_s1_attention_values\n",
    "# s2i_results = S2I_head_metrics(model, ioi_dataset, candidate_list, name_mover_heads, batch_size)\n",
    "\n",
    "# # our three measures are thus:\n",
    "\n",
    "# # attention (higher is better)\n",
    "# s2i_s2_attention = s2i_results['end_s2_attention_values'].mean(0)\n",
    "\n",
    "# # logit diff change (lower is better)\n",
    "# logit_diff_change = (s2i_results['new_logit_diffs'] - s2i_results['baseline_logit_diffs'].unsqueeze(1)).mean(0)\n",
    "\n",
    "# # NMH s1 attention change (higher is better)\n",
    "# nmh_s1_attention_change = (s2i_results['new_nmh_s1_attention_values'] - s2i_results['baseline_nmh_s1_attention_values'].unsqueeze(1)).mean(0).mean(-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "components_over_time.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "        component_scores['s2i_scores']['s2i_ablated_logit_diff_deltas']['token_same_pos_oppo'] * 100, \n",
    "        title=f's2i_ablated_logit_diff_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "        component_scores['s2i_scores']['s2i_io_attention_deltas']['token_same_pos_oppo'] * 100, \n",
    "        title=f's2i_io_attention_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "        component_scores['s2i_scores']['s2i_s1_attention_deltas']['token_same_pos_oppo'] * 100, \n",
    "        title=f's2i_s1_attention_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#component_scores['tertiary_head_scores']['induction_scores']\n",
    "imshow_p(\n",
    "        component_scores['tertiary_head_scores']['induction_scores'] * 100, \n",
    "        title=f'induction_scores_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "        component_scores['tertiary_head_scores']['prev_token_scores'] * 100, \n",
    "        title=f'prev_token_scores for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "        component_scores['tertiary_head_scores']['duplicate_token_scores'] * 100, \n",
    "        title=f'duplicate_token_scores for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(colorbar_ticksuffix=\" %\"),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric in component_scores['direct_effect_scores'].keys():\n",
    "    # if metric contains \"attn\"\n",
    "    if 'attn' in metric:\n",
    "        min = -1\n",
    "        max = 1\n",
    "    else:\n",
    "        min = None\n",
    "        max = None\n",
    "    imshow_p(\n",
    "        component_scores['direct_effect_scores'][metric], \n",
    "        title=f'{metric} for {MODEL_SHORTNAME} at checkpoint {checkpoint}',\n",
    "        labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"metric_value\"},\n",
    "        border=True,\n",
    "        coloraxis=dict(cmin=min, cmax=max),\n",
    "        width=600,\n",
    "        margin={\"r\": 100, \"l\": 100}\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "component_scores['direct_effect_scores'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_mover_heads = checkpoint_nmhs[checkpoint]\n",
    "targeting_nmh = np.logical_or.reduce(np.array([df['target'] == f'a{layer}.h{head}<q>' for layer, head in name_mover_heads]))\n",
    "candidate_s2i = df[targeting_nmh]\n",
    "candidate_s2i = candidate_s2i[candidate_s2i['in_circuit'] == True]\n",
    "\n",
    "candidate_list = candidate_s2i[candidate_s2i['checkpoint']==checkpoint]['source'].unique().tolist()\n",
    "candidate_list = [convert_head_names_to_tuple(c) for c in candidate_list if (c[0] != 'm' and c != 'input')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "candidate_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for i in range(len(candidate_list)):\n",
    "#     print(f'Head {candidate_list[i]}')\n",
    "#     print(f'Attention  to S2:        {s2i_s2_attention[i]:.3f}')\n",
    "#     print(f'Logit Diff Change:       {logit_diff_change[i]:.3f}')\n",
    "#     print(f'NMH S1 Attention Change: {nmh_s1_attention_change[i]:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "s2i_heads = candidate_list # [(7,9), (7,2), (6,6), (6,5),]\n",
    "print(s2i_heads)\n",
    "\n",
    "for head in s2i_heads:\n",
    "    s2i_token_pos_results = S2I_token_pos(model, ioi_dataset, [head], name_mover_heads, batch_size)\n",
    "    print(f'Head {head}')\n",
    "    mean_original_logit_diff = s2i_token_pos_results['ablated_logit_diffs']['token_same_pos_same'].mean()\n",
    "    mean_original_io_attention = s2i_token_pos_results['io_attention_values']['token_same_pos_same'].mean()\n",
    "    mean_original_s1_attention = s2i_token_pos_results['s1_attention_values']['token_same_pos_same'].mean()\n",
    "    mean_original_s2_attention = s2i_token_pos_results['s2_attention_values']['token_same_pos_same'].mean()\n",
    "\n",
    "    for dataset_name in patch_dataset_names:\n",
    "        print(dataset_name)\n",
    "\n",
    "        mean_ablated_logit_diff = s2i_token_pos_results['ablated_logit_diffs'][dataset_name].mean()\n",
    "        mean_ablated_io_attention = s2i_token_pos_results['io_attention_values'][dataset_name].mean()\n",
    "        mean_ablated_s1_attention = s2i_token_pos_results['s1_attention_values'][dataset_name].mean()\n",
    "        mean_ablated_s2_attention = s2i_token_pos_results['s2_attention_values'][dataset_name].mean()\n",
    "\n",
    "        logit_diff_delta = (mean_ablated_logit_diff - mean_original_logit_diff) / mean_original_logit_diff\n",
    "        print(f\"Logit diff after patching: {100 * logit_diff_delta:.2f}%\")\n",
    "        # should be high with pos = same, low with pos = diff\n",
    "\n",
    "        io_attention_delta = (mean_ablated_io_attention - mean_original_io_attention) / mean_original_io_attention\n",
    "        print(f\"NMH IO Attention Value: {100 * io_attention_delta:.2f}%\")\n",
    "        # should be low with pos = same, high with pos = diff\n",
    "\n",
    "        s1_attention_delta = (mean_ablated_s1_attention - mean_original_s1_attention) / mean_original_s1_attention\n",
    "        print(f\"NMH S1 Attention Value: {100 * s1_attention_delta:.2f}%\")\n",
    "        # shouldn't change much\n",
    "\n",
    "        s2_attention_delta = (mean_ablated_s2_attention - mean_original_s2_attention) / mean_original_s2_attention\n",
    "        print(f\"NMH S2 Attention Value: {100 * s2_attention_delta:.2f}%\")\n",
    "        print('\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
