{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory\n",
    "\n",
    "# Tabulate the pretrained-SAE directory: each directory entry's __dict__ becomes one\n",
    "# record, and the transpose puts one release per row. The unwanted columns are\n",
    "# dropped in the same expression (no `inplace=True`), so `df` is built in one step.\n",
    "# TODO: Make this nicer.\n",
    "df = pd.DataFrame.from_records(\n",
    "    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}\n",
    ").T.drop(\n",
    "    columns=[\n",
    "        \"expected_var_explained\",\n",
    "        \"expected_l0\",\n",
    "        \"config_overrides\",\n",
    "        \"conversion_func\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "# print all columns\n",
    "print(df.columns)\n",
    "\n",
    "# Example lookups for a single release:\n",
    "# df[\"model\"][\"sae_bench_pythia70m_sweep_topk_ctx128_0730\"]\n",
    "# df.saes_map[\"sae_bench_pythia70m_sweep_topk_ctx128_0730\"]\n",
    "\n",
    "# Each row is a \"release\" which has multiple SAEs which may have different configs /\n",
    "# match different hook points in a model.\n",
    "# Keep this filter as the LAST expression of the cell: a bare expression anywhere\n",
    "# else is evaluated and thrown away instead of being rendered by the notebook.\n",
    "df[df.release.str.contains(\"bench\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both dictionaries below are keyed by the same SAE identifier string\n",
    "# (looks like \"<release>/<hook point>/<trainer>\" -- TODO confirm the convention).\n",
    "sae_key = \"pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_9\"\n",
    "\n",
    "# Per-SAE info: an (empty) training config plus basic eval numbers.\n",
    "sae_basic_info = {\n",
    "    sae_key: {\n",
    "        \"sae_config_dictionary_learning\": {},\n",
    "        \"basic_eval_results\": {\"l0\": 80, \"frac_recovered\": 0.99},\n",
    "    }\n",
    "}\n",
    "\n",
    "# One custom eval: its config, and its results keyed by SAE identifier.\n",
    "custom_results_dict = {\n",
    "    \"custom_eval_config\": {\"dataset\": \"Bib\", \"n_inputs\": 100},\n",
    "    \"custom_eval_results\": {sae_key: {\"sparse_probing_k_1\": 0.55}},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting and tensor libraries used by the cells below.\n",
    "import plotly.express as px\n",
    "import torch\n",
    "\n",
    "PORT = 8000\n",
    "\n",
    "# Disable autograd globally for this notebook. The trailing semicolon keeps\n",
    "# the notebook from echoing the repr of the object returned by the call.\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device in order of preference: MPS, then CUDA, then CPU.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(f\"Device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from sae_lens import SAE\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "# Load the base model directly onto the device selected above.\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    \"pythia-70m-deduped\",\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAE.from_pretrained hands back three things:\n",
    "#   sae      - the sparse autoencoder itself\n",
    "#   cfg_dict - whatever config was stored in the HF repo. This is not the SAE's own\n",
    "#              config dict, but the SAE config can be extracted from it, and it may\n",
    "#              hold useful information for analysing the SAE (eg: instantiating an\n",
    "#              activation store)\n",
    "#   sparsity - the feature sparsities, which are stored in HF for convenience\n",
    "sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
    "    release=\"sae_bench_pythia70m_sweep_topk_ctx128_0730\",\n",
    "    sae_id=\"blocks.4.hook_resid_post__trainer_10\",\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Explicitly place the SAE on the same device as the model.\n",
    "sae = sae.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the repo config, the feature sparsities and the SAE module, one per line.\n",
    "print(cfg_dict, sparsity, sae, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens.utils import tokenize_and_concatenate\n",
    "\n",
    "# Text corpus, loaded non-streaming.\n",
    "dataset = load_dataset(path=\"NeelNanda/pile-10k\", split=\"train\", streaming=False)\n",
    "\n",
    "# Tokenize with the model's tokenizer; sequence length and BOS handling are\n",
    "# taken from the SAE config so the tokens match what the SAE expects.\n",
    "# NOTE(review): streaming=True here although the dataset above was loaded with\n",
    "# streaming=False -- confirm this is intended.\n",
    "token_dataset = tokenize_and_concatenate(\n",
    "    dataset=dataset,  # type: ignore\n",
    "    tokenizer=model.tokenizer,  # type: ignore\n",
    "    streaming=True,\n",
    "    max_length=sae.cfg.context_size,\n",
    "    add_bos_token=sae.cfg.prepend_bos,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(device, sae.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(token_dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "# Eval mode: per the original note, this prevents an error when a dead neuron\n",
    "# mask is expected.\n",
    "sae.eval()\n",
    "\n",
    "# Collection settings; sizes and the hooked layer come from the SAE config.\n",
    "batch_size = 32\n",
    "batches = 10\n",
    "seq_len = sae.cfg.context_size\n",
    "d_model = sae.cfg.d_in\n",
    "layer = sae.cfg.hook_layer\n",
    "hook_name = f\"blocks.{layer}.hook_resid_post\"\n",
    "\n",
    "# One entry is appended per forward pass by `activation_hook`.\n",
    "all_acts_list_BLD = []\n",
    "\n",
    "\n",
    "def activation_hook(resid_BLD: torch.Tensor, hook):\n",
    "    \"\"\"Record the activation seen at the hook point and return it unchanged.\"\"\"\n",
    "    all_acts_list_BLD.append(resid_BLD)\n",
    "    return resid_BLD\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Run `batches` consecutive slices of `batch_size` token rows through the\n",
    "    # model; only the hook's side effect is wanted, so no output is requested.\n",
    "    for i in range(batches):\n",
    "        start = i * batch_size\n",
    "        stop = (i + 1) * batch_size\n",
    "        batch_tokens = token_dataset[start:stop][\"tokens\"]\n",
    "\n",
    "        model.run_with_hooks(\n",
    "            batch_tokens,\n",
    "            return_type=None,\n",
    "            fwd_hooks=[(hook_name, activation_hook)],\n",
    "        )\n",
    "\n",
    "# Stack everything collected along the first dimension.\n",
    "acts_BLD = torch.cat(all_acts_list_BLD, dim=0)\n",
    "print(acts_BLD.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eval mode: per the original note, this prevents an error when a dead neuron\n",
    "# mask is expected.\n",
    "sae.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    # First 32 token rows (an activation store could also supply tokens).\n",
    "    batch_tokens = token_dataset[:32][\"tokens\"]\n",
    "    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n",
    "\n",
    "    # Sanity check: cached activations and the SAE should be on the same device.\n",
    "    print(cache[sae.cfg.hook_name].device, sae.device)\n",
    "\n",
    "    # Encode the cached activations at the SAE's hook point, then decode them\n",
    "    # back into a reconstruction.\n",
    "    feature_acts = sae.encode(cache[sae.cfg.hook_name])\n",
    "    sae_out = sae.decode(feature_acts)\n",
    "\n",
    "    # The cache is no longer needed; drop it to save memory.\n",
    "    del cache\n",
    "\n",
    "    # L0 = number of features active at each token, skipping position 0 (the\n",
    "    # bos token); the printed value is the average across batch and position.\n",
    "    l0 = (feature_acts[:, 1:] > 0).float().sum(dim=-1).detach()\n",
    "    print(\"average l0\", l0.mean().item())\n",
    "    px.histogram(l0.flatten().cpu().numpy()).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction test: compare the model's loss when the hooked activation is\n",
    "# left alone, replaced by the SAE output, or replaced by zeros.\n",
    "def reconstr_hook(activation, hook, sae_out):\n",
    "    \"\"\"Replace the hooked activation with the supplied SAE output.\"\"\"\n",
    "    return sae_out\n",
    "\n",
    "\n",
    "def zero_abl_hook(activation, hook):\n",
    "    \"\"\"Replace the hooked activation with zeros of the same shape.\"\"\"\n",
    "    return torch.zeros_like(activation)\n",
    "\n",
    "\n",
    "reconstr_hooks = [(sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out))]\n",
    "zero_abl_hooks = [(sae.cfg.hook_name, zero_abl_hook)]\n",
    "\n",
    "orig_loss = model(batch_tokens, return_type=\"loss\").item()\n",
    "print(\"Orig\", orig_loss)\n",
    "\n",
    "reconstr_loss = model.run_with_hooks(\n",
    "    batch_tokens,\n",
    "    fwd_hooks=reconstr_hooks,\n",
    "    return_type=\"loss\",\n",
    ").item()\n",
    "print(\"reconstr\", reconstr_loss)\n",
    "\n",
    "zero_abl_loss = model.run_with_hooks(\n",
    "    batch_tokens,\n",
    "    return_type=\"loss\",\n",
    "    fwd_hooks=zero_abl_hooks,\n",
    ").item()\n",
    "print(\"Zero\", zero_abl_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
