{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9ec82de-3a8e-498f-ab3d-f50ccb0a6c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "import torch\n",
    "import os\n",
    "from transformer_lens import HookedTransformer, FactoredMatrix\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy\n",
    "from scipy.stats import spearmanr, pearsonr\n",
    "\n",
    "from stitching.losses import next_token_cross_entropy_loss, kl_div_loss\n",
    "from stitching.stitching_utils import open_experiment, load_activation_store\n",
    "from stitching.sae_utils import BaseSAE, forward_modified, generate_modified\n",
    "from stitching.generic_experiments import *\n",
    "from stitching.sae_utils import feature_activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759c0175-1438-4900-a89c-0ec0d2ce40de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference-only notebook: turn off autograd globally.\n",
    "torch.set_grad_enabled(mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d36f8c5-58b3-46a1-be13-6ad2edff52dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "TRUNCATION_LENGTH = 512\n",
    "PADDING_TOKEN = 0\n",
    "device = 'cuda'\n",
    "\n",
    "# Machine-specific settings (e.g. the model cache directory) live in global_config.yaml.\n",
    "with open('global_config.yaml') as cfg_file:\n",
    "    global_cfg = yaml.safe_load(cfg_file)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "\n",
    "# Pre-tokenized gpt2-small splits, one tensor file per split.\n",
    "tokenized_dataset = {\n",
    "    split: torch.load(f'data/gpt2-small_tokenized_dataset_200000_{split}_{TRUNCATION_LENGTH}.pt', weights_only=True)\n",
    "    for split in ('train', 'test')\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5b534d-4254-4e6a-8dca-4012d1282336",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base model the SAE was trained on.\n",
    "# (Earlier runs loaded pythia-70m / pythia-70m-deduped / pythia-160m-deduped here.)\n",
    "gpt2_sm = HookedTransformer.from_pretrained(\n",
    "    'gpt2-small',\n",
    "    cache_dir=CACHE_DIR,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7db27d77-c2da-442e-b312-f81f31d7e421",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generic alias used by every later cell.\n",
    "model = gpt2_sm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "961bf128-2803-4c57-b721-d47fba63bbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature results for the gpt2-small layer-6 SAE: a boolean structural mask\n",
    "# and firing densities. (Older runs loaded separate *_features.pt files or\n",
    "# \"pythia-70m_marks_structural.pt\" instead.)\n",
    "pth = \"gpt2-small.6_structural.pt\"\n",
    "res_dict = torch.load(pth, weights_only=True)\n",
    "structural_features = res_dict['structural']\n",
    "densities = res_dict['densities']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c4eec74-674c-4439-9cb8-456538efe025",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_ablation_hook(\n",
    "    value,\n",
    "    hook,\n",
    "    neuron_index,\n",
    "    ablation_value,\n",
    "):\n",
    "    \"\"\"Forward hook that overwrites one unit of the last dimension.\n",
    "\n",
    "    Every position of `value[..., neuron_index]` is set to `ablation_value * 1.0`;\n",
    "    the tensor is modified in place and returned. `hook` is unused but is part of\n",
    "    the hook call signature.\n",
    "    \"\"\"\n",
    "    replacement = ablation_value * 1.0\n",
    "    value[..., neuron_index] = replacement\n",
    "    return value\n",
    "\n",
    "def entropy(p):\n",
    "    \"\"\"Shannon entropy (natural log) along the last dim of a probability tensor.\n",
    "\n",
    "    Entries with p == 0 contribute 0 (the NaN from 0 * log 0 is masked out).\n",
    "    \"\"\"\n",
    "    positive = p > 0\n",
    "    terms = torch.where(positive, p * torch.log(p), 0)\n",
    "    return terms.sum(dim=-1).neg()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd6fd583-7f77-43db-896b-486da4392b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locally trained SAE checkpoint (gpt2-small-topk-sae-checkpoints run 1t637wuk).\n",
    "sae = SAE.load_from_pretrained(os.path.join('gpt2-small-topk-sae-checkpoints/1t637wuk/', 'final_245760000'), device=device)\n",
    "# Alternative: a public pretrained release\n",
    "# release = \"gpt2-small-res-jb\"\n",
    "# sae_id = \"blocks.6.hook_resid_pre\"\n",
    "# sae, cfg_dict, _ = SAE.from_pretrained(release, sae_id, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4571ad-f7c8-4fda-9dc4-e1e6d5db8348",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detached copy of the SAE weights so it can be modified without touching `sae`.\n",
    "orig_sae = BaseSAE(\n",
    "    sae.W_enc.detach().clone(),\n",
    "    sae.W_dec.detach().clone(),\n",
    "    sae.b_enc.detach().clone(),\n",
    "    sae.b_dec.detach().clone(),\n",
    "    sae.activation_fn,\n",
    "    apply_b_dec=sae.cfg.apply_b_dec_to_input\n",
    ")\n",
    "# The decoder IS normalized here (presumably rows rescaled to unit norm - see sae_utils).\n",
    "orig_sae.normalize_decoder_vectors()\n",
    "# NOTE(review): presumably folds the b_dec input subtraction into the encoder - TODO confirm.\n",
    "orig_sae.get_rid_of_decoder_sub()\n",
    "# Layer index used below as stop_at_layer / start_at_layer when splicing in the SAE.\n",
    "orig_sae.layer = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8e72671-0dbd-4370-bc14-22d3afa5c9e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature metrics (correlations, max activations) for stitch run 'snowy-pyramid-3'.\n",
    "# Alternative subdir: f\"eleuther/{run_name_id}/\"\n",
    "stitched = load_activation_store('metrics', 2000, orig_sae.d_sae, subdir=\"fr_fixed_attribution_correlation/snowy-pyramid-3/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38aa0df6-7c37-4ceb-81ae-57c56b60f485",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of each SAE decoder direction's norm lying in the span of W_U's weakest\n",
    "# left singular vectors (the last d_model // 50, i.e. ~2%): an effective null space\n",
    "# of the unembedding.\n",
    "U, s, Vh = torch.linalg.svd(model.W_U.detach().cpu(), full_matrices=False)\n",
    "num_vecs = U.shape[-1] // 50  # singular values are sorted descending -> last columns are weakest\n",
    "bottom_proj = U[:, -num_vecs:] @ U[:, -num_vecs:].T\n",
    "dec = sae.W_dec.detach().cpu()\n",
    "# Divide by the row norm so this is a true fraction even if decoder rows are not unit-norm.\n",
    "csims_to_bottom = torch.linalg.norm(dec @ bottom_proj, dim=-1) / dec.norm(dim=-1).clamp_min(1e-12)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bb02f1e-d173-4069-b2e6-1aee222aa6e5",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Entropy Neurons in Final MLP Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e90c5ae7-d49d-4f10-9d25-8e29db961c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output weights of the FINAL MLP layer (one row per neuron). Derived from the\n",
    "# model config instead of a hard-coded index so it matches the ablation hook\n",
    "# below, which targets blocks.{n_layers - 1}.mlp.\n",
    "final_layer = model.cfg.n_layers - 1\n",
    "final_OUT = model.W_out[final_layer].squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a9a4d9-00ec-40ba-a6a7-132075bdf55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_OUT.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed83dcb8-1bc7-4266-bb5c-893ea2128d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variance of each neuron's normalized direct logit effect (row @ W_U, divided by\n",
    "# the row norm and the W_U column norms). Low variance = nearly uniform logit shift.\n",
    "W_U_col_norms = model.W_U.norm(dim=0)  # loop-invariant, hoisted out of the loop\n",
    "logit_var = []\n",
    "with torch.no_grad():\n",
    "    for feat in tqdm(range(final_OUT.shape[0])):\n",
    "        feature = final_OUT[feat]\n",
    "        projected_feature = feature @ model.W_U\n",
    "        logit_var.append(torch.var(projected_feature / feature.norm() / W_U_col_norms).cpu().item())\n",
    "logit_var = np.array(logit_var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d81c1f0-0155-461f-85d3-9258f3b32f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ||W_out row|| + ||W_in column|| per neuron of the final MLP layer.\n",
    "norms = final_OUT.norm(dim=-1).detach().cpu().numpy() + model.W_in[model.cfg.n_layers - 1].squeeze().norm(dim=0).detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b872597-24ec-4252-a213-3e6136dcbff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outlier thresholds: 99th percentile of weight norm, 1st percentile of logit variance\n",
    "# (the 'decile' in the names is historical).\n",
    "top_decile_norm, bot_decile_logitvar = np.percentile(norms, 99), np.percentile(logit_var, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da100609-4cdd-4509-a49d-9f59bc806751",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(norms, np.log10(logit_var), s=3, alpha=0.5, marker='o')\n",
    "plt.xlabel('$\\\\|W_{in}\\\\|_2 + \\\\|W_{out}\\\\|_2$')\n",
    "plt.ylabel('log10 logit var')\n",
    "# Annotate neurons with a large weight norm AND a tiny logit variance.\n",
    "outliers = np.flatnonzero((norms > top_decile_norm) & (logit_var < bot_decile_logitvar))\n",
    "for i in outliers:\n",
    "    plt.annotate(str(i), (norms[i], np.log10(logit_var[i])), fontsize=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3b7370-ed0d-40f4-a1fc-3ff741b15202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activation values the chosen neuron is clamped to in the sweep below.\n",
    "ablation_values = list(range(7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6868634d-1194-4ebe-9d7c-e0db67fae375",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clamp one final-layer MLP neuron to each value in `ablation_values` and record the\n",
    "# mean loss / entropy (clean vs. ablated) and the final LayerNorm scale.\n",
    "# NOTE(review): `val_dataloader` is not defined in this notebook; it presumably comes\n",
    "# from `stitching.generic_experiments import *` or earlier kernel state - confirm.\n",
    "# NOTE(review): PADDING_TOKEN is 0, but the SAE sweep further down masks token 50256;\n",
    "# confirm which id the tokenized dataset actually pads with.\n",
    "neuron_idx = 2452\n",
    "final_mlp_hook = f'blocks.{model.cfg.n_layers - 1}.mlp.hook_post'\n",
    "ablation_losses = {\n",
    "    'orig_loss': [],\n",
    "    'orig_entropy': [],\n",
    "    'ablated_loss': [],\n",
    "    'ablated_entropy': [],\n",
    "    'ablated_ln_scale': []\n",
    "}\n",
    "for ablation_value in ablation_values:\n",
    "    metrics = {k: [] for k in ablation_losses}\n",
    "    for i, sample in tqdm(enumerate(val_dataloader)):\n",
    "        mask = (sample == PADDING_TOKEN)\n",
    "        original_logits = model(sample, return_type=\"logits\")\n",
    "\n",
    "        def ln_scale_hook(value, hook):\n",
    "            # Record the final LayerNorm scale at non-padding positions only.\n",
    "            metrics['ablated_ln_scale'].append(value[~mask].flatten().cpu())\n",
    "            return value\n",
    "\n",
    "        ablated_logits = model.run_with_hooks(\n",
    "            sample,\n",
    "            return_type=\"logits\",\n",
    "            fwd_hooks=[\n",
    "                (final_mlp_hook, lambda x, hook: mlp_ablation_hook(x, hook, neuron_idx, ablation_value)),\n",
    "                ('ln_final.hook_scale', ln_scale_hook),\n",
    "            ],\n",
    "        )\n",
    "\n",
    "        targets = sample.to(device)\n",
    "        metrics['orig_loss'].append(\n",
    "            next_token_cross_entropy_loss(original_logits, targets, ignore_index=PADDING_TOKEN, reduction='none').flatten().cpu()\n",
    "        )\n",
    "        metrics['orig_entropy'].append(entropy(torch.nn.functional.softmax(original_logits[~mask], dim=-1)).flatten().cpu())\n",
    "        metrics['ablated_loss'].append(\n",
    "            next_token_cross_entropy_loss(ablated_logits, targets, ignore_index=PADDING_TOKEN, reduction='none').flatten().cpu()\n",
    "        )\n",
    "        metrics['ablated_entropy'].append(entropy(torch.nn.functional.softmax(ablated_logits[~mask], dim=-1)).flatten().cpu())\n",
    "    for (k, v) in metrics.items():\n",
    "        ablation_losses[k].append(torch.concatenate(v).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af7f32e-b022-4a4d-baae-6d6c81a9f5e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(ablation_values, ablation_losses['orig_loss'], linestyle='dashed', color='black', alpha=0.75, label='original')\n",
    "ax.plot(ablation_values, ablation_losses['ablated_loss'], color='tab:blue', alpha=0.75, label='ablated')\n",
    "ax.set(xlabel='ablation value', ylabel='mean next-token loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdce8021-6616-4e58-8b6f-be717f88c0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(ablation_values, ablation_losses['ablated_ln_scale'], color='tab:blue', alpha=0.75)\n",
    "ax.set(xlabel='ablation value', ylabel='mean final LayerNorm scale');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90bb2e6b-539c-4837-ab3c-2eb5520b46df",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(ablation_values, ablation_losses['orig_entropy'], linestyle='dashed', color='black', alpha=0.75, label='original')\n",
    "ax.plot(ablation_values, ablation_losses['ablated_entropy'], color='tab:blue', alpha=0.75, label='ablated')\n",
    "ax.set(xlabel='ablation value', ylabel='mean predictive entropy (nats)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1317761d-2798-4858-a846-ba1eb1c6ede2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Entropy SAE features\n",
    "\n",
    "Entropy-regulating SAE features may or may not exist; it is possible that confidence regulation only happens in the final layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb2d0374-d307-4c48-a5f5-ed56f01931d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the logit-variance analysis on SAE decoder directions (one row per feature).\n",
    "final_OUT = sae.W_dec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c31814e3-03fb-4f84-b07f-12c171d292d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variance of each feature's normalized direct logit effect (decoder row @ W_U);\n",
    "# later layers are ignored.\n",
    "W_U_col_norms = model.W_U.norm(dim=0)  # loop-invariant, hoisted out of the loop\n",
    "logit_var = []\n",
    "with torch.no_grad():\n",
    "    for feat in tqdm(range(final_OUT.shape[0])):\n",
    "        feature = final_OUT[feat]\n",
    "        projected_feature = feature @ model.W_U\n",
    "        logit_var.append(torch.var(projected_feature / feature.norm() / W_U_col_norms).cpu().item())\n",
    "logit_var = np.array(logit_var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afab48d7-8880-400e-a524-727c50bc619d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis for the SAE features: max activation of each feature on the original model.\n",
    "# Defined here so this section does not silently reuse the MLP-neuron `norms`\n",
    "# (length d_mlp), which does not line up with the d_sae-long `logit_var`.\n",
    "norms = stitched['max_activation_orig']\n",
    "top_decile_norm = np.percentile(norms, 99)  # 99th percentile\n",
    "bot_decile_logitvar = np.percentile(logit_var, 0.05)  # 0.05th percentile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff0016f2-1a7a-4b42-a44c-820153ce35dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(norms, np.log10(logit_var), s=3, alpha=0.5, marker='o')\n",
    "plt.xlabel('max activation')\n",
    "plt.ylabel('log10 logit var')\n",
    "# Annotate features with a high max activation AND a very low logit variance.\n",
    "for i in np.flatnonzero((logit_var < bot_decile_logitvar) & (norms > top_decile_norm)):\n",
    "    plt.annotate(str(i), (norms[i], np.log10(logit_var)[i]), fontsize=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b2bf3d-8f5c-4a18-923d-eb3b81583a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only features that actually fire (density above 1e-7).\n",
    "alive = densities > 1e-7\n",
    "plt.scatter(stitched['attribution_correlation'][structural_features & alive], csims_to_bottom[structural_features & alive], s=10, alpha=0.5, marker='o', color='tab:blue', label='structural')\n",
    "plt.scatter(stitched['attribution_correlation'][~structural_features & alive], csims_to_bottom[~structural_features & alive], s=10, alpha=0.5, marker='x', color='tab:orange', label='semantic')\n",
    "plt.xlabel('attribution correlation (spearman)')\n",
    "# num_vecs = d_model // 50 singular directions, i.e. roughly the bottom 2%\n",
    "plt.ylabel('norm in bottom ~2% of W_U subspace')\n",
    "plt.legend(bbox_to_anchor=(1.05,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22098ec0-9ca2-4978-b59b-559c27504bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(stitched['activation_correlation'][structural_features], np.log10(logit_var)[structural_features], s=10, alpha=0.5, marker='o', color='tab:blue', label='structural')\n",
    "plt.scatter(stitched['activation_correlation'][~structural_features], np.log10(logit_var)[~structural_features], s=10, alpha=0.5, marker='x', color='tab:orange', label='semantic')\n",
    "plt.xlabel('activation correlation (spearman)')\n",
    "plt.ylabel('log10 logitvar')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f29a9ba-fbbf-490f-9b1b-884710ff6bec",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### ablations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276b1946-b31b-4302-9475-c9324058bdcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_high_cosine_similarity(matrix, threshold=0.99):\n",
    "    \"\"\"Count row pairs (i < j) of `matrix` whose cosine similarity exceeds `threshold`.\n",
    "\n",
    "    Returns (count, indices): `count` is the number of such pairs and `indices`\n",
    "    lists every row i that has at least one later row above the threshold.\n",
    "    Zero-norm rows are treated as having similarity 0 with every other row\n",
    "    (instead of producing NaNs and divide-by-zero warnings).\n",
    "    \"\"\"\n",
    "    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)\n",
    "    normalized_matrix = matrix / np.where(row_norms == 0, 1, row_norms)\n",
    "\n",
    "    count = 0\n",
    "    indices = []\n",
    "    n = normalized_matrix.shape[0]\n",
    "    for i in tqdm(range(n)):\n",
    "        # Compare only against later rows so each pair is counted once.\n",
    "        n_close = int(np.sum(normalized_matrix[i + 1:] @ normalized_matrix[i] > threshold))\n",
    "        count += n_close\n",
    "        if n_close > 0:\n",
    "            indices.append(i)\n",
    "    return count, indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68842d8d-6dc6-45df-8671-4c7c2ab1c1d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "count_high_cosine_similarity(orig_sae.W_dec.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba45928e-bdf9-45b0-b54e-2d9224d20d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values the chosen SAE feature is clamped to in the sweep below.\n",
    "ablation_values = [0, 50, 100, 200]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6d93f5-dd74-4eff-a3f0-e8e455207b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAE feature under study; show its firing density.\n",
    "neuron_idx = 16408\n",
    "densities[neuron_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c4de6c-29c7-4785-835a-e59f7567952b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (activation correlation, attribution correlation, structural flag) for this feature\n",
    "(\n",
    "    stitched['activation_correlation'][neuron_idx],\n",
    "    stitched['attribution_correlation'][neuron_idx],\n",
    "    structural_features[neuron_idx],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455c8b21-3b97-444d-832f-29999b3d743c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Clamp SAE feature `neuron_idx` to each value in `ablation_values` at the SAE layer,\n",
    "# keeping the SAE reconstruction error, and compare loss / entropy with the clean run.\n",
    "# NOTE(review): `val_dataloader` is not defined in this notebook - confirm its source.\n",
    "SAE_PAD_TOKEN = 50256  # token id masked out / ignored as padding in this sweep\n",
    "ablation_losses = {\n",
    "    'orig_loss': [],\n",
    "    'orig_entropy': [],\n",
    "    'ablated_loss': [],\n",
    "    'ablated_entropy': [],\n",
    "}\n",
    "for ablation_value in ablation_values:\n",
    "    metrics = {k: [] for k in ablation_losses}\n",
    "    for i, sample in tqdm(enumerate(val_dataloader)):\n",
    "        mask = (sample == SAE_PAD_TOKEN)\n",
    "        # Residual stream at the SAE layer, then the rest of the model from there.\n",
    "        original_latents = model(sample, stop_at_layer=orig_sae.layer)\n",
    "        original_logits = model(original_latents, start_at_layer=orig_sae.layer)\n",
    "\n",
    "        acts = orig_sae.encode(original_latents)\n",
    "        # Keep the reconstruction error so only the clamped feature changes the stream.\n",
    "        error = original_latents - orig_sae.decode(acts)\n",
    "        acts[..., neuron_idx] = ablation_value\n",
    "        ablated_latents = orig_sae.decode(acts) + error\n",
    "        ablated_logits = model(ablated_latents, start_at_layer=orig_sae.layer)\n",
    "\n",
    "        targets = sample.to(device)\n",
    "        metrics['orig_loss'].append(\n",
    "            next_token_cross_entropy_loss(original_logits, targets, ignore_index=SAE_PAD_TOKEN, reduction='none').flatten().cpu()\n",
    "        )\n",
    "        metrics['orig_entropy'].append(entropy(torch.nn.functional.softmax(original_logits[~mask], dim=-1)).flatten().cpu())\n",
    "        metrics['ablated_loss'].append(\n",
    "            next_token_cross_entropy_loss(ablated_logits, targets, ignore_index=SAE_PAD_TOKEN, reduction='none').flatten().cpu()\n",
    "        )\n",
    "        metrics['ablated_entropy'].append(entropy(torch.nn.functional.softmax(ablated_logits[~mask], dim=-1)).flatten().cpu())\n",
    "    for (k, v) in metrics.items():\n",
    "        ablation_losses[k].append(torch.concatenate(v).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "167ee49a-697f-40b8-85dc-27e605d80bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(ablation_values, ablation_losses['orig_loss'], linestyle='dashed', color='black', alpha=0.75, label='original')\n",
    "ax.plot(ablation_values, ablation_losses['ablated_loss'], color='tab:blue', alpha=0.75, label='clamped')\n",
    "ax.set(xlabel='clamped feature activation', ylabel='mean next-token loss')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5d4060-1ee0-4aa3-aa99-cc7b2a33cb15",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(ablation_values, ablation_losses['orig_entropy'], linestyle='dashed', color='black', alpha=0.75, label='original')\n",
    "ax.plot(ablation_values, ablation_losses['ablated_entropy'], color='tab:blue', alpha=0.75, label='clamped')\n",
    "ax.set(xlabel='clamped feature activation', ylabel='mean predictive entropy (nats)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a27781-b661-49e4-873e-ed8c7c28d27a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Entropy Transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f4d86e-06e9-4df9-bd52-b2573742cfc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.stitching_utils import load_activation_store, open_experiment\n",
    "from stitching.sae_utils import (\n",
    "    BaseSAE,\n",
    "    convert_eleuther_sae_to_BaseSAE,\n",
    "    forward_modified,\n",
    "    generate_modified,\n",
    "    max_csim_transfer_to_orig,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b21e8e6c-5530-408b-bd6d-598bb0aa84cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature max activation on the original model (x-axis of the plots below).\n",
    "# Alternative: sae.W_dec.norm(dim=-1) + sae.W_enc.norm(dim=0) (weight norms).\n",
    "norms = stitched['max_activation_orig']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e82b95-637e-4683-b15f-057c021313fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target model the SAE is transferred into through the learned stitch.\n",
    "transfer_model = HookedTransformer.from_pretrained(\n",
    "    'gpt2-medium',\n",
    "    cache_dir=CACHE_DIR,\n",
    "    device=model.cfg.device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998d60ed-f783-44c1-9f91-9b880a7f1148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names used to locate the stitch checkpoint directory.\n",
    "# NOTE(review): assumed to match the on-disk directory naming - TODO confirm.\n",
    "model_cfg = {'model_a_name': 'gpt2-small', 'model_b_name': 'gpt2-medium'}\n",
    "stitch_dir = f\"checkpoints/stitch_training_{model_cfg['model_a_name']}_to_{model_cfg['model_b_name']}_bidirectional_mse\"\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(model.cfg.d_model, transfer_model.cfg.d_model, stitch_dir, 'snowy-pyramid-3', biases=True, device=device)\n",
    "\n",
    "# Compose the SAE with the stitch: encoder weights/bias absorb Pinv/biasinv,\n",
    "# decoder weights/bias absorb P/bias.\n",
    "transferred_sae = BaseSAE(\n",
    "    Pinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.W_dec.detach().clone() @ P,\n",
    "    orig_sae.b_enc.detach().clone() + biasinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.b_dec.detach().clone() @ P + bias,\n",
    "    orig_sae.activation_fn,\n",
    "    apply_b_dec=False\n",
    ")\n",
    "transferred_sae.normalize_decoder_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76072c47-a583-4db6-b4cb-d8520063cffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_layer = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b47b0878-11f3-41f0-ab16-b01d3eb3c888",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same null-space fraction as for gpt2-small, using gpt2-medium's unembedding.\n",
    "tU, ts, tVh = torch.linalg.svd(transfer_model.W_U.detach().cpu(), full_matrices=False)\n",
    "tnum_vecs = tU.shape[-1] // 50  # weakest ~2% of the left singular vectors\n",
    "tbottom_proj = tU[:, -tnum_vecs:] @ tU[:, -tnum_vecs:].T\n",
    "tdec = transferred_sae.W_dec.detach().cpu()\n",
    "# Divide by the row norm so this is a true fraction of each decoder row's norm.\n",
    "tcsims_to_bottom = torch.linalg.norm(tdec @ tbottom_proj, dim=-1) / tdec.norm(dim=-1).clamp_min(1e-12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b77659-e961-488b-a888-5d9c3b3681b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Percentile rank (by max activation) and null-space fraction of one feature,\n",
    "# in the original model and in the stitched model.\n",
    "neuron_idx = 4034\n",
    "max_orig = stitched['max_activation_orig']\n",
    "max_stitch = stitched['max_activation_stitched']\n",
    "perc = (max_orig <= max_orig[neuron_idx]).sum() / len(max_orig)\n",
    "perc_stitch = (max_stitch <= max_stitch[neuron_idx]).sum() / len(max_stitch)\n",
    "csim = csims_to_bottom[neuron_idx]\n",
    "csim_stitch = tcsims_to_bottom[neuron_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84a9dde5-b28c-4429-99a7-e4b819d864f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "(perc, csim, perc_stitch, csim_stitch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b15b33-2ed6-4528-8847-19712093c96d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidates: top 2% by max activation AND top 2% by null-space fraction.\n",
    "top_norm = np.percentile(norms, 98)\n",
    "top_csims = np.percentile(csims_to_bottom, 98)\n",
    "indices = (norms > top_norm) & (csims_to_bottom.numpy() > top_csims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de310ab7-a0af-4bef-b7a5-1aba09de88de",
   "metadata": {},
   "outputs": [],
   "source": [
    "not_dead = stitched['max_activation_orig'] > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb5ff691-c010-49ca-967d-0ea8f6219b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0c9b8f9-3783-4326-8196-980d7fbe1955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Background points for the plot: up to 2000 DISTINCT features, fixed seed for reproducibility.\n",
    "rng = np.random.default_rng(0)\n",
    "random_sample = rng.choice(len(indices), size=min(2000, len(indices)), replace=False)\n",
    "mask = np.zeros(len(indices), dtype=bool)\n",
    "mask[random_sample] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18c39d96-6617-4654-b4d4-f925189545a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ec3b63-3bd7-45e4-bbde-39d333295bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 14})\n",
    "fig, ax = plt.subplots(1, 2, figsize=(11, 4))\n",
    "\n",
    "background = (mask & not_dead) & ~indices  # sampled, alive, non-candidate features\n",
    "panels = [\n",
    "    (ax[0], stitched['max_activation_orig'], csims_to_bottom, 'gpt2-small'),\n",
    "    (ax[1], stitched['max_activation_stitched'], tcsims_to_bottom, 'gpt2-medium'),\n",
    "]\n",
    "for panel, max_act, null_frac, title in panels:\n",
    "    panel.scatter(max_act[background], null_frac[background], s=3, alpha=0.5)\n",
    "    panel.scatter(max_act[indices], null_frac[indices], color='red', marker='x', s=50, label='entropy')\n",
    "    panel.legend()\n",
    "    panel.set_xlabel('max activation')\n",
    "    panel.set_ylabel('fraction of norm in eff null space')\n",
    "    panel.set_title(title)\n",
    "    panel.set_ylim([0, 1])\n",
    "\n",
    "# savefig fails if the output directory does not exist yet.\n",
    "os.makedirs('results/figures', exist_ok=True)\n",
    "plt.savefig('results/figures/entropy_candidates.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "647d11f2-311b-4de2-aca1-1e746d5d64c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "stitched['max_activation_stitched'].argsort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ae8be5d-cd3b-4d26-a171-3cd1aa5abe0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted max activations of the features that ever fire in the stitched model.\n",
    "np.sort(stitched['max_activation_stitched'][stitched['max_activation_stitched'] > 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c7ea51-b71c-4742-86d4-4cf72e501f76",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(csims_to_bottom, tcsims_to_bottom, s=3)\n",
    "plt.xlabel('eff. null-space fraction (gpt2-small)')\n",
    "plt.ylabel('eff. null-space fraction (gpt2-medium)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5123aaeb-e60f-4b5f-b749-8a317abd4a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(stitched['max_activation_orig'], stitched['max_activation_stitched'], s=3)\n",
    "plt.xlabel('max activation (gpt2-small)')\n",
    "plt.ylabel('max activation (gpt2-medium)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00801ad1-0f13-4105-8005-03ae77c3b647",
   "metadata": {},
   "source": [
    "# Attention Deactivation\n",
    "\n",
    "The heuristic score treats each (normalized) MLP output vector as if it were the input to a later head's query, and measures its dot product with that head's key on the BOS token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc5e6438-6bc1-4d4a-b3eb-c848a2e9bb4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache a short prompt; only position 0 (the BOS token) is read below, to get each\n",
    "# head's key vector for BOS.\n",
    "_, BOS_CACHE = model.run_with_cache(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6885b2-a258-4107-9108-2f4c2458cc5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def heuristic_score_no_rope(model, mlp_out, attention_head, BOS_CACHE):\n",
    "    \"\"\"Score how strongly each output direction would push a head to attend to BOS.\n",
    "\n",
    "    Each row of `mlp_out` (one per neuron/feature, e.g. model.W_out[layer]) is\n",
    "    normalized to unit norm and used directly as the query input of head\n",
    "    `attention_head = (layer, head_index)`; the score is row @ W_Q @ k_BOS.\n",
    "    Query biases, LayerNorm and positional effects are ignored (heuristic only).\n",
    "\n",
    "    Returns a CPU tensor with one score per row of `mlp_out`.\n",
    "    \"\"\"\n",
    "    head_layer, head_index = attention_head\n",
    "    W_out = mlp_out / mlp_out.norm(dim=-1, keepdim=True)\n",
    "    W_Q = model.W_Q[head_layer, head_index]  # presumably (d_model, d_head)\n",
    "    k_BOS = BOS_CACHE[f'blocks.{head_layer}.attn.hook_k'][0, 0, head_index].squeeze()\n",
    "    heuristics = W_out @ W_Q @ k_BOS\n",
    "    return heuristics.cpu()\n",
    "\n",
    "def heuristic_score_with_rope(model, mlp_layer, attention_head, bos_cache=None, n_positions=100):\n",
    "    \"\"\"Rotary-embedding variant of `heuristic_score_no_rope` for `model.W_out[mlp_layer]`.\n",
    "\n",
    "    The normalized query inputs are repeated over `n_positions` positions and passed\n",
    "    through the head's `apply_rotary`, then dotted with the rotated BOS key\n",
    "    (`hook_rot_k`). `bos_cache` defaults to the notebook-level BOS_CACHE.\n",
    "    Returns a CPU tensor of shape (rows of W_out, n_positions).\n",
    "    \"\"\"\n",
    "    cache = BOS_CACHE if bos_cache is None else bos_cache\n",
    "    head_layer, head_index = attention_head\n",
    "    W_out = model.W_out[mlp_layer] / model.W_out[mlp_layer].norm(dim=-1, keepdim=True)\n",
    "    W_Q = model.W_Q[head_layer]\n",
    "    k_BOS = cache[f'blocks.{head_layer}.attn.hook_rot_k'][0, 0, head_index].squeeze()\n",
    "    # (rows, n_heads, d_head) -> (rows, n_positions, n_heads, d_head) for apply_rotary\n",
    "    sequenced = torch.einsum('ij,kjm->ikm', W_out, W_Q).unsqueeze(1).repeat(1, n_positions, 1, 1)\n",
    "    heuristics = model.blocks[head_layer].attn.apply_rotary(sequenced)[:, :, head_index] @ k_BOS\n",
    "    return heuristics.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21f1959b-57f2-4ee8-8608-53c7dc42d236",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1972f2f2-2b47-4f92-b832-3f1dab97a313",
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp_layer = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fc6be57-209a-4e3d-abce-8a5d5c67c650",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Score every head in the layers after `mlp_layer` against (a) the real W_out rows and\n",
    "# (b) a random Gaussian baseline of the same shape; print neurons scoring above 3.\n",
    "torch.manual_seed(0)  # reproducible random baseline\n",
    "mlp_out = model.W_out[mlp_layer]\n",
    "heehee = []\n",
    "heehee_random = []\n",
    "for layer in range(mlp_layer + 1, model.cfg.n_layers):\n",
    "    for head_index in range(model.W_Q.shape[1]):\n",
    "        # BOS_CACHE is a required argument of heuristic_score_no_rope.\n",
    "        cur = heuristic_score_no_rope(model, mlp_out, (layer, head_index), BOS_CACHE).numpy()\n",
    "        rand_scores = heuristic_score_no_rope(model, torch.randn_like(mlp_out), (layer, head_index), BOS_CACHE).numpy()\n",
    "        print(layer, head_index, np.flatnonzero(cur > 3))\n",
    "        heehee.append(cur)\n",
    "        heehee_random.append(rand_scores)\n",
    "heehee = np.concatenate(heehee, axis=0)\n",
    "heehee_random = np.concatenate(heehee_random, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b679095-c6eb-4a44-bb28-7bb2c9d82eeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.histplot(heehee.flatten(), log=True, alpha=0.75, label=f'W_out[{mlp_layer}] rows')\n",
    "sns.histplot(heehee_random.flatten(), log=True, alpha=0.75, label='random baseline')\n",
    "plt.xlabel('BOS query-key heuristic score')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ea626ea-cef4-42c8-9310-b3e5ffd96bf3",
   "metadata": {},
   "source": [
    "Ablation experiment:\n",
    "\n",
    "Pick a sample dataset and only track positions in the second half of the context.\n",
    "Run two forward passes:\n",
    "- Original: record the attention paid from each token to the BOS token.\n",
    "- Ablated: set the neuron to 0 during the forward pass and record the same quantity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b03a60d5-3157-4ad6-ad91-7610176fa38b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero MLP neuron `neuron_idx` in layer `mlp_layer` and measure how much the attention that head\n",
    "# (head_layer, head_idx) places on the BOS token changes, at non-padding positions in the second\n",
    "# half of each context.\n",
    "# NOTE: val_dataloader is created in a later cell (DataLoader over tokenized_dataset); run that cell first.\n",
    "neuron_idx = 2848\n",
    "head_layer = 6\n",
    "head_idx = 9\n",
    "mlp_hook = f\"blocks.{mlp_layer}.mlp.hook_post\"\n",
    "pattern_hook = f\"blocks.{head_layer}.attn.hook_pattern\"\n",
    "metrics = {\n",
    "    'bos_diffs': [],\n",
    "    'orig_activations': []\n",
    "}\n",
    "for sample in val_dataloader:\n",
    "    seq_len = sample.shape[-1]\n",
    "    # (batch, pos) mask; `>=` so the first position of the second half is included.\n",
    "    # NOTE(review): the config cell defines PADDING_TOKEN = 0 -- confirm the dataset pads with model.tokenizer.pad_token_id.\n",
    "    mask = (torch.arange(seq_len) >= seq_len // 2) & (sample != model.tokenizer.pad_token_id)\n",
    "    _, unablated_cache = model.run_with_cache(sample, names_filter=[mlp_hook, pattern_hook], stop_at_layer=head_layer+1)\n",
    "    orig_activation = unablated_cache[mlp_hook][mask][..., neuron_idx]\n",
    "    # hook_pattern is (batch, head, query_pos, key_pos). Keep the batch dim (no squeeze) so that `mask`\n",
    "    # lines up with (batch, query_pos); key_pos 0 is the BOS token -> shape (n_masked,).\n",
    "    unablated_pattern = unablated_cache[pattern_hook][:, head_idx][mask][..., 0]\n",
    "    # Residual stream entering block `head_layer`, with the neuron zeroed.\n",
    "    ablated_resid = model.run_with_hooks(\n",
    "        sample,\n",
    "        stop_at_layer=head_layer,\n",
    "        fwd_hooks=[(\n",
    "            mlp_hook,\n",
    "            lambda x, hook: mlp_ablation_hook(x, hook, neuron_idx, 0)\n",
    "        )]\n",
    "    )\n",
    "    # Only the pattern of block `head_layer` is needed, so stop right after it.\n",
    "    _, ablated_cache = model.run_with_cache(ablated_resid, start_at_layer=head_layer, stop_at_layer=head_layer+1, names_filter=pattern_hook)\n",
    "    ablated_pattern = ablated_cache[pattern_hook][:, head_idx][mask][..., 0]\n",
    "    metrics['bos_diffs'].append((unablated_pattern - ablated_pattern).cpu())\n",
    "    metrics['orig_activations'].append(orig_activation.cpu())\n",
    "for (k,v) in metrics.items():\n",
    "    metrics[k] = torch.cat(v).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e65151-92cd-4de6-aac5-900fa26bd450",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change in attention-to-BOS vs. the neuron's original activation, one point per tracked position.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x=metrics['orig_activations'], y=metrics['bos_diffs'], s=3)\n",
    "ax.grid(True)\n",
    "ax.set(\n",
    "    title=f'MLP{mlp_layer} neuron {neuron_idx} ablation: head ({head_layer},{head_idx})',\n",
    "    xlabel='original activation',\n",
    "    ylabel='p(BOS) orig - p(BOS) ablated',\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ae3804-3314-44d5-a4b3-bfd4f5d4674c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left singular vectors of the unembedding. torch.linalg.svd returns singular values in descending order,\n",
    "# so the last columns of U span the bottom singular directions. (Assumes W_U is (d_model, d_vocab).)\n",
    "U, s, Vh = torch.linalg.svd(model.W_U.detach().cpu(), full_matrices=False)\n",
    "num_vecs = U.shape[-1] // 20  # bottom 5% of the singular directions\n",
    "bottom_basis = U[:, -num_vecs:]  # orthonormal columns\n",
    "# Unit-normalise each row of the layer-`mlp_layer` MLP output matrix.\n",
    "w_out = model.W_out[mlp_layer].cpu()\n",
    "w_out_unit = w_out / w_out.norm(dim=-1, keepdim=True)\n",
    "# Norm of each row's projection onto the bottom subspace. Since bottom_basis has orthonormal columns,\n",
    "# ||x @ B @ B.T|| == ||x @ B||, which avoids building the (d_model, d_model) projector.\n",
    "csims_to_bottom = torch.linalg.norm(w_out_unit @ bottom_basis, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f22e0411-be8d-486f-b75d-d353bbc0a548",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per row of the MLP output matrix W_out (presumably one per MLP neuron -- confirm).\n",
    "csims_to_bottom.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492338de-ac48-445d-acc2-950705985169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bottom-subspace projection norm of neuron 2848, the MLP neuron ablated in the experiment above.\n",
    "csims_to_bottom[2848]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4743c6b9-76dc-4a52-a91a-1c4b54468b0d",
   "metadata": {},
   "source": [
    "This may be a major pivot: switching to GPT-2 models could rescue the project.\n",
    "\n",
    "I don't have the stitching layer yet, but we at least know that these phenomena exist."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aac0c52-61e6-4642-b8fb-586bf1653677",
   "metadata": {},
   "source": [
    "## SAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7067c8e3-dfec-4798-a571-daa6960e6613",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Score every SAE decoder direction against every attention head from layer `sae_layer` onward\n",
    "# (the SAE reads the residual stream entering block `sae_layer`), plus a Gaussian-random baseline\n",
    "# of the same shape.\n",
    "# NOTE: BOS_CACHE must come from `model` (gpt2-small); a later cell reassigns it from the transfer model.\n",
    "torch.manual_seed(0)  # make the random baseline reproducible\n",
    "heehee = {}\n",
    "heehee_random = {}\n",
    "sae_layer = 6\n",
    "for layer in range(sae_layer, model.cfg.n_layers):\n",
    "    for head_index in range(model.cfg.n_heads):\n",
    "        heehee[(layer, head_index)] = heuristic_score_no_rope(model, sae.W_dec, (layer, head_index), BOS_CACHE).numpy()\n",
    "        heehee_random[(layer, head_index)] = heuristic_score_no_rope(model, torch.randn_like(sae.W_dec), (layer, head_index), BOS_CACHE).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "382659f2-c540-4c1a-8dd0-10415db3a7f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score a single SAE latent's decoder direction against every head from `sae_layer` onward, with a\n",
    "# Gaussian-random null distribution (one score per row of sae.W_dec) for comparison.\n",
    "# NOTE: BOS_CACHE must come from `model` (gpt2-small); a later cell reassigns it from the transfer model.\n",
    "torch.manual_seed(0)  # make the random baseline reproducible\n",
    "heehee = {}\n",
    "heehee_random = {}  # reset here so this cell does not depend on an earlier cell having run\n",
    "neuron_idx = 27380\n",
    "sae_layer = 6\n",
    "\n",
    "for layer in range(sae_layer, model.cfg.n_layers):\n",
    "    for head_index in range(model.cfg.n_heads):\n",
    "        heehee[(layer, head_index)] = heuristic_score_no_rope(model, sae.W_dec[neuron_idx], (layer, head_index), BOS_CACHE).numpy()\n",
    "        heehee_random[(layer, head_index)] = heuristic_score_no_rope(model, torch.randn_like(sae.W_dec), (layer, head_index), BOS_CACHE).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75528692-26bc-46a6-8330-756bdcd9900a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-(layer, head_index) heuristic scores computed in the cell above.\n",
    "heehee"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a548de2c-89b7-41c8-b940-cd9595cc9c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the random-baseline scores of every (layer, head) into one null distribution.\n",
    "all_randoms = np.concatenate(list(heehee_random.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a17a353-ae0b-4f70-8c40-0206d6a0187b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the pooled random-baseline (null) distribution.\n",
    "all_randoms.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80eecb27-919c-4bd4-b7a4-1d766aba8aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score threshold: the 99.9th percentile of the random-direction null distribution.\n",
    "# (Alternative considered: keep only the top 50 random scores, i.e. q = (1 - 50 / len(all_randoms)) * 100.)\n",
    "NULL_PERCENTILE = 99.9\n",
    "thresh = np.percentile(all_randoms, NULL_PERCENTILE)\n",
    "thresh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1704e219-dcb1-46d2-88c7-910fea54f4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper and lower 0.1% tails of the null distribution.\n",
    "tuple(np.percentile(all_randoms, [99.9, 0.1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5bb7fa8-71fa-4318-989d-f022485c9455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-display the per-(layer, head_index) heuristic scores for comparison with `thresh`.\n",
    "heehee"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d8d914-7d49-4d6c-95d8-27fd4e6efe9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 500 tokenised sequences, one per batch, in a fixed order, so that per-sample token indices\n",
    "# recorded by the ablation experiments line up between the two models.\n",
    "# NOTE(review): despite the name, this draws from the 'train' split -- confirm that is intended.\n",
    "val_dataloader = torch.utils.data.DataLoader(\n",
    "    tokenized_dataset['train'][:500], batch_size=1, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c73a04a7-66b8-4f9b-ae28-f21e8c4906bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direct-path ablation of SAE latent `neuron_idx` on head (head_layer, head_idx): at one query position\n",
    "# at a time, subtract the latent's decoder contribution from the residual entering `head_layer` and\n",
    "# record the change in attention to BOS. All other paths are kept as in the clean run.\n",
    "# NOTE: reads the global `sae_layer`; make sure it is the gpt2-small SAE layer (6) before running.\n",
    "neuron_idx = 27380\n",
    "head_layer = 6\n",
    "head_idx = 10\n",
    "N_POSITIONS = 50  # query positions sampled per sequence\n",
    "rng = np.random.default_rng(0)  # reproducible position sampling\n",
    "pattern_hook = f\"blocks.{head_layer}.attn.hook_pattern\"\n",
    "metrics = {\n",
    "    'bos_diffs': [],\n",
    "    'orig_activations': []\n",
    "}\n",
    "all_token_indices = {}  # reused by the transfer experiment so both models see the same positions\n",
    "for i, sample in tqdm(enumerate(val_dataloader)):\n",
    "    orig_sae_latents = model(sample, stop_at_layer=sae_layer)\n",
    "    orig_acts = sae.encode(orig_sae_latents)[..., [neuron_idx]]\n",
    "    orig_attn_latents = model(orig_sae_latents, start_at_layer=sae_layer, stop_at_layer=head_layer)\n",
    "\n",
    "    _, unablated_cache = model.run_with_cache(orig_attn_latents, names_filter=[pattern_hook], start_at_layer=head_layer, stop_at_layer=head_layer+1)\n",
    "    # hook_pattern is (batch, head, query_pos, key_pos); key_pos 0 is BOS -> (batch, query_pos)\n",
    "    unablated_pattern = unablated_cache[pattern_hook][:, head_idx][..., 0]\n",
    "\n",
    "    # Distinct positions: np.random.choice's default samples with replacement and double-counts positions.\n",
    "    token_indexes = rng.choice(np.arange(64, 256), N_POSITIONS, replace=False)\n",
    "    all_token_indices[i] = token_indexes\n",
    "\n",
    "    for tok in token_indexes:\n",
    "        if sample[0, tok] == model.tokenizer.pad_token_id:\n",
    "            continue\n",
    "        # Remove only this latent's contribution, and only at the current position.\n",
    "        ablated_attn_latents = orig_attn_latents.detach().clone()\n",
    "        ablated_attn_latents[:, tok] = orig_attn_latents[:, tok] - orig_acts[:, tok].item() * sae.W_dec[[neuron_idx]]\n",
    "        _, ablated_cache = model.run_with_cache(ablated_attn_latents, start_at_layer=head_layer, names_filter=pattern_hook, stop_at_layer=head_layer+1)\n",
    "        ablated_pattern = ablated_cache[pattern_hook][0, head_idx, tok, 0]\n",
    "        metrics['bos_diffs'].append((unablated_pattern[0, tok] - ablated_pattern).item())\n",
    "        metrics['orig_activations'].append(orig_acts[0, tok].item())\n",
    "for (k,v) in metrics.items():\n",
    "    metrics[k] = np.array(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8382bb1-4dfb-4409-b91a-ec49758036df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of (sample, position) pairs recorded by the ablation experiment.\n",
    "metrics['orig_activations'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "974170f9-4e63-4985-b2c1-0f5995bb8154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change in attention-to-BOS vs. the latent's original activation, one point per sampled position.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x=metrics['orig_activations'], y=metrics['bos_diffs'], s=3)\n",
    "ax.grid(True)\n",
    "ax.set(\n",
    "    title=f'SAE latent {neuron_idx} ablation: head ({head_layer},{head_idx})',\n",
    "    xlabel='original activation',\n",
    "    ylabel='p(BOS) orig - p(BOS) ablated',\n",
    ")\n",
    "# To save: fig.savefig('attention_deactivation_orig.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20299158-662a-4156-be92-8c1bc5b06453",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `densities` is defined outside this section; presumably per-latent firing densities of `sae`, indexed by latent -- confirm.\n",
    "densities[neuron_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f527de5-dee3-492b-909c-c9af6af465e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics for SAE latent `neuron_idx`.\n",
    "# csims_to_bottom is computed from model.W_out (one entry per MLP output row), so indexing it with an\n",
    "# SAE latent index is out of range or meaningless. Compute the same bottom-subspace projection norm for\n",
    "# the latent's decoder direction instead (U and num_vecs come from the SVD cell).\n",
    "latent_dir = sae.W_dec[neuron_idx].detach().cpu()\n",
    "latent_bottom_proj = torch.linalg.norm((latent_dir / latent_dir.norm()) @ U[:, -num_vecs:])\n",
    "stitched['activation_correlation'][neuron_idx], stitched['attribution_correlation'][neuron_idx], structural_features[neuron_idx], latent_bottom_proj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91718135-42df-47eb-b908-14c1e23e7265",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution over MLP output directions of their projection norm onto the bottom singular subspace of W_U.\n",
    "ax = sns.histplot(csims_to_bottom)\n",
    "ax.set(xlabel='norm of projection onto bottom W_U subspace', title='MLP output directions vs. bottom of W_U');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92b97fb4-e149-457b-a5f8-170ac1554e11",
   "metadata": {},
   "source": [
    "## Transfer: does some head in the transfer model still show this effect?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44e62ccb-8b22-44d5-ba73-d88cc5a32a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitching helpers for the transfer experiment. Most are also imported in the first cell;\n",
    "# max_csim_transfer_to_orig and convert_eleuther_sae_to_BaseSAE are new here.\n",
    "from stitching.stitching_utils import load_activation_store, open_experiment\n",
    "from stitching.sae_utils import (\n",
    "    BaseSAE,\n",
    "    convert_eleuther_sae_to_BaseSAE,\n",
    "    forward_modified,\n",
    "    generate_modified,\n",
    "    max_csim_transfer_to_orig,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a83d6d8-a8a7-4f8c-8c34-490ea147b2c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target model of the gpt2-small -> gpt2-medium stitch loaded in the next cell, on the same device as `model`.\n",
    "transfer_model = HookedTransformer.from_pretrained('gpt2-medium', cache_dir=CACHE_DIR, device=model.cfg.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b265b541-a4f8-4230-a60a-4f0616595eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained stitch between `model` and `transfer_model` (checkpoint: gpt2-small -> gpt2-medium, bidirectional MSE).\n",
    "# NOTE(review): P/bias presumably map gpt2-small -> gpt2-medium and Pinv/biasinv the reverse; beta is not used in this cell -- confirm.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(model.cfg.d_model, transfer_model.cfg.d_model, 'checkpoints/stitch_training_gpt2-small_to_gpt2-medium_bidirectional_mse', 'snowy-pyramid-3', biases=True,device=device)\n",
    "\n",
    "# Re-express the SAE in the transfer model's residual basis:\n",
    "#   encoder: W_enc' = Pinv @ W_enc,  b_enc' = b_enc + biasinv @ W_enc\n",
    "#   decoder: W_dec' = W_dec @ P,     b_dec' = b_dec @ P + bias\n",
    "# NOTE(review): this transfers `orig_sae`, while the gpt2-small experiments above use `sae` -- confirm they are the same SAE.\n",
    "transferred_sae = BaseSAE(\n",
    "    Pinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.W_dec.detach().clone() @ P,\n",
    "    orig_sae.b_enc.detach().clone() + biasinv @ orig_sae.W_enc.detach().clone(),\n",
    "    orig_sae.b_dec.detach().clone() @ P + bias,\n",
    "    orig_sae.activation_fn,\n",
    "    apply_b_dec=False\n",
    ")\n",
    "transferred_sae.normalize_decoder_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "446d309b-f0c2-42fa-a96c-63a55d2678cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gpt2-medium layer at which the transferred SAE reads the residual stream (the input to block `transfer_layer`).\n",
    "transfer_layer = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b61b0c-a348-4efc-92ba-ae5162511923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activation cache of a short prompt on the transfer model, passed to heuristic_score_no_rope below.\n",
    "# NOTE: this overwrites the BOS_CACHE used by the earlier gpt2-small SAE scoring cells; re-running those\n",
    "# cells after this one would hand them the gpt2-medium cache.\n",
    "_, BOS_CACHE = transfer_model.run_with_cache(\"test\")\n",
    "# _, BOS_CACHE = gpt2_sm.run_with_cache(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74104f8b-23c3-40ba-ace1-f6f8493cdc08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score SAE latent `neuron_idx`, transferred into gpt2-medium, against every head from `transfer_layer`\n",
    "# onward, plus a Gaussian-random null distribution (one score per row of transferred_sae.W_dec).\n",
    "torch.manual_seed(0)  # make the random baseline reproducible\n",
    "heehee = {}\n",
    "# Reset: otherwise gpt2-small entries for layers below `transfer_layer`, left over from the SAE cells,\n",
    "# leak into the gpt2-medium null distribution pooled later.\n",
    "heehee_random = {}\n",
    "neuron_idx = 27380  # alternative latent examined: 30099\n",
    "# Iterate from `transfer_layer` directly rather than re-assigning the global `sae_layer`,\n",
    "# which the gpt2-small cells read.\n",
    "for layer in range(transfer_layer, transfer_model.cfg.n_layers):\n",
    "    for head_index in range(transfer_model.cfg.n_heads):\n",
    "        heehee[(layer, head_index)] = heuristic_score_no_rope(transfer_model, transferred_sae.W_dec[neuron_idx], (layer, head_index), BOS_CACHE).numpy()\n",
    "        heehee_random[(layer, head_index)] = heuristic_score_no_rope(transfer_model, torch.randn_like(transferred_sae.W_dec), (layer, head_index), BOS_CACHE).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5933f07-b122-4911-b115-cadc6eb92830",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the (layer, head_index) with the highest heuristic score; the first one wins on ties.\n",
    "best, lol = None, -float('inf')\n",
    "for head_key, score in heehee.items():\n",
    "    if score > lol:\n",
    "        best, lol = head_key, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d579932-0edd-4829-a5cb-8b5eab66b08e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (layer, head_index) with the highest heuristic score, from the cell above.\n",
    "best"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0ab8932-bb70-4f4a-8c61-59dcf40ac94c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per-(layer, head_index) heuristic scores of the transferred latent.\n",
    "heehee"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba55db2d-bf8c-43bc-8876-e3e5bfc2ca87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool every random-baseline score in heehee_random into one flat null distribution.\n",
    "lol = np.concatenate(list(heehee_random.values())).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6434da2c-bd7d-465a-846a-01688c7eb7c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lower and upper 0.1% tails of the pooled null distribution.\n",
    "tuple(np.percentile(lol, [0.1, 99.9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf00400-9fc2-45bd-9ce2-019c3f87e6f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same direct-path ablation as for gpt2-small, but for the transferred latent in gpt2-medium, at the\n",
    "# same sampled positions (all_token_indices) so the two experiments can be compared point by point.\n",
    "head_layer = 11\n",
    "head_idx = 2  # other candidate (layer, head) pairs: (10,11), (11,1), (11,13), (12,1), (13,12)\n",
    "pattern_hook = f\"blocks.{head_layer}.attn.hook_pattern\"\n",
    "transfer_metrics = {\n",
    "    'bos_diffs': [],\n",
    "    'orig_activations': []\n",
    "}\n",
    "for i, sample in tqdm(enumerate(val_dataloader)):\n",
    "    orig_sae_latents = transfer_model(sample, stop_at_layer=transfer_layer)\n",
    "    orig_acts = transferred_sae.encode(orig_sae_latents)[..., [neuron_idx]]\n",
    "    orig_attn_latents = transfer_model(orig_sae_latents, start_at_layer=transfer_layer, stop_at_layer=head_layer)\n",
    "    # All other paths are kept as in the clean run; only the latent's direct contribution is removed below.\n",
    "    _, unablated_cache = transfer_model.run_with_cache(orig_attn_latents, names_filter=[pattern_hook], start_at_layer=head_layer, stop_at_layer=head_layer+1)\n",
    "    # hook_pattern is (batch, head, query_pos, key_pos); key_pos 0 is BOS -> (batch, query_pos)\n",
    "    unablated_pattern = unablated_cache[pattern_hook][:, head_idx][..., 0]\n",
    "    for tok in all_token_indices[i]:\n",
    "        if sample[0, tok] == transfer_model.tokenizer.pad_token_id:\n",
    "            continue\n",
    "        # Remove only this latent's contribution, and only at the current position.\n",
    "        ablated_attn_latents = orig_attn_latents.detach().clone()\n",
    "        ablated_attn_latents[:, tok] = orig_attn_latents[:, tok] - orig_acts[:, tok].item() * transferred_sae.W_dec[[neuron_idx]]\n",
    "        _, ablated_cache = transfer_model.run_with_cache(ablated_attn_latents, start_at_layer=head_layer, names_filter=pattern_hook, stop_at_layer=head_layer+1)\n",
    "        ablated_pattern = ablated_cache[pattern_hook][0, head_idx, tok, 0]\n",
    "        transfer_metrics['bos_diffs'].append((unablated_pattern[0, tok] - ablated_pattern).item())\n",
    "        transfer_metrics['orig_activations'].append(orig_acts[0, tok].item())\n",
    "for (k,v) in transfer_metrics.items():\n",
    "    transfer_metrics[k] = np.array(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3876c75-b1fc-4e7b-a79c-1720865d7bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions where the gpt2-small latent fires. The figure below applies the same mask to transfer_metrics,\n",
    "# which is only valid if both experiments recorded the same (sample, position) pairs.\n",
    "assert len(metrics['orig_activations']) == len(transfer_metrics['orig_activations']), 'metrics and transfer_metrics are not aligned'\n",
    "subset = (metrics['orig_activations'] > 0)  # optionally also: & (transfer_metrics['orig_activations'] > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e8e553-c04d-44c6-97bc-11e586e06136",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side comparison of the ablation effect in gpt2-small and in gpt2-medium (transferred latent).\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "lims = [-0.6, 0.3]\n",
    "fig, ax = plt.subplots(1, 2, figsize=(15, 5))\n",
    "panels = [\n",
    "    (ax[0], metrics, 'gpt2-small (6,10)'),  # head used in the gpt2-small ablation cell\n",
    "    (ax[1], transfer_metrics, f'gpt2-medium ({head_layer},{head_idx})'),\n",
    "]\n",
    "for panel_ax, panel_metrics, panel_title in panels:\n",
    "    panel_ax.scatter(x=panel_metrics['orig_activations'][subset], y=panel_metrics['bos_diffs'][subset], s=10, alpha=0.5)\n",
    "    panel_ax.set(title=panel_title, xlabel='original activation', ylabel='p(BOS) orig - p(BOS) ablated', ylim=lims)\n",
    "    panel_ax.axhline(0, color='black', alpha=0.5)\n",
    "    panel_ax.axvline(0, color='black', alpha=0.5)\n",
    "os.makedirs('results/figures', exist_ok=True)  # savefig does not create missing directories\n",
    "fig.savefig('results/figures/attention_deactivation_experiment_flipped.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e29a4c4-3891-4d9b-a99c-80286fc29725",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
