{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21ffcd6f-1327-4a72-9422-a4b38115b688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared imports and configuration for the whole notebook.\n",
    "import os\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import wandb\n",
    "import yaml\n",
    "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, SAE\n",
    "from tqdm import tqdm\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from stitching.losses import next_token_cross_entropy_loss\n",
    "from stitching.stitching_utils import open_experiment\n",
    "\n",
    "# Use the GPU when present instead of assuming one exists.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Using device:\", device)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "# CACHE_DIR is passed as cache_dir to HookedTransformer.from_pretrained below.\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "\n",
    "# Evaluation only: disable autograd globally (trailing ';' hides the returned object's repr).\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fe61c8f-fcaf-4b59-b5e2-0d0c7fcfc82e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "# Alternative pair: pythia-70m-deduped (layer 3) -> pythia-160m-deduped (layer 4).\n",
    "modelA_name, layerA = 'gpt2-small', 6\n",
    "modelB_name, layerB = 'gpt2-medium', 10\n",
    "TRUNCATION_LENGTH = 512\n",
    "\n",
    "# Pre-tokenized splits; the file names are keyed by model A's name.\n",
    "tokenized_dataset = {\n",
    "    dataset_key: torch.load(\n",
    "        f'data/{modelA_name}_tokenized_dataset_200000_{dataset_key}_{TRUNCATION_LENGTH}.pt',\n",
    "        weights_only=True,\n",
    "    )\n",
    "    for dataset_key in ('train', 'test')\n",
    "}\n",
    "\n",
    "modelA, modelB = (\n",
    "    HookedTransformer.from_pretrained(name, cache_dir=CACHE_DIR)\n",
    "    for name in (modelA_name, modelB_name)\n",
    ")\n",
    "PADDING_TOKEN = modelA.tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b7d9d8-52f3-4719-a7d0-7d6ef725e277",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ignore_tensor(ignore_index, device):\n",
    "    \"\"\"Return `ignore_index` (an int or a sequence of ints) as a 1-D int tensor on `device`.\"\"\"\n",
    "    return torch.as_tensor(ignore_index, device=device).reshape(-1).int()\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def sae_eval(dataloader, model, layer, sae, ignore_index=0, ctx_size=128):\n",
    "    \"\"\"Evaluate `sae` spliced into `model` at `layer`.\n",
    "\n",
    "    Each batch is truncated to `ctx_size` positions and run up to `layer`. The residual\n",
    "    is encoded and decoded by `sae`. Both the clean residual and the reconstruction\n",
    "    are then run through the remaining layers so their next-token losses can be compared.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of token-id tensors.\n",
    "        model: model accepting `stop_at_layer=` / `start_at_layer=` (e.g. HookedTransformer).\n",
    "        layer: layer at which the SAE is spliced in.\n",
    "        sae: object with `encode` and `decode`.\n",
    "        ignore_index: token id, or sequence of ids (e.g. padding), excluded from 'l0'\n",
    "            and 'uev' and forwarded to the loss function.\n",
    "        ctx_size: number of leading positions of each sample to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each metric to its overall mean, or to None when the dataloader\n",
    "        yielded no batches:\n",
    "        'l0': active latents per kept token.\n",
    "        'uev': per-token squared reconstruction error divided by the token's squared\n",
    "            distance from the batch mean.\n",
    "        'delta': reconstructed minus clean next-token loss.\n",
    "    \"\"\"\n",
    "    results = {'l0': [], 'uev': [], 'delta': []}\n",
    "    device = next(model.parameters()).device\n",
    "    ignore_tokens = _ignore_tensor(ignore_index, device)\n",
    "\n",
    "    for sample in tqdm(dataloader):\n",
    "        sample = sample.to(device)[..., :ctx_size]\n",
    "        resid = model(sample, stop_at_layer=layer)\n",
    "        keep = ~torch.isin(sample, ignore_tokens)\n",
    "        acts = sae.encode(resid)\n",
    "        recon = sae.decode(acts)\n",
    "\n",
    "        # Boolean indexing flattens batch and position: one row per kept token.\n",
    "        acts_kept = acts[keep]\n",
    "        recon_kept = recon[keep]\n",
    "        resid_kept = resid[keep]\n",
    "\n",
    "        results['l0'].append((acts_kept > 0).sum(dim=-1).float().cpu())\n",
    "\n",
    "        # No .squeeze(): with a single kept token it would give a 0-d tensor,\n",
    "        # which torch.cat below refuses to concatenate.\n",
    "        per_token_l2_loss = (recon_kept - resid_kept).pow(2).sum(dim=-1)\n",
    "        # NOTE(review): this is a per-token ratio, so tokens near the batch mean can dominate\n",
    "        # the average; sum(err) / sum(var) is the more common fraction-of-variance-unexplained.\n",
    "        total_variance = (resid_kept - resid_kept.mean(0)).pow(2).sum(-1)\n",
    "        results['uev'].append((per_token_l2_loss / total_variance).cpu())\n",
    "\n",
    "        true_logits = model(resid, start_at_layer=layer)\n",
    "        recon_logits = model(recon, start_at_layer=layer)\n",
    "        true_loss = next_token_cross_entropy_loss(true_logits, sample, reduction='none', ignore_index=ignore_index)\n",
    "        recon_loss = next_token_cross_entropy_loss(recon_logits, sample, reduction='none', ignore_index=ignore_index)\n",
    "        # NOTE(review): this is not masked with `keep`. If ignored positions come back as 0,\n",
    "        # they dilute the mean delta - confirm against next_token_cross_entropy_loss.\n",
    "        results['delta'].append((recon_loss - true_loss).cpu())\n",
    "\n",
    "    return {\n",
    "        key: torch.cat(values).mean().item() if values else None\n",
    "        for key, values in results.items()\n",
    "    }\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_densities(dataloader, model, layer, sae, ctx_size=128, ignore_index=0):\n",
    "    \"\"\"Fraction of non-ignored tokens on which each latent of `sae` is active (> 0).\n",
    "\n",
    "    Args:\n",
    "        dataloader, model, layer, ctx_size: as in `sae_eval`.\n",
    "        sae: a BaseSAE (width in `.d_sae`) or an SAE whose width is in `.cfg.d_sae`.\n",
    "        ignore_index: token id, or sequence of ids, to exclude. Defaults to 0 so existing\n",
    "            calls keep their behaviour.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor of shape (d_sae,).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no token survived the ignore mask (the density would be 0/0).\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    ignore_tokens = _ignore_tensor(ignore_index, device)\n",
    "    # Read the width from `sae` itself rather than from the global sae_A.\n",
    "    d_sae = sae.d_sae if 'BaseSAE' in str(type(sae)) else sae.cfg.d_sae\n",
    "    act_counts = torch.zeros(d_sae, device=device)\n",
    "    total_toks = 0\n",
    "\n",
    "    for sample in tqdm(dataloader):\n",
    "        sample = sample.to(device)[..., :ctx_size]\n",
    "        acts = sae.encode(model(sample, stop_at_layer=layer))\n",
    "        acts = acts[~torch.isin(sample, ignore_tokens)]\n",
    "        act_counts += (acts > 0).sum(dim=0)\n",
    "        total_toks += acts.shape[0]\n",
    "\n",
    "    if total_toks == 0:\n",
    "        raise ValueError('get_densities: no non-ignored tokens were seen')\n",
    "    return (act_counts / total_toks).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc4993d8-2f0b-4d36-b068-a947a51ead05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_sae(checkpoint_dir, checkpoint_name, subdir):\n",
    "    \"\"\"Load a sae_lens SAE stored at checkpoint_dir/checkpoint_name/subdir onto `device`.\"\"\"\n",
    "    sae_path = os.path.join(checkpoint_dir, checkpoint_name, subdir)\n",
    "    return SAE.load_from_pretrained(sae_path, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6387397a-6792-41bb-a9b7-bfe6d3471020",
   "metadata": {},
   "source": [
    "# Zero-Shot Evaluations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32563b7-e0a4-4ea8-a72e-85b94f1b77a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae import Sae\n",
    "from stitching.sae_utils import convert_eleuther_sae_to_BaseSAE, BaseSAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad1ba425-e62b-4101-908c-13d9778c0299",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitching run to evaluate (alternative run id: 'tough-forest-11').\n",
    "transfer_id = 'snowy-pyramid-3'\n",
    "project_name = f\"stitch_training_{modelA_name}_to_{modelB_name}_bidirectional_mse\"\n",
    "checkpoints_dir = os.path.join('checkpoints/', f\"{project_name}/\")\n",
    "# P and bias are used below on the SAE decoder side, Pinv and biasinv on the encoder side.\n",
    "# beta is not used in this notebook.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(\n",
    "    modelA.cfg.d_model,\n",
    "    modelB.cfg.d_model,\n",
    "    checkpoints_dir,\n",
    "    transfer_id,\n",
    "    biases=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adfc2e68-3872-466a-9f85-cb8a1c26b45e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EleutherAI pretrained SAEs for the Pythia pair. The hub repos below are Pythia-specific.\n",
    "# Skip the download when another model pair is configured, since the SAEs would not match.\n",
    "# NOTE(review): the next cell reassigns sae_A unconditionally, so only sae_B survives this cell.\n",
    "# NOTE(review): hookpoint layers.{layer-1} presumably matches resid_pre of `layer` - confirm.\n",
    "if (modelA_name, modelB_name) == ('pythia-70m-deduped', 'pythia-160m-deduped'):\n",
    "    sae_A = Sae.load_from_hub(\"EleutherAI/sae-pythia-70m-deduped-32k\", hookpoint=f\"layers.{layerA-1}\").to(device)\n",
    "    sae_A = convert_eleuther_sae_to_BaseSAE(sae_A)\n",
    "    sae_A.normalize_decoder_vectors()\n",
    "    sae_A.get_rid_of_decoder_sub()\n",
    "    sae_B = Sae.load_from_hub(\"EleutherAI/sae-pythia-160m-deduped-32k\", hookpoint=f\"layers.{layerB-1}\").to(device)\n",
    "    sae_B = convert_eleuther_sae_to_BaseSAE(sae_B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa16918a-e48c-441a-9c26-629e77748764",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-A SAE: a locally trained checkpoint from gpt2-small-topk-sae-checkpoints.\n",
    "# Alternative: the public sae_lens release 'gpt2-small-res-jb' with\n",
    "# sae_id f'blocks.{layerA}.hook_resid_pre', loaded via SAE.from_pretrained.\n",
    "sae_A = SAE.load_from_pretrained(\n",
    "    'gpt2-small-topk-sae-checkpoints/1t637wuk/final_245760000/',\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ac34e97-bb7c-4523-981f-0504f8b094f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-wrap sae_A (a BaseSAE or a sae_lens SAE) as a fresh BaseSAE, then call\n",
    "# get_rid_of_decoder_sub(). It presumably folds the b_dec input subtraction into the\n",
    "# encoder - confirm. Decoder normalization is intentionally not applied to sae_A.\n",
    "apply_b_dec = (\n",
    "    sae_A.apply_b_dec\n",
    "    if 'BaseSAE' in str(type(sae_A))\n",
    "    else sae_A.cfg.apply_b_dec_to_input\n",
    ")\n",
    "sae_A = BaseSAE(\n",
    "    sae_A.W_enc.detach().clone(),\n",
    "    sae_A.W_dec.detach().clone(),\n",
    "    sae_A.b_enc.detach().clone(),\n",
    "    sae_A.b_dec.detach().clone(),\n",
    "    sae_A.activation_fn,\n",
    "    apply_b_dec=apply_b_dec,\n",
    ")\n",
    "sae_A.get_rid_of_decoder_sub()\n",
    "\n",
    "# Zero-shot transfer of sae_A into model B's residual space:\n",
    "#   encoder: W_enc -> Pinv @ W_enc, b_enc -> b_enc + biasinv @ W_enc\n",
    "#   decoder: W_dec -> W_dec @ P,    b_dec -> b_dec @ P + bias\n",
    "# NOTE(review): consistent with x_A ~= x_B @ Pinv + biasinv and x_B ~= x_A @ P + bias - confirm.\n",
    "transferred_sae_B = BaseSAE(\n",
    "    Pinv @ sae_A.W_enc.detach().clone(),\n",
    "    sae_A.W_dec.detach().clone() @ P,\n",
    "    sae_A.b_enc.detach().clone() + biasinv @ sae_A.W_enc.detach().clone(),\n",
    "    sae_A.b_dec.detach().clone() @ P + bias,\n",
    "    sae_A.activation_fn,\n",
    "    apply_b_dec=False,\n",
    ")\n",
    "transferred_sae_B.normalize_decoder_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e97b998-3f8b-4d39-a735-5b2d46f7b2e9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b0759d-45d1-44ed-89fe-224548979218",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: the original SAE spliced into model A on the first 5000 test sequences.\n",
    "orig_A_sae_res = sae_eval(\n",
    "    torch.utils.data.DataLoader(tokenized_dataset['test'][:5000], batch_size=10, shuffle=False),\n",
    "    modelA,\n",
    "    layerA,\n",
    "    sae_A,\n",
    "    ignore_index=modelA.tokenizer.pad_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0daab76-bf01-44d8-baf4-a00a460dc0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean metrics (l0, uev, delta) of the original SAE on model A.\n",
    "orig_A_sae_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a599938e-a720-4f16-837e-c87fdc0bac4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-latent firing frequency of sae_A on model A (same 5000 test sequences).\n",
    "orig_A_sae_densities = get_densities(\n",
    "    torch.utils.data.DataLoader(tokenized_dataset['test'][:5000], batch_size=10, shuffle=False),\n",
    "    modelA,\n",
    "    layerA,\n",
    "    sae_A,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d919913d-69a5-4072-977c-8a1bc0f69920",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latents that never fired on the evaluation set (count, then fraction of all latents).\n",
    "dead = (orig_A_sae_densities < 1e-8).sum().item()\n",
    "print(f\"dead features {dead} {dead / sae_A.d_sae}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b14bd8ed-3583-4954-bbdd-61500dff040e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the transferred SAE from the current sae_A, using the same construction as the\n",
    "# transfer cell above (including .detach()), so this cell can be re-run on its own:\n",
    "#   encoder: W_enc -> Pinv @ W_enc, b_enc -> b_enc + biasinv @ W_enc\n",
    "#   decoder: W_dec -> W_dec @ P,    b_dec -> b_dec @ P + bias\n",
    "transferred_sae_B = BaseSAE(\n",
    "    Pinv @ sae_A.W_enc.detach().clone(),\n",
    "    sae_A.W_dec.detach().clone() @ P,\n",
    "    sae_A.b_enc.detach().clone() + biasinv @ sae_A.W_enc.detach().clone(),\n",
    "    sae_A.b_dec.detach().clone() @ P + bias,\n",
    "    sae_A.activation_fn,\n",
    "    apply_b_dec=False,\n",
    ")\n",
    "transferred_sae_B.normalize_decoder_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb91f91c-91ec-4732-8265-2442aa0218c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot: the transferred SAE spliced into model B on the same 5000 test sequences.\n",
    "# NOTE(review): the name says pythia-160m, but with the current config modelB is gpt2-medium.\n",
    "# The name is kept because later cells read it.\n",
    "zero_shot_pythia_160m_res = sae_eval(\n",
    "    torch.utils.data.DataLoader(tokenized_dataset['test'][:5000], batch_size=10, shuffle=False),\n",
    "    modelB,\n",
    "    layerB,\n",
    "    transferred_sae_B,\n",
    "    ignore_index=modelB.tokenizer.pad_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6a6b12-2d0d-4855-a71a-e94f0c778a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-latent firing frequency of the transferred SAE on model B.\n",
    "zero_shot_pythia_160m_densities = get_densities(\n",
    "    torch.utils.data.DataLoader(tokenized_dataset['test'][:5000], batch_size=10, shuffle=False),\n",
    "    modelB,\n",
    "    layerB,\n",
    "    transferred_sae_B,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a58128d4-f666-400b-96fb-0b23b1e5421e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean metrics (l0, uev, delta) of the transferred SAE on model B.\n",
    "zero_shot_pythia_160m_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97fcc15-eabc-409b-868d-ae5dba204055",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latents of the transferred SAE that never fired (count, then fraction of all latents).\n",
    "dead = (zero_shot_pythia_160m_densities < 1e-8).sum().item()\n",
    "print(f\"dead features {dead} {dead / transferred_sae_B.d_sae}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13c7348f-2589-4b57-990b-55a5529be562",
   "metadata": {},
   "source": [
    "# Cosine Similarity Over Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2dc139-c3d0-41aa-9d89-07b06ccb0d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run imports and configuration so this section, which reloads its own data and\n",
    "# models below, also works on a fresh kernel.\n",
    "import os\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import wandb\n",
    "import yaml\n",
    "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, SAE\n",
    "from tqdm import tqdm\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from stitching.losses import next_token_cross_entropy_loss\n",
    "\n",
    "# Use the GPU when present instead of assuming one exists.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(\"Using device:\", device)\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "\n",
    "# Evaluation only: disable autograd globally (trailing ';' hides the returned object's repr).\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87175c9-fdcf-4dc5-9eb2-c9c6bb3a337f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.sae_utils import get_densities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb92b119-271b-4c6a-8c1a-0e436d9eaaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-tokenized Pythia splits (truncation length 512, per the file name).\n",
    "tokenized_dataset = {\n",
    "    split: torch.load(\n",
    "        f'data/pythia-70m-deduped_tokenized_dataset_200000_{split}_512.pt',\n",
    "        weights_only=True,\n",
    "    )\n",
    "    for split in ('train', 'test')\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d8be87-209c-4675-9b56-26b9083d8aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pythia-160m plus the 'initial_config' checkpoints of the two SAE runs labelled\n",
    "# 'stitched' (kr2slklm) and 'random' (cdvcqe34) in sae_dict below.\n",
    "# NOTE(review): densities computed from these *initial* SAEs are later used to filter the\n",
    "# final-checkpoint histogram. Confirm that is intended, rather than using the final SAEs.\n",
    "model = HookedTransformer.from_pretrained('pythia-160m-deduped', cache_dir=CACHE_DIR)\n",
    "layer = 4\n",
    "# Move the SAEs to the configured `device` rather than a hard-coded 'cuda'.\n",
    "sae = SAE.load_from_pretrained(os.path.join('pythia-160m-sae-topk-checkpoints', 'kr2slklm', 'initial_config')).to(device)\n",
    "sae_random = SAE.load_from_pretrained(os.path.join('pythia-160m-sae-topk-checkpoints', 'cdvcqe34', 'initial_config')).to(device)\n",
    "dataloader = torch.utils.data.DataLoader(tokenized_dataset['train'][:10000], batch_size=10, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12162471-92f4-473d-a312-454a2bc93fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the device the 'stitched' SAE was moved to.\n",
    "sae.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8d9ad1-f92e-4d71-a3af-00918e3fb0e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent firing frequencies of both SAEs on the same 10k training sequences.\n",
    "# NOTE(review): the positional 512 is presumably ctx_size of stitching.sae_utils.get_densities.\n",
    "densities, densities_random = [\n",
    "    get_densities(dataloader, model, layer, candidate_sae, 512)\n",
    "    for candidate_sae in (sae, sae_random)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ba61b1-64ee-46aa-90a0-cefa0fc2a962",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run id (subdirectory of pythia-160m-sae-topk-checkpoints) for each SAE variant.\n",
    "sae_dict = dict(\n",
    "    stitched='kr2slklm',\n",
    "    random='cdvcqe34',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a277f722-682e-4b2f-af2f-0752779673df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _checkpoint_iteration(entry):\n",
    "    \"\"\"Iteration encoded in a checkpoint directory name.\n",
    "\n",
    "    'initial_config' -> 0, 'final_N' -> N, 'N' -> N; other names raise.\n",
    "    \"\"\"\n",
    "    if entry == 'initial_config':\n",
    "        return 0\n",
    "    if 'final' in entry:\n",
    "        return int(entry.split('_')[1])\n",
    "    return int(entry)\n",
    "\n",
    "\n",
    "# For each SAE variant, collect every checkpoint's decoder matrix keyed by iteration,\n",
    "# plus the decoder at initialization.\n",
    "init_mats = dict.fromkeys(sae_dict)\n",
    "decoder_mats = {desc: {} for desc in sae_dict}\n",
    "full_results = {}\n",
    "for desc, checkpoint_dir in sae_dict.items():\n",
    "    print(desc, checkpoint_dir)\n",
    "    path = os.path.join('pythia-160m-sae-topk-checkpoints', checkpoint_dir)\n",
    "    for entry in os.listdir(path):\n",
    "        full_path = os.path.join(path, entry)\n",
    "        if not os.path.isdir(full_path):\n",
    "            continue\n",
    "        print(full_path)\n",
    "        trained_sae = SAE.load_from_pretrained(full_path, device='cpu')\n",
    "        decoder_mats[desc][_checkpoint_iteration(entry)] = trained_sae.W_dec\n",
    "        if 'initial' in entry:\n",
    "            init_mats[desc] = trained_sae.W_dec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ddcb74-e6e2-4bcc-8ec5-8463fad2b62d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-latent cosine similarity between each checkpoint's decoder rows and the decoder at\n",
    "# initialization. An explicit cosine similarity is used because a plain dot product only\n",
    "# equals the cosine when the decoder rows are unit-norm.\n",
    "csim_series = dict.fromkeys(sae_dict)\n",
    "for k in sae_dict:\n",
    "    initial_mat = init_mats[k]\n",
    "    iterations = np.array(list(decoder_mats[k].keys()))\n",
    "    csims = np.stack([\n",
    "        torch.nn.functional.cosine_similarity(initial_mat, decoder_mat, dim=-1).detach().cpu().numpy()\n",
    "        for decoder_mat in decoder_mats[k].values()\n",
    "    ])\n",
    "    order = np.argsort(iterations)\n",
    "    # (sorted checkpoint iterations, similarities of shape (n_checkpoints, n_latents))\n",
    "    csim_series[k] = (iterations[order], csims[order])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e33b8813-11a6-403c-916d-06d3249fc389",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably (n_checkpoints, n_latents).\n",
    "csim_series['random'][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff74ef0-dc3a-46b1-a65b-b5b95a393244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectories of the first 100 latents' cosine similarity to initialization.\n",
    "fig, ax = plt.subplots()\n",
    "for i in range(100):\n",
    "    ax.plot(csim_series['random'][0], csim_series['random'][1][:, i],\n",
    "            color='tab:blue', alpha=0.5, label='random' if i == 0 else None)\n",
    "    ax.plot(csim_series['stitched'][0], csim_series['stitched'][1][:, i],\n",
    "            color='tab:orange', alpha=0.5, label='stitched' if i == 0 else None)\n",
    "ax.set_xlabel('checkpoint (from directory name)')\n",
    "ax.set_ylabel('cosine similarity to initialization')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af1d098-9eb2-4509-962a-2e68478cf306",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4141962c-b1e7-4da4-b961-5e348f431e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final-checkpoint cosine similarity to initialization for latents with density > 1e-5.\n",
    "# NOTE(review): `densities` / `densities_random` come from the 'initial_config' SAEs;\n",
    "# confirm that filtering the final decoders by initial densities is intended.\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "fig, ax = plt.subplots()\n",
    "for label, dens in (('random', densities_random), ('stitched', densities)):\n",
    "    final_csims = csim_series[label][1][-1]\n",
    "    ax.hist(final_csims[np.asarray(dens > 1e-5)], bins=50, range=(-0.1, 1.0), alpha=0.75, label=label)\n",
    "ax.legend()\n",
    "ax.set_xlabel('cosine similarity to initialization')\n",
    "ax.set_ylabel('count')\n",
    "# savefig fails if the output directory is missing, so create it first.\n",
    "os.makedirs('results/figures', exist_ok=True)\n",
    "fig.savefig('results/figures/32k_cosine_similarity_to_initialization.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48c5e58-124b-42eb-af1a-29e9561857b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest 100 final cosine similarities among stitched-SAE latents with density > 1e-5.\n",
    "torch.as_tensor(csim_series['stitched'][1][-1][densities > 1e-5]).topk(100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7958e306-2140-49cc-a1a8-48b7d7acb34f",
   "metadata": {},
   "source": [
    "### Visualize a latent from the clump in the cosine-similarity histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25ad99f-75fb-47a4-85d3-9466d5c8fc30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.sae_viz import get_activations_for_feature, display_acts, highlight_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb4771c-f4ec-4ad5-8fd8-57db58158643",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this rebinds `modelB` (gpt2-medium in the zero-shot section) to pythia-160m.\n",
    "# saeB is the final checkpoint of the 'kr2slklm' run.\n",
    "modelB = HookedTransformer.from_pretrained('pythia-160m-deduped', cache_dir=CACHE_DIR).to(device)\n",
    "saeB = SAE.load_from_pretrained(\n",
    "    'pythia-160m-sae-topk-checkpoints/kr2slklm/final_491520000/'\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66c3a4e-eac9-4c3c-885c-65f99f5d88c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the pre-tokenized Pythia splits (truncation length 512, per the file name).\n",
    "tokenized_dataset = {\n",
    "    split: torch.load(f'data/pythia-70m-deduped_tokenized_dataset_200000_{split}_512.pt', weights_only=True)\n",
    "    for split in ('train', 'test')\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b200a3-ce85-49b9-8bae-ab935c57f3bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent index to inspect (presumably picked from the histogram clump above - confirm).\n",
    "feat_idx = 19444"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9496791-a45c-4f27-9b56-c9d22e30672e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activations of latent feat_idx in saeB, collected over the first 1000 test sequences.\n",
    "# The positional 4 is presumably the layer index - confirm.\n",
    "acts, toks = get_activations_for_feature(\n",
    "    torch.utils.data.DataLoader(tokenized_dataset['test'][:1000], batch_size=10, shuffle=False),\n",
    "    feat_idx,\n",
    "    modelB,\n",
    "    4,\n",
    "    saeB,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f217fac-db89-435b-b1c3-52b9b2494487",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the collected activations for feat_idx.\n",
    "display_acts(\n",
    "    toks, acts, modelB,\n",
    "    k=10, ctx=25, upper_cap=50,\n",
    "    display_density=False, verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8a54a6-e471-4fbc-aecb-4968927102d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
